def check_cycle_generator():
    """Checks the output and number of parameters of the CycleGenerator class.

    Loads the reference checkpoint ``checker_files/cycle_generator.pt`` (a
    dict with ``'state_dict'``, ``'input'`` and ``'output'`` entries), builds
    ``CycleGenerator(conv_dim=32, init_zero_weights=False)``, loads the
    reference weights, runs it on the reference input and prints whether the
    result matches the reference output (``np.allclose``). It then prints
    whether the model's parameter count equals the expected 105856.
    """
    # map_location='cpu' keeps the reference tensors on the CPU next to the
    # freshly constructed (CPU) model. Without it, a checker file saved from a
    # GPU would yield CUDA inputs and cause a device mismatch on the forward pass.
    state = torch.load('checker_files/cycle_generator.pt', map_location='cpu')

    G_XtoY = CycleGenerator(conv_dim=32, init_zero_weights=False)
    G_XtoY.load_state_dict(state['state_dict'])

    images = state['input']
    cycle_generator_expected = state['output']

    # Inference only: no autograd graph is needed to compare outputs.
    with torch.no_grad():
        output = G_XtoY(images)
    output_np = output.cpu().numpy()

    if np.allclose(output_np, cycle_generator_expected):
        print('CycleGenerator output: EQUAL')
    else:
        print('CycleGenerator output: NOT EQUAL')

    num_params = count_parameters(G_XtoY)
    expected_params = 105856

    print('CycleGenerator #params = {}, expected #params = {}, {}'.format(
        num_params, expected_params,
        'EQUAL' if num_params == expected_params else 'NOT EQUAL'))

    print('-' * 80)
def load_checkpoint(opts):
    """Restore both generators and both discriminators from saved checkpoints.

    Reads ``G_XtoY.pkl``, ``G_YtoX.pkl``, ``D_X.pkl`` and ``D_Y.pkl`` from the
    directory ``opts.load``. Stored tensors are mapped onto the CPU while
    loading. The restored models are moved to the GPU when CUDA is available.

    Args:
        opts: options object providing ``load``, ``g_conv_dim``,
            ``init_zero_weights`` and ``d_conv_dim``.

    Returns:
        Tuple ``(G_XtoY, G_YtoX, D_X, D_Y)`` of restored models.
    """
    checkpoint_files = ('G_XtoY.pkl', 'G_YtoX.pkl', 'D_X.pkl', 'D_Y.pkl')
    checkpoint_paths = [os.path.join(opts.load, fname) for fname in checkpoint_files]

    # Build every network first, in the same order as the checkpoint files,
    # before touching the disk.
    generator_kwargs = dict(conv_dim=opts.g_conv_dim,
                            init_zero_weights=opts.init_zero_weights)
    G_XtoY = CycleGenerator(**generator_kwargs)
    G_YtoX = CycleGenerator(**generator_kwargs)
    D_X = DCDiscriminator(conv_dim=opts.d_conv_dim)
    D_Y = DCDiscriminator(conv_dim=opts.d_conv_dim)
    networks = (G_XtoY, G_YtoX, D_X, D_Y)

    def keep_on_cpu(storage, loc):
        # Ignore the device the tensors were saved from; keep them on the CPU.
        return storage

    for path, network in zip(checkpoint_paths, networks):
        network.load_state_dict(torch.load(path, map_location=keep_on_cpu))

    if torch.cuda.is_available():
        for network in networks:
            network.cuda()
        print('Models moved to GPU.')

    return G_XtoY, G_YtoX, D_X, D_Y
# Echo the parsed options (opt is presumably an argparse Namespace defined
# earlier in the script; verify against caller).
print(opt)

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###### Definition of variables ######
# Networks
# NOTE(review): CycleGenerator is called here with two positional channel
# counts (A_nc, B_nc), whereas other code in this file builds it with
# conv_dim=/init_zero_weights= keywords. Confirm which signature applies.
netG_A2B = CycleGenerator(opt.A_nc, opt.B_nc)
netG_B2A = CycleGenerator(opt.B_nc, opt.A_nc)

if opt.cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()

# Load state dicts.
# map_location='cpu' lets GPU-saved weights load on CPU-only machines.
# load_state_dict then copies each tensor onto the device its parameter
# already lives on, so --cuda runs still end up on the GPU.
netG_A2B.load_state_dict(torch.load(opt.generator_A2B, map_location='cpu'))
netG_B2A.load_state_dict(torch.load(opt.generator_B2A, map_location='cpu'))

# Set model's test mode
netG_A2B.eval()
netG_B2A.eval()

# Inputs & targets memory allocation.
# Tensor(...) with sizes allocates uninitialised storage; presumably it is
# filled with real batches later in the script.
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batch_size, opt.A_nc, opt.size, opt.size)
input_B = Tensor(opt.batch_size, opt.B_nc, opt.size, opt.size)

# Dataset loader
data_loader = datasetloader.DatasetLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(dataset)