# NOTE: the helper imports below are assumed to follow the reference CGCNN
# layout; adjust the module paths to this project's actual structure.
from cgcnn.data_muti import (CIFData, collate_pool,
                             get_train_val_test_loader)
from cgcnn.model import CrystalGraphConvNet
from cgcnn.main import Normalizer  # assumed location of Normalizer

if __name__ == '__main__':
    dataset = CIFData(
        r"C:\Users\10989\PycharmProjects\8_CGCNN\cgcnn\data\sample-regression-muti-output"
    )
    (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id = dataset[0]
    print(target)

    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=3,
        train_ratio=0.5,
        num_workers=2,
        val_ratio=0.2,
        test_ratio=0.3,
        pin_memory=False,
        train_size=5,
        val_size=3,
        test_size=2,
        return_test=True)

    # sample_data_list was undefined in the original; assumed here to cover
    # the whole (small) sample dataset, as the other scripts below do
    sample_data_list = [dataset[i] for i in range(len(dataset))]
    _, sample_target, _ = collate_pool(sample_data_list)
    normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len)
    # (the remaining CrystalGraphConvNet hyperparameters were cut off in the
    # source; the constructor defaults are assumed here)
def main():
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        return_test=True)

    # obtain target value normalizer
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected.')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i]
                                for i in sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len, nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=args.task == 'classification')
    if args.cuda:
        model.cuda()

    # define loss function and optimizer
    if args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)
        if mae_error != mae_error:  # NaN check: NaN != NaN
            print('Exit due to NaN')
            sys.exit(1)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_mae_error': best_mae_error,
            'optimizer': optimizer.state_dict(),
            'normalizer': normalizer.state_dict(),
            'args': vars(args)
        }, is_best)

    # test best model
    print('---------Evaluate Model on Test Set---------------')
    best_checkpoint = torch.load('model_best.pth.tar')
    model.load_state_dict(best_checkpoint['state_dict'])
    validate(test_loader, model, criterion, normalizer, test=True)
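# Every script in this section builds a Normalizer from sampled targets and
# saves/loads its state alongside the model. For reference, a minimal sketch
# of that helper, assumed to match the reference CGCNN implementation
# (scalar mean/std over the sampled target tensor):
class Normalizer(object):
    """Normalize a Tensor and restore it later."""

    def __init__(self, tensor):
        # per-dataset statistics estimated from the sampled targets
        self.mean = torch.mean(tensor)
        self.std = torch.std(tensor)

    def norm(self, tensor):
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        return {'mean': self.mean, 'std': self.std}

    def load_state_dict(self, state_dict):
        self.mean = state_dict['mean']
        self.std = state_dict['std']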
def main():
    # taken from sys.argv
    model_name = sys.argv[1]
    save_name = sys.argv[2]

    # variables for the dataset loader
    root_dir = sys.argv[3]
    max_num_nbr = 8
    radius = 4
    dmin = 0
    step = 0.2
    random_seed = 123
    batch_size = 64
    Ntot = 36000  # total number of data points
    train_idx = list(range(100))  # do not change
    val_idx = list(range(100))  # do not change
    test_idx = list(range(Ntot))
    num_workers = 0
    pin_memory = True
    return_test = True

    # variables for the model
    atom_fea_len = 40
    h_fea_len = 80
    n_conv = 3
    n_h = 2
    lr = 0.001
    lr_decay_rate = 0.98
    weight_decay = 0.0
    best_mae_error = 1e10
    start_epoch = 0
    epochs = 200

    # setup
    dataset = CIFData(root_dir, max_num_nbr, radius, dmin, step, random_seed)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset, collate_fn, batch_size, train_idx, val_idx, test_idx,
        num_workers, pin_memory, return_test)
    sample_data_list = [dataset[i] for i in sample(range(len(dataset)), 100)]
    _, sample_target, _ = collate_pool(sample_data_list)
    normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len, n_conv, h_fea_len, n_h)
    model.cuda()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
    scheduler = ExponentialLR(optimizer, gamma=lr_decay_rate)

    # load the trained checkpoint
    print("=> loading checkpoint '{}'".format(model_name))
    checkpoint = torch.load(model_name)
    start_epoch = checkpoint['epoch']
    best_mae_error = checkpoint['best_mae_error']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    normalizer.load_state_dict(checkpoint['normalizer'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        model_name, checkpoint['epoch']))

    print('---------Evaluate Model on Test Set---------------')
    validate(test_loader, model, criterion, normalizer, test=True,
             save_name=save_name)
def main():
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        return_val=True,
        return_test=True,
    )

    # obtain target value normalizer
    if args.task == "classification":
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({"mean": 0.0, "std": 1.0})
    else:
        if len(dataset) < 500:
            warnings.warn(
                "Dataset has less than 500 data points. "
                "Lower accuracy is expected."
            )
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i]
                                for i in sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=args.task == "classification",
        dropout_rate=args.dropout_rate,
    )
    if args.cuda:
        model.cuda()

    # define loss function and optimizer
    if args.task == "classification":
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "Adam":
        optimizer = optim.Adam(
            model.parameters(), args.lr, weight_decay=args.weight_decay
        )
    else:
        raise NameError("Only SGD or Adam is allowed as --optim")

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_mae_error = checkpoint["best_mae_error"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            normalizer.load_state_dict(checkpoint["normalizer"])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)
        if mae_error != mae_error:  # NaN check
            print("Exit due to NaN")
            sys.exit(1)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == "regression":
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_mae_error": best_mae_error,
                "optimizer": optimizer.state_dict(),
                "normalizer": normalizer.state_dict(),
                "args": vars(args),
            },
            is_best,
        )

    # evaluate the best model on the train, val, and test sets
    best_checkpoint = torch.load("model_best.pth.tar")
    model.load_state_dict(best_checkpoint["state_dict"])
    validate(train_loader, model, criterion, normalizer, test=True,
             fname="train_results")
    validate(val_loader, model, criterion, normalizer, test=True,
             fname="val_results")
    validate(test_loader, model, criterion, normalizer, test=True,
             fname="test_results")
def main():
    global args, model_args, best_mae_error

    # load data
    dataset = CIFData(args.cifpath,
                      max_num_nbr=model_args.max_num_nbr,
                      radius=model_args.radius,
                      nn_method=model_args.nn_method,
                      disable_save_torch=args.disable_save_torch)
    collate_fn = collate_pool
    if args.train_val_test:
        train_loader, val_loader, test_loader = get_train_val_test_loader(
            dataset=dataset,
            collate_fn=collate_fn,
            batch_size=model_args.batch_size,
            train_ratio=model_args.train_ratio,
            num_workers=args.workers,
            val_ratio=model_args.val_ratio,
            test_ratio=model_args.test_ratio,
            pin_memory=args.cuda,
            train_size=model_args.train_size,
            val_size=model_args.val_size,
            test_size=model_args.test_size,
            return_test=True)
    else:
        test_loader = DataLoader(dataset,
                                 batch_size=model_args.batch_size,
                                 shuffle=True,
                                 num_workers=args.workers,
                                 collate_fn=collate_fn,
                                 pin_memory=args.cuda)

    # make and clean torch files if needed
    torch_data_path = os.path.join(args.cifpath, 'cifdata')
    if args.clean_torch and os.path.exists(torch_data_path):
        shutil.rmtree(torch_data_path)
    if os.path.exists(torch_data_path):
        if not args.clean_torch:
            warnings.warn('Found torch .json files at ' + torch_data_path +
                          '. Will read in .jsons as-available')
    else:
        os.mkdir(torch_data_path)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len, nbr_fea_len,
        atom_fea_len=model_args.atom_fea_len,
        n_conv=model_args.n_conv,
        h_fea_len=model_args.h_fea_len,
        n_h=model_args.n_h,
        classification=model_args.task == 'classification',
        enable_tanh=model_args.enable_tanh)
    if args.cuda:
        model.cuda()

    # define loss function
    if model_args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()

    # placeholder normalizer; the real statistics are loaded from the
    # checkpoint below
    normalizer = Normalizer(torch.zeros(3))

    # load the trained model
    if os.path.isfile(args.modelpath):
        print("=> loading model '{}'".format(args.modelpath))
        checkpoint = torch.load(args.modelpath,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded model '{}' (epoch {}, validation {})".format(
            args.modelpath, checkpoint['epoch'],
            checkpoint['best_mae_error']))
    else:
        print("=> no model found at '{}'".format(args.modelpath))

    if args.train_val_test:
        print('---------Evaluate Model on Train Set---------------')
        validate(train_loader, model, criterion, normalizer, test=True,
                 csv_name='train_results.csv')
        print('---------Evaluate Model on Val Set---------------')
        validate(val_loader, model, criterion, normalizer, test=True,
                 csv_name='val_results.csv')
        print('---------Evaluate Model on Test Set---------------')
        validate(test_loader, model, criterion, normalizer, test=True,
                 csv_name='test_results.csv')
    else:
        print('---------Evaluate Model on Dataset---------------')
        validate(test_loader, model, criterion, normalizer, test=True,
                 csv_name='predictions.csv')
def main():
    # taken from sys.argv
    chk_name = sys.argv[1]
    best_name = sys.argv[2]
    save_name = sys.argv[3]

    # variables for the dataset loader
    root_dir = '/your/model/path/'
    max_num_nbr = 8
    radius = 4
    dmin = 0
    step = 0.2
    random_seed = 1234
    batch_size = 64
    N_tot = len(open(root_dir + '/id_prop.csv').readlines())
    N_tr = int(N_tot * 0.8)
    N_val = int(N_tot * 0.1)
    N_test = N_tot - N_tr - N_val
    train_idx = list(range(N_tr))
    val_idx = list(range(N_tr, N_tr + N_val))
    test_idx = list(range(N_tr + N_val, N_tr + N_val + N_test))
    num_workers = 0
    pin_memory = True
    return_test = True

    # variables for the model
    # atom_fea_len, h_fea_len, n_conv, n_h, lr_decay_rate = Hyp_loader(root_dir, hyp_idx)
    atom_fea_len = 90
    h_fea_len = 2 * atom_fea_len
    n_conv = 5
    n_h = 2
    lr_decay_rate = 0.97
    lr = 0.001
    weight_decay = 0.0
    resume = False
    resume_path = 'ddd'

    # variables for training
    best_mae_error = 1e10
    start_epoch = 0
    epochs = 200

    # setup
    dataset = CIFData(root_dir, max_num_nbr, radius, dmin, step, random_seed)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset, collate_fn, batch_size, train_idx, val_idx, test_idx,
        num_workers, pin_memory, return_test)
    sample_data_list = [dataset[i] for i in sample(range(len(dataset)), 500)]
    _, sample_target, _ = collate_pool(sample_data_list)
    normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len, n_conv, h_fea_len, n_h)
    model.cuda()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
    scheduler = ExponentialLR(optimizer, gamma=lr_decay_rate)

    # optionally resume from a checkpoint
    if resume:
        print("=> loading checkpoint '{}'".format(resume_path))
        checkpoint = torch.load(resume_path)
        start_epoch = checkpoint['epoch']
        best_mae_error = checkpoint['best_mae_error']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            resume_path, checkpoint['epoch']))

    t0 = time.time()
    for epoch in range(start_epoch, epochs):
        train(train_loader, model, criterion, optimizer, epoch, normalizer)
        mae_error = validate(val_loader, model, criterion, normalizer)
        scheduler.step()
        is_best = mae_error < best_mae_error
        best_mae_error = min(mae_error, best_mae_error)
        save_checkpoint({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'best_mae_error': best_mae_error,
            'optimizer': optimizer.state_dict(),
            'normalizer': normalizer.state_dict()
        }, is_best, chk_name, best_name)
    t1 = time.time()

    print('--------Training time in sec-------------')
    print(t1 - t0)
    print('---------Best Model on Validation Set---------------')
    best_checkpoint = torch.load(best_name)
    print(best_checkpoint['best_mae_error'].cpu().numpy())
    print('---------Evaluate Model on Test Set---------------')
    model.load_state_dict(best_checkpoint['state_dict'])
    validate(test_loader, model, criterion, normalizer, test=True,
             save_name=save_name)
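# The training loop above passes explicit checkpoint and best-model file
# names to save_checkpoint, while earlier scripts call it with only (state,
# is_best). A minimal sketch compatible with both call patterns, assumed to
# mirror the reference CGCNN helper (the actual definition lives outside
# these snippets):
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',
                    best_filename='model_best.pth.tar'):
    # always persist the latest state; keep a copy of the best-so-far model
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)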
def main():
    # taken from sys.argv
    resume = True
    resume_path = sys.argv[1]

    # variables for the dataset loader
    root_dir = '/your/data/path/'
    max_num_nbr = 8
    radius = 4
    dmin = 0
    step = 0.2
    random_seed = 1234
    batch_size = 64
    N_tot = len(open(root_dir + '/id_prop.csv').readlines())
    N_tr = int(N_tot * 0.8)
    N_val = int(N_tot * 0.1)
    N_test = N_tot - N_tr - N_val
    # N_test = N_tot
    train_idx = list(range(N_tr))
    val_idx = list(range(N_tr, N_tr + N_val))
    test_idx = list(range(N_tot))  # evaluate on the full dataset
    num_workers = 0
    pin_memory = False
    return_test = True

    # variables for the model
    atom_fea_len = 90
    h_fea_len = 2 * atom_fea_len
    n_conv = 5
    n_h = 2
    lr_decay_rate = 0.99
    lr = 0.001
    weight_decay = 0.0
    model_args = {
        'radius': radius,
        'dmin': dmin,
        'step': step,
        'batch_size': batch_size,
        'random_seed': random_seed,
        'N_tr': N_tr,
        'N_val': N_val,
        'N_test': N_test,
        'atom_fea_len': atom_fea_len,
        'h_fea_len': h_fea_len,
        'n_conv': n_conv,
        'n_h': n_h,
        'lr': lr,
        'lr_decay_rate': lr_decay_rate,
        'weight_decay': weight_decay
    }

    # variables for training
    best_mae_error = 1e10
    start_epoch = 0
    epochs = 1000

    # setup
    dataset = CIFData(root_dir, max_num_nbr, radius, dmin, step, random_seed)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset, collate_fn, batch_size, train_idx, val_idx, test_idx,
        num_workers, pin_memory, return_test)
    # placeholder normalizer built from a single sample; its statistics are
    # overwritten from the checkpoint below
    sample_data_list = [dataset[i] for i in sample(range(len(dataset)), 1)]
    _, sample_target, _ = collate_pool(sample_data_list)
    normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len, n_conv, h_fea_len, n_h)
    model.cuda()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
    scheduler = ExponentialLR(optimizer, gamma=lr_decay_rate)

    # resume from a checkpoint
    if resume:
        print("=> loading checkpoint '{}'".format(resume_path))
        checkpoint = torch.load(resume_path)
        start_epoch = checkpoint['epoch']
        best_mae_error = checkpoint['best_mae_error']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            resume_path, checkpoint['epoch']))

    print('---------Evaluate Model on Test Set---------------')
    save_name = 'dropout_test.csv'
    validate(test_loader, model, criterion, normalizer, test=True,
             save_name=save_name)
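# The save_name 'dropout_test.csv' above suggests test-time (Monte-Carlo)
# dropout evaluation. Dropout layers are disabled by model.eval(), so they
# have to be switched back to train mode explicitly; a common pattern for
# this (a sketch, not part of the original scripts):
def enable_mc_dropout(model):
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()  # keep dropout stochastic during inference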
def main():
    global args, best_mae_error

    # load data
    dataset = CIFData(*args.data_options,
                      disable_save_torch=args.disable_save_torch)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        train_size=args.train_size,
        val_size=args.val_size,
        test_size=args.test_size,
        return_test=True)

    # make sure more than one class is present in each split
    if args.task == 'classification':
        n_train = total_train = 0
        for _, target, _ in train_loader:
            for target_i in target.squeeze():
                n_train += 1
                total_train += int(target_i)
        if total_train == 0:
            raise ValueError('All 0s in train')
        elif total_train == n_train:
            raise ValueError('All 1s in train')

        n_val = total_val = 0
        for _, target, _ in val_loader:
            if len(target) == 1:
                raise ValueError('Only single entry in val')
            for target_i in target.squeeze():
                n_val += 1
                total_val += int(target_i)
        if total_val == 0:
            raise ValueError('All 0s in val')
        elif total_val == n_val:
            raise ValueError('All 1s in val')

        n_test = total_test = 0
        for _, target, _ in test_loader:
            if len(target) == 1:
                raise ValueError('Only single entry in test')
            for target_i in target.squeeze():
                n_test += 1
                total_test += int(target_i)
        if total_test == 0:
            raise ValueError('All 0s in test')
        elif total_test == n_test:
            raise ValueError('All 1s in test')

    # make output folder if needed
    if not os.path.exists('output'):
        os.mkdir('output')

    # make and clean torch files if needed
    torch_data_path = os.path.join(args.data_options[0], 'cifdata')
    if args.clean_torch and os.path.exists(torch_data_path):
        shutil.rmtree(torch_data_path)
    if os.path.exists(torch_data_path):
        if not args.clean_torch:
            warnings.warn('Found cifdata folder at ' + torch_data_path +
                          '. Will read in .jsons as-available')
    else:
        os.mkdir(torch_data_path)

    # obtain target value normalizer
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected.')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i]
                                for i in sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len, nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=args.task == 'classification')
    if args.cuda:
        model.cuda()

    # define loss function and optimizer
    if args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, normalizer)
        if mae_error != mae_error:  # NaN check
            print('Exit due to NaN')
            sys.exit(1)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_mae_error': best_mae_error,
            'optimizer': optimizer.state_dict(),
            'normalizer': normalizer.state_dict(),
            'args': vars(args)
        }, is_best)

    # test best model
    best_checkpoint = torch.load(os.path.join('output', 'model_best.pth.tar'))
    model.load_state_dict(best_checkpoint['state_dict'])
    print('---------Evaluate Best Model on Train Set---------------')
    validate(train_loader, model, criterion, normalizer, test=True,
             csv_name='train_results.csv')
    print('---------Evaluate Best Model on Val Set---------------')
    validate(val_loader, model, criterion, normalizer, test=True,
             csv_name='val_results.csv')
    print('---------Evaluate Best Model on Test Set---------------')
    validate(test_loader, model, criterion, normalizer, test=True,
             csv_name='test_results.csv')
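# validate() above returns a mean absolute error computed on denormalized
# predictions. The metric itself is a one-liner; a sketch assumed to match
# the reference CGCNN implementation:
def mae(prediction, target):
    """Mean absolute error between two tensors."""
    return torch.mean(torch.abs(target - prediction))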
def main():
    global args, best_mae_error

    # load dataset: (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id
    dataset = CIFData(args.root + args.target)
    collate_fn = collate_pool
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        train_ratio=args.train_ratio,
        num_workers=args.workers,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        pin_memory=args.cuda,
        return_test=True)

    # obtain target value normalizer
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    else:
        sample_data_list = [dataset[i]
                            for i in sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(
        orig_atom_fea_len, nbr_fea_len,
        atom_fea_len=args.atom_fea_len,
        n_conv=args.n_conv,
        h_fea_len=args.h_fea_len,
        n_h=args.n_h,
        classification=args.task == 'classification')

    # print number of trainable model parameters
    trainable_params = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
    print('=> number of trainable model parameters: {:d}'.format(
        trainable_params))

    if args.cuda:
        model.cuda()

    # define loss function and optimizer
    if args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as --optim')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # TensorBoard writer
    summary_root = './runs/'
    if not os.path.exists(summary_root):
        os.mkdir(summary_root)
    summary_file = summary_root + args.target
    if os.path.exists(summary_file):
        shutil.rmtree(summary_file)
    writer = SummaryWriter(summary_file)

    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones,
                            gamma=0.1)

    for epoch in range(args.start_epoch, args.start_epoch + args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, normalizer,
              writer)

        # evaluate on validation set
        mae_error = validate(val_loader, model, criterion, epoch, normalizer,
                             writer)

        scheduler.step()

        # remember the best mae_error and save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_mae_error': best_mae_error,
            'optimizer': optimizer.state_dict(),
            'normalizer': normalizer.state_dict(),
            'args': vars(args)
        }, args.target, is_best)
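# The loop above threads the SummaryWriter through train() and validate().
# A minimal example of the kind of per-epoch logging those functions are
# expected to do (hypothetical helper; the tag names are illustrative):
def log_epoch(writer, epoch, train_loss, val_mae):
    writer.add_scalar('loss/train', train_loss, epoch)
    writer.add_scalar('mae/val', val_mae, epoch)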