Example #1
def test(test_dataloader, model=None):

    if model is None:
        obj = torch.load(args.save_path + '.model',
                         map_location=lambda storage, loc: storage)
        train_args = obj['args']
        model = GNN(word_vocab_size=WORD_VOCAB_SIZE,
                    char_vocab_size=CHAR_VOCAB_SIZE,
                    d_output=d_output,
                    args=train_args)
        model.load_state_dict(obj['model'])
        model.cuda()
        print('Model loaded.')

    test_acc, test_prec, test_recall, test_f1 = evaluate(model,
                                                         test_dataloader,
                                                         output=True,
                                                         args=args)
    print('######## prec   : ', acc_to_str(test_prec))
    print('######## recall : ', acc_to_str(test_recall))
    print('######## f1     : ', acc_to_str(test_f1))
    prec, recall, f1 = np.mean(list(test_prec.values())), np.mean(
        list(test_recall.values())), np.mean(list(test_f1.values()))
    print(prec, recall, f1)
    result_obj['test_prec'] = prec
    result_obj['test_recall'] = recall
    result_obj['test_f1'] = f1
    result_obj['test_info'] = '\n'.join(
        [acc_to_str(test_prec),
         acc_to_str(test_recall),
         acc_to_str(test_f1)])
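
Example #1 refers to a module-level result_obj dict and an acc_to_str helper that are not shown in the snippet; a minimal sketch of those assumed globals (mirroring the helper defined in Example #4) could be:

result_obj = {}

def acc_to_str(acc):
    # Format a {label: score} dict as '{label:0.123, ...}'.
    s = ['%s:%.3f' % (label, acc[label]) for label in acc]
    return '{' + ', '.join(s) + '}'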
Example #2
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with DGL')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help='random seed to use (default: 42)')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help='GNN to use, which can be from [gin, gin-virtual, gcn, gcn-virtual, '
        'gin-virtual-diffpool, gin-virtual-bayes-diffpool] (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset',
                        action='store_true',
                        help='use 10% of the training set for training')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory. If not specified, '
                        'tensorboard will not be used.')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")

    ### automatic dataloading and splitting
    dataset = SampleDglPCQM4MDataset(root='dataset/')

    # split_idx['train'], split_idx['valid'], split_idx['test']
    # separately gives a 1D int64 tensor
    split_idx = dataset.get_idx_split()
    split_idx["train"] = split_idx["train"].type(torch.LongTensor)
    split_idx["test"] = split_idx["test"].type(torch.LongTensor)
    split_idx["valid"] = split_idx["valid"].type(torch.LongTensor)

    ### automatic evaluator.
    evaluator = PCQM4MEvaluator()

    if args.train_subset:
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(
            split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))]
        train_loader = DataLoader(dataset[split_idx["train"][subset_idx]],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_dgl)
    else:
        train_loader = DataLoader(dataset[split_idx["train"]],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_dgl)

    valid_loader = DataLoader(dataset[split_idx["valid"]],
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=collate_dgl)

    if args.save_test_dir != '':
        test_loader = DataLoader(dataset[split_idx["test"]],
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_dgl)

    if args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual-diffpool':
        model = DiffPoolGNN(gnn_type='gin', virtual_node=True,
                            **shared_params).to(device)
    elif args.gnn == 'gin-virtual-bayes-diffpool':
        model = BayesDiffPoolGNN(gnn_type='gin',
                                 virtual_node=True,
                                 **shared_params).to(device)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != '':
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)
    """ load from latest checkpoint """
    # start epoch (default = 1, unless resuming training)
    firstEpoch = 1
    # check if checkpoint exist -> load model
    checkpointFile = os.path.join(args.checkpoint_dir, 'checkpoint.pt')
    if os.path.exists(checkpointFile):
        # load checkpoint file
        checkpointData = torch.load(checkpointFile)
        firstEpoch = checkpointData["epoch"]
        model.load_state_dict(checkpointData["model_state_dict"])
        optimizer.load_state_dict(checkpointData["optimizer_state_dict"])
        scheduler.load_state_dict(checkpointData["scheduler_state_dict"])
        best_valid_mae = checkpointData["best_val_mae"]
        num_params = checkpointData["num_params"]
        print(
            "Loaded existing weights from {}. Continuing from epoch: {} with best valid MAE: {}"
            .format(checkpointFile, firstEpoch, best_valid_mae))

    for epoch in range(firstEpoch, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train_mae = train(model, device, train_loader, optimizer, args.gnn)

        print('Evaluating...')
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir != '':
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != '':
                print('Saving checkpoint...')
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_mae': best_valid_mae,
                    'num_params': num_params
                }
                torch.save(checkpoint,
                           os.path.join(args.checkpoint_dir, 'checkpoint.pt'))

            if args.save_test_dir != '':
                print('Predicting on test data...')
                y_pred = test(model, device, test_loader)
                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': y_pred},
                                               args.save_test_dir)

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

    if args.log_dir != '':
        writer.close()
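
The DGL data loaders in this example pass collate_fn=collate_dgl, a helper defined elsewhere in the repository; a minimal sketch, assuming each dataset item is a (DGLGraph, scalar label) pair, could be:

import dgl
import torch

def collate_dgl(samples):
    # Merge a list of (graph, label) pairs into one batched graph and a label tensor.
    graphs, labels = map(list, zip(*samples))
    return dgl.batch(graphs), torch.tensor(labels, dtype=torch.float32)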
Example #3
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with DGL')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help='random seed to use (default: 42)')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help='GNN to use, which can be from '
        '[gin, gin-virtual, gcn, gcn-virtual] (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")

    ### automatic data loading and splitting
    ### Read in the raw SMILES strings
    smiles_dataset = PCQM4MDataset(root='dataset/', only_smiles=True)
    split_idx = smiles_dataset.get_idx_split()

    test_smiles_dataset = [smiles_dataset[i] for i in split_idx['test']]
    onthefly_dataset = OnTheFlyPCQMDataset(test_smiles_dataset)
    test_loader = DataLoader(onthefly_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             collate_fn=collate_dgl)

    ### automatic evaluator.
    evaluator = PCQM4MEvaluator()

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True,
                    **shared_params).to(device)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint.pt')
    if not os.path.exists(checkpoint_path):
        raise RuntimeError(f'Checkpoint file not found at {checkpoint_path}')

    ## reading in checkpoint
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])

    print('Predicting on test data...')
    y_pred = test(model, device, test_loader)
    print('Saving test submission file...')
    evaluator.save_test_submission({'y_pred': y_pred}, args.save_test_dir)
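
The test helper invoked above is also defined elsewhere; a hypothetical sketch, assuming the model accepts a batched DGL graph directly and returns one prediction per graph, could be:

import torch

@torch.no_grad()
def test(model, device, loader):
    # Hypothetical prediction loop: run the model batch by batch and gather
    # per-graph predictions on the CPU.
    model.eval()
    y_pred = []
    for bg, _labels in loader:
        bg = bg.to(device)
        y_pred.append(model(bg).view(-1).cpu())
    return torch.cat(y_pred, dim=0)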
Example #4
def train(seed):

    print('random seed:', seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.enabled = False

    dataset = read_data('../data', '../graph')
    label2id = dataset.label2id
    print(label2id)

    vocab_size = dataset.vocab_size
    output_dim = len(label2id)

    def acc_to_str(acc):
        s = ['%s:%.3f' % (label, acc[label]) for label in acc]
        return '{' + ', '.join(s) + '}'

    cross_res = {label: [] for label in label2id if label != 'O'}
    output_file = open('%s.mistakes' % args.output, 'w')

    for cross_valid in range(5):

        model = GNN(vocab_size=vocab_size, output_dim=output_dim, args=args)
        model.cuda()
        # print vocab_size

        dataset.split_train_valid_test([0.8, 0.1, 0.1], 5, cross_valid)
        print('train:', len(dataset.train), 'valid:', len(dataset.valid),
              'test:', len(dataset.test))

        def evaluate(model, datalist, output_file=None):
            if output_file is not None:
                output_file.write(
                    '#############################################\n')
            correct = {label: 0 for label in label2id if label != 'O'}
            total = len(datalist)
            model.eval()
            print_cnt = 0
            for data in datalist:
                word, feat = Variable(data.input_word).cuda(), Variable(
                    data.input_feat).cuda()
                a_ud, a_lr = Variable(data.a_ud,
                                      requires_grad=False).cuda(), Variable(
                                          data.a_lr,
                                          requires_grad=False).cuda()
                mask = Variable(data.mask, requires_grad=False).cuda()
                if args.globalnode:
                    logprob, form = model(word, feat, mask, a_ud, a_lr)
                    logprob = logprob.data.view(-1, output_dim)
                else:
                    logprob = model(word, feat, mask, a_ud,
                                    a_lr).data.view(-1, output_dim)
                mask = mask.data.view(-1)
                y_pred = torch.LongTensor(output_dim)
                for i in range(output_dim):
                    prob = logprob[:, i].exp() * mask
                    y_pred[i] = prob.topk(k=1)[1][0]
                # y_pred = logprob.topk(k=1,dim=0)[1].view(-1)
                for label in label2id:
                    if label == 'O':
                        continue
                    labelid = label2id[label]
                    if data.output.view(-1)[y_pred[labelid]] == labelid:
                        correct[label] += 1
                    else:
                        if output_file is not None:
                            num_sent, sent_len, word_len = data.input_word.size()
                            id = y_pred[label2id[label]]
                            word = data.words[data.sents[int(
                                id / sent_len)][id % sent_len]]
                            output_file.write(
                                '%d %d %s %s\n' %
                                (data.set_id, data.fax_id, label, word))
            return {label: float(correct[label]) / total for label in correct}

        batch = 1

        weight = torch.zeros(len(label2id))
        for label, id in label2id.items():
            weight[id] = 1 if label == 'O' else 10
        loss_function = nn.NLLLoss(weight.cuda(), reduction='none')
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr / float(batch),
                                     weight_decay=args.wd)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

        best_acc = -1
        wait = 0

        for epoch in range(args.epochs):
            sum_loss = 0
            model.train()
            # random.shuffle(dataset.train)
            for idx, data in enumerate(dataset.train):
                word, feat = Variable(data.input_word).cuda(), Variable(
                    data.input_feat).cuda()
                a_ud, a_lr = Variable(data.a_ud,
                                      requires_grad=False).cuda(), Variable(
                                          data.a_lr,
                                          requires_grad=False).cuda()
                mask = Variable(data.mask, requires_grad=False).cuda()
                true_output = Variable(data.output).cuda()
                if args.globalnode:
                    logprob, form = model(word, feat, mask, a_ud, a_lr)
                else:
                    logprob = model(word, feat, mask, a_ud, a_lr)
                loss = torch.mean(
                    mask.view(-1) * loss_function(logprob.view(-1, output_dim),
                                                  true_output.view(-1)))
                if args.globalnode:
                    true_form = Variable(torch.LongTensor([data.set_id - 1
                                                           ])).cuda()
                    loss = loss + 0.1 * F.nll_loss(form, true_form)
                sum_loss += loss.data.sum()
                loss.backward()
                if (idx + 1) % batch == 0 or idx + 1 == len(dataset.train):
                    optimizer.step()
                    optimizer.zero_grad()
            train_acc = evaluate(model, dataset.train)
            valid_acc = evaluate(model, dataset.valid)
            test_acc = evaluate(model, dataset.test)
            print('Epoch %d:  Train Loss: %.3f  Train: %s  Valid: %s  Test: %s' \
                % (epoch, sum_loss, acc_to_str(train_acc), acc_to_str(valid_acc), acc_to_str(test_acc)))
            # scheduler.step()

            acc = np.log(list(valid_acc.values())).sum()
            if epoch < 6:
                continue
            if acc >= best_acc:
                torch.save(model.state_dict(), args.output + '.model')
            wait = 0 if acc > best_acc else wait + 1
            best_acc = max(acc, best_acc)
            if wait >= args.patience:
                break

        model.load_state_dict(torch.load(args.output + '.model'))
        test_acc = evaluate(model, dataset.test, output_file=output_file)
        print('########', acc_to_str(test_acc))
        for label in test_acc:
            cross_res[label].append(test_acc[label])

    print("Cross Validation Result:")
    for label in cross_res:
        cross_res[label] = np.mean(cross_res[label])
    print(acc_to_str(cross_res))
    return cross_res
Example #5
def train(dataset):

    print('random seed:', args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.enabled = False

    cross_res = {label: [] for label in label2id if label != 'O'}

    for cross_valid in range(1):

        # print('cross_valid', cross_valid)

        model = GNN(word_vocab_size=WORD_VOCAB_SIZE,
                    char_vocab_size=CHAR_VOCAB_SIZE,
                    d_output=d_output,
                    args=args)
        model.cuda()
        # print vocab_size

        # print('split dataset')
        # dataset.split_train_valid_test_bycase([0.5, 0.1, 0.4], 5, cross_valid)
        print('train:', len(dataset.train), 'valid:', len(dataset.valid),
              'test:', len(dataset.test))
        sys.stdout.flush()

        train_dataloader = DataLoader(dataset.train,
                                      batch_size=args.batch,
                                      shuffle=True)
        valid_dataloader = DataLoader(dataset.valid, batch_size=args.batch)
        test_dataloader = DataLoader(dataset.test, batch_size=args.batch)

        weight = torch.zeros(len(label2id))
        for label, idx in label2id.items():
            weight[idx] = 1 if label == 'O' else 2
        loss_function = nn.CrossEntropyLoss(weight.cuda(), reduction='none')
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr,
                                     weight_decay=args.wd)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

        best_acc = -1
        wait = 0
        batch_cnt = 0

        for epoch in range(args.epochs):
            total_loss = 0
            pending_loss = None
            model.train()
            # random.shuffle(dataset.train)
            load_time, forward_time, backward_time = 0, 0, 0
            model.clear_time()

            train_log = open(args.save_path + '_train.log', 'w')
            for tensors, batch in tqdm(train_dataloader,
                                       file=train_log,
                                       mininterval=60):
                # print(batch[0].case_id, batch[0].doc_id, batch[0].page_id)
                start = time.time()
                data, data_word, pos, length, mask, label, adjs = to_var(
                    tensors, cuda=args.cuda)
                batch_size, docu_len, sent_len, word_len = data.size()
                load_time += (time.time() - start)

                start = time.time()
                logit = model(data, data_word, pos, length, mask, adjs)
                forward_time += (time.time() - start)

                start = time.time()
                if args.crf:
                    logit = logit.view(batch_size * docu_len, sent_len, -1)
                    mask = mask.view(batch_size * docu_len, -1)
                    length = length.view(batch_size * docu_len)
                    label = label.view(batch_size * docu_len, -1)
                    loss = -model.crf_layer.loglikelihood(
                        logit, mask, length, label)
                    loss = torch.masked_select(loss, torch.gt(length,
                                                              0)).mean()
                else:
                    loss = loss_function(logit.view(-1, d_output),
                                         label.view(-1))
                    loss = torch.masked_select(loss, mask.view(-1)).mean()
                total_loss += loss.data.sum()
                # print(total_loss, batch[0].case_id, batch[0].doc_id, batch[0].page_id)
                if math.isnan(total_loss):
                    print('Loss is NaN!')
                    exit()

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                backward_time += (time.time() - start)

                batch_cnt += 1
                if batch_cnt % 20000 != 0:
                    continue
                # print('load %f   forward %f   backward %f'%(load_time, forward_time, backward_time))
                # model.print_time()
                valid_acc, valid_prec, valid_recall, valid_f1 = evaluate(
                    model, valid_dataloader, args=args)

                print('Epoch %d:  Train Loss: %.3f  Valid Acc: %.5f' %
                      (epoch, total_loss, valid_acc))
                # print(acc_to_str(valid_f1))
                # scheduler.step()

                acc = np.mean(list(valid_f1.values()))  # valid_acc
                print(acc)
                if acc >= best_acc:
                    obj = {'args': args, 'model': model.state_dict()}
                    torch.save(obj, args.save_path + '.model')
                    result_obj['valid_prec'] = np.mean(
                        list(valid_prec.values()))
                    result_obj['valid_recall'] = np.mean(
                        list(valid_recall.values()))
                    result_obj['valid_f1'] = np.mean(list(valid_f1.values()))
                wait = 0 if acc > best_acc else wait + 1
                best_acc = max(acc, best_acc)

                model.train()
                sys.stdout.flush()
                if wait >= args.patience:
                    break

            train_log.close()
            os.remove(args.save_path + '_train.log')

            if wait >= args.patience:
                break

        obj = torch.load(args.save_path + '.model')
        model.load_state_dict(obj['model'])

        test(test_dataloader, model)

    # print("Cross Validation Result:")
    # for label in cross_res:
    #     cross_res[label] = np.mean(cross_res[label])
    # print(acc_to_str(cross_res))
    return cross_res
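
This example unpacks its batches through a to_var helper that is not shown; a hypothetical sketch, assuming it simply moves every tensor in the tuple to the GPU when cuda=True, could be:

def to_var(tensors, cuda=True):
    # Hypothetical helper: move each tensor of the batch to the GPU when requested.
    return tuple(t.cuda() if cuda else t for t in tensors)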
Example #6
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='GNN baselines on pcqm4m with PyTorch Geometric')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument(
        '--gnn',
        type=str,
        default='gin-virtual',
        help='GNN to use: gin, gin-virtual, gcn, gcn-virtual, gin-virtual-bnn, '
        'or gin-virtual-bnn-lastLayer (default: gin-virtual)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default='sum',
        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--drop_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--num_layers',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5)')
    parser.add_argument(
        '--emb_dim',
        type=int,
        default=600,
        help='dimensionality of hidden units in GNNs (default: 600)')
    parser.add_argument('--train_subset', action='store_true')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--log_dir',
                        type=str,
                        default="",
                        help='tensorboard log directory')
    parser.add_argument('--use_triplet_loss', action='store_true')
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default='',
                        help='directory to save checkpoint')
    parser.add_argument('--save_test_dir',
                        type=str,
                        default='',
                        help='directory to save test submission file')
    args = parser.parse_args()

    print(args)

    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    random.seed(42)

    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    ### automatic dataloading and splitting
    dataset = PygPCQM4MDataset(root='dataset/')

    split_idx = dataset.get_idx_split()

    ### automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()

    #     if args.use_triplet_loss:
    #         if args.train_subset:
    #             subset_ratio = 0.1
    #             subset_idx = torch.randperm(len(split_idx["train"]))[:int(subset_ratio*len(split_idx["train"]))]
    #             anchor_loader = DataLoader(dataset[split_idx["train"][subset_idx]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    #             positive_loader = DataLoader(dataset[split_idx["train"][subset_idx]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    #             negative_loader = DataLoader(dataset[split_idx["train"][subset_idx]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    #         else:
    #             anchor_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    #             positive_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    #             negative_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    #     elif args.train_subset:
    if args.train_subset:
        subset_ratio = 0.1
        subset_idx = torch.randperm(len(
            split_idx["train"]))[:int(subset_ratio * len(split_idx["train"]))]
        train_loader = DataLoader(dataset[split_idx["train"][subset_idx]],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)
    else:
        train_loader = DataLoader(dataset[split_idx["train"]],
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    valid_loader = DataLoader(dataset[split_idx["valid"]],
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)

    if args.save_test_dir != '':
        test_loader = DataLoader(dataset[split_idx["test"]],
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    if args.checkpoint_dir != '':
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    shared_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    if args.gnn == 'gin':
        model = GNN(gnn_type='gin', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual':
        model = GNN(gnn_type='gin', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gcn':
        model = GNN(gnn_type='gcn', virtual_node=False,
                    **shared_params).to(device)
    elif args.gnn == 'gcn-virtual':
        model = GNN(gnn_type='gcn', virtual_node=True,
                    **shared_params).to(device)
    elif args.gnn == 'gin-virtual-bnn':
        model = BayesianGNN(gnn_type='gin',
                            virtual_node=True,
                            last_layer_only=False,
                            **shared_params).to(device)
    elif args.gnn == 'gin-virtual-bnn-lastLayer':
        model = BayesianGNN(gnn_type='gin',
                            virtual_node=True,
                            last_layer_only=True,
                            **shared_params).to(device)
    else:
        raise ValueError('Invalid GNN type')

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    if args.log_dir != '':
        writer = SummaryWriter(log_dir=args.log_dir)

    best_valid_mae = 1000

    if args.train_subset:
        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)
        args.epochs = 1000
    else:
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    # start epoch (default = 1, unless resuming training)
    firstEpoch = 1
    # check if checkpoint exist -> load model
    checkpointFile = os.path.join(args.checkpoint_dir, 'checkpoint.pt')
    if os.path.exists(checkpointFile):
        # load checkpoint file
        checkpointData = torch.load(checkpointFile)
        firstEpoch = checkpointData["epoch"]
        model.load_state_dict(checkpointData["model_state_dict"])
        optimizer.load_state_dict(checkpointData["optimizer_state_dict"])
        scheduler.load_state_dict(checkpointData["scheduler_state_dict"])
        best_valid_mae = checkpointData["best_val_mae"]
        num_params = checkpointData["num_params"]
        print(
            "Loaded existing weights from {}. Continuing from epoch: {} with best valid MAE: {}"
            .format(checkpointFile, firstEpoch, best_valid_mae))

    if args.use_triplet_loss:
        model.gnn_node.register_forward_hook(get_activation('gnn_node'))

    for epoch in range(firstEpoch, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        if args.use_triplet_loss:
            train_mae = triplet_loss_train(model, device, train_loader,
                                           dataset, optimizer, args.gnn, args)
        else:
            train_mae = train(model, device, train_loader, optimizer, args.gnn)

        print('Evaluating...')
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae})

        if args.log_dir != '':
            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.checkpoint_dir != '':
                print('Saving checkpoint...')
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_mae': best_valid_mae,
                    'num_params': num_params
                }
                torch.save(checkpoint,
                           os.path.join(args.checkpoint_dir, 'checkpoint.pt'))

            if args.save_test_dir != '':
                print('Predicting on test data...')
                y_pred = test(model, device, test_loader)
                print('Saving test submission file...')
                evaluator.save_test_submission({'y_pred': y_pred},
                                               args.save_test_dir)

        scheduler.step()

        print(f'Best validation MAE so far: {best_valid_mae}')

    if args.log_dir != '':
        writer.close()
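
The triplet-loss branch registers a forward hook with get_activation('gnn_node'); a minimal sketch of that hook factory, assuming activations are cached in a module-level dict, could be:

activation = {}

def get_activation(name):
    # Cache the hooked module's output so it can be reused later,
    # e.g. as node embeddings for the triplet loss.
    def hook(module, inputs, output):
        activation[name] = output
    return hook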