def project(self):
    """Round edge predictions, resolving flow violations with a min-cost-flow solve.

    Edges are thresholded at 0.5. Nodes whose ``flow_in`` or ``flow_out`` (as
    returned by ``compute_constr_satisfaction_rate`` for the rounded predictions)
    exceeds 1 are marked violated. Every edge with an endpoint at a violated node
    is copied into a separate ``Graph`` and handed to the min-cost-flow solver;
    all other edges keep their rounded value.

    Side effects:
        Sets ``self.constr_satisf_rate`` (computed on the thresholded
        predictions, before any correction) and overwrites
        ``self.final_graph.edge_preds``: with a NumPy array mixing rounded and
        solver values when violations exist, otherwise with the rounded 0/1
        float tensor.

    Raises:
        RuntimeError: if ``self.solver_backend == 'gurobi'`` (that backend is
            disabled in this code).
    """
    round_preds = (self.final_graph.edge_preds > 0.5).float()
    self.constr_satisf_rate, flow_in, flow_out = compute_constr_satisfaction_rate(
        graph_obj=self.final_graph,
        edges_out=round_preds,
        undirected_edges=False,
        return_flow_vals=True,
    )

    # Nodes with a violated constraint, and every edge with an endpoint at one of them.
    nodes_mask = (flow_in > 1) | (flow_out > 1)
    edge_index = self.final_graph.edge_index
    edges_mask = nodes_mask[edge_index[0]] | nodes_mask[edge_index[1]]

    if not edges_mask.any():
        # No violations: every edge simply takes its rounded value, exactly as the
        # edges outside the solved subgraph do in the general case below.
        self.final_graph.edge_preds = round_preds
        return

    graph_to_project = Graph()
    graph_to_project.edge_preds = self.final_graph.edge_preds[edges_mask]
    graph_to_project.edge_index = edge_index[:, edges_mask]
    # Move node names to the mask's device rather than hard-coding .cuda(), which
    # fails on machines without a GPU.
    # NOTE(review): edge_index keeps the original node indices while node_names
    # only holds the violated nodes -- presumably what the solver expects; confirm.
    graph_to_project.node_names = self.final_graph.node_names.to(nodes_mask.device)[nodes_mask]
    graph_to_project.node_preds = torch.zeros_like(graph_to_project.node_names)

    if self.solver_backend == 'gurobi':
        # To enable Gurobi, replace this raise with:
        #     mcf_solver = GurobiMinCostFlowSolver(graph_to_project.numpy())
        raise RuntimeError('Uncomment gurobi code to run gurobi solver')
    mcf_solver = PuLPMinCostFlowSolver(graph_to_project.numpy())
    mcf_solver.solve()

    # graph_to_project.edge_preds is read back after solve(); presumably .numpy()
    # converts graph_to_project in place and the solver writes its solution into
    # it -- confirm against PuLPMinCostFlowSolver.
    edge_preds = self.final_graph.edge_preds.cpu().numpy()
    edges_mask = edges_mask.cpu().numpy()
    edge_preds[~edges_mask] = round_preds[~edges_mask].cpu().numpy()
    edge_preds[edges_mask] = graph_to_project.edge_preds
    self.final_graph.edge_preds = edge_preds
def project(self):
    """Greedily round edge predictions so that no node has flow above 1.

    Edges are thresholded at 0.5. A node whose ``flow_in`` or ``flow_out`` (as
    returned by ``compute_constr_satisfaction_rate`` for the rounded
    predictions) exceeds 1 yields a violated constraint. Violated constraints
    are visited in descending order of the highest prediction among the edges
    they involve (ties keep flow-in constraints ahead of flow-out ones). For
    each constraint that is still violated when visited, only its
    highest-scoring active edge is kept and its other edges are set to 0.
    Edges are only ever switched off, never on, so no new violation can appear.

    Side effects:
        Sets ``self.constr_satisf_rate`` (computed on the thresholded
        predictions, before any correction) and replaces
        ``self.final_graph.edge_preds`` with the corrected 0/1 float tensor.
    """
    edge_preds = self.final_graph.edge_preds
    edge_index = self.final_graph.edge_index

    round_preds = (edge_preds > 0.5).float()
    self.constr_satisf_rate, flow_in, flow_out = compute_constr_satisfaction_rate(
        graph_obj=self.final_graph,
        edges_out=round_preds,
        undirected_edges=False,
        return_flow_vals=True,
    )

    # Each violated constraint is (node, row); the edges it involves are those with
    # edge_index[row] == node. Row 1 (edge targets) encodes a flow-in violation,
    # row 0 (edge sources) a flow-out violation.
    violated = [(node, 1) for node in torch.where(flow_in > 1)[0].tolist()]
    violated += [(node, 0) for node in torch.where(flow_out > 1)[0].tolist()]

    def constraint_edges(node, row):
        # Boolean mask over edges involved in the (node, row) constraint. Comparing
        # against a Python int works on whatever device edge_index lives on.
        return edge_index[row] == node

    def max_pred(constr):
        # Highest prediction among the constraint's edges; -inf if it has none.
        preds = edge_preds[constraint_edges(*constr)]
        return preds.max().item() if preds.numel() else float('-inf')

    # Visit the most confident constraints first. list.sort is stable even with
    # reverse=True, so ties keep flow-in constraints ahead of flow-out ones.
    violated.sort(key=max_pred, reverse=True)

    for node, row in violated:
        edge_mask = constraint_edges(node, row)
        # An earlier constraint may already have fixed this one.
        if round_preds[edge_mask].sum() <= 1:
            continue
        edge_ix = torch.where(edge_mask)[0]
        # Weight by round_preds so an edge already switched off can never win and be
        # switched back on; argmax returns the first maximum, like builtin max().
        best_ix = edge_ix[torch.argmax(edge_preds[edge_ix] * round_preds[edge_ix])]
        round_preds[edge_mask] = 0
        round_preds[best_ix] = 1

    # Sanity check: no node may have rounded incoming/outgoing flow above 1 any more.
    assert scatter_add(round_preds, edge_index[1], dim_size=self.num_nodes).max() <= 1
    assert scatter_add(round_preds, edge_index[0], dim_size=self.num_nodes).max() <= 1

    self.final_graph.edge_preds = round_preds