Example #1
File: train.py Project: yinxx/ProteinGCN
def main():
    global args, best_error_global, best_error_local, savepath, dataset

    parser = buildParser()
    args = parser.parse_args()

    print('Torch Device being used: ', cfg.device)

    # create the savepath
    savepath = args.save_dir + str(args.name) + '/'
    if not os.path.exists(savepath):
        os.makedirs(savepath)

    # Writes to file and also to terminal
    sys.stdout = Logger(savepath)
    print(vars(args))

    best_error_global, best_error_local = 1e10, 1e10

    randomSeed(args.seed)

    # create train/val/test dataset separately
    assert os.path.exists(args.protein_dir), '{} does not exist!'.format(
        args.protein_dir)
    all_dirs = [
        d for d in os.listdir(args.protein_dir)
        if not d.startswith('.DS_Store')
    ]
    dir_len = len(all_dirs)
    # shuffle the protein directories before splitting
    random.shuffle(all_dirs)

    train_size = math.floor(args.train * dir_len)
    val_size = math.floor(args.val * dir_len)
    test_size = math.floor(args.test * dir_len)
    test_dirs = all_dirs[:test_size]
    train_dirs = all_dirs[test_size:test_size + train_size]
    val_dirs = all_dirs[test_size + train_size:test_size + train_size +
                        val_size]
    print('Testing on {} protein directories:'.format(len(test_dirs)))

    dataset = ProteinDataset(args.pkl_dir,
                             args.id_prop,
                             args.atom_init,
                             random_seed=args.seed)

    print('Dataset length: ', len(dataset))

    # load all model args from pretrained model
    if args.pretrained is not None and os.path.isfile(args.pretrained):
        print("=> loading model params '{}'".format(args.pretrained))
        model_checkpoint = torch.load(
            args.pretrained, map_location=lambda storage, loc: storage)
        model_args = argparse.Namespace(**model_checkpoint['args'])
        # override all args value with model_args
        args.h_a = model_args.h_a
        args.h_g = model_args.h_g
        args.n_conv = model_args.n_conv
        args.random_seed = model_args.seed
        args.lr = model_args.lr

        print("=> loaded model params '{}'".format(args.pretrained))
    else:
        print("=> no model params found at '{}'".format(args.pretrained))

    # build model
    kwargs = {
        'pkl_dir': args.pkl_dir,  # Root directory for data
        'atom_init': args.atom_init,  # Atom Init filename
        'h_a': args.h_a,  # Dim of the hidden atom embedding learnt
        'h_g': args.h_g,  # Dim of the hidden graph embedding after pooling
        'n_conv': args.n_conv,  # Number of GCN layers
        'random_seed': args.seed,  # Seed to fix the simulation
        'lr': args.lr,  # Learning rate for optimizer
    }

    structures, _, _ = dataset[0]
    h_b = structures[1].shape[-1]
    kwargs['h_b'] = h_b  # Dim of the bond embedding initialization

    # Use DataParallel for faster training
    print("Let's use", torch.cuda.device_count(),
          "GPUs and Data Parallel Model.")
    model = ProteinGCN(**kwargs)
    model = torch.nn.DataParallel(model)
    model.cuda()

    print('Trainable Model Parameters: ', count_parameters(model))

    # Create dataloader to iterate through the dataset in batches
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset,
        train_dirs,
        val_dirs,
        test_dirs,
        collate_fn=collate_pool,
        num_workers=args.workers,
        batch_size=args.batch_size,
        pin_memory=False)

    try:
        print('Training data    : ', len(train_loader.sampler))
        print('Validation data  : ', len(val_loader.sampler))
        print('Testing data     : ', len(test_loader.sampler))
    except Exception as e:
        # sometimes test may not be defined
        print('\nException Cause: {}'.format(e.args[0]))

    # obtain target value normalizer
    if len(dataset) < args.avg_sample:
        sample_data_list = [dataset[i] for i in tqdm(range(len(dataset)))]
    else:
        sample_data_list = [
            dataset[i]
            for i in tqdm(random.sample(range(len(dataset)), args.avg_sample))
        ]

    _, _, sample_target = collate_pool(sample_data_list)
    normalizer_global = Normalizer(sample_target[0])
    normalizer_local = Normalizer(sample_target[1])

    # load the model state dict from given pretrained model
    if args.pretrained is not None and os.path.isfile(args.pretrained):
        print("=> loading model '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained,
                                map_location=lambda storage, loc: storage)

        print('Best error global: ', checkpoint['best_error_global'])
        print('Best error local: ', checkpoint['best_error_local'])

        best_error_global = checkpoint['best_error_global']
        best_error_local = checkpoint['best_error_local']

        model.module.load_state_dict(checkpoint['state_dict'])
        model.module.optimizer.load_state_dict(checkpoint['optimizer'])
        normalizer_local.load_state_dict(checkpoint['normalizer_local'])
        normalizer_global.load_state_dict(checkpoint['normalizer_global'])
    else:
        print("=> no model found at '{}'".format(args.pretrained))

    # Main training loop
    for epoch in range(args.epochs):
        # Training
        [train_error_global, train_error_local,
         train_loss] = trainModel(train_loader,
                                  model,
                                  normalizer_global,
                                  normalizer_local,
                                  epoch=epoch)
        # Validation
        [val_error_global, val_error_local,
         val_loss] = trainModel(val_loader,
                                model,
                                normalizer_global,
                                normalizer_local,
                                epoch=epoch,
                                evaluation=True)

        # check for error overflow
        if (val_error_global != val_error_global) or (val_error_local !=
                                                      val_error_local):
            print('Exit due to NaN')
            sys.exit(1)

        # remember the best error and possibly save checkpoint
        is_best = val_error_global < best_error_global
        best_error_global = min(val_error_global, best_error_global)
        best_error_local = val_error_local

        # save best model
        if args.save_checkpoints:
            model.module.save(
                {
                    'epoch': epoch,
                    'state_dict': model.module.state_dict(),
                    'best_error_global': best_error_global,
                    'best_error_local': best_error_local,
                    'optimizer': model.module.optimizer.state_dict(),
                    'normalizer_global': normalizer_global.state_dict(),
                    'normalizer_local': normalizer_local.state_dict(),
                    'args': vars(args)
                }, is_best, savepath)

    # test best model using saved checkpoints
    if args.save_checkpoints and len(test_loader):
        print('---------Evaluate Model on Test Set---------------')
        # this try/except allows the code to test on the go or by defining a pretrained path separately
        try:
            best_checkpoint = torch.load(savepath + 'model_best.pth.tar')
        except Exception as e:
            best_checkpoint = torch.load(args.pretrained)

        model.module.load_state_dict(best_checkpoint['state_dict'])
        [test_error_global, test_error_local,
         test_loss] = trainModel(test_loader,
                                 model,
                                 normalizer_global,
                                 normalizer_local,
                                 testing=True)
Example #2
def main():
    global args, model_args, best_mae_error
    # print(FloatTensor)
    # load data
    dataset = CIFData(args.cifpath)
    collate_fn = collate_pool
    test_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_fn,
        pin_memory=args.cuda)

    # build model
    structures, targets, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    n_p = len(targets)
    properties_loss_weight = torch.ones(n_p)
    model = MTCGCNN(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=model_args.atom_fea_len,
        n_conv=model_args.n_conv,
        h_fea_len=model_args.h_fea_len,
        n_p=n_p,
        n_hp=model_args.n_hp,
        dropout=model_args.dropout)

    if args.cuda:
        model.cuda()

    if model_args.weights is not None:
        USE_WEIGHTED_LOSS = True
        properties_loss_weight = FloatTensor(model_args.weights)
        print('Using weights: ', properties_loss_weight)

    # obtain target value normalizer
    if len(dataset) < 2000:
        warnings.warn('Dataset has less than 2000 data points. '
                      'Lower accuracy is expected. ')
        sample_data_list = [dataset[i] for i in tqdm(range(len(dataset)))]
    else:
        sample_data_list = [
            dataset[i]
            for i in tqdm(random.sample(range(len(dataset)), 2000))
        ]
    _, sample_target, _ = collate_pool(sample_data_list)
    normalizer = Normalizer(sample_target)

    if args.cuda:
        criterion = ModifiedMSELoss().cuda()
    else:
        criterion = ModifiedMSELoss()
    if model_args.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              model_args.lr,
                              momentum=model_args.momentum,
                              weight_decay=model_args.weight_decay)
    elif model_args.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               model_args.lr,
                               weight_decay=model_args.weight_decay)
    else:
        raise NameError('Only SGD or Adam is allowed as optimizer')

    # optionally resume from a checkpoint
    if os.path.isfile(args.modelpath):
        print("=> loading model '{}'".format(args.modelpath))
        checkpoint = torch.load(args.modelpath,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded model '{}' (epoch {}, validation {})".format(
            args.modelpath, checkpoint['epoch'], checkpoint['best_error']))
    else:
        print("=> no model found at '{}'".format(args.modelpath))

    # validate(test_loader, model,n_p, criterion, normalizer, test=True)
    validate(test_loader, model, criterion, normalizer, n_p, properties_loss_weight, test=True, print_checkpoints=True)
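One detail worth noting in the checkpoint loading above: map_location=lambda storage, loc: storage forces every tensor in the checkpoint onto the CPU, regardless of the device it was saved from. The same effect can be written more explicitly with standard torch.load usage ('checkpoint.pth.tar' below is a placeholder path):

import torch

# Load a checkpoint onto the CPU, independent of the device it was trained on
checkpoint = torch.load('checkpoint.pth.tar', map_location=torch.device('cpu'))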
Example #3
File: main.py Project: lukemelas/mt-cgcnn
def main():
    global args, best_mae_error

    # Dataset from CIF files
    dataset = CIFData(*args.data_options)
    print(f'Dataset size: {len(dataset)}')

    # Dataloader from dataset
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_pool,
        batch_size=args.batch_size,
        train_size=args.train_size,
        num_workers=args.workers,
        val_size=args.val_size,
        test_size=args.test_size,
        pin_memory=args.cuda,
        return_test=True)

    # Initialize data normalizer with sample of 500 points
    if args.task == 'classification':
        normalizer = Normalizer(torch.zeros(2))
        normalizer.load_state_dict({'mean': 0., 'std': 1.})
    elif args.task == 'regression':
        if len(dataset) < 500:
            warnings.warn('Dataset has less than 500 data points. '
                          'Lower accuracy is expected. ')
            sample_data_list = [dataset[i] for i in range(len(dataset))]
        else:
            sample_data_list = [
                dataset[i] for i in sample(range(len(dataset)), 500)
            ]
        _, sample_target, _ = collate_pool(sample_data_list)
        normalizer = Normalizer(sample_target)
    else:
        raise NameError('task argument must be regression or classification')

    # Build model
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model = CrystalGraphConvNet(orig_atom_fea_len,
                                nbr_fea_len,
                                atom_fea_len=args.atom_fea_len,
                                n_conv=args.n_conv,
                                h_fea_len=args.h_fea_len,
                                n_h=args.n_h,
                                classification=(args.task == 'classification'))

    # GPU
    if args.cuda:
        model.cuda()

    # Loss function
    criterion = nn.NLLLoss() if args.task == 'classification' else nn.MSELoss()

    # Optimizer
    if args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise NameError('optim argument must be SGD or Adam')

    # Scheduler
    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_milestones,
                            gamma=0.1)

    # Resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Train
    for epoch in range(args.start_epoch, args.epochs):

        # Train (one epoch)
        train(train_loader, model, criterion, optimizer, epoch, normalizer)

        # Validate
        mae_error = validate(val_loader, model, criterion, normalizer)
        assert mae_error == mae_error, 'NaN :('

        # Step learning rate scheduler (MultiStepLR takes no metric argument)
        scheduler.step()

        # Save checkpoint
        if args.task == 'regression':
            is_best = mae_error < best_mae_error
            best_mae_error = min(mae_error, best_mae_error)
        else:
            is_best = mae_error > best_mae_error
            best_mae_error = max(mae_error, best_mae_error)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_mae_error': best_mae_error,
                'optimizer': optimizer.state_dict(),
                'normalizer': normalizer.state_dict(),
                'args': vars(args)
            }, is_best)

    # Evaluate best model on test set
    print('--------- Evaluate model on test set ---------------')
    best_checkpoint = torch.load('model_best.pth.tar')
    model.load_state_dict(best_checkpoint['state_dict'])
    validate(test_loader, model, criterion, normalizer, test=True)
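save_checkpoint is not part of this listing, but the way it is called, together with the model_best.pth.tar file loaded for the final test, matches the usual CGCNN-style helper: always save the latest state and copy it over the best file when is_best is set. A sketch under that assumption:

import shutil

import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # write the latest state every epoch
    torch.save(state, filename)
    if is_best:
        # keep a separate copy of the best model seen so far
        shutil.copyfile(filename, 'model_best.pth.tar')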
Example #4
def main(root_dir, save_dir, disable_cuda=True, workers=0, epochs=30,
				start_epoch=0, batch_size=256, lr=0.01, lr_milestones=[100], momentum=0.9,
				weight_decay=0.0, print_freq=10, resume='', train_size=None, val_size=1000,
				test_size=1000, optimizer='SGD', atom_fea_len=64, h_fea_len=128, n_conv=3, n_hp=1,
				print_checkpoints=False, save_checkpoints=False, scheduler='MultiStepLR', metric='mae',
				seed=123, weights=None, dropout=0):
	global args, best_error, USE_WEIGHTED_LOSS, STORE_GRAD
	args = Argument(root_dir, save_dir, disable_cuda=disable_cuda,
			workers=workers, epochs=epochs, start_epoch=start_epoch, batch_size=batch_size, lr=lr,
			lr_milestones=lr_milestones, momentum=momentum, weight_decay=weight_decay, print_freq=print_freq,
			resume=resume, train_size=train_size, val_size=val_size, test_size=test_size, optimizer=optimizer,
			atom_fea_len=atom_fea_len, h_fea_len=h_fea_len, n_conv=n_conv, n_hp=n_hp, weights=weights,
			scheduler=scheduler, metric=metric, seed=seed, dropout=dropout)

	print(vars(args))

	best_error = 1e10
	best_error_vec = None

	# load data
	print("Loading datasets...")
	full_dataset = CIFData(args.root_dir, random_seed=args.seed)

	# build model
	structures, targets, _ = full_dataset[0]
	orig_atom_fea_len = structures[0].shape[-1]
	nbr_fea_len = structures[1].shape[-1]
	n_p = len(targets)
	print("Predicting ", n_p, " properties!!")
	model = MTCGCNN(orig_atom_fea_len, nbr_fea_len,
								atom_fea_len=args.atom_fea_len,
								n_conv=args.n_conv,
								h_fea_len=args.h_fea_len,
								n_p=n_p, n_hp=args.n_hp, dropout=args.dropout)

	if args.cuda:
		model.cuda()

	# set some defaults
	properties_loss_weight = torch.ones(n_p)

	if args.weights is not None:
		USE_WEIGHTED_LOSS = True
		properties_loss_weight = FloatTensor(args.weights)
		print('Using weights: ', properties_loss_weight)

	collate_fn = collate_pool
	# Only training loader needs to be differentiated, val/test only use full dataset
	train_loader, val_loader, test_loader = get_train_val_test_loader(
		dataset=full_dataset, collate_fn=collate_fn, batch_size=args.batch_size,
		train_size=args.train_size, num_workers=args.workers,
		val_size=args.val_size, test_size=args.test_size,
		pin_memory=args.cuda, return_test=True, return_val=True)
	
	# obtain target value normalizer
	if len(full_dataset) < 2000:
		warnings.warn('Dataset has less than 2000 data points. '
			'Lower accuracy is expected. ')
		sample_data_list = [full_dataset[i] for i in tqdm(range(len(full_dataset)))]
	else:
		sample_data_list = [full_dataset[i] for i in
							tqdm(random.sample(range(len(full_dataset)), 2000))]
	_, sample_target, _ = collate_pool(sample_data_list)
	normalizer = Normalizer(sample_target)

	# define loss func and optimizer
	if args.cuda:
		criterion = ModifiedMSELoss().cuda()
	else:
		criterion = ModifiedMSELoss()
	if args.optimizer == 'SGD':
		optimizer = optim.SGD(model.parameters(), args.lr,
							  momentum=args.momentum,
							  weight_decay=args.weight_decay)
	elif args.optimizer == 'Adam':
		optimizer = optim.Adam(model.parameters(), args.lr,
							weight_decay=args.weight_decay)
	else:
		raise NameError('Only SGD or Adam is allowed as optimizer')

	# optionally resume from a checkpoint
	if args.resume:
		if os.path.isfile(args.resume):
			print("=> loading checkpoint '{}'".format(args.resume))
			checkpoint = torch.load(args.resume)
			args.start_epoch = checkpoint['epoch']
			best_error = checkpoint['best_error']
			model.load_state_dict(checkpoint['state_dict'])
			optimizer.load_state_dict(checkpoint['optimizer'])
			normalizer.load_state_dict(checkpoint['normalizer'])
			print("=> loaded checkpoint '{}' (epoch {})"
				  .format(args.resume, checkpoint['epoch']))
		else:
			print("=> no checkpoint found at '{}'".format(args.resume))

	if args.scheduler == 'MultiStepLR':
		scheduler = MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=0.5)
	elif args.scheduler == 'ReduceLROnPlateau':
		scheduler = ReduceLROnPlateau(optimizer, patience=10, factor=0.8, verbose=True)

	train_error_vec_per_epoch = []
	val_error_vec_per_epoch = []
	train_loss_vec_per_epoch = []
	val_loss_vec_per_epoch = []
	train_loss_list = []
	train_error_list = []
	val_loss_list = []
	val_error_list = []
	for epoch in range(args.start_epoch, args.epochs):
		# train for one epoch
		[train_error_vec, train_loss_vec] = train(train_loader, model, criterion, optimizer,
							epoch, normalizer, n_p, properties_loss_weight,
							print_checkpoints=print_checkpoints)

		train_loss, train_error = torch.mean(train_loss_vec.avg).item(),\
										torch.mean(train_error_vec.avg).item()
		print('Training Error: %0.3f  Loss: %0.3f' % (train_error, train_loss))
		train_loss_list.append(train_loss)
		train_error_list.append(train_error)

		# evaluate on validation set
		[error, val_error_vec, val_loss_vec] = validate(val_loader, model, criterion,
										normalizer, n_p, properties_loss_weight,
										print_checkpoints=print_checkpoints)

		val_loss, val_error = torch.mean(val_loss_vec.avg).item(),\
										torch.mean(val_error_vec.avg).item()
		val_loss_list.append(val_loss)
		val_error_list.append(val_error)

		if error != error:
			print('Exit due to NaN')
			sys.exit(1)

		if args.scheduler == 'MultiStepLR':
			scheduler.step()
		elif args.scheduler == 'ReduceLROnPlateau': 
			scheduler.step(error)

		# store the error values from previous iteration - useful for plotting
		train_error_vec_per_epoch.append(train_error_vec.avg.cpu().numpy().squeeze())
		val_error_vec_per_epoch.append(val_error_vec.avg.cpu().numpy().squeeze())
		# store the loss values from previous iteration - useful for plotting
		train_loss_vec_per_epoch.append(train_loss_vec.avg.cpu().numpy().squeeze())
		val_loss_vec_per_epoch.append(val_loss_vec.avg.cpu().numpy().squeeze())

		# remember the best error and possibly save checkpoint
		is_best = error < best_error
		if is_best:
			best_error_vec = val_error_vec.avg.squeeze()
		best_error = min(error, best_error)
		
		if save_checkpoints:
			save_checkpoint({
				'epoch': epoch + 1,
				'state_dict': model.state_dict(),
				'best_error': best_error,
				'optimizer': optimizer.state_dict(),
				'normalizer': normalizer.state_dict(),
				'args': vars(args)
			}, is_best)

	# Draw some meaningful plots
	
	if save_checkpoints:
		# Plot1: individual property error vs epoch for all properties
		plotMultiGraph(np.array(train_error_vec_per_epoch), np.array(val_error_vec_per_epoch),
					path=args.save_dir, name='train_val_err_vs_epoch')
		np.savetxt(args.save_dir + 'train_error.txt', np.array(train_error_vec_per_epoch))
		np.savetxt(args.save_dir + 'val_error.txt', np.array(val_error_vec_per_epoch))

		# Plot2: individual property loss vs epoch for all properties
		plotMultiGraph(np.array(train_loss_vec_per_epoch), np.array(val_loss_vec_per_epoch),
					path=args.save_dir, name='train_val_loss_vs_epoch')
		np.savetxt(args.save_dir + 'train_loss.txt', np.array(train_loss_vec_per_epoch))
		np.savetxt(args.save_dir + 'val_loss.txt', np.array(val_loss_vec_per_epoch))

		# Plot3: average loss vs epoch
		plotGraph(train_loss_list, val_loss_list, path=args.save_dir, name='train_val_loss_avg_vs_epoch')
		np.savetxt(args.save_dir + 'train_loss_avg.txt', np.array(train_loss_list))
		np.savetxt(args.save_dir + 'val_loss_avg.txt', np.array(val_loss_list))

		# Plot4: error vs epoch overall
		plotGraph(train_error_list, val_error_list, path=args.save_dir, name='train_val_err_avg_vs_epoch')
		np.savetxt(args.save_dir + 'train_error_avg.txt', np.array(train_error_list))
		np.savetxt(args.save_dir + 'val_error_avg.txt', np.array(val_error_list))

	# test best model using saved checkpoints
	if save_checkpoints:
		print('---------Evaluate Model on Test Set---------------')
		best_checkpoint = torch.load(args.save_dir + 'model_best.pth.tar')
		model.load_state_dict(best_checkpoint['state_dict'])
		[test_error, test_error_vec, test_loss_vec] = validate(test_loader, model, criterion, normalizer, n_p,
									properties_loss_weight, test=True, print_checkpoints=print_checkpoints)
		return best_error.item(), test_error.item(), test_error_vec.avg.cpu().numpy().squeeze(),\
					test_loss_vec.avg.cpu().numpy().squeeze()

	return best_error.item(), None, best_error_vec.cpu().numpy(), None
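Unlike the other examples, this main() is driven entirely by function arguments rather than a command-line parser, so it can be called directly from another script or a notebook. A hypothetical invocation (the paths and hyperparameter values below are placeholders, not values taken from the project):

# Hypothetical call; 'data/cifs/' and 'results/run1/' are placeholder directories
best_err, test_err, test_err_vec, test_loss_vec = main(
    'data/cifs/', 'results/run1/',
    disable_cuda=False, epochs=100, batch_size=64,
    optimizer='Adam', scheduler='ReduceLROnPlateau',
    save_checkpoints=True)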
Example #5
File: main.py Project: guozhn/AECNN
def main():
    global opt, best_mae_error

    #dataset = CIFData(*opt.dataroot)
    dataset = h5(*opt.dataroot) 

    collate_fn = collate_pool
    
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset, collate_fn=collate_fn, batch_size=opt.batch_size,
        train_size=opt.train_size, num_workers=opt.workers,
        val_size=opt.val_size, test_size=opt.test_size, pin_memory=opt.cuda,
        return_test=True)
    # obtain target value normalizer
    sample_data_list = [dataset[i] for i in
                        sample(range(len(dataset)), 1000)]
    input, sample_target, _ = collate_pool(sample_data_list)
    input_1 = input[0]
    normalizer = Normalizer(sample_target)
    s = Normalizer(input_1)


    model=NET()

    if torch.cuda.is_available():
        print('cuda is ok')
        model = model.cuda()

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), opt.lr,
                            momentum=opt.momentum,
                            weight_decay=opt.weight_decay)
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint['epoch']
            best_mae_error = checkpoint['best_mae_error']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            normalizer.load_state_dict(checkpoint['normalizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch']))   

        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))
    scheduler = MultiStepLR(optimizer, milestones=opt.lr_milestones,
                            gamma=0.1)
    for epoch in range(opt.start_epoch, opt.epochs):
        train(train_loader, model, criterion, optimizer, epoch, normalizer, s)

        mae_error = validate(val_loader, model, criterion, normalizer, s)

        if mae_error != mae_error:
            print('Exit due to NaN')
            sys.exit(1)
        is_best = mae_error < best_mae_error
        best_mae_error = min(mae_error, best_mae_error)
        
        save_checkpoint({ 
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_mae_error': best_mae_error,
            'optimizer': optimizer.state_dict(),
            'normalizer': normalizer.state_dict(),
            'opt': vars(opt)
        }, is_best)
    # test best model
    print('---------Evaluate Model on Test Set---------------')
    best_checkpoint = torch.load('model_best.pth.tar')
    model.load_state_dict(best_checkpoint['state_dict'])
    validate(test_loader, model, criterion, normalizer, s, test=True)
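Several of these examples guard against divergence with the self-comparison trick (mae_error != mae_error), which is true only for NaN. An equivalent, more explicit form of the same check, assuming mae_error is a Python float or a 0-d tensor:

import math
import sys

# NaN is the only value that is not equal to itself; math.isnan states the intent directly
if math.isnan(float(mae_error)):
    print('Exit due to NaN')
    sys.exit(1)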