Example #1
# collate_pool and get_train_val_test_loader are assumed to live alongside
# CIFData in cgcnn.data_muti (mirroring cgcnn.data in the reference CGCNN
# package); CrystalGraphConvNet comes from cgcnn.model. Normalizer is assumed
# to be defined as in the standard CGCNN training script.
from cgcnn.data_muti import CIFData, collate_pool, get_train_val_test_loader
from cgcnn.model import CrystalGraphConvNet
if __name__ == '__main__':
    dataset = CIFData(
        r"C:\Users\10989\PycharmProjects\8_CGCNN\cgcnn\data\sample-regression-muti-output"
    )
    (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id = dataset[0]
    print(target)

    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=3,
        train_ratio=0.5,
        num_workers=2,
        val_ratio=0.2,
        test_ratio=0.3,
        pin_memory=False,
        train_size=5,
        val_size=3,
        test_size=2,
        return_test=True)
    # sample targets from the (tiny) dataset to fit the normalizer
    sample_data_list = [dataset[i] for i in range(len(dataset))]
    _, sample_target, _ = collate_pool(sample_data_list)
    normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    # remaining hyperparameters (atom_fea_len, n_conv, h_fea_len, n_h) are
    # left at their defaults
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len)
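All of these examples pass collate_pool as the DataLoader collate function but never show it. A minimal sketch consistent with how it is used here (concatenate per-crystal atom and neighbor features, track which atoms belong to which crystal, and stack the targets), modeled on the reference CGCNN implementation, might look like:

import numpy as np
import torch

def collate_pool(dataset_list):
    # Collate a list of ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id)
    # tuples into one batch of tensors plus per-crystal atom index lists.
    batch_atom_fea, batch_nbr_fea, batch_nbr_fea_idx = [], [], []
    crystal_atom_idx, batch_target, batch_cif_ids = [], [], []
    base_idx = 0
    for (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id in dataset_list:
        n_i = atom_fea.shape[0]  # number of atoms in this crystal
        batch_atom_fea.append(atom_fea)
        batch_nbr_fea.append(nbr_fea)
        batch_nbr_fea_idx.append(nbr_fea_idx + base_idx)
        crystal_atom_idx.append(torch.LongTensor(np.arange(n_i) + base_idx))
        batch_target.append(target)
        batch_cif_ids.append(cif_id)
        base_idx += n_i
    return (torch.cat(batch_atom_fea, dim=0),
            torch.cat(batch_nbr_fea, dim=0),
            torch.cat(batch_nbr_fea_idx, dim=0),
            crystal_atom_idx), \
        torch.stack(batch_target, dim=0), batch_cif_ids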
Example #2
def main():
    global args, best_mae_error
    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        return_test=True)
    # obtain target value normalizer
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [
                dataset[i] for i in sample(range(len(dataset)), 500)
            ]

        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)
    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=True if args.task == 'classification' else False)
    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)

        if mae_error != mae_error:  # NaN check (NaN != NaN)
            print('Exit due to NaN')
            sys.exit(1)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_mae_error': best_mae_error,
                'optimizer': optimizer.state_dict(),
                'normalizer': normalizer.state_dict(),
                'args': vars(args)
            }, is_best)

    # test best model
    print('---------Evaluate Model on Test Set---------------')
    best_checkpoint = torch.load('model_best.pth.tar')
    model.load_state_dict(best_checkpoint['state_dict'])
    validate(test_loader, model, criterion, normalizer, test=True)
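Every regression example fits a Normalizer on a sample of target values and round-trips it through the checkpoint via state_dict()/load_state_dict(). A minimal sketch matching that usage (plain mean/std standardization), following the reference CGCNN training script, would be:

import torch

class Normalizer(object):
    """Normalize a tensor of targets to zero mean and unit variance,
    and remember the statistics so predictions can be denormalized."""

    def __init__(self, tensor):
        self.mean = torch.mean(tensor)
        self.std = torch.std(tensor)

    def norm(self, tensor):
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        return {'mean': self.mean, 'std': self.std}

    def load_state_dict(self, state_dict):
        self.mean = state_dict['mean']
        self.std = state_dict['std']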
Example #3
def main():
    #taken from sys.argv
    model_name = sys.argv[1]
    save_name = sys.argv[2]

    #var. for dataset loader
    root_dir = sys.argv[3]
    max_num_nbr = 8
    radius = 4
    dmin = 0
    step = 0.2
    random_seed = 123
    batch_size = 64
    Ntot = 36000  #Total number of data

    train_idx = list(range(100))  #do not change
    val_idx = list(range(100))  #do not change
    test_idx = list(range(Ntot))
    num_workers = 0
    pin_memory = True
    return_test = True

    #var for model
    atom_fea_len = 40
    h_fea_len = 80
    n_conv = 3
    n_h = 2
    lr = 0.001
    lr_decay_rate = 0.98
    weight_decay = 0.0
    best_mae_error = 1e10
    start_epoch = 0
    epochs = 200

    #setup
    dataset = CIFData(root_dir, max_num_nbr, radius, dmin, step, random_seed)
    collate_fn = collate_pool

    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset, collate_fn, batch_size, train_idx, val_idx, test_idx,
        num_workers, pin_memory, return_test)

    sample_data_list = [dataset[i] for i in sample(range(len(dataset)), 100)]
    _, sample_target, _ = collate_pool(sample_data_list)
    normalizer = Normalizer(sample_target)

    #build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len, atom_fea_len,
                                n_conv, h_fea_len, n_h)
    model.cuda()

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
    scheduler = ExponentialLR(optimizer, gamma=lr_decay_rate)

    # optionally resume from a checkpoint
    print("=> loading checkpoint '{}'".format(model_name))
    checkpoint = torch.load(model_name)
    start_epoch = checkpoint['epoch']
    best_mae_error = checkpoint['best_mae_error']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    normalizer.load_state_dict(checkpoint['normalizer'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        model_name, checkpoint['epoch']))

    print('---------Evaluate Model on Test Set---------------')
    validate(test_loader,
             model,
             criterion,
             normalizer,
             test=True,
             save_name=save_name)
Example #4
def main():
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        return_val=True,
        return_test=True,
    )

    # obtain target value normalizer
    if args.task == "classification":
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({"mean": 0.0, "std": 1.0})
    else:
        if len(dataset) < 500:
            warnings.warn(
                "Dataset has less than 500 data points. " "Lower accuracy is expected. "
            )
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i] for i in sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=True if args.task == "classification" else False,
        dropout_rate=args.dropout_rate,
    )

    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if args.task == "classification":
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "Adam":
        optimizer = optim.Adam(
            model.parameters(), args.lr, weight_decay=args.weight_decay
        )
    else:
        raise NameError("Only SGD or Adam is allowed as --optim")

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_mae_error = checkpoint["best_mae_error"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            normalizer.load_state_dict(checkpoint["normalizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)

        if mae_error != mae_error:
            print("Exit due to NaN")
            sys.exit(1)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == "regression":
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_mae_error": best_mae_error,
                "optimizer": optimizer.state_dict(),
                "normalizer": normalizer.state_dict(),
                "args": vars(args),
            },
            is_best,
        )

    # test best model
    best_checkpoint = torch.load("model_best.pth.tar")
    model.load_state_dict(best_checkpoint["state_dict"])
    validate(
        train_loader, model, criterion, normalizer, test=True, fname="train_results"
    )
    validate(val_loader, model, criterion, normalizer, test=True, fname="val_results")
    validate(test_loader, model, criterion, normalizer, test=True, fname="test_results")
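train() and validate() are called in every example but not shown. A minimal regression-only sketch of train(), assuming the 6-argument signature used above and the batch layout produced by collate_pool, could look like this (the real implementations also track timing/MAE meters and handle CUDA transfers and classification):

def train(train_loader, model, criterion, optimizer, epoch, normalizer):
    # one pass over the training set
    model.train()
    for input, target, _ in train_loader:
        # input = (atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
        output = model(*input)
        loss = criterion(output, normalizer.norm(target))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()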
Example #5
def main():
    global args, model_args, best_mae_error

    # load data
    dataset = CIFData(args.cifpath,
                      max_num_nbr=model_args.max_num_nbr,
                      radius=model_args.radius,
                      nn_method=model_args.nn_method,
                      disable_save_torch=args.disable_save_torch)
    collate_fn = collate_pool

    if args.train_val_test:
        train_loader, val_loader, test_loader = get_train_val_test_loader(
            dataset=dataset,
            collate_fn=collate_fn,
            batch_size=model_args.batch_size,
            train_ratio=model_args.train_ratio,
            num_workers=args.workers,
            val_ratio=model_args.val_ratio,
            test_ratio=model_args.test_ratio,
            pin_memory=args.cuda,
            train_size=model_args.train_size,
            val_size=model_args.val_size,
            test_size=model_args.test_size,
            return_test=True)
    else:
        test_loader = DataLoader(dataset,
                                 batch_size=model_args.batch_size,
                                 shuffle=True,
                                 num_workers=args.workers,
                                 collate_fn=collate_fn,
                                 pin_memory=args.cuda)

    # make and clean torch files if needed
    torch_data_path = os.path.join(args.cifpath, 'cifdata')
    if args.clean_torch and os.path.exists(torch_data_path):
        shutil.rmtree(torch_data_path)
    if os.path.exists(torch_data_path):
        if not args.clean_torch:
            warnings.warn('Found torch .json files at ' + torch_data_path +
                          '. Will read in .jsons as-available')
    else:
        os.mkdir(torch_data_path)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=model_args.atom_fea_len,
        n_conv=model_args.n_conv,
        h_fea_len=model_args.h_fea_len,
        n_h=model_args.n_h,
        classification=True if model_args.task == 'classification' else False,
        enable_tanh=model_args.enable_tanh)
    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if model_args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()

    normalizer = Normalizer(torch.zeros(3))

    # optionally resume from a checkpoint
    if os.path.isfile(args.modelpath):
        print("=> loading model '{}'".format(args.modelpath))
        checkpoint = torch.load(args.modelpath,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded model '{}' (epoch {}, validation {})".format(
            args.modelpath, checkpoint['epoch'], checkpoint['best_mae_error']))
    else:
        print("=> no model found at '{}'".format(args.modelpath))

    if args.train_val_test:
        print('---------Evaluate Model on Train Set---------------')
        validate(train_loader,
                 model,
                 criterion,
                 normalizer,
                 test=True,
                 csv_name='train_results.csv')
        print('---------Evaluate Model on Val Set---------------')
        validate(val_loader,
                 model,
                 criterion,
                 normalizer,
                 test=True,
                 csv_name='val_results.csv')
        print('---------Evaluate Model on Test Set---------------')
        validate(test_loader,
                 model,
                 criterion,
                 normalizer,
                 test=True,
                 csv_name='test_results.csv')
    else:
        print('---------Evaluate Model on Dataset---------------')
        validate(test_loader,
                 model,
                 criterion,
                 normalizer,
                 test=True,
                 csv_name='predictions.csv')
Example #6
def main():
    #taken from sys.argv
    chk_name = sys.argv[1]
    best_name = sys.argv[2]
    save_name = sys.argv[3]

    #var. for dataset loader
    root_dir = '/your/model/path/'
    max_num_nbr = 8
    radius = 4
    dmin = 0
    step = 0.2
    random_seed = 1234
    batch_size = 64
    N_tot = len(open(root_dir + '/id_prop.csv').readlines())
    N_tr = int(N_tot * 0.8)
    N_val = int(N_tot * 0.1)
    N_test = N_tot - N_tr - N_val

    train_idx = list(range(N_tr))
    val_idx = list(range(N_tr, N_tr + N_val))
    test_idx = list(range(N_tr + N_val, N_tr + N_val + N_test))

    num_workers = 0
    pin_memory = True
    return_test = True

    #var for model
    #	atom_fea_len,h_fea_len,n_conv,n_h,lr_decay_rate = Hyp_loader(root_dir,hyp_idx)
    atom_fea_len = 90
    h_fea_len = 2 * atom_fea_len
    n_conv = 5
    n_h = 2
    lr_decay_rate = 0.97

    lr = 0.001
    weight_decay = 0.0
    resume = False
    resume_path = 'ddd'

    #var for training
    best_mae_error = 1e10
    start_epoch = 0
    epochs = 200

    #setup
    dataset = CIFData(root_dir, max_num_nbr, radius, dmin, step, random_seed)
    collate_fn = collate_pool

    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset, collate_fn, batch_size, train_idx, val_idx, test_idx,
        num_workers, pin_memory, return_test)

    sample_data_list = [dataset[i] for i in sample(range(len(dataset)), 500)]
    _, sample_target, _ = collate_pool(sample_data_list)
    normalizer = Normalizer(sample_target)

    #build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len, atom_fea_len,
                                n_conv, h_fea_len, n_h)
    model.cuda()

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
    scheduler = ExponentialLR(optimizer, gamma=lr_decay_rate)

    # optionally resume from a checkpoint
    if resume:
        print("=> loading checkpoint '{}'".format(args.resume))
        start_epoch = checkpoint['epoch']
        best_mae_error = checkpoint['best_mae_error']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    t0 = time.time()
    for epoch in range(start_epoch, epochs):
        train(train_loader, model, criterion, optimizer, epoch, normalizer)
        mae_error = validate(val_loader, model, criterion, normalizer)

        scheduler.step()
        is_best = mae_error < best_mae_error
        best_mae_error = min(mae_error, best_mae_error)
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_mae_error': best_mae_error,
                'optimizer': optimizer.state_dict(),
                'normalizer': normalizer.state_dict()
            }, is_best, chk_name, best_name)
    t1 = time.time()
    print('--------Training time in sec-------------')
    print(t1 - t0)
    print('---------Best Model on Validation Set---------------')
    best_checkpoint = torch.load(best_name)
    print(best_checkpoint['best_mae_error'].cpu().numpy())
    print('---------Evaluate Model on Test Set---------------')
    model.load_state_dict(best_checkpoint['state_dict'])
    validate(test_loader,
             model,
             criterion,
             normalizer,
             test=True,
             save_name=save_name)
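save_checkpoint() is called with two arguments in Examples #2, #4, and #8 and with explicit checkpoint/best filenames here. A minimal sketch compatible with both call styles, assuming the usual torch.save-plus-copy pattern, might be:

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',
                    best_name='model_best.pth.tar'):
    # always save the latest state; copy it aside when it is the best so far
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_name)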
Example #7
def main():
    #taken from sys.argv
    resume = True
    resume_path = sys.argv[1]

    #var. for dataset loader
    root_dir = '/your/data/path/'
    max_num_nbr = 8
    radius = 4
    dmin = 0
    step = 0.2
    random_seed = 1234
    batch_size = 64
    N_tot = len(open(root_dir + '/id_prop.csv').readlines())
    N_tr = int(N_tot * 0.8)
    N_val = int(N_tot * 0.1)
    N_test = N_tot - N_tr - N_val
    #	N_test = N_tot

    train_idx = list(range(N_tr))
    val_idx = list(range(N_tr, N_tr + N_val))
    test_idx = list(range(N_tot))

    num_workers = 0
    pin_memory = False
    return_test = True

    #var for model
    atom_fea_len = 90
    h_fea_len = 2 * atom_fea_len
    n_conv = 5
    n_h = 2
    lr_decay_rate = 0.99

    lr = 0.001
    weight_decay = 0.0

    model_args = {
        'radius': radius,
        'dmin': dmin,
        'step': step,
        'batch_size': batch_size,
        'random_seed': random_seed,
        'N_tr': N_tr,
        'N_val': N_val,
        'N_test': N_test,
        'atom_fea_len': atom_fea_len,
        'h_fea_len': h_fea_len,
        'n_conv': n_conv,
        'n_h': n_h,
        'lr': lr,
        'lr_decay_rate': lr_decay_rate,
        'weight_decay': weight_decay
    }

    #var for training
    best_mae_error = 1e10
    start_epoch = 0
    epochs = 1000

    #setup
    dataset = CIFData(root_dir, max_num_nbr, radius, dmin, step, random_seed)
    collate_fn = collate_pool

    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset, collate_fn, batch_size, train_idx, val_idx, test_idx,
        num_workers, pin_memory, return_test)

    sample_data_list = [dataset[i] for i in sample(range(len(dataset)), 1)]
    _, sample_target, _ = collate_pool(sample_data_list)
    normalizer = Normalizer(sample_target)

    #build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len, atom_fea_len,
                                n_conv, h_fea_len, n_h)
    model.cuda()

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
    scheduler = ExponentialLR(optimizer, gamma=lr_decay_rate)

    # optionally resume from a checkpoint
    if resume:
        print("=> loading checkpoint '{}'".format(resume_path))
        checkpoint = torch.load(resume_path)
        start_epoch = checkpoint['epoch']
        best_mae_error = checkpoint['best_mae_error']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            resume_path, checkpoint['epoch']))

    print('---------Evaluate Model on Test Set---------------')
    save_name = 'dropout_test.csv'
    validate(test_loader,
             model,
             criterion,
             normalizer,
             test=True,
             save_name=save_name)
Example #8
def main():
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options,
                      disable_save_torch=args.disable_save_torch)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        return_test=True)

    # Make sure >1 class is present
    if args.task == 'classification':
        total_train = 0
        total_val = 0
        total_test = 0
        for i, (_, target, _) in enumerate(train_loader):
            for target_i in target.squeeze():
                total_train += target_i
        if bool(total_train == 0):
            raise ValueError('All 0s in train')
        elif bool(total_train == 1):
            raise ValueError('All 1s in train')
        for i, (_, target, _) in enumerate(val_loader):
            if len(target) == 1:
                raise ValueError('Only single entry in val')
            for target_i in target.squeeze():
                total_val += target_i
        if bool(total_val == 0):
            raise ValueError('All 0s in val')
        elif bool(total_val == 1):
            raise ValueError('All 1s in val')
        for i, (_, target, _) in enumerate(test_loader):
            if len(target) == 1:
                raise ValueError('Only single entry in test')
            for target_i in target.squeeze():
                total_test += target_i
        if bool(total_test == 0):
            raise ValueError('All 0s in test')
        elif bool(total_test == 1):
            raise ValueError('All 1s in test')

    # make output folder if needed
    if not os.path.exists('output'):
        os.mkdir('output')

    # make and clean torch files if needed
    torch_data_path = os.path.join(args.data_options[0], 'cifdata')
    if args.clean_torch and os.path.exists(torch_data_path):
        shutil.rmtree(torch_data_path)
    if os.path.exists(torch_data_path):
        if not args.clean_torch:
            warnings.warn('Found cifdata folder at ' +
                          torch_data_path+'. Will read in .jsons as-available')
    else:
        os.mkdir(torch_data_path)

    # obtain target value normalizer
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i] for i in
                                sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len=args.atom_fea_len,
                                n_conv=args.n_conv,
                                h_fea_len=args.h_fea_len,
                                n_h=args.n_h,
                                classification=True if args.task ==
                                'classification' else False)
    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)

        if mae_error != mae_error:
            print('Exit due to NaN')
            sys.exit(1)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_mae_error': best_mae_error,
            'optimizer': optimizer.state_dict(),
            'normalizer': normalizer.state_dict(),
            'args': vars(args)
        }, is_best)

    # test best model
    best_checkpoint = torch.load(os.path.join('output', 'model_best.pth.tar'))
    model.load_state_dict(best_checkpoint['state_dict'])
    print('---------Evaluate Best Model on Train Set---------------')
    validate(train_loader, model, criterion, normalizer, test=True,
             csv_name='train_results.csv')
    print('---------Evaluate Best Model on Val Set---------------')
    validate(val_loader, model, criterion, normalizer, test=True,
             csv_name='val_results.csv')
    print('---------Evaluate Best Model on Test Set---------------')
    validate(test_loader, model, criterion, normalizer, test=True,
             csv_name='test_results.csv')
Example #9
def main():
    global args, best_mae_error

    # load dataset: (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id
    dataset = CIFData(args.root + args.target)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        return_test=True)

    # obtain target value normalizer
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        sample_data_list = [dataset[i]
                            for i in sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=True if args.task == 'classification' else False)
    # print number of trainable model parameters
    trainable_params = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
    print('=> number of trainable model parameters: {:d}'.format(
        trainable_params))

    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # TensorBoard writer
    summary_root = './runs/'
    if not os.path.exists(summary_root):
        os.mkdir(summary_root)
    summary_file = summary_root + args.target
    if os.path.exists(summary_file):
        shutil.rmtree(summary_file)
    writer = SummaryWriter(summary_file)

    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.start_epoch + args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer,
              writer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, epoch, normalizer,
                             writer)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_mae_error': best_mae_error,
                'optimizer': optimizer.state_dict(),
                'normalizer': normalizer.state_dict(),
                'args': vars(args)
            }, args.target, is_best)