# Standard imports assumed by this excerpt; project-specific helpers such as
# LoadData, gnn_model and the train/evaluate functions come from the
# surrounding benchmark codebase.
import glob
import os
import random
import time

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # the original project may use tensorboardX instead
from tqdm import tqdm


def view_model_param(MODEL_NAME, net_params):
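    # Instantiate the model once purely to count its trainable parameters.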
    model = gnn_model(MODEL_NAME, net_params)
    total_param = 0
    print("MODEL DETAILS:\n")
    #print(model)
    for param in model.parameters():
        # print(param.data.size())
        total_param += np.prod(list(param.data.size()))
    print('MODEL/Total parameters:', MODEL_NAME, total_param)
    return total_param


def train_val_pipeline(MODEL_NAME, DATASET_NAME, params, net_params, dirs):
    avg_test_acc = []
    avg_train_acc = []
    avg_convergence_epochs = []

    t0 = time.time()
    per_epoch_time = []

    dataset = LoadData(DATASET_NAME)
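    # LoadData returns a dataset wrapper exposing per-split train/val/test sets plus collate functions.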

    if MODEL_NAME in ['GCN', 'GAT']:
        if net_params['self_loop']:
            print(
                "[!] Adding graph self-loops for GCN/GAT models (central node trick)."
            )
            dataset._add_self_loops()

    trainset, valset, testset = dataset.train, dataset.val, dataset.test

    root_log_dir, root_ckpt_dir, write_file_name, write_config_file = dirs
    device = net_params['device']

    # Write the network and optimization hyper-parameters in folder config/
    with open(write_config_file + '.txt', 'w') as f:
        f.write(
            """Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n\nTotal Parameters: {}\n\n"""
            .format(DATASET_NAME, MODEL_NAME, params, net_params,
                    net_params['total_param']))

    # At any point you can hit Ctrl + C to break out of training early.
    try:
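        # TU datasets come with 10 pre-defined train/val/test splits; run one
        # full experiment per split and average the results afterwards.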
        for split_number in range(10):
            t0_split = time.time()
            log_dir = os.path.join(root_log_dir, "RUN_" + str(split_number))
            writer = SummaryWriter(log_dir=log_dir)

            # setting seeds
            random.seed(params['seed'])
            np.random.seed(params['seed'])
            torch.manual_seed(params['seed'])
            if device.type == 'cuda':
                torch.cuda.manual_seed(params['seed'])

            print("RUN NUMBER: ", split_number)
            trainset = dataset.train[split_number]
            valset = dataset.val[split_number]
            testset = dataset.test[split_number]
            print("Training Graphs: ", len(trainset))
            print("Validation Graphs: ", len(valset))
            print("Test Graphs: ", len(testset))
            print("Number of Classes: ", net_params['n_classes'])

            model = gnn_model(MODEL_NAME, net_params)
            model = model.to(device)
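            # Adam optimizer; the scheduler below reduces the LR when the validation loss plateaus.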
            optimizer = optim.Adam(model.parameters(),
                                   lr=params['init_lr'],
                                   weight_decay=params['weight_decay'])
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=params['lr_reduce_factor'],
                patience=params['lr_schedule_patience'],
                verbose=True)

            epoch_train_losses, epoch_val_losses = [], []
            epoch_train_accs, epoch_val_accs = [], []

            # batching exception for Diffpool
            drop_last = (MODEL_NAME == 'DiffPool')

            if MODEL_NAME in ['RingGNN', '3WLGNN']:
                # import train functions specific for WL-GNNs
                from train.train_TUs_graph_classification import train_epoch_dense as train_epoch, evaluate_network_dense as evaluate_network
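                # The dense loaders below use DataLoader's default batch_size=1 (one graph
                # per step); the dense train function is passed params['batch_size'] and
                # handles batching internally.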

                train_loader = DataLoader(trainset,
                                          shuffle=True,
                                          collate_fn=dataset.collate_dense_gnn)
                val_loader = DataLoader(valset,
                                        shuffle=False,
                                        collate_fn=dataset.collate_dense_gnn)
                test_loader = DataLoader(testset,
                                         shuffle=False,
                                         collate_fn=dataset.collate_dense_gnn)

            else:
                # import train functions for all other GCNs
                from train.train_TUs_graph_classification import train_epoch_sparse as train_epoch, evaluate_network_sparse as evaluate_network

                train_loader = DataLoader(trainset,
                                          batch_size=params['batch_size'],
                                          shuffle=True,
                                          drop_last=drop_last,
                                          collate_fn=dataset.collate)
                val_loader = DataLoader(valset,
                                        batch_size=params['batch_size'],
                                        shuffle=False,
                                        drop_last=drop_last,
                                        collate_fn=dataset.collate)
                test_loader = DataLoader(testset,
                                         batch_size=params['batch_size'],
                                         shuffle=False,
                                         drop_last=drop_last,
                                         collate_fn=dataset.collate)

            with tqdm(range(params['epochs'])) as t:
                for epoch in t:

                    t.set_description('Epoch %d' % epoch)

                    start = time.time()

                    if MODEL_NAME in ['RingGNN', '3WLGNN']:  # dense GNNs use a different batch training function
                        epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
                            model, optimizer, device, train_loader, epoch,
                            params['batch_size'])
                    else:  # all other models share a common train function
                        epoch_train_loss, epoch_train_acc, optimizer = train_epoch(
                            model, optimizer, device, train_loader, epoch)

                    epoch_val_loss, epoch_val_acc = evaluate_network(
                        model, device, val_loader, epoch)
                    _, epoch_test_acc = evaluate_network(
                        model, device, test_loader, epoch)

                    epoch_train_losses.append(epoch_train_loss)
                    epoch_val_losses.append(epoch_val_loss)
                    epoch_train_accs.append(epoch_train_acc)
                    epoch_val_accs.append(epoch_val_acc)

                    writer.add_scalar('train/_loss', epoch_train_loss, epoch)
                    writer.add_scalar('val/_loss', epoch_val_loss, epoch)
                    writer.add_scalar('train/_acc', epoch_train_acc, epoch)
                    writer.add_scalar('val/_acc', epoch_val_acc, epoch)
                    writer.add_scalar('test/_acc', epoch_test_acc, epoch)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], epoch)

                    t.set_postfix(time=time.time() - start,
                                  lr=optimizer.param_groups[0]['lr'],
                                  train_loss=epoch_train_loss,
                                  val_loss=epoch_val_loss,
                                  train_acc=epoch_train_acc,
                                  val_acc=epoch_val_acc,
                                  test_acc=epoch_test_acc)

                    per_epoch_time.append(time.time() - start)

                    # Saving checkpoint
                    ckpt_dir = os.path.join(root_ckpt_dir,
                                            "RUN_" + str(split_number))
                    if not os.path.exists(ckpt_dir):
                        os.makedirs(ckpt_dir)
                    torch.save(
                        model.state_dict(),
                        '{}.pkl'.format(ckpt_dir + "/epoch_" + str(epoch)))

                    files = glob.glob(ckpt_dir + '/*.pkl')
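                    # Keep only the two most recent checkpoints; delete the rest.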
                    for file in files:
                        epoch_nb = file.split('_')[-1]
                        epoch_nb = int(epoch_nb.split('.')[0])
                        if epoch_nb < epoch - 1:
                            os.remove(file)

                    scheduler.step(epoch_val_loss)

                    if optimizer.param_groups[0]['lr'] < params['min_lr']:
                        print("\n!! LR EQUAL TO MIN LR SET.")
                        break

                    # Stop this split once it exceeds its share of the params['max_time'] hours budget
                    if time.time() - t0_split > params[
                            'max_time'] * 3600 / 10:  # Dividing max_time by 10, since there are 10 runs in TUs
                        print('-' * 89)
                        print(
                            "Max_time for one train-val-test split experiment elapsed {:.3f} hours, so stopping"
                            .format(params['max_time'] / 10))
                        break

            _, test_acc = evaluate_network(model, device, test_loader, epoch)
            _, train_acc = evaluate_network(model, device, train_loader, epoch)
            avg_test_acc.append(test_acc)
            avg_train_acc.append(train_acc)
            avg_convergence_epochs.append(epoch)

            print("Test Accuracy [LAST EPOCH]: {:.4f}".format(test_acc))
            print("Train Accuracy [LAST EPOCH]: {:.4f}".format(train_acc))
            print("Convergence Time (Epochs): {:.4f}".format(epoch))

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early because of KeyboardInterrupt')

    print("TOTAL TIME TAKEN: {:.4f}hrs".format((time.time() - t0) / 3600))
    print("AVG TIME PER EPOCH: {:.4f}s".format(np.mean(per_epoch_time)))
    print("AVG CONVERGENCE Time (Epochs): {:.4f}".format(
        np.mean(np.array(avg_convergence_epochs))))
    # Final test accuracy value averaged over 10-fold
    print(
        """\n\n\nFINAL RESULTS\n\nTEST ACCURACY averaged: {:.4f} with s.d. {:.4f}"""
        .format(
            np.mean(np.array(avg_test_acc)) * 100,
            np.std(avg_test_acc) * 100))
    print("\nAll splits Test Accuracies:\n", avg_test_acc)
    print(
        """\n\n\nFINAL RESULTS\n\nTRAIN ACCURACY averaged: {:.4f} with s.d. {:.4f}"""
        .format(
            np.mean(np.array(avg_train_acc)) * 100,
            np.std(avg_train_acc) * 100))
    print("\nAll splits Train Accuracies:\n", avg_train_acc)

    writer.close()
    """
        Write the results in out/results folder
    """
    with open(write_file_name + '.txt', 'w') as f:
        f.write("""Dataset: {},\nModel: {}\n\nparams={}\n\nnet_params={}\n\n{}\n\nTotal Parameters: {}\n\n
    FINAL RESULTS\nTEST ACCURACY averaged: {:.4f} with s.d. {:.4f}\nTRAIN ACCURACY averaged: {:.4f} with s.d. {:.4f}\n\n
    Average Convergence Time (Epochs): {:.4f} with s.d. {:.4f}\nTotal Time Taken: {:.4f} hrs\nAverage Time Per Epoch: {:.4f} s\n\n\nAll Splits Test Accuracies: {}"""\
          .format(DATASET_NAME, MODEL_NAME, params, net_params, model, net_params['total_param'],
                  np.mean(np.array(avg_test_acc))*100, np.std(avg_test_acc)*100,
                  np.mean(np.array(avg_train_acc))*100, np.std(avg_train_acc)*100,
                  np.mean(avg_convergence_epochs), np.std(avg_convergence_epochs),
               (time.time()-t0)/3600, np.mean(per_epoch_time), avg_test_acc))
Example #3
        val_losses, val_accuracy = [], []
        test_losses, test_accuracy = [], []
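        # `split` is assumed to be a tqdm-wrapped iterable of (train, val, test)
        # datasets, one per fold, defined in the portion of this example not shown here.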
        for idx_split, (d_train, d_val, d_test) in enumerate(split):
            split.set_description(f'Split #{idx_split}')

            set_seed(params['seed'])

            train_loader = DataLoader(d_train, batch_size=args.batch_size, shuffle=True, collate_fn=dataset.collate)
            val_loader = DataLoader(d_val, batch_size=args.batch_size, collate_fn=dataset.collate)
            test_loader = DataLoader(d_test, batch_size=args.batch_size, collate_fn=dataset.collate)

            net_param = deepcopy(_net_param)
            net_param['in_dim'] = in_dim
            net_param['n_classes'] = dataset.all.num_labels

            net = gnn_model(args.net, net_param)
            net.cuda()

            optimizer = optim.Adam(net.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                                factor=params['lr_reduce_factor'],
                                                                patience=params['lr_schedule_patience'])

            with tqdm(range(params['epochs'])) as epochs:
                for e in epochs:
                    epochs.set_description(f'Epoch #{e}')
                    
                    train_loss, train_acc = train(net, train_loader, optimizer)
                    val_loss, val_acc = val(net, val_loader)

                    epochs.set_postfix(lr=optimizer.param_groups[0]['lr'],