Example #1
def test_early_stopping_high_metric():
    for metric in ['r2', 'roc_auc_score']:
        model1 = nn.Linear(2, 3)
        stopper = EarlyStopping(mode=None,
                                patience=1,
                                filename='test.pkl',
                                metric=metric)

        # Save model in the first step
        stopper.step(1., model1)
        model1.weight.data = model1.weight.data + 1
        model2 = nn.Linear(2, 3)
        stopper.load_checkpoint(model2)
        assert not torch.allclose(model1.weight, model2.weight)

        # Save model checkpoint with performance improvement
        model1.weight.data = model1.weight.data + 1
        stopper.step(2., model1)
        stopper.load_checkpoint(model2)
        assert torch.allclose(model1.weight, model2.weight)

        # Stop when no improvement observed
        model1.weight.data = model1.weight.data + 1
        assert stopper.step(0.5, model1)
        stopper.load_checkpoint(model2)
        assert not torch.allclose(model1.weight, model2.weight)

        remove_file('test.pkl')
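
A note on what this test exercises: EarlyStopping saves a checkpoint whenever the score improves and stops once `patience` non-improving steps accumulate; load_checkpoint restores the best weights. A minimal sketch of that contract, assuming dgllife-style semantics (this is not the library implementation itself):

import torch

class SimpleEarlyStopping:
    """Checkpoint on improvement; stop after `patience` steps without one."""
    def __init__(self, mode='higher', patience=10, filename='checkpoint.pth'):
        assert mode in ['higher', 'lower']
        self.mode = mode
        self.patience = patience
        self.filename = filename
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def _improved(self, score):
        if self.mode == 'higher':
            return score > self.best_score
        return score < self.best_score

    def step(self, score, model):
        if self.best_score is None or self._improved(score):
            # New best score: checkpoint the model and reset the counter
            self.best_score = score
            torch.save({'model_state_dict': model.state_dict()}, self.filename)
            self.counter = 0
        else:
            self.counter += 1
            self.early_stop = self.counter >= self.patience
        return self.early_stop

    def load_checkpoint(self, model):
        model.load_state_dict(torch.load(self.filename)['model_state_dict'])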
Example #2
def main(args, exp_config, train_set, val_set, test_set):
    if args['featurizer_type'] != 'pre_train':
        exp_config['in_node_feats'] = args['node_featurizer'].feat_size()
        if args['edge_featurizer'] is not None:
            exp_config['in_edge_feats'] = args['edge_featurizer'].feat_size()
    exp_config.update({
        'n_tasks': args['n_tasks'],
        'model': args['model']
    })

    train_loader = DataLoader(dataset=train_set, batch_size=exp_config['batch_size'], shuffle=True,
                              collate_fn=collate_molgraphs, num_workers=args['num_workers'])
    val_loader = DataLoader(dataset=val_set, batch_size=exp_config['batch_size'],
                            collate_fn=collate_molgraphs, num_workers=args['num_workers'])
    test_loader = DataLoader(dataset=test_set, batch_size=exp_config['batch_size'],
                             collate_fn=collate_molgraphs, num_workers=args['num_workers'])

    if args['pretrain']:
        args['num_epochs'] = 0
        if args['featurizer_type'] == 'pre_train':
            model = load_pretrained('{}_{}'.format(
                args['model'], args['dataset'])).to(args['device'])
        else:
            model = load_pretrained('{}_{}_{}'.format(
                args['model'], args['featurizer_type'], args['dataset'])).to(args['device'])
    else:
        model = load_model(exp_config).to(args['device'])
        loss_criterion = nn.SmoothL1Loss(reduction='none')
        optimizer = Adam(model.parameters(), lr=exp_config['lr'],
                         weight_decay=exp_config['weight_decay'])
        stopper = EarlyStopping(patience=exp_config['patience'],
                                filename=args['result_path'] + '/model.pth',
                                metric=args['metric'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args['num_epochs'], args['metric'],
            val_score, args['metric'], stopper.best_score))

        if early_stop:
            break

    if not args['pretrain']:
        stopper.load_checkpoint(model)
    val_score = run_an_eval_epoch(args, model, val_loader)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('val {} {:.4f}'.format(args['metric'], val_score))
    print('test {} {:.4f}'.format(args['metric'], test_score))

    with open(args['result_path'] + '/eval.txt', 'w') as f:
        if not args['pretrain']:
            f.write('Best val {}: {}\n'.format(args['metric'], stopper.best_score))
        f.write('Val {}: {}\n'.format(args['metric'], val_score))
        f.write('Test {}: {}\n'.format(args['metric'], test_score))
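
Example #2 constructs EarlyStopping with a metric name instead of an explicit mode; dgllife infers the comparison direction from the metric. A minimal sketch of that mapping, using the metric names the library accepts (the library performs an equivalent check internally):

def mode_for_metric(metric):
    # r2 / roc_auc_score / pr_auc_score: higher is better
    # mae / rmse: lower is better
    if metric in ['r2', 'roc_auc_score', 'pr_auc_score']:
        return 'higher'
    if metric in ['mae', 'rmse']:
        return 'lower'
    raise ValueError('Unexpected metric: {}'.format(metric))

assert mode_for_metric('r2') == 'higher'
assert mode_for_metric('rmse') == 'lower'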
Example #3
def main(args, exp_config, train_set, val_set, test_set):
    # Record settings
    exp_config.update({
        'model': args['model'],
        'in_feats': args['node_featurizer'].feat_size(),
        'n_tasks': args['n_tasks']
    })

    # Set up directory for saving results
    args = init_trial_path(args)

    train_loader = DataLoader(dataset=train_set,
                              batch_size=exp_config['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=exp_config['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(dataset=test_set,
                             batch_size=exp_config['batch_size'],
                             collate_fn=collate_molgraphs)
    model = load_model(exp_config).to(args['device'])

    loss_criterion = nn.SmoothL1Loss(reduction='none')
    optimizer = Adam(model.parameters(),
                     lr=exp_config['lr'],
                     weight_decay=exp_config['weight_decay'])
    stopper = EarlyStopping(patience=exp_config['patience'],
                            filename=args['trial_path'] + '/model.pth')

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion,
                          optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print(
            'epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.
            format(epoch + 1, args['num_epochs'], args['metric'], val_score,
                   args['metric'], stopper.best_score))

        if early_stop:
            break

    stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('test {} {:.4f}'.format(args['metric'], test_score))

    with open(args['trial_path'] + '/eval.txt', 'w') as f:
        f.write('Best val {}: {}\n'.format(args['metric'], stopper.best_score))
        f.write('Test {}: {}\n'.format(args['metric'], test_score))

    with open(args['trial_path'] + '/configure.json', 'w') as f:
        json.dump(exp_config, f, indent=2)

    return args['trial_path'], stopper.best_score
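
Because this main() returns the trial directory and the best validation score, it slots directly into a hyperparameter search driver. A hypothetical sketch (num_trials and sample_config are assumed names, not part of the example):

best_path, best_score = None, None
for _ in range(num_trials):
    exp_config = sample_config()  # assumed: draws one candidate configuration
    trial_path, score = main(args, exp_config, train_set, val_set, test_set)
    if best_score is None or score > best_score:  # assumes a higher-is-better metric
        best_path, best_score = trial_path, score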
Example #4
def main(args):
    args['device'] = torch.device(
        "cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(args['random_seed'])

    dataset = PubChemBioAssayAromaticity(
        smiles_to_graph=args['smiles_to_graph'],
        node_featurizer=args.get('node_featurizer', None),
        edge_featurizer=args.get('edge_featurizer', None))
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset,
        frac_train=args['frac_train'],
        frac_val=args['frac_val'],
        frac_test=args['frac_test'],
        random_state=args['random_seed'])
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(dataset=test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        model = load_model(args)
        loss_fn = nn.MSELoss(reduction='none')
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args['lr'],
                                     weight_decay=args['weight_decay'])
        stopper = EarlyStopping(mode=args['mode'], patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print(
            'epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.
            format(epoch + 1, args['num_epochs'], args['metric_name'],
                   val_score, args['metric_name'], stopper.best_score))

        if early_stop:
            break

    if not args['pre_trained']:
        stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('test {} {:.4f}'.format(args['metric_name'], test_score))
Example #5
def main(args, node_featurizer, edge_featurizer, train_set, val_set, test_set):
    # Record starting time
    t0 = time.time()

    train_mean, train_std = get_label_mean_and_std(train_set)
    train_mean, train_std = train_mean.to(args['device']), train_std.to(
        args['device'])
    args['train_mean'], args['train_std'] = train_mean, train_std

    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            shuffle=False,
                            collate_fn=collate)
    test_loader = DataLoader(dataset=test_set,
                             batch_size=args['batch_size'],
                             shuffle=False,
                             collate_fn=collate)

    model = load_model(args, node_featurizer,
                       edge_featurizer).to(args['device'])
    criterion = nn.SmoothL1Loss(reduction='none')
    optimizer = Adam(model.parameters(),
                     lr=args['lr'],
                     weight_decay=args['weight_decay'])
    stopper = EarlyStopping(patience=args['patience'],
                            filename=args['result_path'] + '/model.pth')

    for epoch in range(args['num_epochs']):
        loss, train_r2, train_mae = run_a_train_epoch(args, model,
                                                      train_loader, criterion,
                                                      optimizer)
        print('Epoch {:d}/{:d} | training | averaged loss {:.4f} | '
              'averaged r2 {:.4f} | averaged mae {:.4f}'.format(
                  epoch + 1, args['num_epochs'], float(np.mean(loss)),
                  float(np.mean(train_r2)), float(np.mean(train_mae))))

        # Validation and early stop
        val_r2, val_mae = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(float(np.mean(val_r2)), model)
        print(
            'Epoch {:d}/{:d} | validation | current r2 {:.4f} | best r2 {:.4f} | mae {:.4f}'
            .format(epoch + 1, args['num_epochs'], float(np.mean(val_r2)),
                    stopper.best_score, float(np.mean(val_mae))))

        if early_stop:
            break

    print('It took {:.4f}s to complete the task'.format(time.time() - t0))
    stopper.load_checkpoint(model)
    log_model_evaluation(args, model, train_loader, val_loader, test_loader)
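
Example #5 standardizes regression targets with statistics computed on the training split only, so no information leaks from the validation or test labels. A minimal sketch of the assumed get_label_mean_and_std helper, assuming the dataset yields (graph, label) pairs:

import torch

def get_label_mean_and_std(dataset):
    # Stack all training labels and compute per-task statistics
    labels = torch.stack([label for _, label in dataset])
    return labels.mean(dim=0), labels.std(dim=0)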
Example #6
def main(args):
    args['device'] = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(args['random_seed'])

    # Interchangeable with other datasets
    dataset, train_set, val_set, test_set = load_dataset_for_classification(
        args)
    train_loader = DataLoader(train_set,
                              batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        args['n_tasks'] = dataset.n_tasks
        model = load_model(args)
        loss_criterion = BCEWithLogitsLoss(
            pos_weight=dataset.task_pos_weights(
                torch.tensor(train_set.indices)).to(args['device']),
            reduction='none')
        optimizer = Adam(model.parameters(), lr=args['lr'])
        stopper = EarlyStopping(patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion,
                          optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print(
            'epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.
            format(epoch + 1, args['num_epochs'], args['metric_name'],
                   val_score, args['metric_name'], stopper.best_score))
        if early_stop:
            break

    if not args['pre_trained']:
        stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('test {} {:.4f}'.format(args['metric_name'], test_score))
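
The pos_weight passed to BCEWithLogitsLoss above counteracts label imbalance, and it is computed from the training indices only. A sketch in the spirit of dataset.task_pos_weights (not the dgllife implementation itself): each task's weight is the ratio of negative to positive labels among the datapoints that actually have a label.

import torch

def task_pos_weights(labels, masks):
    # labels, masks: (num_datapoints, n_tasks); a mask of 0 marks a missing label
    num_pos = (labels * masks).sum(dim=0)
    num_valid = masks.sum(dim=0)
    weights = torch.ones(labels.shape[1])
    has_pos = num_pos > 0
    weights[has_pos] = (num_valid[has_pos] - num_pos[has_pos]) / num_pos[has_pos]
    return weights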
Example #7
def main(args):
    args['device'] = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(args['random_seed'])

    train_set, val_set, test_set = load_dataset_for_regression(args)
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            shuffle=True,
                            collate_fn=collate_molgraphs)
    if test_set is not None:
        test_loader = DataLoader(dataset=test_set,
                                 batch_size=args['batch_size'],
                                 collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        model = load_model(args)
        loss_fn = nn.MSELoss(reduction='none')
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args['lr'],
                                     weight_decay=args['weight_decay'])
        stopper = EarlyStopping(mode='lower', patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print(
            'epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.
            format(epoch + 1, args['num_epochs'], args['metric_name'],
                   val_score, args['metric_name'], stopper.best_score))

        if early_stop:
            break

    if test_set is not None:
        if not args['pre_trained']:
            stopper.load_checkpoint(model)
        test_score = run_an_eval_epoch(args, model, test_loader)
        print('test {} {:.4f}'.format(args['metric_name'], test_score))
Example #8
def test_early_stopping_low():
    model1 = nn.Linear(2, 3)
    stopper = EarlyStopping(mode='lower', patience=1, filename='test.pkl')

    # Save model in the first step
    stopper.step(1., model1)
    model1.weight.data = model1.weight.data + 1
    model2 = nn.Linear(2, 3)
    stopper.load_checkpoint(model2)
    assert not torch.allclose(model1.weight, model2.weight)

    # Save model checkpoint with performance improvement
    model1.weight.data = model1.weight.data + 1
    stopper.step(0.5, model1)
    stopper.load_checkpoint(model2)
    assert torch.allclose(model1.weight, model2.weight)

    # Stop when no improvement observed
    model1.weight.data = model1.weight.data + 1
    assert stopper.step(2, model1)
    stopper.load_checkpoint(model2)
    assert not torch.allclose(model1.weight, model2.weight)

    remove_file('test.pkl')
Example #9
def main(args, train_set, val_set, test_set):
    # Set up directory for saving results
    args = init_trial_path(args)

    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(dataset=test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)
    model = load_model(args).to(args['device'])

    loss_criterion = nn.BCEWithLogitsLoss(reduction='none')
    optimizer = Adam(model.parameters(),
                     lr=args['lr'],
                     weight_decay=args['weight_decay'])
    stopper = EarlyStopping(patience=args['patience'],
                            filename=args['trial_path'] + '/model.pth')

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion,
                          optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print(
            'epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.
            format(epoch + 1, args['num_epochs'], args['metric'], val_score,
                   args['metric'], stopper.best_score))

        if early_stop:
            break

    test_score = run_an_eval_epoch(args, model, test_loader)
    print('test {} {:.4f}'.format(args['metric'], test_score))

    with open(args['trial_path'] + '/eval.txt', 'w') as f:
        f.write('Best val {}: {}\n'.format(args['metric'], stopper.best_score))
        f.write('Test {}: {}\n'.format(args['metric'], test_score))

    return args, stopper.best_score
Example #10
def main(args):
    args['device'] = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(args['random_seed'])

    train_g, train_y, val_g, val_y = load_data(0)

    train_set = list(zip(train_g, train_y))
    val_set = list(zip(val_g, val_y))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)

    model = load_model(args)
    loss_fn = nn.MSELoss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    stopper = EarlyStopping(mode=args['mode'], patience=args['patience'])
    model.to(args['device'])

    validation_score = []  # per-epoch validation scores, appended below
    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

        # Validation and early stop
        val_score, prc_auc = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        validation_score.append(val_score)
        print(
            'epoch {:d}/{:d}, validation {} {:.4f}, prc_auc_score {:.4f}, best validation {} {:.4f}'
            .format(epoch + 1, args['num_epochs'], args['metric_name'],
                    val_score, prc_auc, args['metric_name'],
                    stopper.best_score))

        if early_stop:
            break
Example #11
    def train(self, model, train_loader, val_loader, epochs, lr, patience,
              metric):
        args = dict(
            patience=patience,
            num_epochs=epochs,
            lr=lr,
            device=self.device,
            metric_name=metric,
        )
        self.args = args
        stopper = EarlyStopping(patience=args["patience"])
        loss_fn = nn.MSELoss(reduction="none")
        optimizer = optim.Adam(model.parameters(), lr=args["lr"])
        for epoch in range(args["num_epochs"]):
            self.run_a_train_epoch(args, epoch, model, train_loader, loss_fn,
                                   optimizer)
            # Validation and early stop
            if val_loader is not None:
                val_score = self.run_an_eval_epoch(args, model, val_loader)
                early_stop = stopper.step(val_score, model)
                print(
                    "epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}"
                    .format(
                        epoch + 1,
                        args["num_epochs"],
                        args["metric_name"],
                        val_score,
                        args["metric_name"],
                        stopper.best_score,
                    ))
                if early_stop:
                    break

        self.model = model
        self.args = args
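
A hypothetical call to the train() method above (the trainer instance, model, and data loaders are assumed to exist in the surrounding context and are not part of the example):

trainer.train(model, train_loader, val_loader,
              epochs=100, lr=1e-3, patience=10, metric='rmse')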
Example #12
model = load_model(args)
loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(),
                             lr=args['lr'],
                             weight_decay=args['weight_decay'])
stopper = EarlyStopping(mode=args['mode'], patience=args['patience'])
model.to(args['device'])

validation_score = []  # per-epoch validation scores, appended below
for epoch in range(args['num_epochs']):
    # Train
    run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

    # Validation and early stop
    val_score, prc_auc = run_an_eval_epoch(args, model, val_loader)
    early_stop = stopper.step(val_score, model)
    validation_score.append(val_score)
    print(
        'epoch {:d}/{:d}, validation {} {:.4f}, prc_auc_score {:.4f}, best validation {} {:.4f}'
        .format(epoch + 1, args['num_epochs'], args['metric_name'], val_score,
                prc_auc, args['metric_name'], stopper.best_score))

    if early_stop:
        break

# %%

# Plot the per-epoch ROC AUC scores collected during training
e = len(epoch_roc_accuracies) + 1
plt.plot(list(range(1, e)), epoch_roc_accuracies, c='b')
Example #13
def main(args):

    torch.cuda.set_device(args['gpu'])
    set_random_seed(args['random_seed'])

    dataset, train_set, val_set, test_set = load_dataset_for_classification(
        args)  # 6264, 783, 784

    train_loader = DataLoader(train_set,
                              batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        args['n_tasks'] = dataset.n_tasks
        if args['method'] == 'twp':
            model = load_mymodel(args)
            print(model)
        else:
            model = load_model(args)
            for name, parameters in model.named_parameters():
                print(name, ':', parameters.size())

        method = args['method']
        life_model = importlib.import_module(f'LifeModel.{method}_model')
        life_model_ins = life_model.NET(model, args)
        data_loader = DataLoader(train_set,
                                 batch_size=len(train_set),
                                 collate_fn=collate_molgraphs,
                                 shuffle=True)
        life_model_ins.data_loader = data_loader

        loss_criterion = BCEWithLogitsLoss(
            pos_weight=dataset.task_pos_weights.cuda(), reduction='none')

    model.cuda()
    score_mean = []
    score_matrix = np.zeros([args['n_tasks'], args['n_tasks']])

    prev_model = None
    for task_i in range(args['n_tasks']):  # learn the tasks sequentially
        print('\n********' + str(task_i))
        stopper = EarlyStopping(patience=args['patience'])
        for epoch in range(args['num_epochs']):
            # Train
            if args['method'] == 'lwf':
                life_model_ins.observe(train_loader, loss_criterion, task_i,
                                       args, prev_model)
            else:
                life_model_ins.observe(train_loader, loss_criterion, task_i,
                                       args)

            # Validation and early stop
            val_score = run_an_eval_epoch(args, model, val_loader, task_i)
            early_stop = stopper.step(val_score, model)

            if early_stop:
                print(epoch)
                break

        if not args['pre_trained']:
            stopper.load_checkpoint(model)

        score_matrix[task_i] = run_eval_epoch(args, model, test_loader)
        prev_model = copy.deepcopy(life_model_ins).cuda()

    # AP: mean score on every task after training on the final task
    print('AP: ', round(np.mean(score_matrix[-1, :]), 4))
    # AF: mean backward transfer, i.e. the score on task t after all tasks were
    # learned minus the score right after task t was learned (negative values
    # indicate forgetting)
    backward = []
    for t in range(args['n_tasks'] - 1):
        b = score_matrix[args['n_tasks'] - 1][t] - score_matrix[t][t]
        backward.append(round(b, 4))
    mean_backward = round(np.mean(backward), 4)
    print('AF: ', mean_backward)
Example #14
def main(args):
    
    # fix random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    
    # load CSV dataset
    smlstr = []
    logCMC = []
    with open("../data/dataset.csv") as csvDataFile:
        csvReader = csv.reader(csvDataFile)
        for row in csvReader:
            smlstr.append(row[0])
            logCMC.append(row[1])
    smlstr = np.asarray(smlstr)
    logCMC = np.asarray(logCMC, dtype="float")
    dataset_size = len(smlstr)
    all_ind = np.arange(dataset_size)

    # split into training and testing
    if args.randSplit:
        train_full_ind, test_ind, \
        smlstr_train, smlstr_test, \
        logCMC_train, logCMC_test = train_test_split(all_ind, smlstr, logCMC,
                                                     test_size=args.test_size,
                                                     random_state=args.seed)   
    else:
        if args.dataset in ["nonionic","all"]:
            if args.dataset == "nonionic":
                test_ind = np.array([8,14,26,31,43,54,57,68,72,80,99,110]) # for nonionic surfactants only
            elif args.dataset == "all":
                test_ind = np.array([8,14,26,31,43,54,57,68,72,80,99,110,125,132,140,150,164,171,178,185,192,197])
            train_full_ind = np.asarray([x for x in all_ind if x not in test_ind])
            np.random.shuffle(test_ind)
            np.random.shuffle(train_full_ind)
            smlstr_train  = smlstr[train_full_ind]
            smlstr_test = smlstr[test_ind]
            logCMC_train = logCMC[train_full_ind]
            logCMC_test = logCMC[test_ind]
        else:
            print("Using Random Splits")
            args.randSplit = True
            train_full_ind, test_ind, \
            smlstr_train, smlstr_test, \
            logCMC_train, logCMC_test = train_test_split(all_ind, smlstr, logCMC,
                                                         test_size=args.test_size,
                                                         random_state=args.seed)   

        
    # save train/test data and index corresponding to the original dataset
    pickle.dump(smlstr_train,open("../gnn_logs/smlstr_train.p","wb"))
    pickle.dump(smlstr_test,open("../gnn_logs/smlstr_test.p","wb"))
    pickle.dump(logCMC_train,open("../gnn_logs/logCMC_train.p","wb"))
    pickle.dump(logCMC_test,open("../gnn_logs/logCMC_test.p","wb"))
    pickle.dump(train_full_ind,open("../gnn_logs/original_ind_train_full.p","wb"))
    pickle.dump(test_ind,open("../gnn_logs/original_ind_test.p","wb"))
    rows = zip(train_full_ind,smlstr_train,logCMC_train)
    with open("../gnn_logs/dataset_train.csv",'w',newline='') as f:
        writer = csv.writer(f,delimiter=',')
        for row in rows:
            writer.writerow(row)
    rows = zip(test_ind,smlstr_test,logCMC_test)
    with open("../gnn_logs/dataset_test.csv",'w',newline='') as f:
        writer = csv.writer(f,delimiter=',')
        for row in rows:
            writer.writerow(row)
      
    train_size = len(smlstr_train)
    indices = list(range(train_size))
    
    if not args.skip_cv:
        # K-fold CV setup
        kf = KFold(n_splits=args.cv, random_state=args.seed, shuffle=True)
        cv_index = 0
        index_list_train = []
        index_list_valid = []
        for train_indices, valid_indices in kf.split(indices):
            index_list_train.append(train_indices)
            index_list_valid.append(valid_indices)

            # Build the datasets first so that args.dim_input is set correctly
            # before the model is constructed
            if args.single_feat:
                from dgllife.utils import BaseAtomFeaturizer, atomic_number
                train_full_dataset = graph_dataset(smlstr_train, logCMC_train,
                                                   node_enc=BaseAtomFeaturizer({'h': atomic_number}))
                test_dataset = graph_dataset(smlstr_test, logCMC_test,
                                             node_enc=BaseAtomFeaturizer({'h': atomic_number}))
                args.dim_input = 1
            else:
                train_full_dataset = graph_dataset(smlstr_train, logCMC_train)
                test_dataset = graph_dataset(smlstr_test, logCMC_test)

            model = args.gnn_model(args.dim_input, args.unit_per_layer, 1, False)
            model_arch = 'GCNReg'
            loss_fn = nn.MSELoss()

            # check gpu availability
            if args.gpu >= 0:
                model = model.cuda(args.gpu)
                loss_fn = loss_fn.cuda(args.gpu)
                cudnn.enabled = True
                cudnn.benchmark = True
                cudnn.deterministic = False
            optimizer = torch.optim.Adam(model.parameters(), args.lr)
            train_sampler = SubsetRandomSampler(train_indices)
            valid_sampler = SubsetRandomSampler(valid_indices)
            train_loader = torch.utils.data.DataLoader(train_full_dataset, batch_size=args.batch_size,
                                                       sampler=train_sampler,
                                                       collate_fn=collate,
                                                       shuffle=False)
            val_loader = torch.utils.data.DataLoader(train_full_dataset, batch_size=args.batch_size,
                                                     sampler=valid_sampler,
                                                     collate_fn=collate,
                                                     shuffle=False)
            train_dataset = graph_dataset(smlstr_train[train_indices],logCMC_train[train_indices])
            valid_dataset = graph_dataset(smlstr_train[valid_indices],logCMC_train[valid_indices])
    
            fname = r"ep{}bs{}lr{}kf{}hu{}cvid{}".format(args.epochs, args.batch_size,
                                                               args.lr,
                                                               args.cv,
                                                               args.unit_per_layer, cv_index)
            
            best_rmse = float('inf')
            if args.train:
                print("Training the model ...")
                stopper = EarlyStopping(mode='lower', patience=args.patience, filename=r'../gnn_logs/{}es.pth.tar'.format(fname)) # early stop model
                for epoch in range(args.start_epoch, args.epochs):
                    train_loss = train(train_loader, model, loss_fn, optimizer, epoch, args, fname)
                    rmse = validate(val_loader, model, epoch, args, fname)
                    is_best = rmse < best_rmse
                    best_rmse = min(rmse, best_rmse)
                    if is_best:
                        save_checkpoint({
                            'epoch': epoch + 1,
                            'model_arch': model_arch,
                            'state_dict': model.state_dict(),
                            'best_rmse': best_rmse,
                            'optimizer': optimizer.state_dict(),
                        }, fname)
                    if args.early_stop:
                        early_stop = stopper.step(train_loss, model)
                        if early_stop:
                            print("**********Early Stopping!")
                            break
    
    
            # test
            print("Testing the model ...")
            checkpoint = torch.load(r"../gnn_logs/{}.pth.tar".format(fname))
            args.start_epoch = 0
            best_rmse = checkpoint['best_rmse']
            model = args.gnn_model(args.dim_input, args.unit_per_layer,1,True)
            if args.gpu >= 0:
                model = model.cuda(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            # if args.gpu < 0:
            #     model = model.cpu()
            # else:
            #     model = model.cuda(args.gpu)
            print("=> loaded checkpoint '{}' (epoch {}, rmse {})"
                  .format(fname, checkpoint['epoch'], best_rmse))
            cudnn.deterministic = True
            stage = 'testtest'
            predict(test_dataset, model, -1, args, fname, stage)
            stage = 'testtrain'
            predict(train_dataset, model, -1, args, fname, stage)
            stage = 'testval'
            predict(valid_dataset, model, -1, args, fname, stage)
            cv_index += 1
        pickle.dump(index_list_train,open("../gnn_logs/ind_train_list.p","wb"))
        pickle.dump(index_list_valid,open("../gnn_logs/ind_val_list.p","wb"))

    else:
        # Build the datasets first so that args.dim_input is set correctly
        # before the model is constructed
        if args.single_feat:
            from dgllife.utils import BaseAtomFeaturizer, atomic_number
            train_full_dataset = graph_dataset(smlstr_train, logCMC_train,
                                               node_enc=BaseAtomFeaturizer({'h': atomic_number}))
            test_dataset = graph_dataset(smlstr_test, logCMC_test,
                                         node_enc=BaseAtomFeaturizer({'h': atomic_number}))
            args.dim_input = 1
        else:
            train_full_dataset = graph_dataset(smlstr_train, logCMC_train)
            test_dataset = graph_dataset(smlstr_test, logCMC_test)

        model = args.gnn_model(args.dim_input, args.unit_per_layer, 1, False)
        model_arch = 'GCNReg'
        loss_fn = nn.MSELoss()

        # check gpu availability
        if args.gpu >= 0:
            model = model.cuda(args.gpu)
            loss_fn = loss_fn.cuda(args.gpu)
            cudnn.enabled = True
            cudnn.benchmark = True
            cudnn.deterministic = False
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
        train_loader = torch.utils.data.DataLoader(train_full_dataset, batch_size=args.batch_size,
                                                   collate_fn=collate,
                                                   shuffle=False)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                                  collate_fn=collate,
                                                  shuffle=False)
        train_dataset = graph_dataset(smlstr_train,logCMC_train)
        fname = r"ep{}bs{}lr{}hu{}".format(args.epochs, args.batch_size,
                                                           args.lr,
                                                           args.unit_per_layer)
        
        best_rmse = float('inf')
        if args.train:
            print("Training the model ...")
            stopper = EarlyStopping(mode='lower', patience=args.patience, filename=r'../gnn_logs/{}es.pth.tar'.format(fname)) # early stop model
            for epoch in range(args.start_epoch, args.epochs):
                train_loss = train(train_loader, model, loss_fn, optimizer, epoch, args, fname)
                rmse = validate(test_loader, model, epoch, args, fname)
                is_best = rmse < best_rmse
                best_rmse = min(rmse, best_rmse)
                if is_best:
                    save_checkpoint({
                        'epoch': epoch + 1,
                        'model_arch': model_arch,
                        'state_dict': model.state_dict(),
                        'best_rmse': best_rmse,
                        'optimizer': optimizer.state_dict(),
                    }, fname)
                if args.early_stop:
                    early_stop = stopper.step(train_loss, model)
                    if early_stop:
                        print("**********Early Stopping!")
                        break


        # test
        print("Testing the model ...")
        checkpoint = torch.load(r"../gnn_logs/{}.pth.tar".format(fname))
        args.start_epoch = 0
        best_rmse = checkpoint['best_rmse']
        model = args.gnn_model(args.dim_input, args.unit_per_layer,1,True)
        if args.gpu >= 0:
            model = model.cuda(args.gpu)
        model.load_state_dict(checkpoint['state_dict'])
        # if args.gpu < 0:
        #     model = model.cpu()
        # else:
        #     model = model.cuda(args.gpu)
        print("=> loaded checkpoint '{}' (epoch {}, rmse {})"
              .format(fname, checkpoint['epoch'], best_rmse))
        cudnn.deterministic = True
        stage = 'testtest'
        predict(test_dataset, model, -1, args, fname, stage)
        stage = 'testtrain'
        predict(train_dataset, model, -1, args, fname, stage)
        if args.early_stop:
            checkpoint = torch.load(r"../gnn_logs/{}es.pth.tar".format(fname))
            args.start_epoch = 0
            model = args.gnn_model(args.dim_input, args.unit_per_layer,1,True)
            if args.gpu >= 0:
                model = model.cuda(args.gpu)
            model.load_state_dict(checkpoint['model_state_dict'])
            train_dataset = graph_dataset(smlstr_train,logCMC_train)
            test_dataset = graph_dataset(smlstr_test,logCMC_test)
            cudnn.deterministic = True     
            stage = 'testtest'
            predict(test_dataset, model, -1, args, r"{}es".format(fname), stage)
            stage = 'testtrain'
            predict(train_dataset, model, -1, args, r"{}es".format(fname), stage)

        
    return
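
One subtlety in Example #14: the two checkpoint formats differ in their key names. The local save_checkpoint() stores the weights under 'state_dict', while dgllife's EarlyStopping saves them under 'model_state_dict', which is why the two torch.load branches above index the checkpoint differently. A small loader that tolerates both (a sketch, not part of the original script):

import torch

def load_weights(model, path):
    ckpt = torch.load(path)
    # EarlyStopping checkpoints use 'model_state_dict'; save_checkpoint uses 'state_dict'
    key = 'model_state_dict' if 'model_state_dict' in ckpt else 'state_dict'
    model.load_state_dict(ckpt[key])
    return model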