torch.manual_seed(opts.seed) torch.cuda.manual_seed(opts.seed) # Load experiment setting config = get_config(opts.config) input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b'] council_size = config['council']['council_size'] # Setup model and data loader if not 'new_size_a' in config.keys(): config['new_size_a'] = config['new_size'] is_data_A = opts.a2b style_dim = config['gen']['style_dim'] trainer = Council_Trainer(config) only_one = False if 'gen_' in opts.checkpoint[-21:]: state_dict = torch.load(opts.checkpoint, map_location={'cuda:1': 'cuda:0'}) if opts.a2b: trainer.gen_a2b_s[0].load_state_dict(state_dict['a2b']) else: trainer.gen_b2a_s[0].load_state_dict(state_dict['b2a']) council_size = 1 only_one = True else: for i in range(council_size): if opts.a2b: tmp_checkpoint = opts.checkpoint[:-8] + 'a2b_gen_' + str( i) + '_' + opts.checkpoint[-8:] + '.pt' state_dict = torch.load(tmp_checkpoint,
]).cuda(config['cuda_device']) except: # test_display_images_a = torch.stack([test_loader_a[0].dataset[np.random.randint(test_loader_a[0].__len__())] for _ in range(display_size)]).cuda() test_display_images_a = None try: test_display_images_b = torch.stack([ test_loader_b[0].dataset[np.random.randint(test_loader_b[0].__len__())] for _ in range(display_size) ]).cuda(config['cuda_device']) except: test_display_images_b = torch.stack([ test_loader_b[0].dataset[np.random.randint(test_loader_b[0].__len__())] for _ in range(display_size) ]).cuda(config['cuda_device']) trainer = Council_Trainer(config, config['cuda_device']) trainer.cuda(config['cuda_device']) # Setup logger and output folders model_name = os.path.splitext(os.path.basename(opts.config))[0] output_directory = os.path.join(opts.output_path, model_name) checkpoint_directory, image_directory, log_directory = prepare_sub_folder( output_directory) config_backup_folder = os.path.join(output_directory, 'config_backup') if not os.path.exists(config_backup_folder): os.mkdir(config_backup_folder) shutil.copy(opts.config, os.path.join(config_backup_folder, ('config_backup_' +
def loadModel(config, checkpoint, a2b): seed = 1 torch.manual_seed(seed) torch.cuda.manual_seed(seed) # Load experiment setting config = get_config(config) input_dim = config['input_dim_a'] if a2b else config['input_dim_b'] council_size = config['council']['council_size'] style_dim = config['gen']['style_dim'] trainer = Council_Trainer(config) only_one = False if 'gen_' in checkpoint[-21:]: state_dict = torch.load(checkpoint) try: print(state_dict) if a2b: trainer.gen_a2b_s[0].load_state_dict(state_dict['a2b']) else: trainer.gen_b2a_s[0].load_state_dict(state_dict['b2a']) except: print('a2b should be set to ' + str(not a2b) + ' , Or config file could be wrong') a2b = not a2b if a2b: trainer.gen_a2b_s[0].load_state_dict(state_dict['a2b']) else: trainer.gen_b2a_s[0].load_state_dict(state_dict['b2a']) council_size = 1 only_one = True else: for i in range(council_size): try: if a2b: tmp_checkpoint = checkpoint[:-8] + 'a2b_gen_' + str( i) + '_' + checkpoint[-8:] + '.pt' state_dict = torch.load(tmp_checkpoint) trainer.gen_a2b_s[i].load_state_dict(state_dict['a2b']) else: tmp_checkpoint = checkpoint[:-8] + 'b2a_gen_' + str( i) + '_' + checkpoint[-8:] + '.pt' state_dict = torch.load(tmp_checkpoint) trainer.gen_b2a_s[i].load_state_dict(state_dict['b2a']) except: print('a2b should be set to ' + str(not a2b) + ' , Or config file could be wrong') a2b = not a2b if a2b: tmp_checkpoint = checkpoint[:-8] + 'a2b_gen_' + str( i) + '_' + checkpoint[-8:] + '.pt' state_dict = torch.load(tmp_checkpoint) trainer.gen_a2b_s[i].load_state_dict(state_dict['a2b']) else: tmp_checkpoint = checkpoint[:-8] + 'b2a_gen_' + str( i) + '_' + checkpoint[-8:] + '.pt' state_dict = torch.load(tmp_checkpoint) trainer.gen_b2a_s[i].load_state_dict(state_dict['b2a']) trainer.cuda() trainer.eval() return [trainer, config, council_size, style_dim]