I have implemented Prim's algorithm from *Introduction to Algorithms*. Because its structure is very similar to Dijkstra's algorithm, I adapted my existing Dijkstra implementation.
Please review this code and suggest improvements.
To compile on Linux: g++ -std=c++14 prims.cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <limits>
#include <list>
#include <map>
#include <queue>
#include <stdexcept>
#include <utility>
#include <vector>
/// Weighted, undirected graph that computes a minimum spanning tree (MST)
/// with Prim's algorithm, using a binary heap with lazy deletion.
///
/// Vertices are identified by indices in [0, size). Edge weights must be
/// non-negative because minimum_cost() returns an unsigned total.
class Graph
{
    // Sentinel parent index for vertices not (yet) reached by prim().
    static constexpr std::size_t no_parent = std::numeric_limits<std::size_t>::max();

    struct Vertex
    {
        std::size_t id;
        // Weight of the cheapest known edge connecting this vertex to the tree.
        int distance = std::numeric_limits<int>::max();
        // Index of the tree parent; the root is its own parent. An index (not a
        // pointer into `vertices`) keeps Graph safely copyable.
        std::size_t parent = no_parent;
        // True once the vertex has been extracted from the queue; its distance
        // is then final and must never be updated again.
        bool in_tree = false;
        explicit Vertex(std::size_t id) : id(id) {}
    };

    // Adjacency-list entry: (neighbour index, edge weight).
    using pair_ = std::pair<std::size_t, int>;
    // Priority-queue entry: (key, vertex index); ordered by key first.
    using queue_entry = std::pair<int, std::size_t>;

    std::vector<Vertex> vertices = {};
    // adj_list[v] holds every edge incident to v; each undirected edge is
    // stored once in each endpoint's list.
    std::vector<std::vector<pair_>> adj_list;

public:
    Graph(std::size_t size);
    void add_edge(std::size_t src, std::size_t dest, int weight);
    void prim(std::size_t vertex);
    std::size_t minimum_cost() const;
};

/// Creates a graph with `size` isolated vertices numbered 0 .. size-1.
Graph::Graph(std::size_t size) : adj_list(size)
{
    vertices.reserve(size);
    for (std::size_t i = 0; i < size; ++i)
    {
        vertices.emplace_back(i);
    }
}

/// Adds the undirected edge {src, dest} with the given weight.
///
/// A negative weight is reported on std::cerr and the edge is ignored.
/// @throws std::logic_error  if src == dest or the edge already exists
///                           (in either direction).
/// @throws std::out_of_range if src or dest is not a valid vertex index.
void Graph::add_edge(std::size_t src, std::size_t dest, int weight)
{
    if (weight < 0)
    {
        // Prim's algorithm itself tolerates negative weights; they are rejected
        // here because minimum_cost() returns an unsigned total.
        std::cerr << "Negative weight\n";
        return;
    }
    if (src == dest)
    {
        throw std::logic_error("Source and destination vertices are same");
    }
    if (vertices.size() <= src)
    {
        throw std::out_of_range("Enter correct source vertex");
    }
    if (vertices.size() <= dest)
    {
        throw std::out_of_range("Enter correct destination vertex");
    }

    // Both directions are always inserted together, so checking one side
    // is enough to detect a duplicate.
    const auto& src_edges = adj_list[src];
    const bool exists = std::any_of(src_edges.begin(), src_edges.end(),
                                    [dest](const pair_& e) { return e.first == dest; });
    if (exists)
    {
        throw std::logic_error("Existing edge");
    }

    adj_list[src].emplace_back(dest, weight);
    adj_list[dest].emplace_back(src, weight);
}

/// Runs Prim's algorithm rooted at `vertex`, building the MST of the
/// connected component containing it. Any previous result is discarded, so
/// this may be called repeatedly.
/// @throws std::out_of_range if `vertex` is not a valid vertex index.
void Graph::prim(std::size_t vertex)
{
    if (vertices.size() <= vertex)
    {
        throw std::out_of_range("Enter correct start vertex");
    }

    for (auto& v : vertices)
    {
        v.distance = std::numeric_limits<int>::max();
        v.parent = no_parent;
        v.in_tree = false;
    }

    // Min-priority queue of (key, vertex) candidates.
    std::priority_queue<queue_entry, std::vector<queue_entry>, std::greater<queue_entry>>
        unprocessed;

    vertices[vertex].distance = 0;
    vertices[vertex].parent = vertex;  // the root is its own parent
    unprocessed.emplace(0, vertex);

    while (!unprocessed.empty())
    {
        const std::size_t curr = unprocessed.top().second;
        unprocessed.pop();

        // Lazy deletion: a vertex can be queued several times with decreasing
        // keys; only its first (smallest-key) extraction counts.
        if (vertices[curr].in_tree)
        {
            continue;
        }
        vertices[curr].in_tree = true;

        for (const auto& edge : adj_list[curr])
        {
            auto& next = vertices[edge.first];
            // Unlike Dijkstra, the key is the edge weight alone, and vertices
            // already in the tree must not be relaxed again.
            if (!next.in_tree && edge.second < next.distance)
            {
                next.distance = edge.second;
                next.parent = curr;
                unprocessed.emplace(next.distance, edge.first);
            }
        }
    }
}

/// Returns the total weight of the tree built by the last prim() call.
/// Vertices not reachable from the root are excluded; returns 0 if prim()
/// has not been called.
std::size_t Graph::minimum_cost() const
{
    std::size_t cost = 0;
    for (const auto& v : vertices)
    {
        if (v.in_tree)
        {
            // distance is non-negative: add_edge rejects negative weights.
            cost += static_cast<std::size_t>(v.distance);
        }
    }
    return cost;
}
int main()
{
Graph grp(9);
grp.add_edge(0, 1, 4);
grp.add_edge(0, 2, 8);
grp.add_edge(1, 2, 11);
grp.add_edge(1, 3, 8);
grp.add_edge(3, 4, 2);
grp.add_edge(4, 2, 7);
grp.add_edge(2, 5, 1);
grp.add_edge(5, 4, 6);
grp.add_edge(3, 6, 7);
grp.add_edge(3, 8, 4);
grp.add_edge(5, 8, 2);
grp.add_edge(6, 7, 9);
grp.add_edge(6, 8, 14);
grp.add_edge(7, 8, 10);
grp.prim(0);
std::cout << "The total cost is : " << grp.minimum_cost() << "\n";
}