def __init__(self, in_channels, edge_features, out_channels, dataset_class,
             bias=False, towers=1, divide_input=False):
    dataset = dataset_class('./all_iter', split='train', less_wired=True,
                            device='cpu')
    dlist = [
        torch_geometric.utils.degree(dataset[i].edge_index[0].to(
            get_hyperparameters()["device"]))
        for i in range(len(dataset))
    ]
    # Average-degree statistics required by PNA's degree scalers.
    avg_d = dict(
        lin=sum([torch.mean(D) for D in dlist]) / len(dlist),
        exp=sum([torch.mean(torch.exp(torch.div(1, D)) - 1)
                 for D in dlist]) / len(dlist),
        log=sum([torch.mean(torch.log(D + 1)) for D in dlist]) / len(dlist))
    super(PNAWrapper, self).__init__(
        in_channels, edge_features, out_channels,
        get_hyperparameters()["pna_aggregators"].split(),
        get_hyperparameters()["pna_scalers"].split(),
        avg_d, towers=towers, divide_input=divide_input)

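# A minimal sketch of the degree statistics PNA's scalers expect, computed
# on a toy degree tensor instead of a dataset (PNAWrapper above derives the
# same `avg_d` dict from SingleIterationDataset). The values are made up.
def _toy_avg_d():
    import torch
    D = torch.tensor([1., 2., 2., 3.])  # hypothetical node degrees of one graph
    return dict(lin=torch.mean(D),
                exp=torch.mean(torch.exp(torch.div(1, D)) - 1),
                log=torch.mean(torch.log(D + 1)))
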
def finish(x, y, batch_ids, steps, STEPS_SIZE, GRAPH_SIZES):
    """
    Returns whether this is a final iteration of the real task.

    Returns a true/false value per graph (as a mask).
    N.B. This is the ground truth, not what the network thinks.
    """
    DEVICE = get_hyperparameters()["device"]
    if steps == 0:
        return torch.ones(len(GRAPH_SIZES), device=DEVICE)
    if not steps < STEPS_SIZE - 1:
        return torch.zeros(len(GRAPH_SIZES), device=DEVICE)
    x_curr = torch.index_select(
        x, 1, torch.tensor([steps], dtype=torch.long,
                           device=DEVICE)).squeeze(1).to(DEVICE)
    y_curr = torch.index_select(
        y, 1, torch.tensor([steps], dtype=torch.long,
                           device=DEVICE)).squeeze(1).to(DEVICE)
    # A graph has not terminated iff any of its target features still change.
    noteq = ~(x_curr == y_curr)
    hyperparameters = get_hyperparameters()
    batches_inside = batch_ids.max() + 1
    noteq_batched = noteq.view(batches_inside, -1,
                               hyperparameters["dim_target"])
    true_termination = noteq_batched.any(dim=1).any(dim=-1).float()
    return true_termination

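# Hedged usage sketch for `finish`: two graphs of three nodes each, one
# target dimension, probed at step 1. All tensors are made up; this assumes
# dim_target == 1 in the hyperparameters and a CPU device.
def _finish_example():
    import torch
    x = torch.tensor([[[0.], [1.]]] * 6)                       # (6 nodes, 2 steps, 1)
    y = torch.tensor([[[1.], [1.]]] * 3 + [[[1.], [0.]]] * 3)  # graph 1 still differs
    batch_ids = torch.tensor([0, 0, 0, 1, 1, 1])
    # Expected: tensor([0., 1.]) -- only graph 1's outputs still change.
    return finish(x, y, batch_ids, steps=1, STEPS_SIZE=3,
                  GRAPH_SIZES=torch.tensor([3, 3]))
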
def get_messages_from_features(self, x_i, x_j, walk_edge_index, attrs, batch):
    attrs = attrs.reshape(-1, get_hyperparameters()["dim_edges"])
    enc_attrs = self.encode_edges(attrs)
    # GAT's message() needs the (flipped) edge index and the node count;
    # the other processors only need the encoded edge attributes.
    messages = (self.processor.message(x_i, x_j,
                                       utils.flip_edge_index(walk_edge_index),
                                       enc_attrs, batch.num_nodes)
                if isinstance(self.processor, GAT)
                else self.processor.message(x_i, x_j, enc_attrs))
    messages = messages.reshape(batch.num_graphs, -1,
                                get_hyperparameters()["dim_latent"])
    return messages

def get_pairs(i, stop_move_backward_col, path, do_not_process):
    # Graphs whose paths are still longer than i steps and still processed.
    mask = (i < stop_move_backward_col) & (~do_not_process)
    index = torch.tensor(i, device=get_hyperparameters()["device"],
                         dtype=torch.long)
    pairs = path[mask].index_select(dim=2, index=index).squeeze(-1)
    return pairs

def update_states(self, distances, predecessors_p, continue_p, current_latent):
    super().update_states(continue_p, current_latent)
    self.last_distances = torch.where(self.mask,
                                      utils.bit2integer(distances),
                                      self.last_distances)
    self.last_predecessors_p = torch.where(self.mask, predecessors_p,
                                           self.last_predecessors_p)
    self.last_output = self.last_distances

def obtain_paths(predecessors, GRAPH_SIZES, STEPS_SIZE, SOURCE_NODES,
                 SINK_NODES, return_path_matrix=False):
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    path_matrix = torch.full((len(GRAPH_SIZES), STEPS_SIZE), -100,
                             device=DEVICE, dtype=torch.long)
    stop_move_backward_col = torch.zeros(len(GRAPH_SIZES), device=DEVICE,
                                         dtype=torch.long)
    final = SOURCE_NODES.clone()
    # Walk backwards from each sink along the predecessor pointers.
    for i, n in enumerate(SINK_NODES):
        path_matrix[i][0] = n.item()
    for i in range(1, STEPS_SIZE):
        rowcols = (range(len(GRAPH_SIZES)), stop_move_backward_col)
        # Keep extending a path while we have not reached the source, the
        # predecessor pointer moves us somewhere new, and it is defined.
        upd = path_matrix[rowcols] != SOURCE_NODES
        upd &= path_matrix[rowcols] != predecessors[path_matrix[rowcols]]
        upd &= predecessors[path_matrix[rowcols]] != -1
        path_matrix[upd, i] = predecessors[path_matrix[upd, i - 1]]
        stop_move_backward_col[upd] = i
        final[upd] = path_matrix[upd, i]
    # Pair up consecutive path nodes as (from, to) edges.
    path = torch.stack([
        torch.stack([path_matrix[i][:-1], path_matrix[i][1:]], dim=0)
        for i in range(len(GRAPH_SIZES))
    ], dim=0)
    return ((path_matrix if return_path_matrix else path),
            stop_move_backward_col, final)

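# Hedged worked example for `obtain_paths` on a single 4-node graph whose
# predecessor chain is sink 3 -> 2 -> 0 (the source). Values are made up.
def _obtain_paths_example():
    import torch
    DEVICE = get_hyperparameters()["device"]
    predecessors = torch.tensor([0, 0, 0, 2], device=DEVICE)
    # Expected path_matrix row: [3, 2, 0, -100]; `final` equals the source.
    return obtain_paths(predecessors,
                        GRAPH_SIZES=torch.tensor([4], device=DEVICE),
                        STEPS_SIZE=4,
                        SOURCE_NODES=torch.tensor([0], device=DEVICE),
                        SINK_NODES=torch.tensor([3], device=DEVICE),
                        return_path_matrix=True)
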
def load_algorithms(algorithms, processor, use_ints):
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    DIM_LATENT = hyperparameters["dim_latent"]
    DIM_NODES_BFS = hyperparameters["dim_nodes_BFS"]
    DIM_NODES_AugmentingPath = hyperparameters["dim_nodes_AugmentingPath"]
    DIM_EDGES = hyperparameters["dim_edges"]
    DIM_EDGES_BFS = hyperparameters["dim_edges_BFS"]
    DIM_BITS = hyperparameters["dim_bits"] if use_ints else None
    for algorithm in algorithms:
        if algorithm == "AugmentingPath":
            algo_net = models.AugmentingPathNetwork(
                DIM_LATENT, DIM_NODES_AugmentingPath, DIM_EDGES, processor,
                flow_datasets.SingleIterationDataset, './all_iter',
                bias=hyperparameters["bias"], use_ints=use_ints,
                bits_size=DIM_BITS).to(DEVICE)
        elif algorithm == "BFS":
            algo_net = models.BFSNetwork(
                DIM_LATENT, DIM_NODES_BFS, DIM_EDGES_BFS, processor,
                flow_datasets.BFSSingleIterationDataset, './bfs').to(DEVICE)
        processor.add_algorithm(algo_net, algorithm)

def update_weights(self, optimizer):
    loss = 0
    for name, algorithm in self.algorithms.items():
        print("Algorithm", name)
        losses_dict = algorithm.get_losses_dict()
        pprint(losses_dict)
        loss += algorithm.get_training_loss()
        if get_hyperparameters()["calculate_termination_statistics"]:  # DEPRECATED
            tp = algorithm.true_positive
            fp = algorithm.false_positive
            fn = algorithm.false_negative
            print("Term precision:", tp / (tp + fp) if tp + fp else 'N/A')
            print("Term recall:", tp / (tp + fn) if tp + fn else 'N/A')
    optimizer.zero_grad()
    print("LOSSITEM", loss.item())
    loss.backward()
    optimizer.step()

def prepare_initial_masks(self, batch):
    DEVICE = get_hyperparameters()["device"]
    mask = torch.ones_like(batch.batch, dtype=torch.bool, device=DEVICE)
    mask_cp = torch.ones(batch.num_graphs, dtype=torch.bool, device=DEVICE)
    edge_mask = torch.ones_like(batch.edge_index[0], dtype=torch.bool,
                                device=DEVICE)
    return mask, mask_cp, edge_mask

def update_states(self, continue_p, current_latent):
    DIM_LATENT = get_hyperparameters()["dim_latent"]
    self.last_continue_p = torch.where(self.mask_cp, continue_p,
                                       self.last_continue_p)
    # Broadcast the node mask over the latent dimension.
    self.last_latent = torch.where(
        self.mask.unsqueeze(1).repeat_interleave(DIM_LATENT, dim=1),
        current_latent, self.last_latent)
    return self.last_continue_p

def __init__(self, latent_features, processor_type='MPNN'):
    super(AlgorithmProcessor, self).__init__()
    if processor_type == 'MPNN':
        self.processor = MPNN(latent_features, latent_features,
                              latent_features,
                              bias=get_hyperparameters()['bias'])
    self.algorithms = nn.ModuleDict()

def update_states(self, reachable_p, continue_p, current_latent):
    super().update_states(continue_p, current_latent)
    self.last_reachable_p = torch.where(self.mask, reachable_p,
                                        self.last_reachable_p)
    # Threshold the probabilities into hard reachability labels.
    self.last_reachable = (self.last_reachable_p >= 0.5).float()
    self.last_output = torch.where(self.mask, self.last_reachable,
                                   self.last_output)

def get_step_io(self, x, y):
    DEVICE = get_hyperparameters()["device"]
    x_curr = torch.index_select(
        x, 1, torch.tensor([self.steps], dtype=torch.long,
                           device=DEVICE)).squeeze(1)
    y_curr = torch.index_select(
        y, 1, torch.tensor([self.steps], dtype=torch.long,
                           device=DEVICE)).squeeze(1)
    return x_curr, y_curr

def get_walks(training, batch, output, GRAPH_SIZES, SOURCE_NODES, SINK_NODES):
    # During training, walks are random walks out of the sinks; at test time,
    # walks are decoded from the network's predecessor output.
    if training:
        return random_walk(batch.edge_index[0], batch.edge_index[1],
                           SINK_NODES,
                           walk_length=get_hyperparameters()["walk_length"],
                           coalesced=True).long(), None
    path_matrix, _, mask = get_walks_from_output(output, GRAPH_SIZES,
                                                 SOURCE_NODES, SINK_NODES)
    return path_matrix, mask

def __init__(self, latent_features, dataset, processor_type='MPNN'):
    assert processor_type in ['MPNN', 'PNA', 'GAT']
    super(AlgorithmProcessor, self).__init__()
    if processor_type == 'MPNN':
        self.processor = MPNN(latent_features, latent_features,
                              latent_features,
                              bias=get_hyperparameters()["bias"])
    elif processor_type == 'PNA':
        self.processor = PNAWrapper(latent_features, latent_features,
                                    latent_features, SingleIterationDataset,
                                    bias=get_hyperparameters()["bias"])
    elif processor_type == 'GAT':
        self.processor = GAT(latent_features, latent_features,
                             latent_features,
                             bias=get_hyperparameters()["bias"])
    self.algorithms = nn.ModuleDict()

def set_initial_last_states(self, batch, STEPS_SIZE, SOURCE_NODES):
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    SIZE = batch.num_nodes
    super().set_initial_last_states(batch, STEPS_SIZE, SOURCE_NODES)
    # Large negative/positive values act as (un)reachable logits;
    # node 0 starts out reachable.
    self.last_reachable_p = torch.full([SIZE], -1e3, device=DEVICE)
    self.last_reachable_p[0] = 1e3
    self.last_reachable = torch.zeros(SIZE, device=DEVICE)
    self.last_reachable[0] = 1.

def set_initial_last_states(self, batch, STEPS_SIZE, SOURCE_NODES):
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    DIM_LATENT = hyperparameters["dim_latent"]
    SIZE = batch.num_nodes
    self.last_latent = torch.zeros(SIZE, DIM_LATENT, device=DEVICE)
    self.last_continue_p = torch.ones(batch.num_graphs, device=DEVICE)
    x, y = self.get_input_output_features(batch, SOURCE_NODES)
    x.requires_grad = False
    y.requires_grad = False
    x_curr, _ = self.get_step_io(x, y)
    self.last_output = x_curr[:, 0].clone()

def get_features_from_walk(self, walks, batch, mask_end_of_path=None):
    attr_matrix = torch_geometric.utils.to_dense_adj(
        batch.edge_index, edge_attr=batch.edge_attr).squeeze()
    DIM_LATENT = get_hyperparameters()["dim_latent"]
    walk_nodes_latent = self.last_latent[walks]
    walk_nodes_latent_i = walk_nodes_latent[:, :-1].reshape(-1, DIM_LATENT)
    walk_nodes_latent_j = walk_nodes_latent[:, 1:].reshape(-1, DIM_LATENT)
    # Consecutive walk nodes form the walk's edges.
    wf = walks[:, :-1]
    ws = walks[:, 1:]
    walk_edge_index = torch.stack((wf, ws), dim=0).view(2, -1)
    walk_attrs = attr_matrix[(ws, wf)]
    inv_walk_attrs = attr_matrix[(wf, ws)]
    if self.training:
        # Perturb capacities with random integers so the bottleneck is not
        # always the same edge.
        walk_attrs[:, :, 1] = walk_attrs[:, :, 1] + torch.randint_like(
            walk_attrs[:, :, 1], low=0, high=10)
    wa = walk_attrs[:, :, 1].clone().detach()
    if mask_end_of_path is not None:
        wa[mask_end_of_path[:, 1:]] = 1e9
    actual_argmins = wa.argmin(dim=1)
    # No step was made by the algorithm, so the min is not defined.
    no_step_mask = wa.min(dim=1).values == 1e9
    return (walk_nodes_latent_i, walk_nodes_latent_j, walk_edge_index,
            walk_attrs, inv_walk_attrs, actual_argmins, no_step_mask)

def set_initial_last_states(self, batch, STEPS_SIZE, SOURCE_NODES):
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    SIZE = batch.num_nodes
    super().set_initial_last_states(batch, STEPS_SIZE, SOURCE_NODES)
    # Each source starts as its own predecessor with a near-certain logit.
    self.last_predecessors_p = torch.full((SIZE, SIZE), -1e9, device=DEVICE)
    self.last_predecessors_p[(SOURCE_NODES, SOURCE_NODES)] = 1e9
    self.last_distances = self.last_output.clone()
    self.last_distances[SOURCE_NODES] = 0.

def reweight_batch(batch, batch_bfs, use_ints):
    DEVICE = get_hyperparameters()["device"]
    if use_ints:
        # Integer capacities in [1, 15].
        weights = torch.randint_like(batch.edge_attr[:, 1], 1, 16,
                                     device=DEVICE, dtype=torch.float)
    else:
        # Real capacities uniform in [0.2, 1.0].
        weights = 0.8 * torch.rand_like(batch.edge_attr[:, 1], device=DEVICE,
                                        dtype=torch.float) + 0.2
    batch.edge_attr = torch.stack((weights, batch.edge_attr[:, 1]), dim=1)
    batch_bfs.edge_attr = torch.stack((weights, batch_bfs.edge_attr[:, 1]),
                                      dim=1)

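# A stand-alone sketch of the two capacity-sampling regimes used above;
# `reference` is any float tensor of edge capacities (the name is
# illustrative only).
def _sample_capacities(reference, use_ints):
    import torch
    if use_ints:
        # integer capacities in [1, 15]
        return torch.randint_like(reference, 1, 16, dtype=torch.float)
    # real capacities uniform in [0.2, 1.0]
    return 0.8 * torch.rand_like(reference, dtype=torch.float) + 0.2
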
def update_broken_invariants(self, batch, predecessors, adj_matrix,
                             flow_matrix):
    DEVICE = get_hyperparameters()["device"]
    GRAPH_SIZES, SOURCE_NODES, SINK_NODES = \
        utils.get_sizes_and_source_sink(batch)
    STEPS_SIZE = GRAPH_SIZES.max()
    _, y = self.get_input_output_features(batch, SOURCE_NODES)
    broke_flow = torch.zeros(batch.num_graphs, dtype=torch.bool,
                             device=DEVICE)
    broke_reachability_source = torch.zeros(batch.num_graphs,
                                            dtype=torch.bool, device=DEVICE)
    curr_node = SINK_NODES.clone().detach()
    cnt = 0
    predecessors_real = y[:, -1, -1]
    idx = predecessors[curr_node] != curr_node
    while ((predecessors_real[SINK_NODES] != -1).any() and cnt <= STEPS_SIZE
           and idx.any() and not utils.interrupted()):
        # Ignore graphs that reached the starting-node loop
        # (predecessor[starting node] = starting node).
        move_to_predecessors = torch.stack(
            (predecessors[curr_node], curr_node), dim=0)[:, idx]
        rowcols = (move_to_predecessors[0], move_to_predecessors[1])
        if not adj_matrix[rowcols].all():
            # Every predecessor must lead to a node accessible by an edge!
            print()
            print(adj_matrix)
            print(curr_node)
            print(predecessors[curr_node])
            print("FATAL INVARIANT ERROR")
            exit(0)
        if (flow_matrix[rowcols] <= 0).any():
            broke_flow[idx] |= flow_matrix[rowcols] <= 0
        curr_node[idx] = predecessors[curr_node[idx]]
        idx = ((predecessors[curr_node] != curr_node) &
               (predecessors_real[SINK_NODES] != -1))
        cnt += 1
        if cnt > STEPS_SIZE + 1:
            break
    original_reachable_mask = predecessors_real[SINK_NODES] != -1
    broke_reachability_source |= curr_node != SOURCE_NODES
    broke_invariant = broke_flow | broke_reachability_source
    broke_all = broke_flow & broke_reachability_source
    self.broken_invariants.extend(
        (original_reachable_mask & broke_invariant).clone().detach())
    self.broken_reachabilities.extend(
        (original_reachable_mask & broke_reachability_source).clone().detach())
    self.broken_flows.extend(
        (original_reachable_mask & broke_flow).clone().detach())
    self.broken_all.extend(
        (original_reachable_mask & broke_all).clone().detach())

def add_algorithms(self, algo_list):
    hyperparameters = get_hyperparameters()
    device = hyperparameters['device']
    node_features = hyperparameters['dim_nodes']
    edge_features = hyperparameters['dim_edges']
    latent_features = hyperparameters['dim_latent']
    for algo in algo_list:
        if algo == 'BFS':
            self.algorithms[algo] = models.bfs_network.BFSNetwork(
                node_features, edge_features, latent_features,
                self).to(device)
        elif algo in ['TRANS', 'TIPS', 'BUBBLES']:
            self.algorithms[algo] = models.traversal_network.TraversalNetwork(
                node_features, edge_features, latent_features,
                self).to(device)
        else:
            # Other algorithms are not handled here.
            pass

def get_step_loss(self, batch_ids, mask, mask_cp, y_curr, continue_p,
                  true_termination, reachable_p,
                  compute_losses_and_broken=True):
    reachable_p_masked = reachable_p[mask]
    reachable_real_masked = y_curr[mask].squeeze()
    steps = sum(mask_cp.float())
    loss_reachable, loss_term, processed_nodes, step_acc = 0, 0, 0, 1
    if compute_losses_and_broken:
        processed_nodes = len(reachable_real_masked)
        self.predictions["reachabilities"].extend(reachable_p_masked)
        self.actual["reachabilities"].extend(reachable_real_masked)
        self.predictions["terminations"].extend(continue_p[mask_cp])
        self.actual["terminations"].extend(true_termination[mask_cp])
        loss_reachable = F.binary_cross_entropy_with_logits(
            reachable_p_masked, reachable_real_masked, reduction='sum',
            pos_weight=torch.tensor(1.00))
        loss_term = F.binary_cross_entropy_with_logits(
            continue_p[mask_cp], true_termination[mask_cp], reduction='sum',
            pos_weight=torch.tensor(1.00))
        if get_hyperparameters()["calculate_termination_statistics"]:
            self.update_termination_statistics(continue_p[mask_cp],
                                               true_termination[mask_cp])
    if not self.training:
        reachable_split = utils.split_per_graph(batch_ids, reachable_p > 0)
        reachable_real_split = utils.split_per_graph(batch_ids,
                                                     y_curr.squeeze())
        correct, tot = BFSNetwork.calculate_step_acc(
            reachable_split[mask_cp], reachable_real_split[mask_cp])
        self.mean_step.extend(correct / tot.float())
        step_acc = correct / tot.float()
    return loss_reachable, loss_term, steps, processed_nodes, step_acc

def get_step_loss(self, batch_ids, mask, mask_cp, y_curr, continue_p,
                  true_termination, distances, predecessors_p,
                  compute_losses_and_broken=True):
    distances_masked, distances_real_masked, predecessors_p_masked, \
        predecessors_real_masked = AugmentingPathNetwork.mask_infinities(
            mask, y_curr, distances, predecessors_p)
    steps = sum(mask_cp.float())
    train = self.training
    loss_dist, loss_pred, loss_term, processed_nodes, step_acc = 0, 0, 0, 0, 1
    if distances_real_masked.nelement() != 0 and compute_losses_and_broken:
        processed_nodes = len(distances_real_masked)
        # Distances are regressed directly, or predicted bitwise when the
        # network operates on integers.
        if self.bits_size is None:
            loss_dist = F.mse_loss(distances_masked, distances_real_masked,
                                   reduction='sum')
        else:
            loss_dist = F.binary_cross_entropy_with_logits(
                distances_masked, utils.integer2bit(distances_real_masked),
                reduction='sum')
        loss_pred = F.cross_entropy(predecessors_p_masked,
                                    predecessors_real_masked,
                                    ignore_index=-1, reduction='sum')
    if compute_losses_and_broken:
        assert mask_cp.any(), mask_cp
        loss_term = F.binary_cross_entropy_with_logits(
            continue_p[mask_cp], true_termination[mask_cp], reduction='sum',
            pos_weight=torch.tensor(1.00))
        if get_hyperparameters()["calculate_termination_statistics"]:
            self.update_termination_statistics(continue_p[mask_cp],
                                               true_termination[mask_cp])
        assert loss_term.item() != float('inf')
    if not train and mask_cp.any() and compute_losses_and_broken:
        _, predecessors_real = \
            AugmentingPathNetwork.get_real_output_values(y_curr)
        predecessors_p_split = utils.split_per_graph(batch_ids,
                                                     predecessors_p)
        predecessors_real_split = utils.split_per_graph(batch_ids,
                                                        predecessors_real)
        correct, tot = AugmentingPathNetwork.calculate_step_acc(
            torch.max(predecessors_p_split[mask_cp], dim=2).indices,
            predecessors_real_split[mask_cp])
        self.mean_step.extend(correct / tot.float())
        step_acc = correct / tot.float()
    return loss_dist, loss_pred, loss_term, steps, processed_nodes, step_acc

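# Hedged illustration of the `ignore_index=-1` behaviour relied on by the
# predecessor loss above: targets marked -1 contribute nothing to the loss.
# The logits and targets below are made up.
def _ignore_index_example():
    import torch
    import torch.nn.functional as F
    logits = torch.randn(3, 5)         # 3 nodes, 5 candidate predecessors
    target = torch.tensor([2, -1, 4])  # the middle node is masked out
    return F.cross_entropy(logits, target, ignore_index=-1, reduction='sum')
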
def augment_flow(self, batch, walks, mask_end_of_path, mins):
    BATCH_SIZE = batch.num_graphs
    (x_i, x_j, walk_edge_index, attrs, inv_walk_attrs, actual_argmins,
     no_step_mask) = self.get_features_from_walk(
         walks, batch, mask_end_of_path=mask_end_of_path)
    if no_step_mask.all():
        return attrs[:, :, 1]  # nothing changes
    if self.training:
        # Augmentation and bottleneck finding had to have the same edge
        # attributes, so reuse the ones cached while finding the bottleneck.
        attrs = self._attrs
    messages = self.get_messages_from_features(x_i, x_j, walk_edge_index,
                                               attrs, batch)
    messages_inv = self.get_messages_from_features(x_j, x_i, walk_edge_index,
                                                   inv_walk_attrs, batch)
    minemb = self.bit_encoder(utils.integer2bit(mins.float()))
    minemb = minemb.unsqueeze(1).repeat_interleave(messages.shape[1], dim=1)
    new_weight_distribution = self.subtract_network(
        torch.cat((minemb, messages), dim=-1))
    # A new capacity can never exceed the old one, so mask out all larger
    # candidate values.
    mask = torch.arange(2**self.bits_size,
                        device=get_hyperparameters()["device"])
    mask = mask.repeat(BATCH_SIZE, attrs.shape[1], 1)
    expanded_caps = attrs[:, :, 1].unsqueeze(2).expand_as(mask)
    mask = mask > expanded_caps
    new_weight_distribution[mask] = -1e9
    new_weight_distribution = new_weight_distribution[~no_step_mask]
    real_new_caps = attrs[:, :, 1][~no_step_mask]
    real_new_caps -= mins[~no_step_mask].unsqueeze(-1)
    if mask_end_of_path is not None:
        # Ignore the value for cross entropy.
        real_new_caps[mask_end_of_path[:, 1:][~no_step_mask]] = -100
    # The neural mins provided may be incorrect, so ignore these values.
    just_correct = real_new_caps >= 0
    neural_new_caps = new_weight_distribution.argmax(dim=-1)
    if just_correct.any():
        self.losses["augment"] = F.cross_entropy(
            new_weight_distribution[just_correct].view(-1, 2**self.bits_size),
            real_new_caps[just_correct].view(-1).long())
        subtract_correct = \
            (neural_new_caps == real_new_caps)[just_correct].sum()
        subtract_all = real_new_caps[just_correct].nelement()
        if not self.training:
            self.subtract_correct += subtract_correct
            self.subtract_all += subtract_all
    else:
        return attrs[:, :, 1]  # nothing changes
    return neural_new_caps

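# `utils.integer2bit` is not shown in this file. A hypothetical fixed-width,
# most-significant-bit-first encoding with the same role could look like the
# sketch below (the real helper's signature may differ).
def _integer2bit_sketch(integer, bits=8):
    import torch
    exponents = torch.arange(bits - 1, -1, -1,
                             device=integer.device, dtype=torch.float)
    # e.g. tensor([5.]) with bits=4 -> tensor([[0., 1., 0., 1.]])
    return (integer.unsqueeze(-1) // 2 ** exponents) % 2
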
def zero_hidden(self, num_nodes):
    self.hidden = torch.zeros(num_nodes, self.out_channels).to(
        get_hyperparameters()["device"])

def termination_condition(args, i, threshold, batch, reachable_sinks=None):
    # Without the BFS oracle, run for a fixed number of iterations; with it,
    # run while any sink is still reachable (up to a hard cap).
    if not args["--use-BFS-for-termination"]:
        return i < threshold
    return i < get_hyperparameters()["max_threshold"] and reachable_sinks.any()

def run(args, threshold, processor, probp=1, probq=4, savefile=True):
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    if savefile:
        # Truncate the save file, then append to it.
        with open(args["SAVE_FILE"], "w"):
            pass
        f = open(args["SAVE_FILE"], "a+")
    dataset = GraphOnlyDataset('./graph_only',
                               split='test' + args["--upscale"],
                               less_wired=True, probp=probp, probq=probq,
                               device='cpu')
    dataset_BFS = GraphOnlyDatasetBFS('./graph_only_BFS',
                                      split='test' + args["--upscale"],
                                      probp=probp, probq=probq,
                                      less_wired=True, device='cpu')
    result_maxflows = []
    with torch.no_grad():
        for rep in range(10):
            print(rep)
            current_result = []
            loader = DataLoader(dataset,
                                batch_size=hyperparameters["batch_size"],
                                shuffle=False, drop_last=False,
                                num_workers=0)
            loader_BFS = iter(
                DataLoader(dataset_BFS,
                           batch_size=hyperparameters["batch_size"],
                           shuffle=False, drop_last=False, num_workers=0))
            for batch in tqdm(loader, dynamic_ncols=True):
                batch_bfs = next(loader_BFS)
                batch_bfs.to(DEVICE)
                batch.to(DEVICE)
                GRAPH_SIZES, SOURCE_NODES, SINK_NODES = \
                    utils.get_sizes_and_source_sink(batch)
                # We make at most |V|-1 steps.
                STEPS_SIZE = GRAPH_SIZES.max()
                inv_edge_index = utils.create_inv_edge_index(
                    len(GRAPH_SIZES), GRAPH_SIZES.max(), batch.edge_index)
                do_not_process = torch.zeros_like(GRAPH_SIZES,
                                                  dtype=torch.bool,
                                                  device=DEVICE)
                path_matrix, stop_move_backward_col, bottleneck = \
                    find_augmenting_path(args, batch, batch_bfs,
                                         do_not_process, processor,
                                         inv_edge_index, threshold)
                do_not_process = bottleneck == 0
                # Keep augmenting flow until no graph has an augmenting path.
                while (bottleneck != 0).any():
                    wrong_minus = augment_flow(
                        batch, inv_edge_index, path_matrix,
                        stop_move_backward_col, bottleneck, do_not_process,
                        args["--use-neural-augmentation"],
                        augmenting_path_network=processor
                        .algorithms["AugmentingPath"])
                    batch_bfs.edge_attr[:, 1] = batch.edge_attr[:, 1]
                    do_not_process |= wrong_minus
                    path_matrix, stop_move_backward_col, bottleneck = \
                        find_augmenting_path(args, batch, batch_bfs,
                                             do_not_process, processor,
                                             inv_edge_index, threshold)
                    do_not_process |= (bottleneck == 0)
                    bottleneck[do_not_process] = 0
                    assert ((bottleneck <= 1) & (bottleneck >= 0)).all(), \
                        (bottleneck, path_matrix)
                # Sum the flow into each sink to read off the max-flow value.
                maxflows = [0 for _ in SINK_NODES]
                for isn, sn in enumerate(SINK_NODES):
                    for i in range(len(batch.batch)):
                        if inv_edge_index[sn][i] != -100:
                            assert 0 <= \
                                batch.edge_attr[inv_edge_index[sn][i]][1] <= 1
                            maxflows[isn] += \
                                batch.edge_attr[inv_edge_index[sn][i]][1]
                if savefile:
                    print(*[int(mf.item()) for mf in maxflows],
                          sep=' ', end=' ', file=f)
                else:
                    current_result.extend([int(mf.item()) for mf in maxflows])
            if savefile:
                f.write('\n')
            else:
                result_maxflows.append(current_result)
    if savefile:
        f.close()
    else:
        return result_maxflows

def augment_flow(batch, inv_edge_index, path, stop_move_backward_col,
                 bottleneck, do_not_process, use_neural_augmentation,
                 augmenting_path_network=None):

    def get_edge_indexes(step):
        pairs = get_pairs(step, stop_move_backward_col, path, do_not_process)
        edge_idx = inv_edge_index[(pairs[:, 1], pairs[:, 0])]
        edge_idx_rev = inv_edge_index[(pairs[:, 0], pairs[:, 1])]
        return edge_idx, edge_idx_rev

    flow = batch.edge_attr[:, 1]
    old_flow = flow.clone()
    path_matrix = torch.cat((path[:, 0, :], path[:, 1, -1].unsqueeze(-1)),
                            dim=-1)
    mask_end_of_path = path_matrix == -100
    path_matrix[mask_end_of_path] = 0
    if use_neural_augmentation:
        new_flows = augmenting_path_network.augment_flow(
            batch, path_matrix, mask_end_of_path, bottleneck)
    largest_len = max(stop_move_backward_col)
    needs_rerun = torch.zeros(batch.num_graphs,
                              device=get_hyperparameters()["device"],
                              dtype=torch.bool)
    for i in range(largest_len):
        mask = (i < stop_move_backward_col) & (~do_not_process)
        if not mask.any():
            break
        edge_idx, edge_idx_rev = get_edge_indexes(i)
        assert (edge_idx != -100).all()
        assert (edge_idx_rev != -100).all()
        assert (abs(flow[edge_idx]) <= 1).all(), \
            flow[edge_idx][~(abs(flow[edge_idx]) <= 1)]
        assert (abs(bottleneck[mask]) <= 1).all()
        assert (bottleneck[do_not_process] == 0).all()
        assert (bottleneck[~do_not_process] == 1).all()
        assert (bottleneck[mask] == 1).all()
        if use_neural_augmentation:
            new_flow = \
                new_flows[mask[stop_move_backward_col > 0]][:, i].float()
            should_be = flow[edge_idx] - bottleneck[mask]
            not_what_should_be = new_flow != should_be
            needs_rerun[mask] |= not_what_should_be
            # Keep each residual pair consistent: forward and reverse flow
            # must sum to the same constant as before the update.
            consts = flow[edge_idx] + flow[edge_idx_rev]
            new_flow_rev = consts - new_flow
            flow[edge_idx] = torch.where(not_what_should_be,
                                         flow[edge_idx], new_flow)
            flow[edge_idx_rev] = torch.where(not_what_should_be,
                                             flow[edge_idx_rev], new_flow_rev)
        else:
            flow[edge_idx] -= bottleneck[mask]
            flow[edge_idx_rev] += bottleneck[mask]
    if needs_rerun.any():
        # A neural augmentation was wrong somewhere: undo the changes so the
        # affected graphs can be re-run.
        for i in range(largest_len):
            mask = (i < stop_move_backward_col) & (~do_not_process)
            if not mask.any():
                break
            edge_idx, edge_idx_rev = get_edge_indexes(i)
            flow[edge_idx] = old_flow[edge_idx]
            flow[edge_idx_rev] = old_flow[edge_idx_rev]
    return needs_rerun

def find_augmenting_path(args, batch, batch_bfs, do_not_process, processor,
                         inv_edge_index, threshold, debug=False):
    DEVICE = get_hyperparameters()["device"]
    GRAPH_SIZES, SOURCE_NODES, SINK_NODES = \
        utils.get_sizes_and_source_sink(batch)
    STEPS_SIZE = GRAPH_SIZES.max()
    redo_mask = torch.ones_like(GRAPH_SIZES, device=DEVICE,
                                dtype=torch.bool) & (~do_not_process)
    predecessors_last = torch.full_like(batch.batch, -1, device=DEVICE)
    reachable_last = torch.ones_like(predecessors_last, dtype=torch.bool)
    flow = batch.edge_attr[:, 1]
    final = SINK_NODES.clone()
    bottleneck = torch.zeros(batch.num_graphs, device=DEVICE)
    path_matrix = torch.full((batch.num_graphs, STEPS_SIZE), -100,
                             device=DEVICE, dtype=torch.long)
    stop_move_backward_col = torch.zeros(batch.num_graphs, device=DEVICE,
                                         dtype=torch.long)
    wrong_bottleneck_mask = None
    for algorithm in processor.algorithms.values():
        algorithm.zero_validation_stats()
    if args["--use-BFS-for-termination"]:
        with torch.no_grad():
            reachable = processor.algorithms["BFS"].process(
                batch_bfs, EPSILON=0, enforced_mask=redo_mask,
                compute_losses_and_broken=False)
    else:
        reachable = torch.ones(batch.num_nodes, device=DEVICE,
                               dtype=torch.bool)
    i = 0
    while termination_condition(args, i, threshold, batch,
                                reachable[SINK_NODES]):
        i += 1
        with torch.no_grad():
            predecessors = processor.algorithms["AugmentingPath"].process(
                batch, EPSILON=0, enforced_mask=redo_mask,
                compute_losses_and_broken=False)
        # Only refresh predecessors for graphs that are still processed.
        predecessors = torch.where(redo_mask[batch.batch], predecessors,
                                   predecessors_last)
        predecessors_last = predecessors
        predecessors = predecessors_last.clone()
        path_matrix, stop_move_backward_col, final = obtain_paths(
            predecessors, GRAPH_SIZES, STEPS_SIZE, SOURCE_NODES, SINK_NODES)
        if args["--use-neural-bottleneck"]:
            walks, mask_end_of_path = utils.get_walks(
                False, batch, predecessors, GRAPH_SIZES, SOURCE_NODES,
                SINK_NODES)
            bottleneck = processor.algorithms["AugmentingPath"].find_mins(
                batch, walks, mask_end_of_path, GRAPH_SIZES, SOURCE_NODES,
                SINK_NODES)
            zero_unreachable_and_frozen_bottleneck(bottleneck, do_not_process,
                                                   reachable[SINK_NODES])
            real_bottleneck = get_bottleneck(
                path_matrix, stop_move_backward_col, flow, inv_edge_index,
                do_not_process, reachable[SINK_NODES])
            # Freeze graphs whose predicted bottleneck disagrees with the
            # real one.
            wrong_bottleneck_mask = (bottleneck == 1) & (real_bottleneck == 0)
            do_not_process |= wrong_bottleneck_mask
        else:
            bottleneck = get_bottleneck(
                path_matrix, stop_move_backward_col, flow, inv_edge_index,
                do_not_process, reachable[SINK_NODES])
        reweight_batch(batch, batch_bfs, args["--use-ints"])
        if args["--use-BFS-for-termination"]:
            with torch.no_grad():
                reachable = processor.algorithms["BFS"].process(
                    batch_bfs, EPSILON=0, enforced_mask=redo_mask,
                    compute_losses_and_broken=False)
            reachable = torch.where(redo_mask[batch.batch], reachable,
                                    reachable_last)
            reachable_last = reachable
            reachable = reachable_last.clone()
        reachable_sinks = reachable[SINK_NODES]
        redo_mask = get_redo_mask(args["--use-BFS-for-termination"],
                                  args["--use-neural-bottleneck"], final,
                                  SOURCE_NODES, bottleneck, do_not_process,
                                  reachable_sinks)
        do_not_process[redo_mask] |= ~reachable_sinks[redo_mask]
        if debug:
            if (bottleneck == 0).any():
                print("Broke flow cap invariant", file=sys.stderr)
            if (final != SOURCE_NODES).any():
                print("Broke reachability invariant", file=sys.stderr)
        if not redo_mask.any():
            bottleneck[do_not_process] = 0
            return path_matrix, stop_move_backward_col, bottleneck
        bottleneck[redo_mask] = 0
        # Hack to set the redo mask to false for the next iterations.
        bottleneck[do_not_process] = 0
    return path_matrix, stop_move_backward_col, bottleneck