def __init__(self,
             in_channels,
             edge_features,
             out_channels,
             dataset_class,
             bias=False,
             towers=1,
             divide_input=False):
    dataset = dataset_class('./all_iter',
                            split='train',
                            less_wired=True,
                            device='cpu')
    dlist = [
        torch_geometric.utils.degree(dataset[i].edge_index[0].to(
            get_hyperparameters()["device"])) for i in range(len(dataset))
    ]
    avg_d = dict(
        lin=sum(torch.mean(D) for D in dlist) / len(dlist),
        exp=sum(torch.mean(torch.exp(torch.div(1, D)) - 1)
                for D in dlist) / len(dlist),
        log=sum(torch.mean(torch.log(D + 1)) for D in dlist) / len(dlist))
    super(PNAWrapper,
          self).__init__(in_channels,
                         edge_features,
                         out_channels,
                         get_hyperparameters()["pna_aggregators"].split(),
                         get_hyperparameters()["pna_scalers"].split(),
                         avg_d,
                         towers=towers,
                         divide_input=divide_input)
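PNA's scalers are normalized by dataset-wide degree statistics, which is why the constructor above averages three transforms of the node-degree distribution over every training graph. A minimal, self-contained sketch of those statistics for a single toy graph (the edge_index below is an assumed illustration, not project data):

import torch
from torch_geometric.utils import degree

# Hypothetical toy graph: 4 nodes, each with at least one outgoing edge,
# so 1/D is finite for every node.
edge_index = torch.tensor([[0, 0, 1, 2, 3],
                           [1, 2, 2, 3, 0]])

D = degree(edge_index[0], num_nodes=4)  # out-degree per node: [2., 1., 1., 1.]
avg_d = dict(
    lin=torch.mean(D),
    exp=torch.mean(torch.exp(torch.div(1, D)) - 1),
    log=torch.mean(torch.log(D + 1)),
)

The real constructor computes the same three quantities per graph and then averages them over the whole training dataset.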
Example #2
def finish(x, y, batch_ids, steps, STEPS_SIZE, GRAPH_SIZES):
    """
    Returns whether it's a final iteration or not in real task

    Returns true/false value per graph (as a mask)

    N.B. Not what the network thinks
    """
    DEVICE = get_hyperparameters()["device"]
    if steps == 0:
        return torch.ones(len(GRAPH_SIZES), device=DEVICE)
    if steps >= STEPS_SIZE - 1:
        return torch.zeros(len(GRAPH_SIZES), device=DEVICE)
    x_curr = torch.index_select(
        x, 1, torch.tensor([steps], dtype=torch.long,
                           device=DEVICE)).squeeze(1).to(DEVICE)
    y_curr = torch.index_select(
        y, 1, torch.tensor([steps], dtype=torch.long,
                           device=DEVICE)).squeeze(1).to(DEVICE)
    noteq = (x_curr != y_curr)
    hyperparameters = get_hyperparameters()
    batches_inside = batch_ids.max() + 1
    noteq_batched = noteq.view(batches_inside, -1,
                               hyperparameters["dim_target"])
    true_termination = noteq_batched.any(dim=1).any(dim=-1).float()
    return true_termination
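To make the per-graph mask semantics concrete, here is a minimal, self-contained illustration of the comparison finish performs (the tensor shapes and toy values are assumptions; in the project x/y come from the dataset and dim_target from the hyperparameters):

import torch

steps = 1
# 2 graphs, 3 nodes each, 4 time steps, dim_target = 1.
x = torch.randint(0, 2, (6, 4, 1)).float()
y = x.clone()
y[0, steps] = 1 - y[0, steps]  # perturb one node of graph 0 at this step

x_curr, y_curr = x[:, steps], y[:, steps]  # current-step slices
noteq = x_curr != y_curr
# One entry per graph: 1.0 iff any node/feature still differs at this step.
true_termination = noteq.view(2, -1, 1).any(dim=1).any(dim=-1).float()
print(true_termination)  # tensor([1., 0.])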
Example #3
    def get_messages_from_features(self, x_i, x_j, walk_edge_index, attrs,
                                   batch):
        attrs = attrs.reshape(-1, get_hyperparameters()["dim_edges"])
        enc_attrs = self.encode_edges(attrs)
        messages = (
            self.processor.message(x_i, x_j,
                                   utils.flip_edge_index(walk_edge_index),
                                   enc_attrs, batch.num_nodes)
            if isinstance(self.processor, GAT) else
            self.processor.message(x_i, x_j, enc_attrs))
        messages = messages.reshape(batch.num_graphs, -1,
                                    get_hyperparameters()["dim_latent"])
        return messages
Example #4
def get_pairs(i, stop_move_backward_col, path, do_not_process):
    mask = (i < stop_move_backward_col) & (~do_not_process)
    # Select the i-th node pair of the path for every still-active graph.
    pairs = path[mask][:, :, i]
    return pairs
Example #5
    def update_states(self, distances, predecessors_p,
                      continue_p, current_latent):
        super().update_states(continue_p, current_latent)
        self.last_distances = torch.where(self.mask,
                                          utils.bit2integer(distances),
                                          self.last_distances)
        self.last_predecessors_p = torch.where(self.mask, predecessors_p,
                                               self.last_predecessors_p)
        self.last_output = self.last_distances
Example #6
def obtain_paths(predecessors,
                 GRAPH_SIZES,
                 STEPS_SIZE,
                 SOURCE_NODES,
                 SINK_NODES,
                 return_path_matrix=False):
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    path_matrix = torch.full((len(GRAPH_SIZES), STEPS_SIZE),
                             -100,
                             device=DEVICE,
                             dtype=torch.long)
    stop_move_backward_col = torch.zeros(len(GRAPH_SIZES),
                                         device=DEVICE,
                                         dtype=torch.long)
    final = SOURCE_NODES.clone()

    path_matrix[:, 0] = SINK_NODES

    for i in range(1, STEPS_SIZE):
        rowcols = (range(len(GRAPH_SIZES)), stop_move_backward_col)
        upd = (path_matrix[rowcols] != SOURCE_NODES)
        upd2 = path_matrix[rowcols] != predecessors[path_matrix[rowcols]]
        upd3 = predecessors[path_matrix[rowcols]] != -1
        upd &= upd2 & upd3
        path_matrix[upd, i] = predecessors[path_matrix[upd, i - 1]]
        stop_move_backward_col[upd] = i
        final[upd] = path_matrix[upd, i]
    path = torch.stack((path_matrix[:, :-1], path_matrix[:, 1:]), dim=1)
    return path_matrix if return_path_matrix else path, stop_move_backward_col, final
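The loop above is a batched version of the classic predecessor walk; a single-graph sketch (the toy predecessors tensor is an assumption for illustration) shows the idea:

import torch

# Hypothetical predecessor pointers for a 5-node graph
# (node 0 is the source, node 4 the sink, -1 means "no predecessor"):
predecessors = torch.tensor([0, 0, 1, -1, 2])

path, node = [4], 4
# Walk back until the source (its own predecessor) or a dead end (-1).
while predecessors[node] != node and predecessors[node] != -1:
    node = predecessors[node].item()
    path.append(node)
print(path)  # [4, 2, 1, 0]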
Example #7
def load_algorithms(algorithms, processor, use_ints):
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    DIM_LATENT = hyperparameters["dim_latent"]
    DIM_NODES_BFS = hyperparameters["dim_nodes_BFS"]
    DIM_NODES_AugmentingPath = hyperparameters["dim_nodes_AugmentingPath"]
    DIM_EDGES = hyperparameters["dim_edges"]
    DIM_EDGES_BFS = hyperparameters["dim_edges_BFS"]
    DIM_BITS = hyperparameters["dim_bits"] if use_ints else None
    for algorithm in algorithms:
        if algorithm == "AugmentingPath":
            algo_net = models.AugmentingPathNetwork(
                DIM_LATENT,
                DIM_NODES_AugmentingPath,
                DIM_EDGES,
                processor,
                flow_datasets.SingleIterationDataset,
                './all_iter',
                bias=hyperparameters["bias"],
                use_ints=use_ints,
                bits_size=DIM_BITS).to(DEVICE)
        elif algorithm == "BFS":
            algo_net = models.BFSNetwork(
                DIM_LATENT, DIM_NODES_BFS, DIM_EDGES_BFS, processor,
                flow_datasets.BFSSingleIterationDataset, './bfs').to(DEVICE)
        processor.add_algorithm(algo_net, algorithm)
Example #8

    def update_weights(self, optimizer):
        loss = 0
        for name, algorithm in self.algorithms.items():
            print("Algorithm", name)
            losses_dict = algorithm.get_losses_dict()
            pprint(losses_dict)
            loss += algorithm.get_training_loss()
            if get_hyperparameters()["calculate_termination_statistics"]:  # DEPRECATED
                tp = algorithm.true_positive
                fp = algorithm.false_positive
                fn = algorithm.false_negative
                print("Term precision:", tp / (tp + fp) if tp + fp else 'N/A')
                print("Term recall:", tp / (tp + fn) if tp + fn else 'N/A')

        start = time.time()
        optimizer.zero_grad()
        print("LOSSITEM", loss.item())
        loss.backward()
        optimizer.step()
Example #9
    def prepare_initial_masks(self, batch):
        DEVICE = get_hyperparameters()["device"]
        mask = torch.ones_like(batch.batch, dtype=torch.bool, device=DEVICE)
        mask_cp = torch.ones(batch.num_graphs, dtype=torch.bool, device=DEVICE)
        edge_mask = torch.ones_like(batch.edge_index[0],
                                    dtype=torch.bool,
                                    device=DEVICE)
        return mask, mask_cp, edge_mask
Example #10
    def update_states(self, continue_p, current_latent):
        DIM_LATENT = get_hyperparameters()["dim_latent"]
        self.last_continue_p = torch.where(self.mask_cp, continue_p,
                                           self.last_continue_p)
        self.last_latent = torch.where(
            self.mask.unsqueeze(1).repeat_interleave(DIM_LATENT, dim=1),
            current_latent, self.last_latent)
        return self.last_continue_p
Example #11
    def __init__(self, latent_features, processor_type='MPNN'):
        super(AlgorithmProcessor, self).__init__()
        if processor_type == 'MPNN':
            self.processor = MPNN(latent_features,
                                  latent_features,
                                  latent_features,
                                  bias=get_hyperparameters()['bias'])
        self.algorithms = nn.ModuleDict()
Example #12
    def update_states(self, reachable_p, continue_p, current_latent):
        super().update_states(continue_p, current_latent)
        self.last_reachable_p = torch.where(self.mask, reachable_p,
                                            self.last_reachable_p)
        self.last_reachable = (self.last_reachable_p >= 0.5).float()
        self.last_output = torch.where(self.mask, self.last_reachable,
                                       self.last_output)
Example #13
    def get_step_io(self, x, y):
        DEVICE = get_hyperparameters()["device"]
        x_curr = torch.index_select(
            x, 1, torch.tensor([self.steps], dtype=torch.long,
                               device=DEVICE)).squeeze(1)
        y_curr = torch.index_select(
            y, 1, torch.tensor([self.steps], dtype=torch.long,
                               device=DEVICE)).squeeze(1)
        return x_curr, y_curr
Example #14
def get_walks(training, batch, output, GRAPH_SIZES, SOURCE_NODES, SINK_NODES):
    if training:
        return random_walk(batch.edge_index[0],
                           batch.edge_index[1],
                           SINK_NODES,
                           walk_length=get_hyperparameters()["walk_length"],
                           coalesced=True).long(), None
    path_matrix, _, mask = get_walks_from_output(output, GRAPH_SIZES,
                                                 SOURCE_NODES, SINK_NODES)
    return path_matrix, mask
Example #15

    def __init__(self, latent_features, dataset, processor_type='MPNN'):
        assert processor_type in ['MPNN', 'PNA', 'GAT']
        super(AlgorithmProcessor, self).__init__()
        if processor_type == 'MPNN':
            self.processor = MPNN(latent_features,
                                  latent_features,
                                  latent_features,
                                  bias=get_hyperparameters()["bias"])
        elif processor_type == 'PNA':
            self.processor = PNAWrapper(latent_features,
                                        latent_features,
                                        latent_features,
                                        SingleIterationDataset,
                                        bias=get_hyperparameters()["bias"])
        elif processor_type == 'GAT':
            self.processor = GAT(latent_features,
                                 latent_features,
                                 latent_features,
                                 bias=get_hyperparameters()["bias"])
        self.algorithms = nn.ModuleDict()
Example #16
    def set_initial_last_states(self, batch, STEPS_SIZE, SOURCE_NODES):
        hyperparameters = get_hyperparameters()
        DEVICE = hyperparameters["device"]
        DIM_LATENT = hyperparameters["dim_latent"]

        SIZE = batch.num_nodes
        INF = STEPS_SIZE
        super().set_initial_last_states(batch, STEPS_SIZE, SOURCE_NODES)
        self.last_reachable_p = torch.full([SIZE], -1e3, device=DEVICE)
        self.last_reachable_p[0] = 1e3
        self.last_reachable = torch.zeros(SIZE, device=DEVICE)
        self.last_reachable[0] = 1.
Example #17
    def set_initial_last_states(self, batch, STEPS_SIZE, SOURCE_NODES):
        hyperparameters = get_hyperparameters()
        DEVICE = hyperparameters["device"]
        DIM_LATENT = hyperparameters["dim_latent"]

        SIZE = batch.num_nodes
        self.last_latent = torch.zeros(SIZE, DIM_LATENT, device=DEVICE)
        self.last_continue_p = torch.ones(batch.num_graphs, device=DEVICE)
        x, y = self.get_input_output_features(batch, SOURCE_NODES)
        x.requires_grad = False
        y.requires_grad = False
        x_curr, _ = self.get_step_io(x, y)
        self.last_output = x_curr[:, 0].clone()
Example #18
    def get_features_from_walk(self, walks, batch, mask_end_of_path=None):
        attr_matrix = torch_geometric.utils.to_dense_adj(batch.edge_index, edge_attr=batch.edge_attr).squeeze()
        walk_nodes_latent = self.last_latent[walks]
        walk_nodes_latent_i = walk_nodes_latent[:, :-1]
        walk_nodes_latent_j = walk_nodes_latent[:, 1:]
        walk_nodes_latent_i = walk_nodes_latent_i.reshape(-1, get_hyperparameters()["dim_latent"])
        walk_nodes_latent_j = walk_nodes_latent_j.reshape(-1, get_hyperparameters()["dim_latent"])
        wf = walks[:, :-1]
        ws = walks[:, 1:]
        walk_edge_index = torch.stack((wf, ws), dim=0)
        walk_edge_index = walk_edge_index.view(2, -1)

        walk_attrs = attr_matrix[(ws, wf)]
        inv_walk_attrs = attr_matrix[(wf, ws)]
        if self.training:
            walk_attrs[:, :, 1] = walk_attrs[:, :, 1] + torch.randint_like(walk_attrs[:, :, 1], low=0, high=10)
        wa = walk_attrs[:, :, 1].clone().detach()
        if mask_end_of_path is not None:
            wa[mask_end_of_path[:, 1:]] = 1e9
        actual_argmins = wa.argmin(dim=1)
        no_step_mask = (wa.min(dim=1).values == 1e9)  # no step was made by the algorithm, so the min is undefined
        return walk_nodes_latent_i, walk_nodes_latent_j, walk_edge_index, walk_attrs, inv_walk_attrs, actual_argmins, no_step_mask
Example #19
    def set_initial_last_states(self, batch, STEPS_SIZE, SOURCE_NODES):
        hyperparameters = get_hyperparameters()
        DEVICE = hyperparameters["device"]
        DIM_LATENT = hyperparameters["dim_latent"]
        DIM_NODES = hyperparameters["dim_nodes_AugmentingPath"]
        DIM_EDGES = hyperparameters["dim_edges"]

        SIZE = batch.num_nodes
        super().set_initial_last_states(batch, STEPS_SIZE, SOURCE_NODES)
        self.last_predecessors_p = torch.full((SIZE, SIZE), -1e9, device=DEVICE)
        self.last_predecessors_p[(SOURCE_NODES, SOURCE_NODES)] = 1e9
        self.last_distances = self.last_output.clone()
        self.last_distances[SOURCE_NODES] = 0.
Example #20
def reweight_batch(batch, batch_bfs, use_ints):
    DEVICE = get_hyperparameters()["device"]
    if use_ints:
        weights = torch.randint_like(batch.edge_attr[:, 1],
                                     1,
                                     16,
                                     device=DEVICE,
                                     dtype=torch.float)
    else:
        weights = 0.8 * torch.rand_like(
            batch.edge_attr[:, 1], device=DEVICE, dtype=torch.float) + 0.2
    batch.edge_attr = torch.stack((weights, batch.edge_attr[:, 1]), dim=1)
    batch_bfs.edge_attr = torch.stack((weights, batch_bfs.edge_attr[:, 1]),
                                      dim=1)
Example #21
    def update_broken_invariants(self, batch, predecessors, adj_matrix,
                                 flow_matrix):
        start = time.time()
        DEVICE = get_hyperparameters()["device"]
        GRAPH_SIZES, SOURCE_NODES, SINK_NODES = utils.get_sizes_and_source_sink(batch)
        STEPS_SIZE = GRAPH_SIZES.max()
        _, y = self.get_input_output_features(batch, SOURCE_NODES)
        broke_flow = torch.zeros(batch.num_graphs, dtype=torch.bool, device=DEVICE)
        broke_reachability_source = torch.zeros(batch.num_graphs, dtype=torch.bool, device=DEVICE)
        broke_invariant = torch.zeros(batch.num_graphs, dtype=torch.bool, device=DEVICE)
        curr_node = SINK_NODES.clone().detach()
        cnt = 0
        predecessors_real = y[:, -1, -1]

        idx = predecessors[curr_node] != curr_node

        while ((predecessors_real[SINK_NODES] != -1).any() and cnt <= STEPS_SIZE
               and idx.any() and not utils.interrupted()):
            # Ignore if we reached the starting node loop
            # (predecessor[starting node] = starting node)
            move_to_predecessors = torch.stack((predecessors[curr_node], curr_node), dim=0)[:, idx]
            rowcols = (move_to_predecessors[0], move_to_predecessors[1])
            if not adj_matrix[rowcols].all():
                # Each predecessor must lead to a node accessible by an edge!
                print()
                print(adj_matrix)
                print(curr_node)
                print(predecessors[curr_node])
                print("FATAL INVARIANT ERROR")
                exit(0)

            assert adj_matrix[rowcols].all()

            if (flow_matrix[rowcols] <= 0).any():
                broke_flow[idx] |= flow_matrix[rowcols] <= 0
            curr_node[idx] = predecessors[curr_node[idx]]
            idx = (predecessors[curr_node] != curr_node) & (predecessors_real[SINK_NODES] != -1)
            cnt += 1
            if cnt > STEPS_SIZE+1:
                break

        original_reachable_mask = (predecessors_real[SINK_NODES] != -1)
        broke_reachability_source |= (curr_node != SOURCE_NODES)
        broke_invariant = broke_flow | broke_reachability_source
        broke_all = broke_flow & broke_reachability_source
        
        self.broken_invariants.extend((original_reachable_mask & broke_invariant).clone().detach())
        self.broken_reachabilities.extend((original_reachable_mask & broke_reachability_source).clone().detach())
        self.broken_flows.extend((original_reachable_mask & broke_flow).clone().detach())
        self.broken_all.extend((original_reachable_mask & broke_all).clone().detach())
Example #22
    def add_algorithms(self, algo_list):
        hyperparameters = get_hyperparameters()
        device = hyperparameters['device']
        node_features = hyperparameters['dim_nodes']
        edge_features = hyperparameters['dim_edges']
        latent_features = hyperparameters['dim_latent']
        for algo in algo_list:
            if algo == 'BFS':
                self.algorithms[algo] = models.bfs_network.BFSNetwork(
                    node_features, edge_features, latent_features, self).to(device)
            elif algo in ['TRANS', 'TIPS', 'BUBBLES']:
                self.algorithms[algo] = models.traversal_network.TraversalNetwork(
                    node_features, edge_features, latent_features, self).to(device)
            else:
                # Other algorithms are not handled here.
                pass
Example #23
    def get_step_loss(self,
                      batch_ids,
                      mask,
                      mask_cp,
                      y_curr,
                      continue_p,
                      true_termination,
                      reachable_p,
                      compute_losses_and_broken=True):
        reachable_p_masked = reachable_p[mask]
        reachable_real_masked = y_curr[mask].squeeze()
        steps = mask_cp.float().sum()

        loss_reachable, loss_term, processed_nodes, step_acc = 0, 0, 0, 1

        if compute_losses_and_broken:
            processed_nodes = len(reachable_real_masked)
            self.predictions["reachabilities"].extend(reachable_p_masked)
            self.actual["reachabilities"].extend(reachable_real_masked)
            self.predictions["terminations"].extend(continue_p[mask_cp])
            self.actual["terminations"].extend(true_termination[mask_cp])
            loss_reachable = F.binary_cross_entropy_with_logits(
                reachable_p_masked,
                reachable_real_masked,
                reduction='sum',
                pos_weight=torch.tensor(1.00))
            loss_term = F.binary_cross_entropy_with_logits(
                continue_p[mask_cp],
                true_termination[mask_cp],
                reduction='sum',
                pos_weight=torch.tensor(1.00))
            if get_hyperparameters()["calculate_termination_statistics"]:
                self.update_termination_statistics(continue_p[mask_cp],
                                                   true_termination[mask_cp])

            if not self.training:
                reachable_split = utils.split_per_graph(
                    batch_ids, reachable_p > 0)
                reachable_real_split = utils.split_per_graph(
                    batch_ids, y_curr.squeeze())
                correct, tot = BFSNetwork.calculate_step_acc(
                    reachable_split[mask_cp], reachable_real_split[mask_cp])
                self.mean_step.extend(correct / tot.float())
                step_acc = correct / tot.float()

        return loss_reachable, loss_term, steps, processed_nodes, step_acc
Example #24
    def get_step_loss(self,
                      batch_ids, mask, mask_cp,
                      y_curr,
                      continue_p, true_termination,
                      distances, predecessors_p,
                      compute_losses_and_broken=True):
        (distances_masked, distances_real_masked, predecessors_p_masked,
         predecessors_real_masked) = AugmentingPathNetwork.mask_infinities(
             mask, y_curr, distances, predecessors_p)
        steps = mask_cp.float().sum()

        train = self.training

        loss_dist, loss_pred, loss_term, processed_nodes, step_acc = 0, 0, 0, 0, 1
        if distances_real_masked.nelement() != 0 and compute_losses_and_broken:
            processed_nodes = len(distances_real_masked)
            if self.bits_size is None:
                loss_dist = F.mse_loss(distances_masked, distances_real_masked, reduction='sum')
            else:
                loss_dist = F.binary_cross_entropy_with_logits(distances_masked, utils.integer2bit(distances_real_masked), reduction='sum')
            loss_pred = F.cross_entropy(predecessors_p_masked, predecessors_real_masked, ignore_index=-1, reduction='sum')

        if compute_losses_and_broken:
            assert mask_cp.any(), mask_cp
            loss_term = F.binary_cross_entropy_with_logits(continue_p[mask_cp], true_termination[mask_cp], reduction='sum', pos_weight=torch.tensor(1.00))
            if get_hyperparameters()["calculate_termination_statistics"]:
                self.update_termination_statistics(continue_p[mask_cp], true_termination[mask_cp])

            assert loss_term.item() != float('inf')

        if not train and mask_cp.any() and compute_losses_and_broken:
            assert mask_cp.any()
            _, predecessors_real = AugmentingPathNetwork.get_real_output_values(y_curr)
            predecessors_p_split = utils.split_per_graph(batch_ids, predecessors_p)
            predecessors_real_split = utils.split_per_graph(batch_ids, predecessors_real)
            correct, tot = AugmentingPathNetwork.calculate_step_acc(
                torch.max(predecessors_p_split[mask_cp], dim=2).indices,
                predecessors_real_split[mask_cp])
            self.mean_step.extend(correct / tot.float())
            step_acc = correct / tot.float()

        return loss_dist, loss_pred, loss_term, steps, processed_nodes, step_acc
Example #25
    def augment_flow(self, batch, walks, mask_end_of_path, mins):
        BATCH_SIZE = batch.num_graphs
        (x_i, x_j, walk_edge_index, attrs, inv_walk_attrs, actual_argmins,
         no_step_mask) = self.get_features_from_walk(
             walks, batch, mask_end_of_path=mask_end_of_path)
        if no_step_mask.all():
            return attrs[:, :, 1]  # nothing changes

        if self.training:
            # Augmentation and bottleneck finding must use the same edge attributes.
            attrs = self._attrs
        messages = self.get_messages_from_features(x_i, x_j, walk_edge_index, attrs, batch)
        messages_inv = self.get_messages_from_features(x_j, x_i, walk_edge_index, inv_walk_attrs, batch)
        minemb = self.bit_encoder(utils.integer2bit(mins.float()))
        minemb = minemb.unsqueeze(1).repeat_interleave(messages.shape[1], dim=1)
        new_weight_distribution = self.subtract_network(torch.cat((minemb, messages), dim=-1))

        mask = torch.arange(2**self.bits_size, device=get_hyperparameters()["device"])
        mask = mask.repeat(BATCH_SIZE, attrs.shape[1], 1)
        expanded_caps = attrs[:, :, 1].unsqueeze(2).expand_as(mask)
        mask = mask > expanded_caps
        new_weight_distribution[mask] = -1e9
        new_weight_distribution = new_weight_distribution[~no_step_mask]
        real_new_caps = attrs[:, :, 1][~no_step_mask]
        real_new_caps -= mins[~no_step_mask].unsqueeze(-1)
        if mask_end_of_path is not None:
            real_new_caps[mask_end_of_path[:, 1:][~no_step_mask]] = -100  # -100 is ignored by F.cross_entropy (default ignore_index)
        just_correct = real_new_caps >= 0  # the neural mins may be incorrect, so ignore those entries
        neural_new_caps = new_weight_distribution.argmax(dim=-1)
        if just_correct.any():
            self.losses["augment"] = F.cross_entropy(
                new_weight_distribution[just_correct].view(-1, 2**self.bits_size),
                real_new_caps[just_correct].view(-1).long())
            subtract_correct = (neural_new_caps == real_new_caps)[just_correct].sum()
            subtract_all = real_new_caps[just_correct].nelement()
            if not self.training:
                self.subtract_correct += subtract_correct
                self.subtract_all += subtract_all
        else:
            return attrs[:, :, 1]  # nothing changes
        return neural_new_caps
Example #26
    def zero_hidden(self, num_nodes):
        self.hidden = torch.zeros(num_nodes, self.out_channels,
                                  device=get_hyperparameters()["device"])
Example #27
def termination_condition(args, i, threshold, batch, reachable_sinks=None):
    if not args["--use-BFS-for-termination"]:
        return i < threshold
    return i < get_hyperparameters()["max_threshold"] and reachable_sinks.any()
Example #28
def run(args, threshold, processor, probp=1, probq=4, savefile=True):
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    DIM_LATENT = hyperparameters["dim_latent"]
    if savefile:
        f = open(args["SAVE_FILE"], "w")
    dataset = GraphOnlyDataset('./graph_only',
                               split='test' + args["--upscale"],
                               less_wired=True,
                               probp=probp,
                               probq=probq,
                               device='cpu')
    dataset_BFS = GraphOnlyDatasetBFS('./graph_only_BFS',
                                      split='test' + args["--upscale"],
                                      probp=probp,
                                      probq=probq,
                                      less_wired=True,
                                      device='cpu')
    result_maxflows = []
    with torch.no_grad():
        for rep in range(10):
            print(rep)
            current_result = []
            loader = DataLoader(dataset,
                                batch_size=hyperparameters["batch_size"],
                                shuffle=False,
                                drop_last=False,
                                num_workers=0)
            loader_BFS = iter(
                DataLoader(dataset_BFS,
                           batch_size=hyperparameters["batch_size"],
                           shuffle=False,
                           drop_last=False,
                           num_workers=0))
            for batch in tqdm(loader, dynamic_ncols=True):
                batch_bfs = next(loader_BFS)
                batch_bfs.to(DEVICE)
                batch.to(DEVICE)
                start_iter = time.time()
                start = time.time()
                GRAPH_SIZES, SOURCE_NODES, SINK_NODES = utils.get_sizes_and_source_sink(
                    batch)
                # we make at most |V|-1 steps
                STEPS_SIZE = GRAPH_SIZES.max()
                inv_edge_index = utils.create_inv_edge_index(
                    len(GRAPH_SIZES), GRAPH_SIZES.max(), batch.edge_index)

                do_not_process = torch.zeros_like(GRAPH_SIZES,
                                                  dtype=torch.bool,
                                                  device=DEVICE)
                start = time.time()
                path_matrix, stop_move_backward_col, bottleneck = find_augmenting_path(
                    args, batch, batch_bfs, do_not_process, processor,
                    inv_edge_index, threshold)
                do_not_process = bottleneck == 0

                cnt = 0
                while (bottleneck != 0).any():
                    wrong_minus = augment_flow(
                        batch,
                        inv_edge_index,
                        path_matrix,
                        stop_move_backward_col,
                        bottleneck,
                        do_not_process,
                        args["--use-neural-augmentation"],
                        augmenting_path_network=processor.algorithms["AugmentingPath"])
                    batch_bfs.edge_attr[:, 1] = batch.edge_attr[:, 1]
                    do_not_process |= wrong_minus

                    path_matrix, stop_move_backward_col, bottleneck = find_augmenting_path(
                        args, batch, batch_bfs, do_not_process, processor,
                        inv_edge_index, threshold)
                    do_not_process |= (bottleneck == 0)
                    bottleneck[do_not_process] = 0
                    assert ((bottleneck <= 1) &
                            (bottleneck >= 0)).all(), (bottleneck, path_matrix)

                    cnt += 1

                start = time.time()
                maxflows = [0 for _ in SINK_NODES]
                for isn, sn in enumerate(SINK_NODES):
                    for i in range(len(batch.batch)):
                        edge_id = inv_edge_index[sn][i]
                        if edge_id != -100:
                            assert 0 <= batch.edge_attr[edge_id][1] <= 1
                            maxflows[isn] += batch.edge_attr[edge_id][1]

                if savefile:
                    print(*[int(mf.item()) for mf in maxflows],
                          sep=' ',
                          end=' ',
                          file=f)
                else:
                    current_result.extend([int(mf.item()) for mf in maxflows])
            if savefile:
                f.write('\n')
            else:
                result_maxflows.append(current_result)
    if savefile:
        f.close()
    else:
        return result_maxflows
Example #29
def augment_flow(batch,
                 inv_edge_index,
                 path,
                 stop_move_backward_col,
                 bottleneck,
                 do_not_process,
                 use_neural_augmentation,
                 augmenting_path_network=None):
    def get_edge_indexes(step):
        pairs = get_pairs(step, stop_move_backward_col, path, do_not_process)
        edge_idx = inv_edge_index[(pairs[:, 1], pairs[:, 0])]
        edge_idx_rev = inv_edge_index[(pairs[:, 0], pairs[:, 1])]
        return edge_idx, edge_idx_rev

    flow = batch.edge_attr[:, 1]
    old_flow = flow.clone()
    path_matrix = torch.cat((path[:, 0, :], path[:, 1, -1].unsqueeze(-1)),
                            dim=-1)
    mask_end_of_path = path_matrix == -100
    path_matrix[mask_end_of_path] = 0
    if use_neural_augmentation:
        new_flows = augmenting_path_network.augment_flow(
            batch, path_matrix, mask_end_of_path, bottleneck)
    largest_len = stop_move_backward_col.max()
    needs_rerun = torch.zeros(batch.num_graphs,
                              device=get_hyperparameters()["device"],
                              dtype=torch.bool)
    for i in range(largest_len):
        mask = (i < stop_move_backward_col) & (~do_not_process)
        if not mask.any():
            break
        edge_idx, edge_idx_rev = get_edge_indexes(i)
        assert (edge_idx != -100).all()
        assert (edge_idx_rev != -100).all()
        assert (flow[edge_idx].abs() <= 1).all(), \
            flow[edge_idx][flow[edge_idx].abs() > 1]
        assert (abs(bottleneck[mask]) <= 1).all()
        assert (bottleneck[do_not_process] == 0).all()
        assert (bottleneck[~do_not_process] == 1).all()
        assert (bottleneck[mask] == 1).all()
        if use_neural_augmentation:
            new_flow = new_flows[mask[stop_move_backward_col > 0]][:, i].float()
            should_be = flow[edge_idx] - bottleneck[mask]
            not_what_should_be = (new_flow != should_be)
            needs_rerun[mask] |= not_what_should_be
            consts = flow[edge_idx] + flow[edge_idx_rev]
            new_flow_rev = consts - new_flow
            flow[edge_idx] = torch.where(not_what_should_be, flow[edge_idx],
                                         new_flow)
            flow[edge_idx_rev] = torch.where(not_what_should_be,
                                             flow[edge_idx_rev], new_flow_rev)
        else:
            flow[edge_idx] -= bottleneck[mask]
            flow[edge_idx_rev] += bottleneck[mask]
    if needs_rerun.any():
        for i in range(largest_len):
            mask = (i < stop_move_backward_col) & (~do_not_process)
            if not mask.any():
                break
            edge_idx, edge_idx_rev = get_edge_indexes(i)
            flow[edge_idx] = old_flow[edge_idx]
            flow[edge_idx_rev] = old_flow[edge_idx_rev]
    return needs_rerun
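Absent the neural path (use_neural_augmentation=False), the update above is plain Ford-Fulkerson bookkeeping: subtract the bottleneck from each forward edge on the path and add it to the paired reverse edge, so each pair's sum is conserved, which is the same invariant the neural branch exploits through consts = flow[edge_idx] + flow[edge_idx_rev]. A minimal sketch with assumed toy values:

import torch

flow = torch.tensor([1.0, 0.0])  # flow[0]: forward edge, flow[1]: its reverse
bottleneck = 1.0

flow[0] -= bottleneck  # push flow along the augmenting path
flow[1] += bottleneck  # mirror it on the residual (reverse) edge

assert flow.sum() == 1.0  # the pair-sum is conserved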
Example #30
def find_augmenting_path(args,
                         batch,
                         batch_bfs,
                         do_not_process,
                         processor,
                         inv_edge_index,
                         threshold,
                         debug=False):
    DEVICE = get_hyperparameters()["device"]
    GRAPH_SIZES, SOURCE_NODES, SINK_NODES = utils.get_sizes_and_source_sink(
        batch)
    STEPS_SIZE = GRAPH_SIZES.max()
    redo_mask = ~do_not_process
    predecessors_last = torch.full_like(batch.batch, -1, device=DEVICE)
    reachable_last = torch.ones_like(predecessors_last, dtype=torch.bool)
    weights = batch.edge_attr[:, 0]
    flow = batch.edge_attr[:, 1]
    final = SINK_NODES.clone()
    bottleneck = torch.zeros(batch.num_graphs, device=DEVICE)
    path_matrix = torch.full((batch.num_graphs, STEPS_SIZE),
                             -100,
                             device=DEVICE,
                             dtype=torch.long)
    stop_move_backward_col = torch.zeros(batch.num_graphs,
                                         device=DEVICE,
                                         dtype=torch.long)
    wrong_bottleneck_mask = None
    for algorithm in processor.algorithms.values():
        algorithm.zero_validation_stats()

    if args["--use-BFS-for-termination"]:
        with torch.no_grad():
            reachable = processor.algorithms["BFS"].process(
                batch_bfs,
                EPSILON=0,
                enforced_mask=redo_mask,
                compute_losses_and_broken=False)
    else:
        reachable = torch.ones(batch.num_nodes,
                               device=DEVICE,
                               dtype=torch.bool)
    i = 0
    while termination_condition(args, i, threshold, batch,
                                reachable[SINK_NODES]):
        i += 1

        start = time.time()
        with torch.no_grad():
            predecessors = processor.algorithms["AugmentingPath"].process(
                batch,
                EPSILON=0,
                enforced_mask=redo_mask,
                compute_losses_and_broken=False)

        predecessors = torch.where(redo_mask[batch.batch], predecessors,
                                   predecessors_last)
        predecessors_last = predecessors.clone()
        path_matrix, stop_move_backward_col, final = obtain_paths(
            predecessors, GRAPH_SIZES, STEPS_SIZE, SOURCE_NODES, SINK_NODES)
        if args["--use-neural-bottleneck"]:
            walks, mask_end_of_path = utils.get_walks(False, batch,
                                                      predecessors,
                                                      GRAPH_SIZES,
                                                      SOURCE_NODES, SINK_NODES)
            bottleneck = processor.algorithms["AugmentingPath"].find_mins(
                batch, walks, mask_end_of_path, GRAPH_SIZES, SOURCE_NODES,
                SINK_NODES)
            zero_unreachable_and_frozen_bottleneck(bottleneck, do_not_process,
                                                   reachable[SINK_NODES])
            real_bottleneck = get_bottleneck(path_matrix,
                                             stop_move_backward_col, flow,
                                             inv_edge_index, do_not_process,
                                             reachable[SINK_NODES])
            wrong_bottleneck_mask = (bottleneck == 1) & (real_bottleneck == 0)
            do_not_process |= wrong_bottleneck_mask
        else:
            bottleneck = get_bottleneck(path_matrix, stop_move_backward_col,
                                        flow, inv_edge_index, do_not_process,
                                        reachable[SINK_NODES])

        reweight_batch(batch, batch_bfs, args["--use-ints"])
        if args["--use-BFS-for-termination"]:
            with torch.no_grad():
                reachable = processor.algorithms["BFS"].process(
                    batch_bfs,
                    EPSILON=0,
                    enforced_mask=redo_mask,
                    compute_losses_and_broken=False)

        reachable = torch.where(redo_mask[batch.batch], reachable,
                                reachable_last)
        reachable_last = reachable.clone()
        reachable_sinks = reachable[SINK_NODES]
        redo_mask = get_redo_mask(args["--use-BFS-for-termination"],
                                  args["--use-neural-bottleneck"], final,
                                  SOURCE_NODES, bottleneck, do_not_process,
                                  reachable_sinks)
        do_not_process[redo_mask] |= (~reachable_sinks[redo_mask])

        if debug:
            if (bottleneck == 0).any():
                print("Broke flow cap invariant", file=sys.stderr)
            if (final != SOURCE_NODES).any():
                print("Broke reachability invariant", file=sys.stderr)

        if not redo_mask.any():
            bottleneck[do_not_process] = 0
            return path_matrix, stop_move_backward_col, bottleneck

    bottleneck[redo_mask] = 0
    # Hack: zero these so the redo mask is false on later iterations.
    bottleneck[do_not_process] = 0
    return path_matrix, stop_move_backward_col, bottleneck