##################
## Specify data ##
##################

_, _, data_shape = get_data(args)

###################
## Specify model ##
###################

# Resolve the target device before loading the checkpoint so its tensors can be
# remapped onto it. Without map_location, a checkpoint saved from CUDA tensors
# cannot be loaded on a CPU-only machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = get_model(args, data_shape=data_shape)
if args.parallel == 'dp':
    # Wrap before load_state_dict so parameter names line up with the saved
    # state dict. NOTE(review): presumably the checkpoint was saved from the
    # same wrapper during training — confirm against the training script.
    model = DataParallelDistribution(model)

checkpoint = torch.load(path_check, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

############
## Sample ##
############

path_samples = '{}/samples/sample_ep{}_s{}.png'.format(
    eval_args.model, checkpoint['current_epoch'], eval_args.seed)
# makedirs(exist_ok=True) also creates a missing parent directory and does not
# fail if another process creates the directory between check and create.
os.makedirs(os.path.dirname(path_samples), exist_ok=True)

model = model.to(device)
model = model.eval()
##################
eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

####################
## Specify models ##
####################

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Single load target shared by both checkpoints: tensors are remapped onto the
# selected device as they are read.
load_location = torch.device(device)

# -- Conditional model: built from the data shape plus the conditioning shape.
model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)
checkpoint = torch.load(path_check, map_location=load_location)
model.load_state_dict(checkpoint['model'])
model = model.to(device).eval()
print('Loaded weights for conditional model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

# -- Prior model: keeps entry 0 of data_shape and floor-divides entries 1 and 2
# by args.sr_scale_factor.
# NOTE(review): presumably data_shape is (channels, height, width) — confirm.
prior_data_shape = (data_shape[0],
                    data_shape[1] // args.sr_scale_factor,
                    data_shape[2] // args.sr_scale_factor)
prior_model = get_model(prior_args, data_shape=prior_data_shape)
if prior_args.parallel == 'dp':
    prior_model = DataParallelDistribution(prior_model)
prior_checkpoint = torch.load(path_prior_check, map_location=load_location)
prior_model.load_state_dict(prior_checkpoint['model'])