def main():
    """Predict targets for every CIF structure in ``args.cifpath``.

    Uses the module-level ``args`` (command-line options) and ``model_args``
    (hyper-parameters of the saved model) to rebuild a ``CrystalGraphConvNet``,
    restores its weights and the target ``Normalizer`` from
    ``args.modelpath``, runs ``validate(..., test=True, fname="predictions")``
    over the whole dataset and finally prints ``dataset.bad_indices``.

    If ``args.modelpath`` is not a file, a message is printed and the function
    returns without predicting. Otherwise an untrained network and the
    placeholder normalizer would be used to produce results.
    """
    global args, model_args

    # load data
    dataset = CIFData(args.cifpath)
    collate_fn = collate_pool
    test_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_fn,
        pin_memory=args.cuda,
    )

    # build model; input feature sizes are read from the first sample
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    is_classification = model_args.task == "classification"
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=model_args.atom_fea_len,
        n_conv=model_args.n_conv,
        h_fea_len=model_args.h_fea_len,
        n_h=model_args.n_h,
        classification=is_classification,
    )
    if args.cuda:
        model.cuda()

    # define loss func
    criterion = nn.NLLLoss() if is_classification else nn.MSELoss()
    # placeholder statistics; overwritten from the checkpoint below
    normalizer = Normalizer(torch.zeros(3))

    # without a trained model there is nothing meaningful to predict
    if not os.path.isfile(args.modelpath):
        print("=> no model found at '{}'".format(args.modelpath))
        return

    print("=> loading model '{}'".format(args.modelpath))
    # keep all storages on the CPU so GPU-saved checkpoints load anywhere
    checkpoint = torch.load(args.modelpath, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["state_dict"])
    normalizer.load_state_dict(checkpoint["normalizer"])
    print(
        "=> loaded model '{}' (epoch {}, validation {})".format(
            args.modelpath, checkpoint["epoch"], checkpoint["best_mae_error"]
        )
    )

    validate(test_loader, model, criterion, normalizer, test=True, fname="predictions")
    print(dataset.bad_indices)
def main():
    """Evaluate a saved CGCNN model on every structure in ``args.cifpath``.

    Uses the module-level ``args`` (command-line options) and ``model_args``
    (hyper-parameters of the saved model) to rebuild a ``CrystalGraphConvNet``,
    restores its weights and the target ``Normalizer`` from
    ``args.modelpath`` and runs ``validate(..., test=True)`` over the whole
    dataset.

    If ``args.modelpath`` is not a file, a message is printed and the function
    returns instead of evaluating an untrained network.
    """
    global args, model_args

    # load data
    dataset = CIFData(args.cifpath)
    collate_fn = collate_pool
    test_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                             num_workers=args.workers, collate_fn=collate_fn,
                             pin_memory=args.cuda)

    # build model; input feature sizes are read from the first sample
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    is_classification = model_args.task == 'classification'
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len=model_args.atom_fea_len,
                                n_conv=model_args.n_conv,
                                h_fea_len=model_args.h_fea_len,
                                n_h=model_args.n_h,
                                classification=is_classification)
    if args.cuda:
        model.cuda()

    # define loss func (no optimizer is needed for evaluation)
    criterion = nn.NLLLoss() if is_classification else nn.MSELoss()
    # placeholder statistics; overwritten from the checkpoint below
    normalizer = Normalizer(torch.zeros(3))

    # without a trained model there is nothing meaningful to evaluate
    if not os.path.isfile(args.modelpath):
        print("=> no model found at '{}'".format(args.modelpath))
        return

    print("=> loading model '{}'".format(args.modelpath))
    # keep all storages on the CPU so GPU-saved checkpoints load anywhere
    checkpoint = torch.load(args.modelpath,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    normalizer.load_state_dict(checkpoint['normalizer'])
    print("=> loaded model '{}' (epoch {}, validation {})".format(
        args.modelpath, checkpoint['epoch'], checkpoint['best_mae_error']))

    validate(test_loader, model, criterion, normalizer, test=True)
def predict(args, dataset, collate_fn, test_loader):
    """Evaluate a saved CGCNN model on ``test_loader``.

    Args:
        args: options object; ``cuda`` and ``modelpath`` are read.
        dataset: dataset whose first sample fixes the input feature sizes.
        collate_fn: accepted for interface compatibility; not used here.
        test_loader: batches passed to ``validate(..., test=True)``.

    The architecture hyper-parameters and task are read from the
    module-level ``model_args``, not from ``args``. If ``args.modelpath`` is
    not a file, a message is printed and the function returns instead of
    evaluating an untrained network.
    """
    # build model; input feature sizes are read from the first sample
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    is_classification = model_args.task == 'classification'
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len=model_args.atom_fea_len,
                                n_conv=model_args.n_conv,
                                h_fea_len=model_args.h_fea_len,
                                n_h=model_args.n_h,
                                classification=is_classification)
    if args.cuda:
        model.cuda()

    # define loss func
    criterion = nn.NLLLoss() if is_classification else nn.MSELoss()
    # placeholder statistics; overwritten from the checkpoint below
    normalizer = Normalizer(torch.zeros(3))

    # without a trained model there is nothing meaningful to evaluate
    if not os.path.isfile(args.modelpath):
        print("=> no model found at '{}'".format(args.modelpath))
        return

    print("=> loading model '{}'".format(args.modelpath))
    # keep all storages on the CPU so GPU-saved checkpoints load anywhere
    checkpoint = torch.load(args.modelpath,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    normalizer.load_state_dict(checkpoint['normalizer'])
    print("=> loaded model '{}' (epoch {}, validation {})".format(
        args.modelpath, checkpoint['epoch'], checkpoint['best_mae_error']))

    validate(test_loader, model, criterion, normalizer, test=True)
def main():
    """Train a CGCNN on ``args.data_options`` and evaluate the best model.

    Reads all options from the module-level ``args`` and updates the global
    ``best_mae_error``. Steps:

    - Build train/val/test loaders, a target normalizer, the network, the loss
      and the optimizer.
    - Optionally resume from ``args.resume``.
    - Train for epochs ``args.start_epoch .. args.epochs - 1`` with a
      ``MultiStepLR`` schedule, calling ``save_checkpoint`` after every epoch.
    - Reload ``model_best.pth.tar`` and run ``validate`` on the test loader.

    Exits the process with status 1 if the validation error becomes NaN.

    Raises:
        NameError: if ``args.optim`` is neither 'SGD' nor 'Adam'.
    """
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset, collate_fn=collate_fn, batch_size=args.batch_size,
        train_ratio=args.train_ratio, num_workers=args.workers,
        val_ratio=args.val_ratio, test_ratio=args.test_ratio,
        pin_memory=args.cuda, train_size=args.train_size,
        val_size=args.val_size, test_size=args.test_size,
        return_test=True)

    is_classification = args.task == 'classification'

    # obtain target value normalizer
    if is_classification:
        # classification uses a fixed mean 0 / std 1
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        # build the normalizer from the targets of up to 500 random samples
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i] for i in
                                sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model; input feature sizes are read from the first sample
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len=args.atom_fea_len,
                                n_conv=args.n_conv,
                                h_fea_len=args.h_fea_len,
                                n_h=args.n_h,
                                classification=is_classification)
    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    criterion = nn.NLLLoss() if is_classification else nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Load onto the CPU so a checkpoint saved on a GPU can be resumed
            # on a CPU-only host. The model/optimizer load_state_dict calls
            # move the values onto the parameters' device.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)

        # NaN is the only value that is not equal to itself
        if mae_error != mae_error:
            print('Exit due to NaN')
            sys.exit(1)

        scheduler.step()

        # remember the best error (lower for regression, higher otherwise)
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_mae_error': best_mae_error,
            'optimizer': optimizer.state_dict(),
            'normalizer': normalizer.state_dict(),
            'args': vars(args)
        }, is_best)

    # test best model
    print('---------Evaluate Model on Test Set---------------')
    best_checkpoint = torch.load('model_best.pth.tar')
    model.load_state_dict(best_checkpoint['state_dict'])
    validate(test_loader, model, criterion, normalizer, test=True)
# Script fragment: `target`, `dataset` and `sample_data_list` are defined
# before this excerpt and are not visible here.
print(target)
collate_fn = collate_pool
# Small fixed-size split for experimentation.
# NOTE(review): explicit train_size/val_size/test_size are passed alongside
# the ratios; presumably the sizes take precedence inside
# get_train_val_test_loader — confirm.
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset=dataset, collate_fn=collate_fn, batch_size=3, train_ratio=0.5,
    num_workers=2, val_ratio=0.2, test_ratio=0.3, pin_memory=False,
    train_size=5, val_size=3, test_size=2, return_test=True)
# Target normalizer built from the batched targets of `sample_data_list`.
_, sample_target, _ = collate_pool(sample_data_list)
normalizer = Normalizer(sample_target)
# build model
# Input feature sizes are taken from the first sample of the dataset.
structures, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]
# Small regression network with hard-coded hyper-parameters.
model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len, atom_fea_len=4, n_conv=3, h_fea_len=32, n_h=1, classification=False)
def main():
    """Train a CGCNN on ``args.data_options`` and evaluate the best model.

    Reads all options from the module-level ``args`` and updates the global
    ``best_mae_error``. Steps:

    - Build train/val/test loaders, a target normalizer, the network (with
      ``dropout_rate``), the loss and the optimizer.
    - Optionally resume from ``args.resume``.
    - Train for epochs ``args.start_epoch .. args.epochs - 1`` with a
      ``MultiStepLR`` schedule, calling ``save_checkpoint`` after every epoch.
    - Reload ``model_best.pth.tar`` and run ``validate`` on the train, val and
      test loaders (files tagged "train_results", "val_results",
      "test_results").

    Exits the process with status 1 if the validation error becomes NaN.

    Raises:
        NameError: if ``args.optim`` is neither "SGD" nor "Adam".
    """
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        return_val=True,
        return_test=True,
    )

    is_classification = args.task == "classification"

    # obtain target value normalizer
    if is_classification:
        # classification uses a fixed mean 0 / std 1
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({"mean": 0.0, "std": 1.0})
    else:
        # build the normalizer from the targets of up to 500 random samples
        if len(dataset) < 500:
            warnings.warn(
                "Dataset has less than 500 data points. "
                "Lower accuracy is expected. "
            )
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i] for i in sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model; input feature sizes are read from the first sample
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=is_classification,
        dropout_rate=args.dropout_rate,
    )
    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    criterion = nn.NLLLoss() if is_classification else nn.MSELoss()
    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "Adam":
        optimizer = optim.Adam(
            model.parameters(), args.lr, weight_decay=args.weight_decay
        )
    else:
        raise NameError("Only SGD or Adam is allowed as --optim")

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Load onto the CPU so a checkpoint saved on a GPU can be resumed
            # on a CPU-only host. The model/optimizer load_state_dict calls
            # move the values onto the parameters' device.
            checkpoint = torch.load(args.resume, map_location="cpu")
            args.start_epoch = checkpoint["epoch"]
            best_mae_error = checkpoint["best_mae_error"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            normalizer.load_state_dict(checkpoint["normalizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)

        # NaN is the only value that is not equal to itself
        if mae_error != mae_error:
            print("Exit due to NaN")
            sys.exit(1)

        scheduler.step()

        # remember the best error (lower for regression, higher otherwise)
        if args.task == "regression":
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_mae_error": best_mae_error,
                "optimizer": optimizer.state_dict(),
                "normalizer": normalizer.state_dict(),
                "args": vars(args),
            },
            is_best,
        )

    # test best model
    best_checkpoint = torch.load("model_best.pth.tar")
    model.load_state_dict(best_checkpoint["state_dict"])
    validate(
        train_loader, model, criterion, normalizer, test=True, fname="train_results"
    )
    validate(val_loader, model, criterion, normalizer, test=True, fname="val_results")
    validate(test_loader, model, criterion, normalizer, test=True, fname="test_results")
def cv():
    """Train and score a freshly initialised CGCNN on each cross-validation split.

    Splits are produced by ``get_cv_loader``. For every split:

    - Build a new normalizer, model, loss, optimizer and ``MultiStepLR``
      scheduler, and reset the global ``best_mae_error``.
    - Train for epochs ``args.start_epoch .. args.epochs - 1``, saving
      checkpoints under ``./checkpoints/<split>``.
    - Reload ``./checkpoints/<split>_model_best.pth.tar`` and record its
      train/val/test errors.

    After all splits, the mean of each error list is appended to
    ``results.out``. Exits the process with status 1 if a validation error is
    NaN.

    Raises:
        NameError: if ``args.optim`` is neither "SGD" nor "Adam".
    """
    global args, best_mae_error

    checkpoint_dir = "./checkpoints"
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool
    is_classification = args.task == "classification"
    is_regression = args.task == "regression"

    def build_normalizer():
        # Classification: fixed mean 0 / std 1.
        # Otherwise: normalizer built from up to 500 random samples.
        if is_classification:
            fixed = Normalizer(torch.zeros(2))
            fixed.load_state_dict({"mean": 0.0, "std": 1.0})
            return fixed
        n_items = len(dataset)
        if n_items < 500:
            warnings.warn(
                "Dataset has less than 500 data points. "
                "Lower accuracy is expected. "
            )
            picked = [dataset[k] for k in range(n_items)]
        else:
            picked = [dataset[k] for k in sample(range(n_items), 500)]
        _, picked_targets, _ = collate_pool(picked)
        return Normalizer(picked_targets)

    def build_model():
        # input feature sizes are read from the first sample
        first_structures, _, _ = dataset[0]
        net = CrystalGraphConvNet(
            first_structures[0].shape[-1],
            first_structures[1].shape[-1],
            atom_fea_len=args.atom_fea_len,
            n_conv=args.n_conv,
            h_fea_len=args.h_fea_len,
            n_h=args.n_h,
            classification=is_classification,
            dropout_rate=args.dropout_rate,
        )
        if args.cuda:
            net.cuda()
        return net

    def build_optimizer(params):
        if args.optim == "SGD":
            return optim.SGD(
                params,
                args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
            )
        if args.optim == "Adam":
            return optim.Adam(params, args.lr, weight_decay=args.weight_decay)
        raise NameError("Only SGD or Adam is allowed as --optim")

    train_maes, val_maes, test_maes = [], [], []
    splits = get_cv_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        cross_validation=args.cross_validation,
    )
    for split_no, (train_loader, val_loader, test_loader) in enumerate(splits, start=1):
        normalizer = build_normalizer()
        model = build_model()
        criterion = nn.NLLLoss() if is_classification else nn.MSELoss()
        optimizer = build_optimizer(model.parameters())
        scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=0.1)

        print(f"Split {split_no}")
        # fresh "best" baseline per split: lower is better for regression
        best_mae_error = 1e10 if is_regression else 0.0

        for epoch in range(args.start_epoch, args.epochs):
            train(train_loader, model, criterion, optimizer, epoch, normalizer)
            mae_error = validate(val_loader, model, criterion, normalizer)
            # NaN is the only value that is not equal to itself
            if mae_error != mae_error:
                print("Exit due to NaN")
                sys.exit(1)
            scheduler.step()

            if is_regression:
                is_best = mae_error < best_mae_error
                best_mae_error = min(mae_error, best_mae_error)
            else:
                is_best = mae_error > best_mae_error
                best_mae_error = max(mae_error, best_mae_error)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_mae_error": best_mae_error,
                    "optimizer": optimizer.state_dict(),
                    "normalizer": normalizer.state_dict(),
                    "args": vars(args),
                },
                is_best,
                pad_string=f"./checkpoints/{split_no}",
            )

        # score the best checkpoint of this split on all three subsets
        best_state = torch.load(f"./checkpoints/{split_no}_model_best.pth.tar")
        model.load_state_dict(best_state["state_dict"])
        split_train = validate(
            train_loader,
            model,
            criterion,
            normalizer,
            test=True,
            split=split_no,
            fname="train",
            to_save=False,
        )
        split_val = validate(
            val_loader, model, criterion, normalizer, test=True, fname="val", to_save=False
        )
        split_test = validate(
            test_loader, model, criterion, normalizer, test=True, fname="test", to_save=False
        )
        for bucket, err in (
            (train_maes, split_train),
            (val_maes, split_val),
            (test_maes, split_test),
        ):
            bucket.append(err.detach().item())

    with open("results.out", "a+") as out:
        out.write("\n")
        out.write(f"Avg Train MAE: {np.mean(train_maes):.4f}\n")
        out.write(f"Avg Val MAE: {np.mean(val_maes):.4f}\n")
        out.write(f"Avg Test MAE: {np.mean(test_maes):.4f}\n")
from cgcnn.data import collate_pool, get_train_val_test_loader

# Exploratory script: inspect a pre-trained CGCNN checkpoint, rebuild the
# network from it and start building data loaders.
# NOTE(review): this excerpt ends partway through the
# get_train_val_test_loader(...) call at the bottom.
if __name__ == '__main__':
    # NOTE(review): hard-coded, machine-specific Windows paths; consider
    # taking these from the command line.
    PATH = r"C:\Users\10989\PycharmProjects\8_CGCNN\cgcnn\pre-trained\band-gap.pth.tar"
    # Load every tensor onto the CPU regardless of the device it was saved on.
    model_dict = torch.load(PATH, map_location=torch.device('cpu'))
    # Print the checkpoint's top-level keys...
    dict_name = list(model_dict)
    for i, p in enumerate(dict_name):
        print(i, p)
    # ...then each parameter name in its state_dict with the tensor size.
    for i, p in enumerate(model_dict["state_dict"]):
        print(i, p,'\t',model_dict["state_dict"][p].size())
    from cgcnn.model import CrystalGraphConvNet
    from cgcnn.data import CIFData
    # 92 / 41 are passed as orig_atom_fea_len / nbr_fea_len; they must match
    # the shapes stored in the checkpoint for load_state_dict to succeed.
    model = CrystalGraphConvNet(92, 41, atom_fea_len=64, n_conv=4, h_fea_len=32, n_h=1, classification=False)
    model.load_state_dict(model_dict["state_dict"])
    # load data
    # NOTE(review): `cif_data` is an identical second CIFData instance and is
    # not used in the visible code.
    cif_data = CIFData(r"C:\Users\10989\PycharmProjects\8_CGCNN\cgcnn\data\sample-classification")
    dataset = CIFData(r"C:\Users\10989\PycharmProjects\8_CGCNN\cgcnn\data\sample-classification")
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset, collate_fn=collate_fn, batch_size=3, train_ratio=0.5,
        num_workers=2, val_ratio=0.2, test_ratio=0.3,
def main():
    """Evaluate a saved CGCNN model on the CIF files in ``args.cifpath``.

    Uses the module-level ``args`` (command-line options) and ``model_args``
    (hyper-parameters of the saved model).

    - With ``args.train_val_test``, the data is split with
      ``get_train_val_test_loader`` and each split is written to
      ``train_results.csv``, ``val_results.csv`` and ``test_results.csv``.
    - Otherwise the whole dataset is evaluated into ``predictions.csv``.

    The ``cifdata`` folder under ``args.cifpath`` is removed when
    ``args.clean_torch`` is set, and it is created if it is missing.

    If ``args.modelpath`` is not a file, a message is printed and the function
    returns instead of evaluating an untrained network.
    """
    global args, model_args

    # load data
    dataset = CIFData(args.cifpath, max_num_nbr=model_args.max_num_nbr,
                      radius=model_args.radius,
                      nn_method=model_args.nn_method,
                      disable_save_torch=args.disable_save_torch)
    collate_fn = collate_pool
    if args.train_val_test:
        train_loader, val_loader, test_loader = get_train_val_test_loader(
            dataset=dataset, collate_fn=collate_fn,
            batch_size=model_args.batch_size,
            train_ratio=model_args.train_ratio, num_workers=args.workers,
            val_ratio=model_args.val_ratio, test_ratio=model_args.test_ratio,
            pin_memory=args.cuda, train_size=model_args.train_size,
            val_size=model_args.val_size, test_size=model_args.test_size,
            return_test=True)
    else:
        test_loader = DataLoader(dataset, batch_size=model_args.batch_size,
                                 shuffle=True, num_workers=args.workers,
                                 collate_fn=collate_fn, pin_memory=args.cuda)

    # make and clean torch files if needed
    torch_data_path = os.path.join(args.cifpath, 'cifdata')
    if args.clean_torch and os.path.exists(torch_data_path):
        shutil.rmtree(torch_data_path)
    if os.path.exists(torch_data_path):
        if not args.clean_torch:
            warnings.warn('Found torch .json files at ' + torch_data_path +
                          '. Will read in .jsons as-available')
    else:
        os.mkdir(torch_data_path)

    # build model; input feature sizes are read from the first sample
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    is_classification = model_args.task == 'classification'
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len=model_args.atom_fea_len,
                                n_conv=model_args.n_conv,
                                h_fea_len=model_args.h_fea_len,
                                n_h=model_args.n_h,
                                classification=is_classification,
                                enable_tanh=model_args.enable_tanh)
    if args.cuda:
        model.cuda()

    # define loss func
    criterion = nn.NLLLoss() if is_classification else nn.MSELoss()
    # placeholder statistics; overwritten from the checkpoint below
    normalizer = Normalizer(torch.zeros(3))

    # without a trained model there is nothing meaningful to evaluate
    if not os.path.isfile(args.modelpath):
        print("=> no model found at '{}'".format(args.modelpath))
        return

    print("=> loading model '{}'".format(args.modelpath))
    # keep all storages on the CPU so GPU-saved checkpoints load anywhere
    checkpoint = torch.load(args.modelpath,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    normalizer.load_state_dict(checkpoint['normalizer'])
    print("=> loaded model '{}' (epoch {}, validation {})".format(
        args.modelpath, checkpoint['epoch'], checkpoint['best_mae_error']))

    if args.train_val_test:
        print('---------Evaluate Model on Train Set---------------')
        validate(train_loader, model, criterion, normalizer, test=True,
                 csv_name='train_results.csv')
        print('---------Evaluate Model on Val Set---------------')
        validate(val_loader, model, criterion, normalizer, test=True,
                 csv_name='val_results.csv')
        print('---------Evaluate Model on Test Set---------------')
        validate(test_loader, model, criterion, normalizer, test=True,
                 csv_name='test_results.csv')
    else:
        print('---------Evaluate Model on Dataset---------------')
        validate(test_loader, model, criterion, normalizer, test=True,
                 csv_name='predictions.csv')
def _tc_trans2_folds(indices):
    """Return the five ``(train, test)`` index lists used by ``tc_trans2``.

    The fold boundaries are hard-coded. The test folds are ``indices[-16:]``,
    ``indices[45:-16]``, ``indices[29:-32]``, ``indices[13:-48]`` and
    ``indices[:-64]``. For a 77-entry list (13 + 4 * 16) the test folds
    partition the list, and each training list is the complement of its
    test fold.
    """
    return [
        (indices[:61], indices[-16:]),
        (indices[:45] + indices[-16:], indices[45:-16]),
        (indices[:29] + indices[-32:], indices[29:-32]),
        (indices[:13] + indices[-48:], indices[13:-48]),
        (indices[-64:], indices[:-64]),
    ]


def tc_trans2():
    """Five-fold training of a ``SimpleNN`` head on two pre-trained CGCNNs.

    Two ``CrystalGraphConvNet`` models are loaded from the bulk-moduli and
    SPS checkpoints under ``../pre-trained/research-model/``. Forward hooks
    capture the output of each model's ``conv_to_fc`` layer into
    ``activation_a`` / ``activation_b``, and ``train``/``validate`` receive
    these dicts. Only the ``SimpleNN`` head's parameters are given to the
    optimizer.

    For each fold from ``_tc_trans2_folds``:

    - Train for epochs ``args.start_epoch .. args.epochs - 1``.
    - Reload ``../result/<args.property>-model_best.pth.tar``.
    - Collect ``validate(..., test=True, tc=True)`` outputs on the test fold.

    The results from all folds are concatenated and plotted as prediction
    vs. observation. Exits the process with status 1 if the monitored error
    becomes NaN.

    Raises:
        NameError: if ``args.optim`` is neither 'SGD' nor 'Adam'.
    """
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool

    is_classification = args.task == 'classification'

    # obtain target value normalizer
    if is_classification:
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [
                dataset[i] for i in sample(range(len(dataset)), 500)
            ]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build the two feature extractors and the trainable head
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]

    def build_cgcnn():
        return CrystalGraphConvNet(
            orig_atom_fea_len, nbr_fea_len,
            atom_fea_len=args.atom_fea_len,
            n_conv=args.n_conv,
            h_fea_len=args.h_fea_len,
            n_h=args.n_h,
            classification=is_classification)

    model_a = build_cgcnn()
    model_b = build_cgcnn()
    # NOTE(review): in_feature=256 presumably matches the combined conv_to_fc
    # outputs of model_a and model_b — confirm against train/validate.
    model = SimpleNN(in_feature=256, out_feature=1)

    # pretrained model paths
    model_a_path = '../pre-trained/research-model/bulk_moduli-model_best.pth.tar'
    model_b_path = '../pre-trained/research-model/sps-model_best.pth.tar'
    ckpt_a = torch.load(model_a_path)
    ckpt_b = torch.load(model_b_path)
    model_a.load_state_dict(ckpt_a['state_dict'])
    model_b.load_state_dict(ckpt_b['state_dict'])

    def get_activation(name, store):
        """Return a forward hook that saves a detached copy of the output in ``store[name]``."""
        def hook(module, inputs, output):
            store[name] = output.detach()
        return hook

    if args.cuda:
        model_a.cuda()
        model_b.cuda()
        model.cuda()

    # conv_to_fc outputs are written here on every forward pass
    activation_a = {}
    activation_b = {}
    model_a.conv_to_fc.register_forward_hook(
        get_activation('conv_to_fc', activation_a))
    model_b.conv_to_fc.register_forward_hook(
        get_activation('conv_to_fc', activation_b))

    # define loss func and optimizer (head parameters only)
    criterion = nn.NLLLoss() if is_classification else nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones,
                            gamma=0.1)

    batch_size = args.batch_size
    num_workers = args.workers
    pin_memory = args.cuda
    indices = list(range(len(dataset)))

    # predictions (X) and targets (T) accumulated over all test folds
    X = torch.Tensor()
    T = torch.Tensor()
    for train_idx, test_idx in _tc_trans2_folds(indices):
        # Each fold gets its own training sampler. Previously
        # `x.extend(...)` returned None and a `train_samplre` typo made folds
        # 1-4 reuse fold 0's training set.
        train_sampler = SubsetRandomSampler(train_idx)
        test_sampler = SubsetRandomSampler(test_idx)
        train_loader = DataLoader(dataset, batch_size=batch_size,
                                  sampler=train_sampler,
                                  num_workers=num_workers,
                                  collate_fn=collate_fn, pin_memory=pin_memory)
        test_loader = DataLoader(dataset, batch_size=batch_size,
                                 sampler=test_sampler,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn, pin_memory=pin_memory)
        print(test_sampler)

        # NOTE(review): model, optimizer, scheduler and best_mae_error are not
        # reset between folds, so later folds start from weights trained on
        # earlier folds — consider re-initialising them per fold.
        for epoch in range(args.start_epoch, args.epochs):
            # train for one epoch
            train(args, train_loader, model_a, model_b, model, activation_a,
                  activation_b, criterion, optimizer, epoch, normalizer)

            # NOTE(review): the monitored "validation" error is computed on the
            # training fold (there is no separate validation split here).
            mae_error = validate(args, train_loader, model_a, model_b, model,
                                 activation_a, activation_b, criterion,
                                 normalizer)

            # NaN is the only value that is not equal to itself
            if mae_error != mae_error:
                print('Exit due to NaN')
                sys.exit(1)

            scheduler.step()

            # remember the best error and save checkpoint
            if args.task == 'regression':
                is_best = mae_error < best_mae_error
                best_mae_error = min(mae_error, best_mae_error)
            else:
                is_best = mae_error > best_mae_error
                best_mae_error = max(mae_error, best_mae_error)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_mae_error': best_mae_error,
                    'optimizer': optimizer.state_dict(),
                    'normalizer': normalizer.state_dict(),
                    'args': vars(args)
                }, is_best, prop=args.property)

        # test best model
        print('---------Evaluate Model on Test Set---------------')
        best_checkpoint = torch.load('../result/' + args.property +
                                     '-model_best.pth.tar')
        model.load_state_dict(best_checkpoint['state_dict'])
        x, t = validate(args, test_loader, model_a, model_b, model,
                        activation_a, activation_b, criterion, normalizer,
                        test=True, tc=True)
        X = torch.cat((X, x), dim=0)
        T = torch.cat((T, t), dim=0)

    # parity plot: observation (t) on x-axis, prediction (x) on y-axis
    x, t = X.numpy(), T.numpy()
    n_max = max(np.max(x), np.max(t))
    n_min = min(np.min(x), np.min(t))
    a = np.linspace(n_min - abs(n_max), n_max + abs(n_max))
    b = a
    plt.rcParams["font.family"] = "Times New Roman"
    plt.plot(a, b, color='blue')
    plt.scatter(t, x, marker=".", color='red', edgecolors='black')
    plt.xlim(n_min - abs(n_min), n_max + abs(n_min))
    plt.ylim(n_min - abs(n_min), n_max + abs(n_min))
    plt.title(
        "Thermal Conductivity Prediction by CGCNN with Combined Model Transfer Learning"
    )
    plt.xlabel("observation")
    plt.ylabel("prediction")
    plt.show()
def _require_both_classes(loader, split_name):
    """Raise ``ValueError`` unless the binary labels in ``loader`` contain both classes.

    Iterates ``loader`` once and treats the second element of each batch as
    the label tensor, whose values are compared against 0 and 1.

    Raises:
        ValueError: with one of these messages:

            - 'Only single entry in <split_name>' if the split holds exactly
              one label.
            - 'All 0s in <split_name>' if every label is 0. An empty split is
              also reported this way.
            - 'All 1s in <split_name>' if every label is 1.
    """
    n_total = n_zeros = n_ones = 0
    for _, target, _ in loader:
        # flatten so a final batch holding a single sample is counted like any other
        flat = target.reshape(-1)
        n_total += flat.numel()
        n_zeros += int((flat == 0).sum())
        n_ones += int((flat == 1).sum())
    if n_total == 1:
        raise ValueError('Only single entry in {}'.format(split_name))
    if n_zeros == n_total:
        raise ValueError('All 0s in {}'.format(split_name))
    if n_ones == n_total:
        raise ValueError('All 1s in {}'.format(split_name))


def main():
    """Train a CGCNN with per-epoch checkpointing, then evaluate the best model.

    Reads options from the module-level ``args`` and updates the global
    ``best_mae_error``. Steps:

    - For classification tasks, check that every split contains both labels.
    - Create ``output/`` and manage the ``cifdata`` folder next to the data.
    - Build the normalizer, model, loss and optimizer, and optionally resume
      from ``args.resume``.
    - Train for epochs ``args.start_epoch .. args.epochs - 1``.
    - Reload ``output/model_best.pth.tar`` and write train/val/test result
      CSVs through ``validate``.

    Exits the process with status 1 if the validation error becomes NaN.

    Raises:
        ValueError: for classification, if a split has a single entry or
            only one class.
        NameError: if ``args.optim`` is neither 'SGD' nor 'Adam'.
    """
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options,
                      disable_save_torch=args.disable_save_torch)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset, collate_fn=collate_fn, batch_size=args.batch_size,
        train_ratio=args.train_ratio, num_workers=args.workers,
        val_ratio=args.val_ratio, test_ratio=args.test_ratio,
        pin_memory=args.cuda, train_size=args.train_size,
        val_size=args.val_size, test_size=args.test_size,
        return_test=True)

    is_classification = args.task == 'classification'

    # Make sure >1 class is present in every split
    # NOTE(review): this iterates all loaders before the cifdata folder below
    # is prepared; confirm CIFData does not need that folder while loading.
    if is_classification:
        _require_both_classes(train_loader, 'train')
        _require_both_classes(val_loader, 'val')
        _require_both_classes(test_loader, 'test')

    # make output folder if needed
    if not os.path.exists('output'):
        os.mkdir('output')

    # make and clean torch files if needed
    torch_data_path = os.path.join(args.data_options[0], 'cifdata')
    if args.clean_torch and os.path.exists(torch_data_path):
        shutil.rmtree(torch_data_path)
    if os.path.exists(torch_data_path):
        if not args.clean_torch:
            warnings.warn('Found cifdata folder at ' + torch_data_path +
                          '. Will read in .jsons as-available')
    else:
        os.mkdir(torch_data_path)

    # obtain target value normalizer
    if is_classification:
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i] for i in
                                sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model; input feature sizes are read from the first sample
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len=args.atom_fea_len,
                                n_conv=args.n_conv,
                                h_fea_len=args.h_fea_len,
                                n_h=args.n_h,
                                classification=is_classification)
    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    criterion = nn.NLLLoss() if is_classification else nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Load onto the CPU so a checkpoint saved on a GPU can be resumed
            # on a CPU-only host. The model/optimizer load_state_dict calls
            # move the values onto the parameters' device.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)

        # NaN is the only value that is not equal to itself
        if mae_error != mae_error:
            print('Exit due to NaN')
            sys.exit(1)

        scheduler.step()

        # remember the best error (lower for regression, higher otherwise)
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_mae_error': best_mae_error,
            'optimizer': optimizer.state_dict(),
            'normalizer': normalizer.state_dict(),
            'args': vars(args)
        }, is_best)

    # test best model
    best_checkpoint = torch.load(os.path.join('output', 'model_best.pth.tar'))
    model.load_state_dict(best_checkpoint['state_dict'])
    print('---------Evaluate Best Model on Train Set---------------')
    validate(train_loader, model, criterion, normalizer, test=True,
             csv_name='train_results.csv')
    print('---------Evaluate Best Model on Val Set---------------')
    validate(val_loader, model, criterion, normalizer, test=True,
             csv_name='val_results.csv')
    print('---------Evaluate Best Model on Test Set---------------')
    validate(test_loader, model, criterion, normalizer, test=True,
             csv_name='test_results.csv')
def main(): global args, best_mae_error # load dataset: (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id dataset = CIFData(args.root + args.target) collate_fn = collate_pool train_loader, val_loader, test_loader = get_train_val_test_loader( dataset=dataset, collate_fn=collate_fn, batch_size=args.batch_size, train_ratio=args.train_ratio, num_workers=args.workers, val_ratio=args.val_ratio, test_ratio=args.test_ratio, pin_memory=args.cuda, return_test=True) # obtain target value normalizer if args.task == 'classification': normalizer = Normalizer(torch.zeros(2)) normalizer.load_state_dict({'mean': 0., 'std': 1.}) else: sample_data_list = [dataset[i] for i in \ sample(range(len(dataset)), 500)] _, sample_target, _ = collate_pool(sample_data_list) normalizer = Normalizer(sample_target) # build model structures, _, _ = dataset[0] orig_atom_fea_len = structures[0].shape[-1] nbr_fea_len = structures[1].shape[-1] model = CrystalGraphConvNet( orig_atom_fea_len, nbr_fea_len, atom_fea_len=args.atom_fea_len, n_conv=args.n_conv, h_fea_len=args.h_fea_len, n_h=args.n_h, classification=True if args.task == 'classification' else False) # pring number of trainable model parameters trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print('=> number of trainable model parameters: {:d}'.format( trainable_params)) if args.cuda: model.cuda() # define loss func and optimizer if args.task == 'classification': criterion = nn.NLLLoss() else: criterion = nn.MSELoss() if args.optim == 'SGD': optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) elif args.optim == 'Adam': optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) else: raise NameError('Only SGD or Adam is allowed as --optim') # optionally resume from a checkpoint if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume, 
map_location=torch.device('cpu')) args.start_epoch = checkpoint['epoch'] best_mae_error = checkpoint['best_mae_error'] model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) normalizer.load_state_dict(checkpoint['normalizer']) print("=> loaded checkpoint '{}' (epoch {})".format( args.resume, checkpoint['epoch'])) else: print("=> no checkpoint found at '{}'".format(args.resume)) # TensorBoard writer summary_root = './runs/' if not os.path.exists(summary_root): os.mkdir(summary_root) summary_file = summary_root + args.target if os.path.exists(summary_file): shutil.rmtree(summary_file) writer = SummaryWriter(summary_file) scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=0.1) for epoch in range(args.start_epoch, args.start_epoch + args.epochs): # train for one epoch train(train_loader, model, criterion, optimizer, epoch, normalizer, writer) # evaluate on validation set mae_error = validate(val_loader, model, criterion, epoch, normalizer, writer) scheduler.step() # remember the best mae_eror and save checkpoint if args.task == 'regression': is_best = mae_error < best_mae_error best_mae_error = min(mae_error, best_mae_error) else: is_best = mae_error > best_mae_error best_mae_error = max(mae_error, best_mae_error) save_checkpoint( { 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_mae_error': best_mae_error, 'optimizer': optimizer.state_dict(), 'normalizer': normalizer.state_dict(), 'args': vars(args) }, args.target, is_best)