def check_dc_discriminator(path='checker_files/dc_discriminator.pt'):
    """Checks the output and number of parameters of the DCDiscriminator class.

    Loads a reference checker file holding a ``state_dict``, an ``input``
    batch and the expected ``output``. Builds ``DCDiscriminator(conv_dim=32)``
    with those weights, runs it on the input, and prints whether the result
    matches the reference. Then prints whether the model's parameter count
    equals 167872.

    Args:
        path: Location of the checker file. Defaults to the relative path the
            checker has always used.
    """
    # NOTE(review): this module defines check_dc_discriminator more than once;
    # only the last definition in the file is bound after import.

    # Map every tensor to CPU so a checker file saved from a GPU session still
    # loads on a CPU-only machine; D below is also kept on CPU.
    state = torch.load(path, map_location='cpu')
    D = DCDiscriminator(conv_dim=32)
    D.load_state_dict(state['state_dict'])
    images = state['input']
    dc_discriminator_expected = state['output']

    # Inference only: no autograd graph is needed to compare outputs.
    with torch.no_grad():
        output = D(images)
    output_np = output.cpu().numpy()

    try:
        outputs_equal = np.allclose(output_np, dc_discriminator_expected)
    except ValueError:
        # Shapes that cannot be broadcast together are a mismatch, not a
        # reason to abort the check.
        outputs_equal = False

    if outputs_equal:
        print('DCDiscriminator output: EQUAL')
    else:
        print('DCDiscriminator output: NOT EQUAL')

    num_params = count_parameters(D)
    expected_params = 167872
    print('DCDiscriminator #params = {}, expected #params = {}, {}'.format(
        num_params, expected_params,
        'EQUAL' if num_params == expected_params else 'NOT EQUAL'))
    print('-' * 80)
def load_checkpoint(opts):
    """Restores both CycleGAN generators and both discriminators from disk.

    Reads ``G_XtoY.pkl``, ``G_YtoX.pkl``, ``D_X.pkl`` and ``D_Y.pkl`` from the
    directory ``opts.load``. Each file is deserialized with all tensors kept
    on CPU and applied to a freshly built model via ``load_state_dict``. If
    CUDA is available, the four models are then moved to the GPU.

    Args:
        opts: Options object providing ``load``, ``g_conv_dim``,
            ``init_zero_weights`` and ``d_conv_dim``.

    Returns:
        tuple: ``(G_XtoY, G_YtoX, D_X, D_Y)``.
    """
    def keep_on_cpu(storage, loc):
        # map_location hook: ignore the saved device and leave the storage
        # where it was deserialized (CPU); GPU placement happens afterwards.
        return storage

    checkpoint_paths = [
        os.path.join(opts.load, filename)
        for filename in ('G_XtoY.pkl', 'G_YtoX.pkl', 'D_X.pkl', 'D_Y.pkl')
    ]

    # Build the models in the order X->Y gen, Y->X gen, D_X, D_Y so they line
    # up with checkpoint_paths above.
    generators = [
        CycleGenerator(conv_dim=opts.g_conv_dim,
                       init_zero_weights=opts.init_zero_weights)
        for _ in range(2)
    ]
    discriminators = [DCDiscriminator(conv_dim=opts.d_conv_dim)
                      for _ in range(2)]
    networks = generators + discriminators

    for net, ckpt_path in zip(networks, checkpoint_paths):
        net.load_state_dict(torch.load(ckpt_path, map_location=keep_on_cpu))

    if torch.cuda.is_available():
        for net in networks:
            net.cuda()
        print('Models moved to GPU.')

    G_XtoY, G_YtoX, D_X, D_Y = networks
    return G_XtoY, G_YtoX, D_X, D_Y
def check_dc_discriminator(path='checker_files/dc_discriminator.pt'):
    """Checks the output and number of parameters of the DCDiscriminator class.

    Loads a reference checker file holding a ``state_dict``, an ``input``
    batch and the expected ``output``. Builds ``DCDiscriminator(conv_dim=32)``
    with those weights, runs it on the input, and prints whether the result
    matches the reference. On a mismatch it also prints both shapes to help
    debugging. Finally it prints whether the model's parameter count equals
    167872.

    Args:
        path: Location of the checker file. Defaults to a path relative to the
            working directory instead of one user's absolute home path, so the
            checker runs on any machine.
    """
    # Map every tensor to CPU so a checker file saved from a GPU session still
    # loads on a CPU-only machine; D below is also kept on CPU.
    state = torch.load(path, map_location='cpu')
    D = DCDiscriminator(conv_dim=32)
    D.load_state_dict(state['state_dict'])
    images = state['input']
    dc_discriminator_expected = state['output']

    # Inference only: no autograd graph is needed to compare outputs.
    with torch.no_grad():
        output = D(images)
    output_np = output.cpu().numpy()

    try:
        outputs_equal = np.allclose(output_np, dc_discriminator_expected)
    except ValueError:
        # Shapes that cannot be broadcast together are a mismatch; fall
        # through to the shape diagnostics instead of crashing.
        outputs_equal = False

    if outputs_equal:
        print('DCDiscriminator output: EQUAL')
    else:
        print("output_np: ", output_np.shape)
        print("dc_discriminator_expected: ", dc_discriminator_expected.shape)
        print('DCDiscriminator output: NOT EQUAL')

    num_params = count_parameters(D)
    expected_params = 167872
    print('DCDiscriminator #params = {}, expected #params = {}, {}'.format(
        num_params, expected_params,
        'EQUAL' if num_params == expected_params else 'NOT EQUAL'))
    print('-' * 80)