Code example #1

import torch
import torch.nn as nn
import torch.optim as optim
import IPython.display

# SimpleAE, enumerate_cycle, mnist_batched and interrupted are defined
# elsewhere in the notebook (see the sketch after this example for the
# two helpers).
if __name__ == '__main__':

    # Create a new model object, and initialize its parameters
    fit = SimpleAE(latent_dim=6)

    # Prepare to iterate through all the training data.
    # See the note at the top, under Utilities.
    iter_training_data = enumerate_cycle(mnist_batched)

    optimizer = optim.Adadelta(fit.parameters(), lr=1)

    print("Press Ctrl+C to end training and save parameters")

    while not interrupted():
        (epoch, batch_num), (imgs, lbls) = next(iter_training_data)
        optimizer.zero_grad()
        _, rx = fit(imgs)
        e = nn.functional.mse_loss(imgs, rx)
        e.backward()
        optimizer.step()

        if batch_num % 25 == 0:
            IPython.display.clear_output(wait=True)
            print(
                f'epoch={epoch} batch={batch_num}/{len(mnist_batched)} loss={e.item()}'
            )

    # Optionally, save all the parameters
    torch.save(fit.state_dict(), 'intro_ae_6d.pt')
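
Both autoencoder snippets in this section rely on two helpers that the original notebook defines under Utilities. The following is a minimal sketch of what they plausibly look like, inferred from how they are used above; the names enumerate_cycle and interrupted come from the snippets, but these bodies are assumptions, not the original implementations:

import itertools
import signal

def enumerate_cycle(iterable):
    # Yield ((epoch, batch_num), item) forever, re-iterating the data
    # source after each full pass (matches how it is unpacked above).
    for epoch in itertools.count():
        for batch_num, item in enumerate(iterable):
            yield (epoch, batch_num), item

_interrupted = False

def _on_sigint(signum, frame):
    global _interrupted
    _interrupted = True

# Trap Ctrl+C so the training loop can exit cleanly and save parameters
# instead of dying on KeyboardInterrupt.
signal.signal(signal.SIGINT, _on_sigint)

def interrupted():
    return _interrupted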
Code example #2
    best_model = models.AlgorithmProcessor(DIM_LATENT, SingleIterationDataset,
                                           args["--processor-type"]).to(DEVICE)
    best_model.algorithms = nn.ModuleDict(processor.algorithms.items())
    best_model.load_state_dict(copy.deepcopy(processor.state_dict()))

    torch.set_printoptions(precision=20)

    # Profiling is switched off here; set enabled=True to collect a trace.
    with torch.autograd.profiler.profile(enabled=False, use_cuda=True) as prof:
        # for algorithm in processor.algorithms:
        #     algorithm.loader = DataLoader(algorithm.train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=8)
        #     algorithm.val_loader = DataLoader(algorithm.val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=8)
        # 'params' is defined earlier in the full script (presumably the
        # processor's trainable parameters).
        optimizer = optim.Adam(params,
                               lr=hyperparameters["lr"],
                               weight_decay=hyperparameters["weight_decay"])
        for epoch in range(3000):  # FIXME
            if interrupted():
                break
            processor.train()
            iterate_over(processor, optimizer)

            # 'patience' is an early-stopping counter initialised outside
            # this excerpt; see the sketch after this example for how it is
            # typically reset.
            patience += 1
            print('Epoch {:4d}:'.format(epoch))
            processor.eval()
            iterate_over(processor)
            # print("Mean/Last step acc", processor.algorithms[0].get_validation_accuracies())
            # total_loss = sum(processor.algorithms[0].get_validation_losses()) #TODO PRETTIER!

            if augmenting_path_network is None:
                total_loss = sum(
                    processor.algorithms["BFS"].get_validation_losses())
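
The fragment increments patience every epoch, but the excerpt ends before the counter is reset or the best_model copy refreshed. The usual pattern, sketched under the assumption that total_loss is the validation quantity being minimised (best_loss and PATIENCE_LIMIT are hypothetical names, not from the original):

            if total_loss < best_loss:                  # hypothetical name
                best_loss = total_loss
                patience = 0                            # reset on improvement
                best_model.load_state_dict(copy.deepcopy(processor.state_dict()))
            elif patience >= PATIENCE_LIMIT:            # hypothetical name
                break  # validation has stalled; keep best_model as the result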
Code example #3
    def loop_condition(self, batch_ids, x, y, STEPS_SIZE, GRAPH_SIZES):
        # Continue while some graph in the batch is still active: during
        # evaluation, while mask_cp has any bit set; during training, while
        # utils.finish(...) still reports True for some graph. Also require
        # that the step budget is not exhausted and the user has not
        # pressed Ctrl+C (utils.interrupted()).
        return (((not self.training and self.mask_cp.any()) or
                 (self.training
                  and utils.finish(x, y, batch_ids, self.steps, STEPS_SIZE,
                                   GRAPH_SIZES).bool().any()))
                and self.steps + 1 < STEPS_SIZE and not utils.interrupted())
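
For context, a condition like this usually gates an iterative message-passing loop. A hypothetical caller could look like the following (the process and step methods and their arguments are assumptions for illustration, not part of the original class):

    def process(self, batch_ids, x, y, STEPS_SIZE, GRAPH_SIZES):
        self.steps = 0
        while self.loop_condition(batch_ids, x, y, STEPS_SIZE, GRAPH_SIZES):
            x = self.step(x)  # hypothetical: one round of message passing
            self.steps += 1
        return x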
Code example #4

if __name__ == '__main__':
    # Create a new model object, and initialize its parameters
    fit = ND_AE(latent_dim=8)

    # Prepare to iterate through all the training data.
    # See the note at the top, under Utilities.
    iter_training_data = enumerate_cycle(mnist_batched)

    optimizer = optim.Adadelta(fit.parameters(), lr=1)

    print("Press Ctrl+C to end training and save parameters")

    epoch = 0
    # Train until Ctrl+C, or stop once the data iterator reports epoch 10
    # (epoch is updated inside the loop from the iterator).
    while not interrupted() and epoch < 10:
        (epoch, batch_num), (imgs, lbls) = next(iter_training_data)
        optimizer.zero_grad()
        _, rx = fit(imgs)
        e = nn.functional.mse_loss(imgs, rx)
        e.backward()
        optimizer.step()

        if batch_num % 25 == 0:
            IPython.display.clear_output(wait=True)
            print(
                f'epoch={epoch} batch={batch_num}/{len(mnist_batched)} loss={e.item()}'
            )

    # Optionally, save all the parameters
    torch.save(fit.state_dict(), 'nd_ae_8d.pt')
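
To reuse the saved parameters later, the standard PyTorch pattern is to build the same architecture and load the state dict back in (this assumes the ND_AE class above is in scope):

restored = ND_AE(latent_dim=8)
restored.load_state_dict(torch.load('nd_ae_8d.pt'))
restored.eval()  # disable training-only behaviour before inference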
Code example #5
    def update_broken_invariants(self, batch, predecessors, adj_matrix, flow_matrix):

        start = time.time()
        DEVICE = get_hyperparameters()["device"]
        GRAPH_SIZES, SOURCE_NODES, SINK_NODES = utils.get_sizes_and_source_sink(batch)
        STEPS_SIZE = GRAPH_SIZES.max()
        _, y = self.get_input_output_features(batch, SOURCE_NODES)
        broke_flow = torch.zeros(batch.num_graphs, dtype=torch.bool, device=DEVICE)
        broke_reachability_source = torch.zeros(batch.num_graphs, dtype=torch.bool, device=DEVICE)
        broke_invariant = torch.zeros(batch.num_graphs, dtype=torch.bool, device=DEVICE)
        curr_node = SINK_NODES.clone().detach()
        cnt = 0
        predecessors_real = y[:, -1, -1]

        idx = predecessors[curr_node] != curr_node
    
        while ((predecessors_real[SINK_NODES] != -1).any() and cnt <= STEPS_SIZE
               and idx.any() and not utils.interrupted()):
            # Ignore if we reached the starting node loop
            # (predecessor[starting node] = starting node)
            move_to_predecessors = torch.stack((predecessors[curr_node], curr_node), dim=0)[:, idx]
            rowcols = (move_to_predecessors[0], move_to_predecessors[1])
            if not adj_matrix[rowcols].all():
                # Every predecessor must lead into the current node along an
                # existing edge; dump the offending state before aborting.
                print()
                print(adj_matrix)
                print(curr_node)
                print(predecessors[curr_node])
                print("FATAL INVARIANT ERROR")
                exit(1)

            if (flow_matrix[rowcols] <= 0).any():
                broke_flow[idx] |= flow_matrix[rowcols] <= 0
            curr_node[idx] = predecessors[curr_node[idx]]
            idx = (predecessors[curr_node] != curr_node) & (predecessors_real[SINK_NODES] != -1)
            cnt += 1

        original_reachable_mask = (predecessors_real[SINK_NODES] != -1)
        broke_reachability_source |= (curr_node != SOURCE_NODES)
        broke_invariant = broke_flow | broke_reachability_source
        broke_all = broke_flow & broke_reachability_source
        
        self.broken_invariants.extend((original_reachable_mask & broke_invariant).clone().detach())
        self.broken_reachabilities.extend((original_reachable_mask & broke_reachability_source).clone().detach())
        self.broken_flows.extend((original_reachable_mask & broke_flow).clone().detach())
        self.broken_all.extend((original_reachable_mask & broke_all).clone().detach())
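
The batched tensor indexing above is dense; the same invariant check for a single graph, written with plain Python values, is roughly the following. This is a simplified sketch, not the original code: predecessors maps each node to its predecessor (with predecessors[source] == source, and -1 meaning unreachable, as for predecessors_real), and adj/flow are plain 2-D arrays.

def check_invariants_single(source, sink, predecessors, adj, flow, max_steps):
    # Walk predecessor pointers from the sink back towards the source,
    # flagging any edge on the path with non-positive remaining flow.
    broke_flow = False
    node, steps = sink, 0
    while predecessors[node] != node and steps <= max_steps:
        prev = predecessors[node]
        assert adj[prev][node], "predecessor must reach node along an edge"
        if flow[prev][node] <= 0:
            broke_flow = True
        node = prev
        steps += 1
    broke_reachability = (node != source)  # walk did not end at the source
    return broke_flow, broke_reachability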