Example 1
def main():
    global args, model_args, best_mae_error

    # load data
    dataset = CIFData(args.cifpath)
    collate_fn = collate_pool
    test_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_fn,
        pin_memory=args.cuda,
    )

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=model_args.atom_fea_len,
        n_conv=model_args.n_conv,
        h_fea_len=model_args.h_fea_len,
        n_h=model_args.n_h,
        classification=(model_args.task == "classification"),
    )
    if args.cuda:
        model.cuda()

    # define loss function
    if model_args.task == "classification":
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()

    normalizer = Normalizer(torch.zeros(3))

    # optionally resume from a checkpoint
    if os.path.isfile(args.modelpath):
        print("=> loading model '{}'".format(args.modelpath))
        checkpoint = torch.load(args.modelpath,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["state_dict"])
        normalizer.load_state_dict(checkpoint["normalizer"])
        print("=> loaded model '{}' (epoch {}, validation {})".format(
            args.modelpath, checkpoint["epoch"], checkpoint["best_mae_error"]))
    else:
        print("=> no model found at '{}'".format(args.modelpath))

    validate(test_loader,
             model,
             criterion,
             normalizer,
             test=True,
             fname="predictions")
    print(dataset.bad_indices)
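For reference across these examples: each script builds a Normalizer and restores its mean/std from the checkpoint's 'normalizer' entry. A minimal sketch consistent with how it is called here (the norm/denorm method names are assumptions; the state_dict keys match the checkpoints above):

import torch

class Normalizer:
    """Normalize a tensor with its mean and std; restore it later."""

    def __init__(self, tensor):
        self.mean = torch.mean(tensor)
        self.std = torch.std(tensor)

    def norm(self, tensor):
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        return {'mean': self.mean, 'std': self.std}

    def load_state_dict(self, state_dict):
        self.mean = state_dict['mean']
        self.std = state_dict['std']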
Example 2
def main():
    global args, model_args, best_mae_error

    # load data
    dataset = CIFData(args.cifpath)
    collate_fn = collate_pool
    test_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.workers,
                             collate_fn=collate_fn,
                             pin_memory=args.cuda)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=model_args.atom_fea_len,
        n_conv=model_args.n_conv,
        h_fea_len=model_args.h_fea_len,
        n_h=model_args.n_h,
        classification=(model_args.task == 'classification'))
    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if model_args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    # if args.optim == 'SGD':
    #     optimizer = optim.SGD(model.parameters(), args.lr,
    #                           momentum=args.momentum,
    #                           weight_decay=args.weight_decay)
    # elif args.optim == 'Adam':
    #     optimizer = optim.Adam(model.parameters(), args.lr,
    #                            weight_decay=args.weight_decay)
    # else:
    #     raise NameError('Only SGD or Adam is allowed as --optim')

    normalizer = Normalizer(torch.zeros(3))

    # optionally resume from a checkpoint
    if os.path.isfile(args.modelpath):
        print("=> loading model '{}'".format(args.modelpath))
        checkpoint = torch.load(args.modelpath,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded model '{}' (epoch {}, validation {})".format(
            args.modelpath, checkpoint['epoch'], checkpoint['best_mae_error']))
    else:
        print("=> no model found at '{}'".format(args.modelpath))

    validate(test_loader, model, criterion, normalizer, test=True)
Example 3
def predict(args, dataset, collate_fn, test_loader):

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=model_args.atom_fea_len,
        n_conv=model_args.n_conv,
        h_fea_len=model_args.h_fea_len,
        n_h=model_args.n_h,
        classification=(model_args.task == 'classification'))
    if args.cuda:
        model.cuda()

    # define loss function
    if model_args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()

    normalizer = Normalizer(torch.zeros(3))

    # optionally resume from a checkpoint
    if os.path.isfile(args.modelpath):
        print("=> loading model '{}'".format(args.modelpath))
        checkpoint = torch.load(args.modelpath,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded model '{}' (epoch {}, validation {})".format(
            args.modelpath, checkpoint['epoch'], checkpoint['best_mae_error']))
    else:
        print("=> no model found at '{}'".format(args.modelpath))

    validate(test_loader, model, criterion, normalizer, test=True)
Example 4
def main():
    global args, best_mae_error
    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        return_test=True)
    # obtain target value normalizer
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [
                dataset[i] for i in sample(range(len(dataset)), 500)
            ]

        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)
    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=(args.task == 'classification'))
    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)

        if mae_error != mae_error:
            print('Exit due to NaN')
            sys.exit(1)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_mae_error': best_mae_error,
                'optimizer': optimizer.state_dict(),
                'normalizer': normalizer.state_dict(),
                'args': vars(args)
            }, is_best)

    # test best model
    print('---------Evaluate Model on Test Set---------------')
    best_checkpoint = torch.load('model_best.pth.tar')
    model.load_state_dict(best_checkpoint['state_dict'])
    validate(test_loader, model, criterion, normalizer, test=True)
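A note on the `mae_error != mae_error` guard used in these training loops: NaN is the only float that compares unequal to itself, so the test detects a diverged loss. `math.isnan` is the equivalent, more explicit spelling:

import math

x = float('nan')
print(x != x)          # True: NaN never equals itself
print(math.isnan(x))   # True: explicit equivalent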
Example 5
    print(target)

    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=3,
        train_ratio=0.5,
        num_workers=2,
        val_ratio=0.2,
        test_ratio=0.3,
        pin_memory=False,
        train_size=5,
        val_size=3,
        test_size=2,
        return_test=True)
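    # NOTE: sample_data_list is not defined in this excerpt; the other examples
    # build it as [dataset[i] for i in range(len(dataset))] or a 500-point sample.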
    _, sample_target, _ = collate_pool(sample_data_list)
    normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len,
                                nbr_fea_len,
                                atom_fea_len=4,
                                n_conv=3,
                                h_fea_len=32,
                                n_h=1,
                                classification=False)
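This fragment wires up loaders with concrete sizes, which makes it a convenient place to show the batch structure. A small helper for inspecting one batch, assuming the nested (input, target, cif_ids) tuple that collate_pool returns in the reference implementation:

def peek_batch(loader):
    """Print shapes of the first batch; structure assumed from collate_pool."""
    (atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx), target, cif_ids = next(iter(loader))
    print(atom_fea.shape, nbr_fea.shape, target.shape, len(cif_ids))

peek_batch(train_loader)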
Example 6
def main():
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        return_val=True,
        return_test=True,
    )

    # obtain target value normalizer
    if args.task == "classification":
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({"mean": 0.0, "std": 1.0})
    else:
        if len(dataset) < 500:
            warnings.warn(
                "Dataset has less than 500 data points. " "Lower accuracy is expected. "
            )
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i] for i in sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=(args.task == "classification"),
        dropout_rate=args.dropout_rate,
    )

    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if args.task == "classification":
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "Adam":
        optimizer = optim.Adam(
            model.parameters(), args.lr, weight_decay=args.weight_decay
        )
    else:
        raise NameError("Only SGD or Adam is allowed as --optim")

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_mae_error = checkpoint["best_mae_error"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            normalizer.load_state_dict(checkpoint["normalizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)

        if mae_error != mae_error:
            print("Exit due to NaN")
            sys.exit(1)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == "regression":
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_mae_error": best_mae_error,
                "optimizer": optimizer.state_dict(),
                "normalizer": normalizer.state_dict(),
                "args": vars(args),
            },
            is_best,
        )

    # test best model
    best_checkpoint = torch.load("model_best.pth.tar")
    model.load_state_dict(best_checkpoint["state_dict"])
    validate(
        train_loader, model, criterion, normalizer, test=True, fname="train_results"
    )
    validate(val_loader, model, criterion, normalizer, test=True, fname="val_results")
    validate(test_loader, model, criterion, normalizer, test=True, fname="test_results")
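save_checkpoint itself is not shown in these excerpts. A minimal sketch consistent with how the loops call it (a state dict plus an is_best flag) and with the model_best.pth.tar file read back afterwards; the default filename is an assumption:

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Persist the latest state; keep a separate copy of the best one."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')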
Example 7
def cv():
    global args, best_mae_error

    if not os.path.exists("./checkpoints"):
        os.mkdir("./checkpoints")
    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool
    i = 0
    train_maes = []
    val_maes = []
    test_maes = []
    for train_loader, val_loader, test_loader in get_cv_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        cross_validation=args.cross_validation,
    ):
        i += 1
        # obtain target value normalizer
        if args.task == "classification":
            normalizer = Normalizer(torch.zeros(2))
            normalizer.load_state_dict({"mean": 0.0, "std": 1.0})
        else:
            if len(dataset) < 500:
                warnings.warn(
                    "Dataset has less than 500 data points. "
                    "Lower accuracy is expected. "
                )
                sample_data_list = [dataset[i] for i in range(len(dataset))]
            else:
                sample_data_list = [
                    dataset[i] for i in sample(range(len(dataset)), 500)
                ]
            _, sample_target, _ = collate_pool(sample_data_list)
            normalizer = Normalizer(sample_target)

        # build model
        structures, _, _ = dataset[0]
        orig_atom_fea_len = structures[0].shape[-1]
        nbr_fea_len = structures[1].shape[-1]
        model = CrystalGraphConvNet(
            orig_atom_fea_len,
            nbr_fea_len,
            atom_fea_len=args.atom_fea_len,
            n_conv=args.n_conv,
            h_fea_len=args.h_fea_len,
            n_h=args.n_h,
            classification=(args.task == "classification"),
            dropout_rate=args.dropout_rate,
        )

        if args.cuda:
            model.cuda()

        # define loss func and optimizer
        if args.task == "classification":
            criterion = nn.NLLLoss()
        else:
            criterion = nn.MSELoss()
        if args.optim == "SGD":
            optimizer = optim.SGD(
                model.parameters(),
                args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
            )
        elif args.optim == "Adam":
            optimizer = optim.Adam(
                model.parameters(), args.lr, weight_decay=args.weight_decay
            )
        else:
            raise NameError("Only SGD or Adam is allowed as --optim")

        scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=0.1)

        print(f"Split {i}")
        if args.task == "regression":
            best_mae_error = 1e10
        else:
            best_mae_error = 0.0
        for epoch in range(args.start_epoch, args.epochs):
            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch, normalizer)

            # evaluate on validation set
            mae_error = validate(val_loader, model, criterion, normalizer)

            if mae_error != mae_error:
                print("Exit due to NaN")
                sys.exit(1)

            scheduler.step()

            # remember the best mae_error and save checkpoint
            if args.task == "regression":
                is_best = mae_error < best_mae_error
                best_mae_error = min(mae_error, best_mae_error)
            else:
                is_best = mae_error > best_mae_error
                best_mae_error = max(mae_error, best_mae_error)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_mae_error": best_mae_error,
                    "optimizer": optimizer.state_dict(),
                    "normalizer": normalizer.state_dict(),
                    "args": vars(args),
                },
                is_best,
                pad_string=f"./checkpoints/{i}",
            )

        # test best model
        best_checkpoint = torch.load(f"./checkpoints/{i}_model_best.pth.tar")
        model.load_state_dict(best_checkpoint["state_dict"])
        train_mae = validate(
            train_loader,
            model,
            criterion,
            normalizer,
            test=True,
            split=i,
            fname="train",
            to_save=False,
        )
        val_mae = validate(
            val_loader,
            model,
            criterion,
            normalizer,
            test=True,
            fname="val",
            to_save=False,
        )
        test_mae = validate(
            test_loader,
            model,
            criterion,
            normalizer,
            test=True,
            fname="test",
            to_save=False,
        )
        train_maes.append(train_mae.detach().item())
        val_maes.append(val_mae.detach().item())
        test_maes.append(test_mae.detach().item())
    with open("results.out", "a+") as fw:
        fw.write("\n")
        fw.write(f"Avg Train MAE: {np.mean(train_maes):.4f}\n")
        fw.write(f"Avg Val MAE: {np.mean(val_maes):.4f}\n")
        fw.write(f"Avg Test MAE: {np.mean(test_maes):.4f}\n")
Example 8
from cgcnn.data import collate_pool, get_train_val_test_loader
if __name__ == '__main__':

    PATH = r"C:\Users\10989\PycharmProjects\8_CGCNN\cgcnn\pre-trained\band-gap.pth.tar"
    model_dict = torch.load(PATH, map_location=torch.device('cpu'))
    dict_name = list(model_dict)
    for i, p in enumerate(dict_name):
        print(i, p)
    for i, p in enumerate(model_dict["state_dict"]):
        print(i, p, '\t', model_dict["state_dict"][p].size())
    from cgcnn.model import CrystalGraphConvNet
    from cgcnn.data import CIFData

    model = CrystalGraphConvNet(92, 41,
                                atom_fea_len=64,
                                n_conv=4,
                                h_fea_len=32,
                                n_h=1,
                                classification=False)
    model.load_state_dict(model_dict["state_dict"])
    # load data
    dataset = CIFData(r"C:\Users\10989\PycharmProjects\8_CGCNN\cgcnn\data\sample-classification")
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=3,
        train_ratio=0.5,
        num_workers=2,
        val_ratio=0.2,
        test_ratio=0.3,
        return_test=True)
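The checkpoint inspected above was produced by the same save_checkpoint pattern as the training scripts, which store the training hyperparameters under the 'args' key (see the save_checkpoint dicts in the other examples). Rather than hard-coding 92/41/64, the model can usually be rebuilt from them; a hedged sketch:

import argparse
import torch

ckpt = torch.load(PATH, map_location=torch.device('cpu'))  # PATH as defined above
saved_args = argparse.Namespace(**ckpt['args'])  # hyperparameters stored via vars(args)
print(saved_args.atom_fea_len, saved_args.n_conv, saved_args.h_fea_len, saved_args.n_h)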
Example 9
def main():
    global args, model_args, best_mae_error

    # load data
    dataset = CIFData(args.cifpath,
                      max_num_nbr=model_args.max_num_nbr,
                      radius=model_args.radius,
                      nn_method=model_args.nn_method,
                      disable_save_torch=args.disable_save_torch)
    collate_fn = collate_pool

    if args.train_val_test:
        train_loader, val_loader, test_loader = get_train_val_test_loader(
            dataset=dataset,
            collate_fn=collate_fn,
            batch_size=model_args.batch_size,
            train_ratio=model_args.train_ratio,
            num_workers=args.workers,
            val_ratio=model_args.val_ratio,
            test_ratio=model_args.test_ratio,
            pin_memory=args.cuda,
            train_size=model_args.train_size,
            val_size=model_args.val_size,
            test_size=model_args.test_size,
            return_test=True)
    else:
        test_loader = DataLoader(dataset,
                                 batch_size=model_args.batch_size,
                                 shuffle=True,
                                 num_workers=args.workers,
                                 collate_fn=collate_fn,
                                 pin_memory=args.cuda)

    # make and clean torch files if needed
    torch_data_path = os.path.join(args.cifpath, 'cifdata')
    if args.clean_torch and os.path.exists(torch_data_path):
        shutil.rmtree(torch_data_path)
    if os.path.exists(torch_data_path):
        if not args.clean_torch:
            warnings.warn('Found torch .json files at ' + torch_data_path +
                          '. Will read in .jsons as-available')
    else:
        os.mkdir(torch_data_path)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=model_args.atom_fea_len,
        n_conv=model_args.n_conv,
        h_fea_len=model_args.h_fea_len,
        n_h=model_args.n_h,
        classification=(model_args.task == 'classification'),
        enable_tanh=model_args.enable_tanh)
    if args.cuda:
        model.cuda()

    # define loss function
    if model_args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()

    normalizer = Normalizer(torch.zeros(3))

    # optionally resume from a checkpoint
    if os.path.isfile(args.modelpath):
        print("=> loading model '{}'".format(args.modelpath))
        checkpoint = torch.load(args.modelpath,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded model '{}' (epoch {}, validation {})".format(
            args.modelpath, checkpoint['epoch'], checkpoint['best_mae_error']))
    else:
        print("=> no model found at '{}'".format(args.modelpath))

    if args.train_val_test:
        print('---------Evaluate Model on Train Set---------------')
        validate(train_loader,
                 model,
                 criterion,
                 normalizer,
                 test=True,
                 csv_name='train_results.csv')
        print('---------Evaluate Model on Val Set---------------')
        validate(val_loader,
                 model,
                 criterion,
                 normalizer,
                 test=True,
                 csv_name='val_results.csv')
        print('---------Evaluate Model on Test Set---------------')
        validate(test_loader,
                 model,
                 criterion,
                 normalizer,
                 test=True,
                 csv_name='test_results.csv')
    else:
        print('---------Evaluate Model on Dataset---------------')
        validate(test_loader,
                 model,
                 criterion,
                 normalizer,
                 test=True,
                 csv_name='predictions.csv')
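The `map_location=lambda storage, loc: storage` idiom used when loading models in these scripts keeps every tensor on CPU regardless of the device the checkpoint was saved from; passing a device is the equivalent, more explicit spelling:

checkpoint = torch.load(args.modelpath, map_location=torch.device('cpu'))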
Example 10
def tc_trans2():
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool

    # obtain target value normalizer
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [
                dataset[i] for i in sample(range(len(dataset)), 500)
            ]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model_a = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=(args.task == 'classification'))
    model_b = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=(args.task == 'classification'))
    model = SimpleNN(in_feature=256, out_feature=1)

    # pretrained model path
    model_a_path = '../pre-trained/research-model/bulk_moduli-model_best.pth.tar'
    model_b_path = '../pre-trained/research-model/sps-model_best.pth.tar'

    # load latest model state
    ckpt_a = torch.load(model_a_path)
    ckpt_b = torch.load(model_b_path)

    # load model
    model_a.load_state_dict(ckpt_a['state_dict'])
    model_b.load_state_dict(ckpt_b['state_dict'])

    def get_activation_a(name, activation_a):
        def hook(model, input, output):
            activation_a[name] = output.detach()

        return hook

    def get_activation_b(name, activation_b):
        def hook(model, input, output):
            activation_b[name] = output.detach()

        return hook

    if args.cuda:
        model_a.cuda()
        model_b.cuda()
        model.cuda()

    activation_a = {}
    activation_b = {}

    # hook the activation function
    model_a.conv_to_fc.register_forward_hook(
        get_activation_a('conv_to_fc', activation_a))
    model_b.conv_to_fc.register_forward_hook(
        get_activation_b('conv_to_fc', activation_b))

    # define loss func and optimizer
    if args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_milestones,
                            gamma=0.1)
    X = torch.Tensor()
    T = torch.Tensor()
    for i in range(5):

        total_size = len(dataset)
        indices = list(range(total_size))
        batch_size = args.batch_size
        num_workers = args.workers
        pin_memory = args.cuda

        if i == 0:
            train_sampler = SubsetRandomSampler(indices[:61])
            test_sampler = SubsetRandomSampler(indices[-16:])
        if i == 1:
            # list.extend returns None, so build the index list by concatenation
            train_sampler = SubsetRandomSampler(indices[:45] + indices[-16:])
            test_sampler = SubsetRandomSampler(indices[45:-16])
        if i == 2:
            train_sampler = SubsetRandomSampler(indices[:29] + indices[-32:])
            test_sampler = SubsetRandomSampler(indices[29:-32])
        if i == 3:
            train_sampler = SubsetRandomSampler(indices[:13] + indices[-48:])
            test_sampler = SubsetRandomSampler(indices[13:-48])
        if i == 4:
            train_sampler = SubsetRandomSampler(indices[-64:])
            test_sampler = SubsetRandomSampler(indices[:-64])

        train_loader = DataLoader(dataset,
                                  batch_size=batch_size,
                                  sampler=train_sampler,
                                  num_workers=num_workers,
                                  collate_fn=collate_fn,
                                  pin_memory=pin_memory)

        test_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 sampler=test_sampler,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn,
                                 pin_memory=pin_memory)
        print(test_sampler)
        for epoch in range(args.start_epoch, args.epochs):
            # train for one epoch
            train(args, train_loader, model_a, model_b, model, activation_a,
                  activation_b, criterion, optimizer, epoch, normalizer)

            # evaluate on validation set
            mae_error = validate(args, train_loader, model_a, model_b, model,
                                 activation_a, activation_b, criterion,
                                 normalizer)

            if mae_error != mae_error:
                print('Exit due to NaN')
                sys.exit(1)

            scheduler.step()

            # remember the best mae_error and save checkpoint
            if args.task == 'regression':
                is_best = mae_error < best_mae_error
                best_mae_error = min(mae_error, best_mae_error)
            else:
                is_best = mae_error > best_mae_error
                best_mae_error = max(mae_error, best_mae_error)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_mae_error': best_mae_error,
                    'optimizer': optimizer.state_dict(),
                    'normalizer': normalizer.state_dict(),
                    'args': vars(args)
                },
                is_best,
                prop=args.property)

        # test best model
        print('---------Evaluate Model on Test Set---------------')
        best_checkpoint = torch.load('../result/' + args.property +
                                     '-model_best.pth.tar')
        model.load_state_dict(best_checkpoint['state_dict'])
        x, t = validate(args,
                        test_loader,
                        model_a,
                        model_b,
                        model,
                        activation_a,
                        activation_b,
                        criterion,
                        normalizer,
                        test=True,
                        tc=True)
        X = torch.cat((X, x), dim=0)
        T = torch.cat((T, t), dim=0)
        x, t = X.numpy(), T.numpy()
        n_max = max(np.max(x), np.max(t))
        n_min = min(np.min(x), np.min(t))
        a = np.linspace(n_min - abs(n_max), n_max + abs(n_max))
        b = a
        plt.rcParams["font.family"] = "Times New Roman"
        plt.plot(a, b, color='blue')
        plt.scatter(t, x, marker=".", color='red', edgecolors='black')
        plt.xlim(n_min - abs(n_min), n_max + abs(n_min))
        plt.ylim(n_min - abs(n_min), n_max + abs(n_min))
        plt.title(
            "Thermal Conductivity Prediction by CGCNN with Combined Model Transfer Learning"
        )
        plt.xlabel("observation")
        plt.ylabel("prediction")
    plt.show()
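Example 10's transfer-learning setup hinges on register_forward_hook: the hooks stash the conv_to_fc activations of the two pretrained CGCNNs so they can be concatenated and fed to SimpleNN. The hook mechanics in isolation, on a stand-in linear layer:

import torch
import torch.nn as nn

activation = {}

def get_activation(name, store):
    def hook(module, inputs, output):
        store[name] = output.detach()
    return hook

layer = nn.Linear(8, 4)
handle = layer.register_forward_hook(get_activation('fc', activation))
_ = layer(torch.randn(2, 8))
handle.remove()                # detach the hook when done
print(activation['fc'].shape)  # torch.Size([2, 4])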
Example 11
def main():
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options,
                      disable_save_torch=args.disable_save_torch)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        return_test=True)

    # Make sure >1 class is present
    if args.task == 'classification':
        total_train = 0
        n_train = 0
        total_val = 0
        n_val = 0
        total_test = 0
        n_test = 0
        for i, (_, target, _) in enumerate(train_loader):
            for target_i in target.squeeze():
                total_train += target_i
                n_train += 1
        if bool(total_train == 0):
            raise ValueError('All 0s in train')
        elif bool(total_train == n_train):
            raise ValueError('All 1s in train')
        for i, (_, target, _) in enumerate(val_loader):
            if len(target) == 1:
                raise ValueError('Only single entry in val')
            for target_i in target.squeeze():
                total_val += target_i
                n_val += 1
        if bool(total_val == 0):
            raise ValueError('All 0s in val')
        elif bool(total_val == n_val):
            raise ValueError('All 1s in val')
        for i, (_, target, _) in enumerate(test_loader):
            if len(target) == 1:
                raise ValueError('Only single entry in test')
            for target_i in target.squeeze():
                total_test += target_i
                n_test += 1
        if bool(total_test == 0):
            raise ValueError('All 0s in test')
        elif bool(total_test == n_test):
            raise ValueError('All 1s in test')

    # make output folder if needed
    if not os.path.exists('output'):
        os.mkdir('output')

    # make and clean torch files if needed
    torch_data_path = os.path.join(args.data_options[0], 'cifdata')
    if args.clean_torch and os.path.exists(torch_data_path):
        shutil.rmtree(torch_data_path)
    if os.path.exists(torch_data_path):
        if not args.clean_torch:
            warnings.warn('Found cifdata folder at ' +
                          torch_data_path+'. Will read in .jsons as-available')
    else:
        os.mkdir(torch_data_path)

    # obtain target value normalizer
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i] for i in
                                sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len=args.atom_fea_len,
                                n_conv=args.n_conv,
                                h_fea_len=args.h_fea_len,
                                n_h=args.n_h,
                                classification=(args.task == 'classification'))
    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)

        if mae_error != mae_error:
            print('Exit due to NaN')
            sys.exit(1)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_mae_error': best_mae_error,
            'optimizer': optimizer.state_dict(),
            'normalizer': normalizer.state_dict(),
            'args': vars(args)
        }, is_best)

    # test best model
    best_checkpoint = torch.load(os.path.join('output', 'model_best.pth.tar'))
    model.load_state_dict(best_checkpoint['state_dict'])
    print('---------Evaluate Best Model on Train Set---------------')
    validate(train_loader, model, criterion, normalizer, test=True,
             csv_name='train_results.csv')
    print('---------Evaluate Best Model on Val Set---------------')
    validate(val_loader, model, criterion, normalizer, test=True,
             csv_name='val_results.csv')
    print('---------Evaluate Best Model on Test Set---------------')
    validate(test_loader, model, criterion, normalizer, test=True,
             csv_name='test_results.csv')
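The element-by-element accumulation in the class-balance check above can be collapsed with tensor ops; an equivalent helper for one loader, assuming 0/1 targets:

import torch

def check_balance(loader, name):
    targets = torch.cat([target.view(-1) for _, target, _ in loader])
    if targets.sum() == 0:
        raise ValueError(f'All 0s in {name}')
    if targets.sum() == len(targets):
        raise ValueError(f'All 1s in {name}')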
Example 12
def main():
    global args, best_mae_error

    # load dataset: (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id
    dataset = CIFData(args.root + args.target)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        return_test=True)

    # obtain target value normalizer
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        sample_data_list = [dataset[i] for i in
                            sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=(args.task == 'classification'))
    # print number of trainable model parameters
    trainable_params = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
    print('=> number of trainable model parameters: {:d}'.format(
        trainable_params))

    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # TensorBoard writer
    summary_root = './runs/'
    if not os.path.exists(summary_root):
        os.mkdir(summary_root)
    summary_file = summary_root + args.target
    if os.path.exists(summary_file):
        shutil.rmtree(summary_file)
    writer = SummaryWriter(summary_file)

    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.start_epoch + args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer,
              writer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, epoch, normalizer,
                             writer)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_mae_error': best_mae_error,
                'optimizer': optimizer.state_dict(),
                'normalizer': normalizer.state_dict(),
                'args': vars(args)
            }, args.target, is_best)
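Example 12 threads a TensorBoard SummaryWriter through train and validate; the minimal logging pattern those helpers presumably use (the scalar tags are illustrative):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./runs/demo')
for step in range(3):
    writer.add_scalar('val/mae', 1.0 / (step + 1), step)  # tag, value, global_step
writer.close()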