def main():
    global args, model_args, best_mae_error

    # load data
    dataset = CIFData(args.cifpath)
    collate_fn = collate_pool
    test_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                             num_workers=args.workers, collate_fn=collate_fn,
                             pin_memory=args.cuda)

    # build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len=model_args.atom_fea_len,
                                n_conv=model_args.n_conv,
                                h_fea_len=model_args.h_fea_len,
                                n_h=model_args.n_h,
                                classification=(model_args.task == 'classification'))
    if args.cuda:
        model.cuda()

    # define loss func and optimizer
    if model_args.task == 'classification':
        criterion = nn.NLLLoss()
    else:
        criterion = nn.MSELoss()
    normalizer = Normalizer(torch.zeros(3))

    # optionally resume from a checkpoint
    if os.path.isfile(args.modelpath):
        print("=> loading model '{}'".format(args.modelpath))
        checkpoint = torch.load(args.modelpath,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded model '{}' (epoch {}, validation {})".format(
            args.modelpath, checkpoint['epoch'], checkpoint['best_mae_error']))
    else:
        print("=> no model found at '{}'".format(args.modelpath))

    validate(test_loader, model, criterion, normalizer, test=True)
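# The prediction script above only rebuilds a Normalizer from the checkpoint's saved
# state. As a reference, a minimal sketch of what such a target normalizer could look
# like is given below, assuming simple mean/std standardization and the 'mean'/'std'
# state-dict keys used elsewhere in this code; the class actually shipped with the
# repository may differ in detail.
import torch


class Normalizer:
    """Normalize a target tensor and denormalize predictions (sketch)."""

    def __init__(self, tensor):
        # statistics over the sample of targets used to initialize the normalizer
        self.mean = torch.mean(tensor)
        self.std = torch.std(tensor)

    def norm(self, tensor):
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        return {'mean': self.mean, 'std': self.std}

    def load_state_dict(self, state_dict):
        self.mean = state_dict['mean']
        self.std = state_dict['std']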
def main():
    global args, best_error_global, best_error_local, savepath, dataset

    parser = buildParser()
    args = parser.parse_args()

    print('Torch Device being used: ', cfg.device)

    # create the savepath
    savepath = args.save_dir + str(args.name) + '/'
    if not os.path.exists(savepath):
        os.makedirs(savepath)

    # Writes to file and also to terminal
    sys.stdout = Logger(savepath)
    print(vars(args))

    best_error_global, best_error_local = 1e10, 1e10
    randomSeed(args.seed)

    # create train/val/test dataset separately
    assert os.path.exists(args.protein_dir), '{} does not exist!'.format(args.protein_dir)
    all_dirs = [d for d in os.listdir(args.protein_dir)
                if not d.startswith('.DS_Store')]
    dir_len = len(all_dirs)
    indices = list(range(dir_len))
    random.shuffle(indices)

    train_size = math.floor(args.train * dir_len)
    val_size = math.floor(args.val * dir_len)
    test_size = math.floor(args.test * dir_len)

    test_dirs = all_dirs[:test_size]
    train_dirs = all_dirs[test_size:test_size + train_size]
    val_dirs = all_dirs[test_size + train_size:test_size + train_size + val_size]
    print('Testing on {} protein directories:'.format(len(test_dirs)))

    dataset = ProteinDataset(args.pkl_dir, args.id_prop, args.atom_init,
                             random_seed=args.seed)
    print('Dataset length: ', len(dataset))

    # load all model args from the pretrained model
    if args.pretrained is not None and os.path.isfile(args.pretrained):
        print("=> loading model params '{}'".format(args.pretrained))
        model_checkpoint = torch.load(args.pretrained,
                                      map_location=lambda storage, loc: storage)
        model_args = argparse.Namespace(**model_checkpoint['args'])
        # override args values with the pretrained model's args
        args.h_a = model_args.h_a
        args.h_g = model_args.h_g
        args.n_conv = model_args.n_conv
        args.random_seed = model_args.seed
        args.lr = model_args.lr
        print("=> loaded model params '{}'".format(args.pretrained))
    else:
        print("=> no model params found at '{}'".format(args.pretrained))

    # build model
    kwargs = {
        'pkl_dir': args.pkl_dir,      # Root directory for data
        'atom_init': args.atom_init,  # Atom init filename
        'h_a': args.h_a,              # Dim of the learnt hidden atom embedding
        'h_g': args.h_g,              # Dim of the hidden graph embedding after pooling
        'n_conv': args.n_conv,        # Number of GCN layers
        'random_seed': args.seed,     # Seed to fix the simulation
        'lr': args.lr,                # Learning rate for the optimizer
    }

    structures, _, _ = dataset[0]
    h_b = structures[1].shape[-1]
    kwargs['h_b'] = h_b               # Dim of the bond embedding initialization

    # Use DataParallel for faster training
    print("Let's use", torch.cuda.device_count(), "GPUs and Data Parallel Model.")
    model = ProteinGCN(**kwargs)
    model = torch.nn.DataParallel(model)
    model.cuda()
    print('Trainable Model Parameters: ', count_parameters(model))

    # Create dataloaders to iterate through the dataset in batches
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset, train_dirs, val_dirs, test_dirs,
        collate_fn=collate_pool,
        num_workers=args.workers,
        batch_size=args.batch_size,
        pin_memory=False)

    try:
        print('Training data   : ', len(train_loader.sampler))
        print('Validation data : ', len(val_loader.sampler))
        print('Testing data    : ', len(test_loader.sampler))
    except Exception as e:
        # sometimes the test split may not be defined
        print('\nException Cause: {}'.format(e.args[0]))

    # obtain target value normalizers
    if len(dataset) < args.avg_sample:
        sample_data_list = [dataset[i] for i in tqdm(range(len(dataset)))]
    else:
        sample_data_list = [dataset[i] for i in
                            tqdm(random.sample(range(len(dataset)), args.avg_sample))]
    _, _, sample_target = collate_pool(sample_data_list)
    normalizer_global = Normalizer(sample_target[0])
    normalizer_local = Normalizer(sample_target[1])

    # load the model state dict from the given pretrained model
    if args.pretrained is not None and os.path.isfile(args.pretrained):
        print("=> loading model '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained,
                                map_location=lambda storage, loc: storage)
        print('Best error global: ', checkpoint['best_error_global'])
        print('Best error local: ', checkpoint['best_error_local'])
        best_error_global = checkpoint['best_error_global']
        best_error_local = checkpoint['best_error_local']
        model.module.load_state_dict(checkpoint['state_dict'])
        model.module.optimizer.load_state_dict(checkpoint['optimizer'])
        normalizer_local.load_state_dict(checkpoint['normalizer_local'])
        normalizer_global.load_state_dict(checkpoint['normalizer_global'])
    else:
        print("=> no model found at '{}'".format(args.pretrained))

    # Main training loop
    for epoch in range(args.epochs):
        # Training
        [train_error_global, train_error_local, train_loss] = trainModel(
            train_loader, model, normalizer_global, normalizer_local, epoch=epoch)
        # Validation
        [val_error_global, val_error_local, val_loss] = trainModel(
            val_loader, model, normalizer_global, normalizer_local,
            epoch=epoch, evaluation=True)

        # check for error overflow (NaN check: NaN != NaN)
        if (val_error_global != val_error_global) or (val_error_local != val_error_local):
            print('Exit due to NaN')
            sys.exit(1)

        # remember the best error and possibly save a checkpoint
        is_best = val_error_global < best_error_global
        best_error_global = min(val_error_global, best_error_global)
        best_error_local = val_error_local

        # save best model
        if args.save_checkpoints:
            model.module.save({
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'best_error_global': best_error_global,
                'best_error_local': best_error_local,
                'optimizer': model.module.optimizer.state_dict(),
                'normalizer_global': normalizer_global.state_dict(),
                'normalizer_local': normalizer_local.state_dict(),
                'args': vars(args)
            }, is_best, savepath)

    # test best model using saved checkpoints
    if args.save_checkpoints and len(test_loader):
        print('---------Evaluate Model on Test Set---------------')
        # this try/except allows testing on the go or with a separately defined
        # pretrained path
        try:
            best_checkpoint = torch.load(savepath + 'model_best.pth.tar')
        except Exception as e:
            best_checkpoint = torch.load(args.pretrained)
        model.module.load_state_dict(best_checkpoint['state_dict'])
        [test_error_global, test_error_local, test_loss] = trainModel(
            test_loader, model, normalizer_global, normalizer_local, testing=True)
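# count_parameters() is called above but not defined in this file; a minimal sketch,
# assuming it simply counts the trainable parameters of the model, could be:
def count_parameters(model):
    """Return the number of trainable parameters in a model (sketch)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)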
def main():
    global args, best_mae_error

    # Dataset from CIF files
    dataset = CIFData(*args.data_options)
    print(f'Dataset size: {len(dataset)}')

    # Dataloaders from dataset
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_pool,
        batch_size=args.batch_size,
        train_size=args.train_size,
        num_workers=args.workers,
        val_size=args.val_size,
        test_size=args.test_size,
        pin_memory=args.cuda,
        return_test=True)

    # Initialize data normalizer with a sample of 500 points
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    elif args.task == 'regression':
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [dataset[i] for i in
                                sample(range(len(dataset)), 500)]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)
    else:
        raise NameError('task argument must be regression or classification')

    # Build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len=args.atom_fea_len,
                                n_conv=args.n_conv,
                                h_fea_len=args.h_fea_len,
                                n_h=args.n_h,
                                classification=(args.task == 'classification'))

    # GPU
    if args.cuda:
        model.cuda()

    # Loss function
    criterion = nn.NLLLoss() if args.task == 'classification' else nn.MSELoss()

    # Optimizer
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(), args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(), args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('optim argument must be SGD or Adam')

    # Scheduler
    scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=0.1)

    # Resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Train
    for epoch in range(args.start_epoch, args.epochs):
        # Train (one epoch)
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # Validate
        mae_error = validate(val_loader, model, criterion, normalizer)
        assert mae_error == mae_error, 'NaN :('

        # Step learning rate scheduler (MultiStepLR.step takes no metric argument)
        scheduler.step()

        # Save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_mae_error': best_mae_error,
            'optimizer': optimizer.state_dict(),
            'normalizer': normalizer.state_dict(),
            'args': vars(args)
        }, is_best)

    # Evaluate best model on test set
    print('--------- Evaluate model on test set ---------------')
    best_checkpoint = torch.load('model_best.pth.tar')
    model.load_state_dict(best_checkpoint['state_dict'])
    validate(test_loader, model, criterion, normalizer, test=True)
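# save_checkpoint() is called above and 'model_best.pth.tar' is loaded afterwards;
# this suggests the common PyTorch pattern of writing every checkpoint and copying
# the best one. A minimal sketch, assuming that pattern and these default filenames
# (the repository's actual helper may differ), could be:
import shutil
import torch


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write the latest checkpoint; copy it to model_best.pth.tar if best (sketch)."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')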