예제 #1
0
                          sep=' ',
                          end=' ',
                          file=f)
                else:
                    current_result.extend([int(mf.item()) for mf in maxflows])
            if savefile:
                f.write('\n')
            else:
                result_maxflows.append(current_result)
    if savefile:
        f.close()
    else:
        return result_maxflows


if __name__ == "__main__":
    # Script entry point: build a processor from the CLI arguments, load a
    # trained state dict into it and call `run` with gradient tracking off.
    hyperparameters = get_hyperparameters()
    # Kept as module-level names: other code in this module may read them as
    # globals (NOTE(review): confirm before renaming).
    DEVICE = hyperparameters["device"]
    DIM_LATENT = hyperparameters["dim_latent"]
    args = docopt(__doc__)
    print("ARGS", args["--algorithms"])
    processor = models.AlgorithmProcessor(DIM_LATENT, SingleIterationDataset,
                                          args["--processor-type"]).to(DEVICE)
    utils.load_algorithms(args["--algorithms"], processor, args["--use-ints"])
    # load_algorithms may register new sub-modules after the .to() above;
    # moving again is a no-op for anything already on DEVICE.
    processor.to(DEVICE)
    # map_location remaps tensors saved on another device (e.g. CUDA) onto
    # DEVICE, so a GPU-trained checkpoint also loads on a CPU-only machine.
    processor.load_state_dict(
        torch.load(args["MODEL_TO_LOAD"], map_location=DEVICE))
    processor.eval()

    with torch.no_grad():
        run(args, int(args["--threshold"]), processor, int(args["--probp"]),
            int(args["--probq"]))
예제 #2
0
def main(algo_list, test, train_path, test_path, model_to_load=None):
    """Train (unless ``test``) and then evaluate an ``AlgorithmProcessor``.

    In train mode the processor is optimised with Adam (lr 1e-5) and
    early-stopped on the validation loss. The weights with the lowest
    validation loss are saved to ``trained_models/processor_<timestamp>.pt``,
    and loss/accuracy plots are drawn. In both modes a checkpoint is then
    loaded back into the processor and evaluated on the test set.

    Args:
        algo_list: algorithms to learn. More than one selects
            ``MultiAlgoDataset``, otherwise ``SingleAlgoDataset`` is used.
        test: if truthy, skip training and only run the test loop.
        train_path: dataset root used for the train/validation split. A
            ``processed`` sub-directory is created if missing.
        test_path: dataset root used for testing. A ``processed``
            sub-directory is created if missing.
        model_to_load: optional path of a saved state dict to evaluate. When
            None, the checkpoint written by this run is used. That file is
            only produced in train mode, so pass a path when ``test`` is set.
    """
    hyperparameters = get_hyperparameters()
    num_epochs = hyperparameters['num_epochs']
    device = hyperparameters['device']
    dim_latent = hyperparameters['dim_latent']
    batch_size = hyperparameters['batch_size']
    patience_limit = hyperparameters['patience_limit']

    mode = 'test' if test else 'train'
    time_now = datetime.now().strftime('%Y-%b-%d-%H-%M')

    processor = models.AlgorithmProcessor(dim_latent).to(device)
    processor.add_algorithms(algo_list)
    # add_algorithms may create sub-modules after the .to() above; moving
    # again is a no-op for parameters already on `device`.
    processor.to(device)
    params = list(processor.parameters())
    model_path = f'trained_models/processor_{time_now}.pt'

    # exist_ok avoids the isdir()-then-mkdir() race and tolerates reruns.
    for directory in (os.path.join(train_path, 'processed'),
                      os.path.join(test_path, 'processed'),
                      'trained_models',
                      'figures'):
        os.makedirs(directory, exist_ok=True)

    dataset_cls = MultiAlgoDataset if len(algo_list) > 1 else SingleAlgoDataset
    ds = dataset_cls(train_path)
    ds_test = dataset_cls(test_path)

    num_graphs = len(ds)
    valid_fraction = 0.3
    valid_size = int(round(num_graphs * valid_fraction))
    train_size = num_graphs - valid_size
    ds_train, ds_valid = random_split(ds, [train_size, valid_size])

    dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
    dl_valid = DataLoader(ds_valid, batch_size=batch_size, shuffle=False)
    dl_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False)

    # Also passed to the train=False calls below; presumably process_graph_all
    # only steps it when train=True (TODO confirm).
    optimizer = optim.Adam(params, lr=1e-5)

    # TRAINING
    if mode == 'train':
        loss_per_epoch_train = []
        loss_per_epoch_valid = []
        accuracy_per_epoch_train = {'TRANS': [], 'TIPS': [], 'BUBBLES': []}
        accuracy_per_epoch_valid = {'TRANS': [], 'TIPS': [], 'BUBBLES': []}

        # A detached deep copy, so later optimizer steps cannot alter it
        # (unlike a second model that shares the live algorithm modules).
        best_state = copy.deepcopy(processor.state_dict())
        patience = 0

        for epoch in range(num_epochs):
            print(f'Epoch: {epoch}')
            processor.train()
            patience += 1
            loss_per_graph = []
            accuracy_per_graph = {'TRANS': [], 'TIPS': [], 'BUBBLES': []}
            for data in dl_train:
                processor.process_graph_all(data,
                                            optimizer,
                                            loss_per_graph,
                                            accuracy_per_graph,
                                            train=True,
                                            device=device)

            loss_per_epoch_train.append(
                sum(loss_per_graph) / len(loss_per_graph))
            append_accuracy_list(accuracy_per_epoch_train, accuracy_per_graph,
                                 algo_list)

            # VALIDATION
            with torch.no_grad():
                processor.eval()
                loss_per_graph = []
                accuracy_per_graph = {'TRANS': [], 'TIPS': [], 'BUBBLES': []}
                for data in dl_valid:
                    processor.process_graph_all(data,
                                                optimizer,
                                                loss_per_graph,
                                                accuracy_per_graph,
                                                train=False,
                                                device=device)
            current_loss = sum(loss_per_graph) / len(loss_per_graph)

            # The first epoch always counts as an improvement, so a
            # checkpoint exists even if later epochs never beat it.
            improved = (not loss_per_epoch_valid
                        or current_loss < min(loss_per_epoch_valid))
            # Record before a possible early stop so the train and valid
            # curves passed to the plotting helpers have equal length.
            loss_per_epoch_valid.append(current_loss)
            append_accuracy_list(accuracy_per_epoch_valid,
                                 accuracy_per_graph, algo_list)

            if improved:
                patience = 0
                best_state = copy.deepcopy(processor.state_dict())
                torch.save(best_state, model_path)
            elif patience > patience_limit:
                break

        draw_loss_plot(loss_per_epoch_train, loss_per_epoch_valid, time_now)
        draw_accuracy_plots(accuracy_per_epoch_train, accuracy_per_epoch_valid,
                            algo_list, time_now)

        # Persist the best weights, not the last ones: saving the live
        # processor here would overwrite the early-stopping checkpoint.
        torch.save(best_state, model_path)

    load_path = model_to_load if model_to_load is not None else model_path
    # map_location remaps tensors saved on another device onto `device`.
    processor.load_state_dict(torch.load(load_path, map_location=device))

    # TESTING
    with torch.no_grad():
        processor.eval()

        loss_per_graph = []
        accuracy = {'TRANS': [], 'TIPS': [], 'BUBBLES': []}
        last_step = {'TRANS': [], 'TIPS': [], 'BUBBLES': []}

        for data in dl_test:
            processor.process_graph_all(data,
                                        optimizer,
                                        loss_per_graph,
                                        accuracy,
                                        train=False,
                                        last_step=last_step,
                                        device=device)

        print_mean_accuracy(accuracy, algo_list)