# Lightning trainer configured like the training run. It is not referenced
# again in this section.
runner = Trainer(
    default_save_path=f"{tt_logger.save_dir}",
    min_nb_epochs=1,
    logger=tt_logger,
    log_save_interval=100,
    train_percent_check=1.,
    val_percent_check=1.,
    num_sanity_val_steps=5,
    early_stop_callback=False,
    **config['trainer_params']
)

print(f"======= Training {config['model_params']['name']} =======")

# Restore the trained weights. Subscript access matches every other
# `config[...]` lookup; attribute access (`config.ckpt_path`) only works when
# `config` is an attribute-dict and raises AttributeError on a plain dict.
load_dict = torch.load(config['ckpt_path'])
experiment.load_state_dict(load_dict['state_dict'])
experiment.cuda()
experiment.eval()

# Only the labels of the first training batch are used further: the batch
# images are replaced just below by the fixed inputs from cifar10_index.mat.
sample_dataloader = experiment.train_dataloader()
test_input, test_label = next(iter(sample_dataloader))
test_label = test_label.to('cuda')
# NOTE(review): labels come from a training batch while inputs come from the
# .mat file -- confirm they are meant to correspond and have matching lengths.
test_input = torch.Tensor(scio.loadmat('./cifar10_index.mat')['data']).to('cuda')

# Inference only: eval() does not disable autograd, so turn it off explicitly
# to avoid building a computation graph for the reconstructions.
with torch.no_grad():
    imgs_recon = experiment.model.generate(test_input, labels=test_label)

FID_IS_tf = build_GAN_metric(config['GAN_metric'])


class SampleFunc(object):
    """Store a model on the instance.

    Only the constructor is defined in this section. The name suggests the
    class is meant to wrap sampling from ``model`` -- TODO confirm.

    Args:
        model: the model object to keep as ``self.model``.
    """

    def __init__(self, model):
        self.model = model
# The run directory is the current working directory (presumably
# '<save_dir>/<name>/version_<n>' -- TODO confirm). Relative paths from the
# config are resolved against the directory three levels above it.
model_save_path = os.getcwd()
parent = '/'.join(model_save_path.split('/')[:-3])
config['logging_params']['save_dir'] = os.path.join(
    parent, config['logging_params']['save_dir'])
config['exp_params']['data_path'] = os.path.join(
    parent, config['exp_params']['data_path'])
print(parent, config['exp_params']['data_path'])

model = vae_models[config['model_params']['name']](
    imsize=config['exp_params']['img_size'], **config['model_params'])
experiment = VAEXperiment(model, config['exp_params'])

# Select the most recently modified checkpoint in the run directory. Names
# from os.listdir are joined with their directory instead of relying on them
# resolving against the CWD.
weights = [x for x in os.listdir(model_save_path) if x.endswith('.ckpt')]
if not weights:
    raise FileNotFoundError(f"no .ckpt files found in {model_save_path}")
weights.sort(key=lambda x: os.path.getmtime(os.path.join(model_save_path, x)))
load_weight = weights[-1]
print('loading: ', load_weight)
# The model is never moved to a GPU in this section, so load tensors onto the
# CPU instead of requiring the device the checkpoint was saved from.
checkpoint = torch.load(os.path.join(model_save_path, load_weight),
                        map_location='cpu')
experiment.load_state_dict(checkpoint['state_dict'])
# Return value discarded; presumably called for its side effects on
# `experiment` (e.g. dataset setup) -- TODO confirm.
_ = experiment.train_dataloader()
experiment.eval()
experiment.freeze()
experiment.sample_interpolate(
    save_dir=config['logging_params']['save_dir'],
    name=config['logging_params']['name'],
    version=config['logging_params']['version'],
    save_svg=True,
    other_interpolations=config['logging_params']['other_interpolations'])