def project(self):
    """Project thresholded edge predictions onto a solution without flow violations.

    Edge scores in ``self.final_graph.edge_preds`` are rounded at 0.5. The
    flow values returned by ``compute_constr_satisfaction_rate`` for the
    rounded predictions mark a node as violating when its in- or out-flow
    exceeds 1. Every edge touching such a node is extracted into a subgraph
    and re-solved with a min-cost-flow solver chosen by
    ``self.solver_backend`` (only the PuLP backend is enabled here). All other
    edges keep their rounded value.

    Side effects:
        Sets ``self.constr_satisf_rate`` (first value returned by
        ``compute_constr_satisfaction_rate``) and overwrites
        ``self.final_graph.edge_preds``. When nothing is violated, the new
        value is the rounded float tensor. Otherwise it is a NumPy array that
        combines the rounded values with the solver's output.

    Raises:
        Exception: if ``self.solver_backend == 'gurobi'`` (the Gurobi solver
            is disabled in this code).
    """
    round_preds = (self.final_graph.edge_preds > 0.5).float()

    self.constr_satisf_rate, flow_in, flow_out = compute_constr_satisfaction_rate(
        graph_obj=self.final_graph,
        edges_out=round_preds,
        undirected_edges=False,
        return_flow_vals=True)

    # Nodes with more than one rounded incoming or outgoing edge; every edge
    # touching one of them is handed to the solver.
    nodes_mask = (flow_in > 1) | (flow_out > 1)
    edge_index = self.final_graph.edge_index
    edges_mask = nodes_mask[edge_index[0]] | nodes_mask[edge_index[1]]

    if not edges_mask.any():
        # Already feasible: the rounded solution is its own projection.
        self.final_graph.edge_preds = round_preds
        return

    graph_to_project = Graph()
    graph_to_project.edge_preds = self.final_graph.edge_preds[edges_mask]
    graph_to_project.edge_index = edge_index.T[edges_mask].T
    # Follow the mask's device rather than forcing CUDA, so CPU-only runs work.
    graph_to_project.node_names = (
        self.final_graph.node_names.to(nodes_mask.device)[nodes_mask])
    graph_to_project.node_preds = torch.zeros_like(graph_to_project.node_names)

    if self.solver_backend == 'gurobi':
        #mcf_solver = GurobiMinCostFlowSolver(graph_to_project.numpy())
        raise Exception('Uncomment gurobi code to run gorubi solver')
    mcf_solver = PuLPMinCostFlowSolver(graph_to_project.numpy())
    mcf_solver.solve()

    # Merge both parts on the CPU. Move every operand to NumPy before
    # indexing, so a CUDA tensor is never indexed with a NumPy mask.
    # NOTE(review): assumes the solver writes its solution back into
    # graph_to_project.edge_preds (i.e. Graph.numpy() converts in place) -
    # TODO confirm.
    edges_mask = edges_mask.cpu().numpy()
    round_preds = round_preds.cpu().numpy()
    merged_preds = self.final_graph.edge_preds.cpu().numpy()
    merged_preds[~edges_mask] = round_preds[~edges_mask]
    merged_preds[edges_mask] = graph_to_project.edge_preds
    self.final_graph.edge_preds = merged_preds
# Example #2
0
    def project(self):
        """Greedily round edge predictions into a solution with no flow violations.

        Edge scores in ``self.final_graph.edge_preds`` are thresholded at 0.5.
        A node whose rounded incoming (or outgoing) flow exceeds 1 is a
        violated constraint. Violated constraints are visited in descending
        order of the highest edge score they involve. For each one that is
        still violated when visited, only its highest-scoring still-active
        edge is kept and the rest are set to 0.

        Side effects:
            Sets ``self.constr_satisf_rate``, computed by
            ``compute_constr_satisfaction_rate`` on the rounded predictions
            *before* any fixing. Replaces ``self.final_graph.edge_preds`` with
            the resulting 0/1 float tensor.

        Raises:
            AssertionError: if some node still has in- or out-flow above 1.
        """
        edge_preds = self.final_graph.edge_preds
        edge_index = self.final_graph.edge_index
        round_preds = (edge_preds > 0.5).float()

        self.constr_satisf_rate, flow_in, flow_out = compute_constr_satisfaction_rate(
            graph_obj=self.final_graph,
            edges_out=round_preds,
            undirected_edges=False,
            return_flow_vals=True)

        # Collect each violated constraint as (max score, edge ids). Incoming
        # edges of a node are matched on edge_index[1], outgoing ones on
        # edge_index[0]. All indices stay on edge_index's device.
        violated = []
        for flow_vals, endpoint_row in ((flow_in, 1), (flow_out, 0)):
            for node in torch.where(flow_vals > 1)[0].tolist():
                edges_ix = torch.where(edge_index[endpoint_row] == node)[0]
                violated.append((edge_preds[edges_ix].max().item(), edges_ix))

        # Resolve the most confident constraints first (stable sort: ties keep
        # in-flow before out-flow, then ascending node id).
        violated.sort(key=lambda constr: constr[0], reverse=True)

        for _, edges_ix in violated:
            if round_preds[edges_ix].sum() <= 1:
                continue  # Already satisfied by an earlier decision.
            # Multiplying by round_preds means an edge zeroed earlier cannot win
            # (argmax returns the first maximum on ties).
            scores = edge_preds[edges_ix] * round_preds[edges_ix]
            best_ix = edges_ix[scores.argmax()]
            round_preds[edges_ix] = 0
            round_preds[best_ix] = 1

        # Sanity check: no node may have more than one incoming or outgoing edge.
        assert scatter_add(round_preds,
                           edge_index[1],
                           dim_size=self.num_nodes).max() <= 1
        assert scatter_add(round_preds,
                           edge_index[0],
                           dim_size=self.num_nodes).max() <= 1

        self.final_graph.edge_preds = round_preds