#include <cassert>
#include <queue>
#include <iostream>

#include <gp-bnb/incremental_bfs.hpp>

// Sets up the search on graph g: creates the s-side and the t-side subtree,
// handing each of them the graph and the flow_edges_ member, and sizes
// flow_edges_ with one entry per edge of g.
//
// NOTE(review): members are initialized in declaration order, not in the
// order written below. s_ and t_ receive flow_edges_ here; if the header
// declares flow_edges_ after them, it is not yet constructed at that point --
// confirm that the ibfs_subtree constructor only stores a reference to it.
incremental_bfs::incremental_bfs(const graph& g)
    : g_(g),
      s_(subtree::s, g, flow_edges_),
      t_(subtree::t, g, flow_edges_),
      flow_edges_(g.num_edges()) {
}

void incremental_bfs::reset(std::vector<node_id>& sources, std::vector<node_id>& sinks) {
    assert(!sources.empty());
    assert(!sinks.empty());

	node_assignments_.assign(g_.num_nodes(), none);
    for (const auto& node : sources) {
        node_assignments_[node-1] = s_root;
    }
    for (const auto& node : sinks) {
        node_assignments_[node-1] = t_root;
    }

	flow_edges_.assign(g_.num_edges(), false);
    flow_ = 0;

    s_.reset(sources);
    t_.reset(sinks);

}

void incremental_bfs::run() {
    // increments the flow value to the number of pairwise neighbors of sources and sinks
    for (const auto& node : s_.get_front()) {
		const auto& adjacency = g_.get_adjacency(node);
		const auto& edge_ids = g_.get_edge_ids(node);
        for (node_id i = 0; i < adjacency.size(); ++i) {
			if (node_assignments_[adjacency[i]-1] == t_root) {
                flow_edges_[edge_ids[i]-1] = true;
                flow_++;
            }
        } 
    }
    
    // grows the subtrees alternating until one subtree cannot grow anymore
    for (int i = 0; ; ++i) {
        if (i%2 == 0) {
            if (!grow(s_)) break;
        } else {
            if (!grow(t_)) break;
        }
    }
}

// Performs one growth round for subtree st.
//
// Every node in the current front of st is examined. For each incident edge
// that is not yet a flow edge, the neighbor is either
//   - used for an augmentation, if its assignment equals the negated id of st
//     (presumably: it belongs to the opposite side), or
//   - added to st on level current_max+1, if it is still unassigned.
//
// Returns false if the front of st was empty when the round started (st
// cannot grow anymore); otherwise returns true after incrementing the
// current maximum level of st.
bool incremental_bfs::grow(ibfs_subtree& st) {
    // only st may adopt on the level that this round creates
    st.set_max_adoption_label(st.get_current_max()+1);

    // work list, seeded with the current front of st
    std::queue<node_id> pending;
    for (const auto& front_node : st.get_front()) {
        pending.push(front_node);
    }

    if (pending.empty()) {
        return false;
    }

    while (!pending.empty()) {
        const node_id current = pending.front();
        pending.pop();

        // a queued node may have been orphaned during this round without
        // being put back into the front; such nodes are skipped
        if (!st.is_front_element(current)) {
            continue;
        }

        const auto& neighbors = g_.get_adjacency(current);
        const auto& incident_edges = g_.get_edge_ids(current);
        for (node_id idx = 0; idx < neighbors.size(); ++idx) {
            const node_id neighbor = neighbors[idx];

            // edges that already carry flow are not considered
            if (flow_edges_[incident_edges[idx]-1]) {
                continue;
            }

            if (-st.get_id() == node_assignments_[neighbor-1]) {
                // augmentation: the connecting edge becomes a flow edge
                flow_edges_[incident_edges[idx]-1] = true;

                // The growing tree reduces its path from `current`, and the
                // nodes returned by that call are queued for examination.
                // The other tree reduces its path from `neighbor`. A tree
                // whose current maximum level is 0 is left untouched.
                if (st.get_id() == subtree::s) {
                    if (s_.get_current_max() > 0) {
                        for (node_id requeued : s_.reduce_path(current)) {
                            pending.push(requeued);
                        }
                    }
                    if (t_.get_current_max() > 0) {
                        t_.reduce_path(neighbor);
                    }
                } else {
                    if (t_.get_current_max() > 0) {
                        for (node_id requeued : t_.reduce_path(current)) {
                            pending.push(requeued);
                        }
                    }
                    if (s_.get_current_max() > 0) {
                        s_.reduce_path(neighbor);
                    }
                }

                // orphans of both trees become unassigned, so they can be
                // inserted again
                for (const auto& orphan : s_.get_last_resid_orphans()) {
                    node_assignments_[orphan-1] = none;
                }
                for (const auto& orphan : t_.get_last_resid_orphans()) {
                    node_assignments_[orphan-1] = none;
                }

                ++flow_;

                // once st has a level above 0, the scan of this node's
                // neighbors ends after an augmentation (presumably because
                // `current` may no longer be part of st -- TODO confirm)
                if (st.get_current_max() > 0) {
                    break;
                }
            } else if (node_assignments_[neighbor-1] == none) {
                // extension: the unassigned neighbor joins st as a child of
                // `current` on the next level
                if (st.get_id() == subtree::s) {
                    node_assignments_[neighbor-1] = s;
                } else {
                    node_assignments_[neighbor-1] = t;
                }
                st.add_node(st.get_current_max()+1, neighbor, current);
            }
        }
    }

    st.increment_current_max();

    return true;
}