def main():
    """Restore a trained model from a checkpoint and sample images with it.

    Parses CLI arguments, builds the model named by ``--model`` ('vae' or
    'hm') in float64 on the best available device, loads
    ``ckpt['model_state_dict']`` from ``args.save_path`` and passes the model,
    the batch size, the sample count, the test dataset returned by
    ``build_datasets`` and the output directory ``args.out`` (created if
    missing) to ``_sample_images``.

    Returns None. An unsupported ``--model`` value is reported with
    ``logging.critical`` and the function returns early.
    """
    parser = _build_parser()
    args = parser.parse_args()
    logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.model == 'vae':
        model = VAE()
    elif args.model == 'hm':
        model = HM()
    else:
        # Lazy %-style argument: formatted only if the record is emitted.
        logging.critical('model unimplemented: %s', args.model)
        return
    # Both models are run in float64.
    model = model.double().to(device)

    # exist_ok avoids the race between an exists() check and mkdir().
    args.out.mkdir(parents=True, exist_ok=True)

    # Only the second dataset returned by build_datasets is used here.
    _, test_ds = build_datasets(args.im_path, train_test_split=1)
    ckpt = torch.load(args.save_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    _sample_images(model, args.batch, args.samples, test_ds, args.out)
def reset_gen():
    """Build the pretrained generator selected by the global ``args.model``.

    Supported values:
      * 'iagan_began_cs'       - BEGAN generator (Generator128), img_size 128
      * 'iagan_dcgan_cs'       - DCGAN generator, img_size 64
      * 'iagan_vanilla_vae_cs' - decoder of a trained VAE, img_size 128

    Returns:
        (gen, img_size): the generator in eval mode on ``DEVICE`` and the
        image size associated with that model.

    Raises:
        NotImplementedError: for any other ``args.model`` value.
    """
    # NOTE(review): reads the module-level ``args`` (presumably set by the CLI
    # entry point) rather than taking the model name as a parameter.
    # Checkpoints are loaded with map_location='cpu' so weights saved from CUDA
    # tensors also load on CPU-only hosts; the module is moved to DEVICE after.
    if args.model in ['iagan_began_cs']:
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif args.model in ['iagan_dcgan_cs']:
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'), map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    elif args.model in ['iagan_vanilla_vae_cs']:
        gen = VAE()
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt',
                       map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        # Only the decoder is used as the generator.
        gen = gen.decoder
        img_size = 128
    else:
        raise NotImplementedError(f'unsupported model: {args.model}')
    return gen, img_size
def main():
    """Compare VAE samples against the test set and save the recorded pairs.

    Loads the VAE checkpoint from ``args.save_path`` and draws
    ``args.samples`` samples without gradient tracking, reshaped to
    ``(-1, *IM_DIMS, 3)``. It then streams the test dataset through
    ``_update_winner``, which updates a record created by ``_init_record``.
    Finally it saves ``record['pair']`` to ``args.out`` with ``np.save`` and
    prints ``record['distance']``. What "winner" means (presumably the closest
    test image per sample) is defined by those helpers, not here.

    Returns None. An unsupported ``--model`` value is reported with
    ``logging.critical`` and the function returns early.
    """
    parser = _build_parser()
    args = parser.parse_args()
    logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.model != 'vae':
        # Lazy %-style argument: formatted only if the record is emitted.
        logging.critical('model unimplemented: %s', args.model)
        return
    model = VAE().double().to(device)

    # parents=True also creates missing ancestor directories; exist_ok avoids
    # an exists()/mkdir() race.
    args.out.parent.mkdir(parents=True, exist_ok=True)

    # Only the second dataset returned by build_datasets is used here.
    _, test_ds = build_datasets(args.im_path, train_test_split=1)
    ckpt = torch.load(args.save_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    with torch.no_grad():
        samps = model.sample(args.samples).reshape(-1, *IM_DIMS, 3)

    loader = DataLoader(test_ds,
                        batch_size=args.batch,
                        num_workers=args.workers,
                        pin_memory=torch.cuda.is_available())
    record = _init_record(samps)
    with tqdm(total=TOTAL_IMAGES) as pbar:
        for chunk in loader:
            _update_winner(chunk.reshape(-1, *IM_DIMS, 3), record, pbar)
    np.save(args.out, record['pair'])
    print('final distances:', record['distance'])
def reset_gen(model):
    """Load a pretrained generator and return it with its image size.

    Args:
        model: 'began' (Generator128, img_size 128), 'vae' (decoder of a
            trained VAE, img_size 128) or 'dcgan' (img_size 64).

    Returns:
        (gen, img_size): the generator in eval mode on ``DEVICE`` and the
        image size associated with that model.

    Raises:
        NotImplementedError: for any other ``model`` value, instead of
            failing with an unbound-local error at the ``return``.
    """
    # Checkpoints are loaded with map_location='cpu' so weights saved from CUDA
    # tensors also load on CPU-only hosts; the module is moved to DEVICE after.
    if model == 'began':
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif model == 'vae':
        gen = VAE()
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt',
                       map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        # Only the decoder is used as the generator.
        gen = gen.decoder
        img_size = 128
    elif model == 'dcgan':
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'), map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    else:
        raise NotImplementedError(f'unsupported model: {model}')
    return gen, img_size
# NOTE(review): this fragment is truncated on both ends. It begins mid-script
# and is cut off inside the per-example encoding loop, so the code that fills
# mu_points / var_points / feats and the accumulators below is not visible.
# Accumulators for LDA accuracies; judging by the names, group-pair scores on
# PCA features (e.g. white vs asian) -- TODO confirm against the code that
# appends to them.
white_to_asian_lda_acc_pca = []
white_to_black_lda_acc_pca = []
asian_to_black_lda_acc_pca = []
male_to_female_lda_acc_pca = []
for conf in configs:
    print('processing conf: {}'.format(conf.name))
    # Each configuration gets its own output directory under out_path.
    out_dir = out_path / conf.name
    if not out_dir.exists():
        out_dir.mkdir(parents=True)
    # set up model
    # Rebuild the VAE with this configuration's parameters and load its
    # checkpoint onto the CPU.
    model = VAE(**conf.params)
    ckpt = torch.load(conf.save_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    # encode data
    test_len = len(test_ds)
    ldims = model.latent_dims
    # One row per test example, one column per latent dimension.
    mu_points = np.zeros((test_len, ldims))
    var_points = np.zeros((test_len, ldims))
    feats = []
    # Encoding is only done when out_dir has no mu_points.npy yet (presumably
    # a cache written later in this loop -- not visible in this fragment).
    if not (out_dir / 'mu_points.npy').exists():
        for i in tqdm(range(test_len)):
            # Each test_ds item unpacks into an image and its features.
            im, feat = test_ds[i]
            # Add a leading dimension of size 1 (presumably the batch axis).
            im = im.unsqueeze(0)
            mu, var = model.encode(im)
def main():
    """Evaluate GZSL accuracy of a classifier trained on replayed features.

    Steps:
      1. Load task-1 training data and split ``pkl/task_1_test.pkl`` into its
         first 500 entries (test) and the remainder (validation).
      2. Build the ResNet trunk feature extractor and the VAE, restore both
         from ``./ckpts`` and freeze them (``requires_grad = False``,
         ``eval()``).
      3. Generate synthetic task-0 features with the VAE and concatenate them
         with real task-1 features whose labels are shifted by 150.
      4. Train a 2048 -> 200 linear classifier on that mix; keep the network
         returned by ``train_syn_val`` (selected on the validation split).
      5. Test on seen task-0 data and on task-1 test data, print the harmonic
         mean ``H`` and the weighted mean ``OM`` of the two accuracies, and
         append the config and scores to ``results/<dataset>_<arch>_<method>_<task_id>.txt``.

    Runs on CUDA when available, otherwise on the CPU.
    """
    CUDA = torch.cuda.is_available()
    if CUDA:
        print('cuda available')
    torch.backends.cudnn.benchmark = True
    # Single target device for every module/tensor below, so the script does
    # not require CUDA.
    device = torch.device('cuda' if CUDA else 'cpu')

    config = config_process(parser.parse_args())
    print(config)

    with open('pkl/task_1_train.pkl', 'rb') as f:
        task_1_train = pkl.load(f)
    # pkl/task_1_test.pkl holds both the task-1 test and validation entries:
    # the first 500 are used for testing, the rest for validation.
    with open('pkl/task_1_test.pkl', 'rb') as g:
        task_1_testval = pkl.load(g)
    task_1_test = task_1_testval[:500]
    task_1_val = task_1_testval[500:]

    # Task 0: seen-class training data and unseen test data.
    examples, labels, class_map = image_load(config['class_file'],
                                             config['image_label'])
    # Task 0: seen-class test data.
    examples_0, labels_0, _ = image_load(config['class_file'],
                                         config['test_seen_classes'])
    datasets = split_byclass(config, examples, labels,
                             np.loadtxt(config['attributes_file']), class_map)
    # NOTE(review): the seen-test split reuses ``class_map`` from the first
    # image_load call -- presumably to keep label ids aligned; confirm.
    datasets_0 = split_byclass(config, examples_0, labels_0,
                               np.loadtxt(config['attributes_file']), class_map)
    print('load the task 0 train: {} the task 1 as test: {}'.format(
        len(datasets[0][0]), len(datasets[0][1])))
    print('load task 0 test data {}'.format(len(datasets_0[0][0])))

    train_attr = F.normalize(datasets[0][3])
    # Alias, not a copy: the keys set below are also visible through
    # ``config``, which is passed to the helpers further down.
    best_cfg = config
    best_cfg['n_classes'] = datasets[0][3].size(0)
    best_cfg['n_train_lbl'] = datasets[0][3].size(0)
    best_cfg['n_test_lbl'] = datasets[0][4].size(0)
    task_1_train_set = grab_data(best_cfg, task_1_train, datasets[0][2], True)
    task_1_test_set = grab_data(best_cfg, task_1_test, datasets[0][2], False)
    task_1_val_set = grab_data(best_cfg, task_1_val, datasets[0][2], False)
    task_0_seen_test_set = grab_data(best_cfg, datasets_0[0][0],
                                     datasets_0[0][2], False)

    base_model = models.__dict__[config['arch']](pretrained=True)
    if config['arch'].startswith('resnet'):
        # Keep every child module except the last (the classifier head).
        FE_model = nn.Sequential(*list(base_model.children())[:-1])
    else:
        print('untested')
        raise NotImplementedError

    print('load pretrained FE_model')
    FE_path = './ckpts/{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['softmax_method'], config['arch'],
        config['task_id'], config['finetune'], 'checkpoint.pth')
    # map_location='cpu' lets checkpoints written from CUDA tensors load on a
    # CPU-only host; the modules are moved to ``device`` further down.
    FE_model.load_state_dict(
        torch.load(FE_path, map_location='cpu')['state_dict_FE'])
    for para in FE_model.parameters():
        para.requires_grad = False

    vae = VAE(encoder_layer_sizes=config['encoder_layer_sizes'],
              latent_size=config['latent_size'],
              decoder_layer_sizes=config['decoder_layer_sizes'],
              num_labels=config['num_labels'])
    vae_path = './ckpts/{}_{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['method'], config['softmax_method'],
        config['arch'], config['task_id'], config['finetune'], 'ckpt.pth')
    vae.load_state_dict(torch.load(vae_path, map_location='cpu'))
    for para in vae.parameters():
        para.requires_grad = False

    FE_model.eval()
    vae.eval()
    FE_model = FE_model.to(device)
    vae = vae.to(device)

    # Real features of the task-1 (seen) training data.
    task_1_real_train = get_prev_feat(FE_model, task_1_train_set, CUDA)
    print('have got real trainval feats and labels')

    print('...GENERATING fake features...')
    # 150 synthetic-feature classes (presumably the task-0 classes); the real
    # task-1 labels are shifted by 150 so both ranges fit in one 200-way head.
    task_0_fake = generate_syn_feature(150, vae, train_attr, config['syn_num'],
                                       config)
    train_X = torch.cat((task_0_fake[0].to(device),
                         task_1_real_train[0].to(device)))
    train_Y = torch.cat((task_0_fake[1].to(device),
                         task_1_real_train[1].to(device) + 150))
    test_Dataset = PURE(train_X, train_Y)
    test_dataloader = torch.utils.data.DataLoader(test_Dataset,
                                                  batch_size=256,
                                                  shuffle=True)
    # NOTE(review): 2048 is assumed to be the feature size FE_model emits.
    test_loss_net = nn.Linear(in_features=2048, out_features=200).to(device)
    test_loss_net_optimizer = torch.optim.Adam(test_loss_net.parameters())

    print('...TRAIN test set CLASSIFIER...')
    # NOTE(review): train_syn_val does not receive the CUDA flag -- confirm it
    # does not move data with .cuda() internally when running on CPU.
    best_loss_net = train_syn_val(config, FE_model, test_loss_net,
                                  task_1_val_set, test_dataloader,
                                  test_loss_net_optimizer, 200)
    test_loss_net = copy.deepcopy(best_loss_net)

    print('\n...TESTING... GZSL: 200 labels')
    test_0_acc, _, _, _, _ = test(config, FE_model, test_loss_net,
                                  task_0_seen_test_set, CUDA, 0,
                                  eval_mode=0, print_sign=1)
    test_1_acc, _, _, _, _ = test(config, FE_model, test_loss_net,
                                  task_1_test_set, CUDA, 1,
                                  eval_mode=1, print_sign=1)

    # Harmonic mean of the two accuracies; taken as 0 when both are 0 so the
    # results are still written instead of failing on a division by zero.
    acc_sum = test_0_acc + test_1_acc
    H = 2 * test_0_acc * test_1_acc / acc_sum if acc_sum else 0.0
    print(H)
    OM = (3 * test_0_acc + test_1_acc) / 4
    print(OM)

    os.makedirs('results', exist_ok=True)
    file_name = '{}_{}_{}_{}.txt'.format(config['dataset'], config['arch'],
                                         config['method'], config['task_id'])
    with open('results/' + file_name, 'a') as fp:
        print(best_cfg, file=fp)
        print(
            'task B: {:.3f}, task A:{:.3f}, H= {:.3f}, OM= {:.3f} \n'.format(
                test_0_acc, test_1_acc, H, OM),
            file=fp)
def main():
    """Fine-tune the task-0 VAE on task-1 training data and save it.

    Steps:
      1. Load task-1 training data (``pkl/task_1_train.pkl``) and the dataset
         splits/attributes through the project helpers.
      2. Build the ResNet trunk feature extractor, load its weights from
         ``./ckpts``, freeze it and keep it in eval mode.
      3. Initialise the VAE from the task-0 'vae' checkpoint (``task_id_0``)
         and train it for ``config['epoch']`` epochs with Adam and a StepLR
         schedule (lr x0.1 every ``config['step']`` epochs). ``train`` receives
         the normalised attributes ``F.normalize(datasets[0][4])``.
      4. Save the VAE ``state_dict`` under ``./ckpts`` using the current
         method / softmax_method / task_id.
    """
    CUDA = False
    if torch.cuda.is_available():
        CUDA = True
        print('cuda available')
    torch.backends.cudnn.benchmark = True
    config = config_process(parser.parse_args())
    print(config)

    with open('pkl/task_1_train.pkl', 'rb') as f:
        task_1_train = pkl.load(f)

    # Task 0: seen-class training data and unseen test data.
    examples, labels, class_map = image_load(config['class_file'],
                                             config['image_label'])
    # Task 0: seen-class test data (only its size is reported below).
    examples_0, labels_0, _ = image_load(config['class_file'],
                                         config['test_seen_classes'])
    datasets = split_byclass(config, examples, labels,
                             np.loadtxt(config['attributes_file']), class_map)
    datasets_0 = split_byclass(config, examples_0, labels_0,
                               np.loadtxt(config['attributes_file']), class_map)
    print('load the task 0 train: {} the task 1 as test: {}'.format(
        len(datasets[0][0]), len(datasets[0][1])))
    print('load task 0 test data {}'.format(len(datasets_0[0][0])))

    test_attr = F.normalize(datasets[0][4])
    # Alias, not a copy: the keys set below are also visible through
    # ``config``, which is passed to the helpers further down.
    best_cfg = config
    best_cfg['n_classes'] = datasets[0][3].size(0)
    best_cfg['n_train_lbl'] = datasets[0][3].size(0)
    best_cfg['n_test_lbl'] = datasets[0][4].size(0)
    task_1_train_set = grab_data(best_cfg, task_1_train, datasets[0][2], True)

    # Weights come from FE_path below, so no pretrained download is requested.
    base_model = models.__dict__[config['arch']](pretrained=False)
    if config['arch'].startswith('resnet'):
        # Keep every child module except the last (the classifier head).
        FE_model = nn.Sequential(*list(base_model.children())[:-1])
    else:
        print('untested')
        raise NotImplementedError

    print('load pretrained FE_model')
    FE_path = './ckpts/{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['softmax_method'], config['arch'],
        config['task_id'], config['finetune'], 'checkpoint.pth')
    # map_location='cpu': checkpoints in this project are saved from .cuda()
    # modules (see the torch.save at the end), which would otherwise fail to
    # load on a CPU-only host. Modules are moved to the GPU below if present.
    FE_model.load_state_dict(
        torch.load(FE_path, map_location='cpu')['state_dict_FE'])
    for para in FE_model.parameters():
        para.requires_grad = False

    vae = VAE(encoder_layer_sizes=config['encoder_layer_sizes'],
              latent_size=config['latent_size'],
              decoder_layer_sizes=config['decoder_layer_sizes'],
              num_labels=config['num_labels'])
    # Start from the VAE trained on task 0.
    vae_path = './ckpts/{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], 'vae', config['arch'], 0, config['finetune'],
        'ckpt.pth')
    vae.load_state_dict(torch.load(vae_path, map_location='cpu'))
    print(vae)
    if CUDA:
        FE_model = FE_model.cuda()
        vae = vae.cuda()
    FE_model.eval()

    optimizer = torch.optim.Adam(vae.parameters(), lr=config['lr'])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                config['step'],
                                                gamma=0.1,
                                                last_epoch=-1)
    criterion = loss_fn
    print('have got real trainval feats and labels')
    for epoch in range(config['epoch']):
        print('\n epoch: %d' % epoch)
        print('...TRAIN...')
        print_learning_rate(optimizer)
        # Task 1 variant: test_attr / task_1_train_set are used where task 0
        # used train_attr / task_0_train_set.
        train(epoch, FE_model, vae, task_1_train_set, optimizer, criterion,
              test_attr, CUDA)
        scheduler.step()

    vae_ckpt_name = './ckpts/{}_{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['method'], config['softmax_method'],
        config['arch'], config['task_id'], config['finetune'], 'ckpt.pth')
    torch.save(vae.state_dict(), vae_ckpt_name)
def _biggan_cond_vector(g, device):
    """Return a BigGAN conditioning vector for ImageNet class 949 (strawberry).

    Concatenates a fresh ``torch.randn(1, 128)`` noise vector (drawn from the
    CPU RNG, then moved to ``device``) with the class embedding produced by
    ``g.biggan.embeddings`` for a one-hot vector over 1000 classes.
    """
    class_vector = torch.tensor(949, dtype=torch.long).to(device).unsqueeze(0)
    embed = g.biggan.embeddings(
        torch.nn.functional.one_hot(class_vector,
                                    num_classes=1000).to(torch.float))
    return torch.cat((torch.randn(1, 128).to(device), embed), dim=1)


def _to_hwc_image(t):
    """Convert a single-image (1, C, H, W) tensor to an (H, W, C) numpy array."""
    return t.detach().cpu().squeeze(0).numpy().transpose([1, 2, 0])


def generator_samples(model, device='cuda:0'):
    """Save a grid of random samples from a pretrained generator.

    Rows correspond to ``n_cuts`` in 0..5 and columns to seeds 0..9. Each cell
    re-seeds torch and numpy with its column index, so a column uses the same
    seed across rows. For BigGAN, ``n_cuts == 0`` uses the unmodified BigGAN
    forward pass (``orig_biggan_forward``). The figure is written to
    ``./figures/generator_samples/model=<model>.pdf`` and then closed.

    Args:
        model: one of 'began', 'vae', 'biggan', 'dcgan'.
        device: torch device to run the generator on; defaults to 'cuda:0'.

    Raises:
        NotImplementedError: for any other ``model`` value.
    """
    if model == 'began':
        g = Generator128(64).to(device)
        g = load_trained_net(
            g, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                '.n_cuts=0/gen_ckpt.49.pt'))
    elif model == 'vae':
        g = VAE().to(device)
        g.load_state_dict(
            torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt',
                       map_location=device))
        # Only the decoder is used as the generator.
        g = g.decoder
    elif model == 'biggan':
        g = BigGanSkip().to(device)
    elif model == 'dcgan':
        g = dcgan_generator().to(device)
        g.load_state_dict(
            torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'),
                       map_location=device))
    else:
        raise NotImplementedError(f'unknown model: {model}')
    # Sample in inference mode so layers such as BatchNorm (if present) use
    # their stored statistics instead of those of a single-image batch.
    g.eval()

    nseed = 10
    n_cuts_list = [0, 1, 2, 3, 4, 5]
    fig, ax = plt.subplots(len(n_cuts_list), nseed,
                           figsize=(10, len(n_cuts_list)))
    with torch.no_grad():
        for row, n_cuts in enumerate(n_cuts_list):
            input_shapes = g.input_shapes[n_cuts]
            z1_shape = input_shapes[0]
            z2_shape = input_shapes[1]
            for col in range(nseed):
                torch.manual_seed(col)
                np.random.seed(col)
                if model == 'biggan' and n_cuts == 0:
                    cond_vector = _biggan_cond_vector(g, device)
                    out = orig_biggan_forward(g.biggan.generator, cond_vector,
                                              truncation=1.0)
                elif model == 'biggan':
                    # z1 is drawn before the class-conditioning noise, keeping
                    # the RNG consumption order fixed per seed.
                    z1 = torch.randn(1, *z1_shape).to(device)
                    z2 = _biggan_cond_vector(g, device)
                    out = g(z1, z2, truncation=1.0, n_cuts=n_cuts)
                else:
                    z1 = torch.randn(1, *z1_shape).to(device)
                    # An empty z2 shape means this cut has no second input.
                    if len(z2_shape) == 0:
                        z2 = None
                    else:
                        z2 = torch.randn(1, *z2_shape).to(device)
                    out = g(z1, z2, n_cuts=n_cuts)
                img = _to_hwc_image(out)
                if g.rescale:
                    # Map [-1, 1] outputs to [0, 1] for display.
                    img = (img + 1) / 2
                cell = ax[row, col]
                cell.imshow(np.clip(img, 0, 1), aspect='auto')
                cell.set_xticks([])
                cell.set_yticks([])
                cell.set_frame_on(False)
                if col == 0:
                    cell.set_ylabel(f'{n_cuts}')
    fig.subplots_adjust(0, 0, 1, 1, 0, 0)
    os.makedirs('./figures/generator_samples', exist_ok=True)
    fig.savefig((f'./figures/generator_samples/'
                 f'model={model}.pdf'),
                dpi=300,
                bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
# NOTE(review): fragment. It starts mid-script (``file`` is presumably an open
# handle from an enclosing ``with open(...)``) and is cut off right after
# ``if args.save:``, whose body is not visible here.
# NOTE(review): yaml.Loader can construct arbitrary Python objects; if the
# config file may come from an untrusted source, yaml.SafeLoader is safer.
params = yaml.load(file, Loader=yaml.Loader)
# args.layer values joined with '_'; used in the saved-model filename below.
layer = '_'.join([str(l) for l in args.layer])
path = '{}/{}/{}'.format(params['dataset_dir'], params['dataset'],
                         params['feature'])
print('feature data path: ' + path)
# Read a Matrix Market file (``io`` is presumably scipy.io) and densify it as
# float32.
# NOTE(review): the ``.A`` shorthand on SciPy sparse matrices is deprecated
# in recent SciPy releases; ``.toarray()`` is the portable spelling.
X = io.mmread(path).A.astype('float32')
# Store the data dimensions on args (rows -> args.n, columns -> args.d),
# presumably so VAE(args) can size itself.
args.n, args.d = X.shape
# X = normalize(X, norm='l2', axis=0)
vae = VAE(args)
if args.load:
    # Resume from the model saved for this dataset / layer combination.
    path = '{}/model/vae_{}_{}'.format(params['dataset_dir'],
                                       params['dataset'], layer)
    vae.load_state_dict(torch.load(path))
vae.to(args.device)
# psm = PSM(torch.from_numpy(R.astype('float32')).to(args.device), args).to(args.device)
# Batches are single row indices 0..args.n-1, in shuffled order.
loader = DataLoader(np.arange(args.n), batch_size=1, shuffle=True)
optimizer = optim.Adam(vae.parameters(), lr=args.lr)
evaluator = Evaluator({'recall', 'dcg_cut'})
# variational()
# maximum()
# evaluate()
# train() takes no arguments; presumably it uses the module-level names
# defined above (vae, loader, optimizer, ...). TODO confirm.
train()
if args.save:
    # (body truncated -- not visible in this view)
def main():
    """Plot a joint t-SNE of per-attribute features generated by two VAEs.

    Builds one one-hot vector per CUB attribute (312 of them) and asks two
    frozen VAE checkpoints (task_id 1, 'softmax_distill': the 'vae' and the
    'vae_distill' variant) to generate ``config['syn_num']`` features per
    attribute via ``generate_syn_feature``. Both feature sets are pickled
    under ``attr_tsne_data/``, embedded together with t-SNE and shown in one
    scatter plot. The first VAE's points use 'X' markers and the second's
    'o', with one colour and legend label per attribute.
    """
    import os  # local import: used only to create the output directory

    ATTR_NUM = 312  # number of CUB attributes (one one-hot vector each)

    with open('data/CUB/attributes.txt', 'r') as attr_file:
        # Attribute name = second space-separated field of each line.
        atts = [line.strip().split(' ')[1] for line in attr_file]

    CUDA = False
    if torch.cuda.is_available():
        CUDA = True
        print('cuda available')
    torch.backends.cudnn.benchmark = True
    config = config_process(parser.parse_args())
    print(config)

    with open('pkl/task_0_train.pkl', 'rb') as f:
        task_0_train = pkl.load(f)
    with open('pkl/task_1_train.pkl', 'rb') as f:
        task_1_train = pkl.load(f)

    # Task 0: seen-class training data and unseen test data.
    examples, labels, class_map = image_load(config['class_file'],
                                             config['image_label'])
    # Task 0: seen-class test data (only its size is reported below).
    examples_0, labels_0, _ = image_load(config['class_file'],
                                         config['test_seen_classes'])
    datasets = split_byclass(config, examples, labels,
                             np.loadtxt(config['attributes_file']), class_map)
    datasets_0 = split_byclass(config, examples_0, labels_0,
                               np.loadtxt(config['attributes_file']), class_map)
    print('load the task 0 train: {} the task 1 as test: {}'.format(
        len(datasets[0][0]), len(datasets[0][1])))
    print('load task 0 test data {}'.format(len(datasets_0[0][0])))

    # One one-hot row per attribute; each row conditions generation for that
    # attribute alone.
    test_attr = torch.eye(ATTR_NUM, dtype=torch.float32)

    # Alias, not a copy: the keys set below are also visible through
    # ``config``, which is passed to generate_syn_feature.
    best_cfg = config
    best_cfg['n_classes'] = datasets[0][3].size(0)
    best_cfg['n_train_lbl'] = datasets[0][3].size(0)
    best_cfg['n_test_lbl'] = datasets[0][4].size(0)
    # NOTE(review): these two data sets are not used below; kept in case
    # grab_data has side effects -- TODO confirm and drop them with the pickles.
    task_0_train_set = grab_data(best_cfg, task_0_train, datasets[0][2], True)
    task_1_train_set = grab_data(best_cfg, task_1_train, datasets[0][2], False)

    base_model = models.__dict__[config['arch']](pretrained=True)
    if config['arch'].startswith('resnet'):
        # Keep every child module except the last (the classifier head).
        FE_model = nn.Sequential(*list(base_model.children())[:-1])
    else:
        print('untested')
        raise NotImplementedError

    print('load pretrained FE_model')
    FE_path = './ckpts/{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['softmax_method'], config['arch'],
        config['task_id'], config['finetune'], 'checkpoint.pth')
    # map_location='cpu' lets checkpoints written from CUDA tensors load on a
    # CPU-only host; modules are moved to the GPU below when available.
    FE_model.load_state_dict(
        torch.load(FE_path, map_location='cpu')['state_dict_FE'])
    for para in FE_model.parameters():
        para.requires_grad = False

    vae = VAE(encoder_layer_sizes=config['encoder_layer_sizes'],
              latent_size=config['latent_size'],
              decoder_layer_sizes=config['decoder_layer_sizes'],
              num_labels=config['num_labels'])
    vae2 = VAE(encoder_layer_sizes=config['encoder_layer_sizes'],
               latent_size=config['latent_size'],
               decoder_layer_sizes=config['decoder_layer_sizes'],
               num_labels=config['num_labels'])
    vae_path = './ckpts/{}_{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], 'vae', 'softmax_distill', config['arch'], 1,
        config['finetune'], 'ckpt.pth')
    vae2_path = './ckpts/{}_{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], 'vae_distill', 'softmax_distill', config['arch'],
        1, config['finetune'], 'ckpt.pth')
    vae.load_state_dict(torch.load(vae_path, map_location='cpu'))
    vae2.load_state_dict(torch.load(vae2_path, map_location='cpu'))
    for net in (vae, vae2):
        for para in net.parameters():
            para.requires_grad = False
        # Generate in inference mode, as the evaluation script does.
        net.eval()
    if CUDA:
        FE_model = FE_model.cuda()
        vae = vae.cuda()
        vae2 = vae2.cuda()

    SYN_NUM = config['syn_num']
    attr_feat, _ = generate_syn_feature(ATTR_NUM, vae, test_attr, SYN_NUM,
                                        config)
    attr_feat2, _ = generate_syn_feature(ATTR_NUM, vae2, test_attr, SYN_NUM,
                                         config)

    os.makedirs('attr_tsne_data', exist_ok=True)
    with open('attr_tsne_data/attr_vae_time_2_fe_distill.pkl', 'wb') as f:
        pkl.dump(attr_feat, f)
    with open('attr_tsne_data/attr_vae_time_2_double_distill.pkl', 'wb') as f:
        pkl.dump(attr_feat2, f)

    colors = cm.rainbow(np.linspace(0, 1, ATTR_NUM))
    fig = plt.figure(figsize=(16, 9))
    tsne = TSNE(n_components=2)
    # TSNE needs host memory: detach and copy to a CPU numpy array (the
    # features may live on the GPU).
    feat = torch.cat((attr_feat, attr_feat2)).detach().cpu().numpy()
    tsne_results = tsne.fit_transform(feat)

    ax = fig.add_subplot(1, 1, 1)
    # The slicing assumes each generate_syn_feature call returned
    # ATTR_NUM * SYN_NUM rows grouped by attribute: vae's rows come first,
    # vae2's rows start at ATTR_NUM * SYN_NUM.
    second_set = tsne_results[ATTR_NUM * SYN_NUM:, :]
    for points, marker in ((tsne_results, 'X'), (second_set, 'o')):
        for i in range(ATTR_NUM):
            chunk = points[i * SYN_NUM:(i + 1) * SYN_NUM]
            ax.scatter(chunk[:, 0],
                       chunk[:, 1],
                       label=atts[i],
                       c=np.tile(colors[i].reshape(1, -1), (SYN_NUM, 1)),
                       s=20,
                       marker=marker)
    plt.legend()
    plt.show()
# NOTE(review): fragment. It starts after ``parser`` was created and is cut
# off inside the training loop.
# NOTE(review): ``type=bool`` does not parse booleans: argparse calls bool()
# on the raw string, so any non-empty value (including "False") enables
# --use-trained. action='store_true' would fix this but changes the CLI.
parser.add_argument('--use-trained', type=bool, default=False, metavar='UT', help='load pretrained model (default: False)')
args = parser.parse_args()

batch_loader = BatchLoader()
# Model hyper-parameters, sized by the loader's vocabulary.
parameters = Parameters(batch_loader.vocab_size)

vae = VAE(parameters.vocab_size, parameters.embed_size,
          parameters.latent_size, parameters.decoder_rnn_size,
          parameters.decoder_rnn_num_layers)
if args.use_trained:
    # ``t`` is presumably ``torch`` imported under an alias.
    vae.load_state_dict(t.load('trained_VAE'))
if args.use_cuda:
    vae = vae.cuda()

optimizer = Adam(vae.parameters(), args.learning_rate)

for iteration in range(args.num_iterations):
    '''Train step'''
    # NOTE(review): ``input`` shadows the builtin of the same name here.
    input, decoder_input, target = batch_loader.next_batch(
        args.batch_size, 'train', args.use_cuda)
    # Flatten targets to 1-D to line up with logits reshaped to
    # (-1, vocab_size) below.
    target = target.view(-1)
    logits, aux_logits, kld = vae(args.dropout, input, decoder_input)
    logits = logits.view(-1, batch_loader.vocab_size)
    # NOTE(review): ``size_average`` is deprecated in PyTorch;
    # reduction='sum' is the equivalent of size_average=False.
    cross_entropy = F.cross_entropy(logits, target, size_average=False)
def mgan_images(args):
    """Run mGAN recovery for every image in ``args.img_dir`` and save results.

    The experiment grid is iterated in this order:
      * each image (sorted by file name),
      * each ``n_cuts`` value from ``recovery_settings[args.model]``,
      * each forward model and argument set from ``forward_models[args.model]``,
      * each (``z_init_mode``, ``limit``) pair.

    For every combination, ``mgan_recover`` is run and the recovered image,
    the metadata dict and the PSNR against the original are saved into the
    results folder. The original image and, when ``forward_model.viewable``,
    the distorted image are saved once per image folder. Combinations whose
    ``recovered.pt`` already exists are skipped unless ``args.overwrite`` is
    set.

    The shared ``recovery_settings`` / ``forward_models`` tables are read but
    not modified, so the function can be called more than once per process.

    Raises:
        NotImplementedError: if ``args.model`` is not 'mgan_began_cs',
            'mgan_vanilla_vae_cs' or 'mgan_dcgan_cs'.
    """
    if args.set_seed:
        torch.manual_seed(0)
        np.random.seed(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.makedirs(BASE_DIR, exist_ok=True)

    # Checkpoints are loaded with map_location='cpu' so weights saved from CUDA
    # tensors also load on CPU-only hosts; the module is moved to DEVICE after.
    if args.model in ['mgan_began_cs']:
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif args.model in ['mgan_vanilla_vae_cs']:
        gen = VAE()
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt',
                       map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        # Only the decoder is used as the generator.
        gen = gen.decoder
        img_size = 128
    elif args.model in ['mgan_dcgan_cs']:
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'), map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    else:
        raise NotImplementedError(f'unsupported model: {args.model}')

    img_shape = (3, img_size, img_size)

    # Work on a shallow copy. Deleting keys from the shared settings entry
    # would make a second call fail with KeyError. Popping keeps the same key
    # order as deleting, and the per-run keys are re-added below in that order.
    metadata = dict(recovery_settings[args.model])
    n_cuts_list = metadata.pop('n_cuts_list')
    z_init_mode_list = metadata.pop('z_init_mode')
    limit_list = metadata.pop('limit')
    assert len(z_init_mode_list) == len(limit_list)

    forwards = forward_models[args.model]

    data_split = Path(args.img_dir).name
    for img_name in tqdm(sorted(os.listdir(args.img_dir)),
                         desc='Images',
                         leave=True,
                         disable=args.disable_tqdm):
        # Load image and get filename without extension
        orig_img = load_target_image(os.path.join(args.img_dir, img_name),
                                     img_size).to(DEVICE)
        img_basename, _ = os.path.splitext(img_name)

        for n_cuts in tqdm(n_cuts_list,
                           desc='N_cuts',
                           leave=False,
                           disable=args.disable_tqdm):
            metadata['n_cuts'] = n_cuts
            for f, f_args_list in tqdm(forwards.items(),
                                       desc='Forwards',
                                       leave=False,
                                       disable=args.disable_tqdm):
                for f_args in tqdm(f_args_list,
                                   desc=f'{f} Args',
                                   leave=False,
                                   disable=args.disable_tqdm):
                    # New kwargs dict: img_shape is not written back into the
                    # shared forward_models entry.
                    forward_model = get_forward_model(
                        f, **{**f_args, 'img_shape': img_shape})

                    for z_init_mode, limit in zip(
                            tqdm(z_init_mode_list,
                                 desc='z_init_mode',
                                 leave=False,
                                 disable=args.disable_tqdm), limit_list):
                        metadata['z_init_mode'] = z_init_mode
                        metadata['limit'] = limit

                        # Before doing recovery, check if results already exist
                        # and possibly skip
                        recovered_name = 'recovered.pt'
                        results_folder = get_results_folder(
                            image_name=img_basename,
                            model=args.model,
                            n_cuts=n_cuts,
                            split=data_split,
                            forward_model=forward_model,
                            recovery_params=dict_to_str(metadata),
                            base_dir=BASE_DIR)

                        os.makedirs(results_folder, exist_ok=True)

                        recovered_path = results_folder / recovered_name
                        if os.path.exists(
                                recovered_path) and not args.overwrite:
                            print(
                                f'{recovered_path} already exists, skipping...'
                            )
                            continue

                        if args.run_name is not None:
                            current_run_name = (
                                f'{img_basename}.{forward_model}'
                                f'.{dict_to_str(metadata)}'
                                f'.{args.run_name}')
                        else:
                            current_run_name = None

                        recovered_img, distorted_img, _ = mgan_recover(
                            orig_img, gen, n_cuts, forward_model,
                            metadata['optimizer'], z_init_mode, limit,
                            metadata['z_lr'], metadata['n_steps'],
                            metadata['z_number'], metadata['restarts'],
                            args.run_dir, current_run_name, args.disable_tqdm)

                        # Make images folder
                        img_folder = get_images_folder(split=data_split,
                                                       image_name=img_basename,
                                                       img_size=img_size,
                                                       base_dir=BASE_DIR)
                        os.makedirs(img_folder, exist_ok=True)

                        # Save original image if needed
                        original_img_path = img_folder / 'original.pt'
                        if not os.path.exists(original_img_path):
                            torch.save(orig_img, original_img_path)

                        # Save distorted image if needed
                        if forward_model.viewable:
                            distorted_img_path = img_folder / f'{forward_model}.pt'
                            if not os.path.exists(distorted_img_path):
                                torch.save(distorted_img, distorted_img_path)

                        # Save recovered image and metadata; `with` closes
                        # the pickle files deterministically.
                        torch.save(recovered_img, recovered_path)
                        with open(results_folder / 'metadata.pkl', 'wb') as fh:
                            pickle.dump(metadata, fh)
                        p = psnr(recovered_img, orig_img)
                        with open(results_folder / 'psnr.pkl', 'wb') as fh:
                            pickle.dump(p, fh)