Example 1
torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)

# Load experiment settings
config = get_config(opts.config)
input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']
council_size = config['council']['council_size']

# Setup model and data loader
if 'new_size_a' not in config:
    config['new_size_a'] = config['new_size']
is_data_A = opts.a2b

style_dim = config['gen']['style_dim']
trainer = Council_Trainer(config)
only_one = False
if 'gen_' in opts.checkpoint[-21:]:
    state_dict = torch.load(opts.checkpoint, map_location={'cuda:1': 'cuda:0'})
    if opts.a2b:
        trainer.gen_a2b_s[0].load_state_dict(state_dict['a2b'])
    else:
        trainer.gen_b2a_s[0].load_state_dict(state_dict['b2a'])
    council_size = 1
    only_one = True
else:
    for i in range(council_size):
        if opts.a2b:
            tmp_checkpoint = opts.checkpoint[:-8] + 'a2b_gen_' + str(
                i) + '_' + opts.checkpoint[-8:] + '.pt'
            state_dict = torch.load(tmp_checkpoint,
                                    map_location={'cuda:1': 'cuda:0'})
            trainer.gen_a2b_s[i].load_state_dict(state_dict['a2b'])
        else:
            tmp_checkpoint = opts.checkpoint[:-8] + 'b2a_gen_' + str(
                i) + '_' + opts.checkpoint[-8:] + '.pt'
            state_dict = torch.load(tmp_checkpoint,
                                    map_location={'cuda:1': 'cuda:0'})
            trainer.gen_b2a_s[i].load_state_dict(state_dict['b2a'])
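The per-member checkpoint names built above follow the pattern '<prefix><direction>_gen_<i>_<suffix>.pt'. A minimal sketch of that path construction as a standalone helper; council_checkpoint_path is a hypothetical name, not part of the original code:

def council_checkpoint_path(checkpoint, direction, i):
    """Build the checkpoint path for council member i ('a2b' or 'b2a')."""
    prefix, suffix = checkpoint[:-8], checkpoint[-8:]
    return prefix + direction + '_gen_' + str(i) + '_' + suffix + '.pt'

# council_checkpoint_path('outputs/run/01000000', 'a2b', 2)
# -> 'outputs/run/a2b_gen_2_01000000.pt'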
Example 2
try:
    test_display_images_a = torch.stack([
        test_loader_a[0].dataset[np.random.randint(test_loader_a[0].__len__())]
        for _ in range(display_size)
    ]).cuda(config['cuda_device'])
except Exception:
    test_display_images_a = None
try:
    test_display_images_b = torch.stack([
        test_loader_b[0].dataset[np.random.randint(test_loader_b[0].__len__())]
        for _ in range(display_size)
    ]).cuda(config['cuda_device'])
except Exception:
    # Mirror the a-side fallback: re-running the identical expression in the
    # except branch would just fail again.
    test_display_images_b = None

trainer = Council_Trainer(config, config['cuda_device'])
trainer.cuda(config['cuda_device'])
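The two try/except blocks above duplicate the same sampling logic, so it can be collapsed into one helper. A minimal sketch under the same assumptions (the first loader wraps an integer-indexable dataset); sample_display_images is a hypothetical name, and it draws indices over len(dataset) rather than len(loader) as the original expression does:

def sample_display_images(loaders, display_size, device):
    """Stack display_size random dataset samples on device, or None on failure."""
    try:
        dataset = loaders[0].dataset
        picks = [dataset[np.random.randint(len(dataset))]
                 for _ in range(display_size)]
        return torch.stack(picks).cuda(device)
    except Exception:
        return None

# test_display_images_a = sample_display_images(test_loader_a, display_size,
#                                               config['cuda_device'])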

# Setup logger and output folders
model_name = os.path.splitext(os.path.basename(opts.config))[0]
output_directory = os.path.join(opts.output_path, model_name)
checkpoint_directory, image_directory, log_directory = prepare_sub_folder(
    output_directory)
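prepare_sub_folder is imported from the project's utilities and is not shown in these excerpts. A plausible sketch, assuming it simply creates the three output subfolders unpacked above (the real helper may differ):

import os

def prepare_sub_folder(output_directory):
    """Create (if missing) and return the checkpoint, image, and log folders."""
    paths = [os.path.join(output_directory, name)
             for name in ('checkpoints', 'images', 'logs')]
    for path in paths:
        os.makedirs(path, exist_ok=True)
    return tuple(paths)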

config_backup_folder = os.path.join(output_directory, 'config_backup')
if not os.path.exists(config_backup_folder):
    os.mkdir(config_backup_folder)
shutil.copy(opts.config,
            os.path.join(config_backup_folder,
                         # backup filename is an assumption; the source
                         # excerpt is truncated mid-expression here
                         'config_backup_' + model_name + '.yaml'))
Example 3
import torch

# get_config and Council_Trainer are assumed to live in the project's own
# utils/trainer modules, matching the calls in the excerpts above.
from utils import get_config
from trainer import Council_Trainer


def loadModel(config, checkpoint, a2b):
    seed = 1

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Load experiment settings
    config = get_config(config)
    input_dim = config['input_dim_a'] if a2b else config['input_dim_b']
    council_size = config['council']['council_size']

    style_dim = config['gen']['style_dim']
    trainer = Council_Trainer(config)
    only_one = False
    if 'gen_' in checkpoint[-21:]:
        state_dict = torch.load(checkpoint)
        try:
            if a2b:
                trainer.gen_a2b_s[0].load_state_dict(state_dict['a2b'])
            else:
                trainer.gen_b2a_s[0].load_state_dict(state_dict['b2a'])
        except Exception:
            # The checkpoint holds the other direction: flip a2b and retry.
            print('a2b should probably be set to ' + str(not a2b) +
                  ', or the config file could be wrong')
            a2b = not a2b
            if a2b:
                trainer.gen_a2b_s[0].load_state_dict(state_dict['a2b'])
            else:
                trainer.gen_b2a_s[0].load_state_dict(state_dict['b2a'])

        council_size = 1
        only_one = True
    else:
        for i in range(council_size):
            try:
                if a2b:
                    tmp_checkpoint = checkpoint[:-8] + 'a2b_gen_' + str(
                        i) + '_' + checkpoint[-8:] + '.pt'
                    state_dict = torch.load(tmp_checkpoint)
                    trainer.gen_a2b_s[i].load_state_dict(state_dict['a2b'])
                else:
                    tmp_checkpoint = checkpoint[:-8] + 'b2a_gen_' + str(
                        i) + '_' + checkpoint[-8:] + '.pt'
                    state_dict = torch.load(tmp_checkpoint)
                    trainer.gen_b2a_s[i].load_state_dict(state_dict['b2a'])
            except Exception:
                # As above: flip the direction and retry with the other naming.
                print('a2b should probably be set to ' + str(not a2b) +
                      ', or the config file could be wrong')

                a2b = not a2b
                if a2b:
                    tmp_checkpoint = checkpoint[:-8] + 'a2b_gen_' + str(
                        i) + '_' + checkpoint[-8:] + '.pt'
                    state_dict = torch.load(tmp_checkpoint)
                    trainer.gen_a2b_s[i].load_state_dict(state_dict['a2b'])
                else:
                    tmp_checkpoint = checkpoint[:-8] + 'b2a_gen_' + str(
                        i) + '_' + checkpoint[-8:] + '.pt'
                    state_dict = torch.load(tmp_checkpoint)
                    trainer.gen_b2a_s[i].load_state_dict(state_dict['b2a'])

    trainer.cuda()
    trainer.eval()

    return [trainer, config, council_size, style_dim]
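A hedged usage sketch for loadModel; the config and checkpoint paths below are placeholders, not taken from the original:

if __name__ == '__main__':
    trainer, config, council_size, style_dim = loadModel(
        config='configs/example.yaml',                 # placeholder path
        checkpoint='outputs/example/gen_00100000.pt',  # placeholder path
        a2b=True)
    # A random style code of the advertised dimension can then drive sampling,
    # e.g. (decode() is assumed generator API, not confirmed by the excerpt):
    # style = torch.randn(1, style_dim, 1, 1).cuda()
    # image_out = trainer.gen_a2b_s[0].decode(content_code, style)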