示例#1
0
文件: eval.py 项目: preddy5/Im2Vec
    imsize=config['exp_params']['img_size'], **config['model_params'])
experiment = VAEXperiment(model, config['exp_params'])
# Run directory of the logger version whose checkpoints are evaluated below:
# <save_dir>/<name>/version_<tt_logger.version>
_log_cfg = config['logging_params']
model_save_path = f"{_log_cfg['save_dir']}/{_log_cfg['name']}/version_{tt_logger.version}"

# Choose which checkpoint to evaluate. Without an explicit `resume` file name,
# take the most recently modified entry whose name contains '.ckpt' in the run
# directory; otherwise use the named file inside that directory.
resume_name = config['logging_params']['resume']
if resume_name is None:
    weights = [
        os.path.join(model_save_path, fname)
        for fname in os.listdir(model_save_path)
        if '.ckpt' in fname
    ]
    if not weights:
        raise FileNotFoundError(
            f"no '.ckpt' files found in {model_save_path!r}; set "
            "logging_params.resume to the checkpoint file to load")
    # Newest checkpoint by modification time.
    model_path = max(weights, key=os.path.getmtime)
    print('loading: ', model_path)
else:
    model_path = '{}/{}'.format(model_save_path, resume_name)
# Load exactly once, after either branch has chosen the path.
experiment = VAEXperiment.load_from_checkpoint(model_path,
                                               vae_model=model,
                                               params=config['exp_params'])
# Put the loaded experiment into eval mode and freeze it, then generate
# interpolation samples (with SVG output enabled).
experiment.eval()
experiment.freeze()
_log_cfg = config['logging_params']
# NOTE(review): `version` comes from the config, while the checkpoint path above
# was built from tt_logger.version — confirm these are meant to agree.
experiment.sample_interpolate(
    save_dir=_log_cfg['save_dir'],
    name=_log_cfg['name'],
    version=_log_cfg['version'],
    save_svg=True,
    other_interpolations=_log_cfg['other_interpolations'],
)
示例#2
0
# Lightning trainer for the VAE stage; extra options come from the config's
# 'trainer_params' section.
runner = Trainer(
    default_root_dir=f"{tt_logger.save_dir}",
    logger=tt_logger,
    val_check_interval=1.0,
    num_sanity_val_steps=5,
    **config['trainer_params'],
)

print(f"======= Training {config['model_params']['name']} =======")
# When args.reg_only is set, VAE training is skipped.
train_vae = not args.reg_only
if train_vae:
    runner.fit(experiment)

# NN_Reg part starts here...
import os  # for os.path.getmtime / isfile; harmless if already imported

# Reload the VAE from this run's checkpoints directory before the regression
# stage. When several checkpoints exist, take the most recently modified one
# instead of whatever order the filesystem happens to list them in.
dir_path = f"{config['logging_params']['save_dir']}/{config['logging_params']['name']}/version_{config['logging_params']['version']}"
# glob.escape so glob metacharacters (e.g. '[') in save_dir/name match literally.
ckpt_candidates = [
    path for path in glob.glob(glob.escape(dir_path) + '/checkpoints/*')
    if os.path.isfile(path)
]
if not ckpt_candidates:
    raise FileNotFoundError(f"no checkpoint files found in {dir_path}/checkpoints")
ckpt_path = max(ckpt_candidates, key=os.path.getmtime)
# NOTE(review): map_location is hard-coded to the first CUDA device; this
# presumably fails on machines without a GPU — consider making it configurable.
experiment = VAEXperiment.load_from_checkpoint(ckpt_path, vae_model=model, params=config['exp_params'],
                                               map_location='cuda:0')

# Separate TestTube logger for the regression stage. Its version string is
# "version_<N>/nn_reg" (presumably nesting it under the VAE run's version
# directory) and its prefix is 'nn_reg_' — confirm against TestTubeLogger docs.
_reg_log_cfg = config['logging_params']
nn_reg_tt_logger = TestTubeLogger(
    save_dir=_reg_log_cfg['save_dir'],
    name=_reg_log_cfg['name'],
    version=f"version_{_reg_log_cfg['version']}/nn_reg",
    prefix='nn_reg_',
    debug=False,
    create_git_tag=False,
)

# Look up the regression model class by name and build it from the full
# 'reg_params' section (the 'name' key is passed through as a kwarg too).
_reg_cfg = config['reg_params']
nn_reg_model = vae_models[_reg_cfg['name']](**_reg_cfg)
# The regression experiment wraps the trained VAE's model plus the new regressor.
nn_reg_experiment = RegExperiment(experiment.model, nn_reg_model, config['reg_exp_params'])

nn_reg_runner = Trainer(default_root_dir=f"{nn_reg_tt_logger.save_dir}",