def main():
    """Sample images from a trained checkpoint into the ``args.out`` directory.

    Parses the command line via ``_build_parser()``, builds the requested
    model (``'vae'`` or ``'hm'``) in double precision on CUDA when available
    (CPU otherwise), restores its weights from ``args.save_path`` and passes
    it to ``_sample_images`` together with the test split returned by
    ``build_datasets(args.im_path, train_test_split=1)``.

    An unknown ``--model`` value is logged at CRITICAL level and the function
    returns without touching the filesystem.
    """
    cli_args = _build_parser().parse_args()
    logging.basicConfig(format="%(levelname)s: %(message)s",
                        level=logging.DEBUG)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Builders are invoked lazily so only the requested model is constructed.
    builders = {
        'vae': lambda: VAE(),
        'hm': lambda: HM(),
    }
    build = builders.get(cli_args.model)
    if build is None:
        logging.critical('model unimplemented: %s' % cli_args.model)
        return
    net = build().double().to(device)

    out_dir = cli_args.out
    if not out_dir.exists():
        out_dir.mkdir(parents=True)

    _train_ds, test_ds = build_datasets(cli_args.im_path, train_test_split=1)

    checkpoint = torch.load(cli_args.save_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    _sample_images(net, cli_args.batch, cli_args.samples, test_ds, out_dir)
def reset_gen():
    """Build the pretrained generator selected by the module-level ``args.model``.

    Note that ``args`` is read from module scope; it is not a parameter.

    Returns:
        tuple: ``(gen, img_size)``. ``gen`` is the generator in eval mode on
        ``DEVICE``; for ``'iagan_vanilla_vae_cs'`` it is the VAE's
        ``decoder`` submodule. ``img_size`` is 128 for the BEGAN and VAE
        models and 64 for DCGAN.

    Raises:
        NotImplementedError: if ``args.model`` is not one of
            ``'iagan_began_cs'``, ``'iagan_dcgan_cs'`` or
            ``'iagan_vanilla_vae_cs'``.
    """
    if args.model in ['iagan_began_cs']:
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif args.model in ['iagan_dcgan_cs']:
        gen = dcgan_generator()
        # Load to CPU first. load_state_dict copies the values into gen's
        # existing parameters, and this keeps GPU-saved checkpoints loadable
        # on CPU-only hosts. The module is moved to DEVICE below.
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'), map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    elif args.model in ['iagan_vanilla_vae_cs']:
        gen = VAE()
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt',
                       map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        # Only the decoder half of the VAE is used as the generator.
        gen = gen.decoder
        img_size = 128
    else:
        raise NotImplementedError()
    return gen, img_size
def main():
    """Compare model samples against the test set and save the best pairs.

    Builds the VAE (the only supported ``--model``) in double precision and
    restores it from ``args.save_path``. It then draws ``args.samples``
    samples and reshapes them to ``(-1, *IM_DIMS, 3)``. Next it streams the
    test split through ``_update_winner`` in batches of ``args.batch``.
    Finally it saves ``record['pair']`` to ``args.out`` with ``np.save`` and
    prints ``record['distance']``.

    Judging by the ``_update_winner`` / ``record['distance']`` naming, this
    presumably finds the closest test image for each sample (TODO confirm
    against ``_update_winner``).

    An unknown ``--model`` value is logged at CRITICAL level and the function
    returns.
    """
    parser = _build_parser()
    args = parser.parse_args()
    logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.model == 'vae':
        model = VAE().double().to(device)
    else:
        logging.critical('model unimplemented: %s', args.model)
        return

    # parents=True: the output file may sit under several not-yet-created
    # directories (a bare mkdir() fails in that case). exist_ok avoids the
    # check-then-create race.
    args.out.parent.mkdir(parents=True, exist_ok=True)

    _, test_ds = build_datasets(args.im_path, train_test_split=1)
    ckpt = torch.load(args.save_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    with torch.no_grad():
        samps = model.sample(args.samples).reshape(-1, *IM_DIMS, 3)

    loader = DataLoader(test_ds, batch_size=args.batch,
                        num_workers=args.workers,
                        pin_memory=torch.cuda.is_available())
    record = _init_record(samps)
    with tqdm(total=TOTAL_IMAGES) as pbar:
        for chunk in loader:
            _update_winner(chunk.reshape(-1, *IM_DIMS, 3), record, pbar)

    np.save(args.out, record['pair'])
    print('final distances:', record['distance'])
def reset_gen(model):
    """Build the pretrained generator named by ``model``.

    Args:
        model: one of ``'began'``, ``'vae'`` or ``'dcgan'``.

    Returns:
        tuple: ``(gen, img_size)``. ``gen`` is the generator in eval mode on
        ``DEVICE``; for ``'vae'`` it is the VAE's ``decoder`` submodule.
        ``img_size`` is 128 for BEGAN and VAE and 64 for DCGAN.

    Raises:
        NotImplementedError: for any other ``model`` value.
    """
    if model == 'began':
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif model == 'vae':
        gen = VAE()
        # Load to CPU first so GPU-saved checkpoints also load on CPU-only
        # hosts. load_state_dict copies into gen's parameters regardless.
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt',
                       map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        # Only the decoder half of the VAE is used as the generator.
        gen = gen.decoder
        img_size = 128
    elif model == 'dcgan':
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'), map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    else:
        # Without this branch an unknown name reached the return statement
        # with gen/img_size unbound and surfaced as UnboundLocalError.
        raise NotImplementedError('unknown generator model: {!r}'.format(model))
    return gen, img_size
white_to_asian_lda_acc_pca = [] white_to_black_lda_acc_pca = [] asian_to_black_lda_acc_pca = [] male_to_female_lda_acc_pca = [] for conf in configs: print('processing conf: {}'.format(conf.name)) out_dir = out_path / conf.name if not out_dir.exists(): out_dir.mkdir(parents=True) # set up model model = VAE(**conf.params) ckpt = torch.load(conf.save_path, map_location=torch.device('cpu')) model.load_state_dict(ckpt['model_state_dict']) model.eval() # encode data test_len = len(test_ds) ldims = model.latent_dims mu_points = np.zeros((test_len, ldims)) var_points = np.zeros((test_len, ldims)) feats = [] if not (out_dir / 'mu_points.npy').exists(): for i in tqdm(range(test_len)): im, feat = test_ds[i] im = im.unsqueeze(0) mu, var = model.encode(im) with torch.no_grad():
def main():
    """Train and evaluate a GZSL classifier for task 1 from frozen models.

    Steps:

    1. Load the task-1 train pickle. Split the task-1 test pickle into its
       first 500 entries (test) and the remainder (validation).
    2. Load task-0 data with ``image_load`` / ``split_byclass``.
    3. Build the feature extractor (ResNet architectures only) and the VAE,
       load both from ``./ckpts`` and freeze them.
    4. Build a training set from real task-1 features plus synthetic
       features for the 150 task-0 classes. Task-1 labels are shifted by 150.
    5. Train a 2048 -> 200 linear classifier via ``train_syn_val``, which
       validates on the task-1 validation split.
    6. Test on the task-0 seen test set and the task-1 test set. Print the
       harmonic mean ``H`` and the weighted mean ``OM``, and append them to
       ``results/<dataset>_<arch>_<method>_<task_id>.txt``.

    Raises:
        NotImplementedError: if ``config['arch']`` does not start with
            ``'resnet'``.
    """
    CUDA = torch.cuda.is_available()
    if CUDA:
        print('cuda available')
        torch.backends.cudnn.benchmark = True
    # Every tensor/module created directly in this function goes to `device`.
    # The script therefore no longer crashes at unconditional .cuda() calls
    # on CPU-only machines. Helpers still receive the CUDA flag as before.
    device = torch.device('cuda' if CUDA else 'cpu')

    config = config_process(parser.parse_args())
    print(config)

    with open('pkl/task_1_train.pkl', 'rb') as f:
        task_1_train = pkl.load(f)
    # The task-1 "test" pickle holds two splits: the first 500 entries are
    # the test set, the rest are used for validation / model selection.
    with open('pkl/task_1_test.pkl', 'rb') as g:
        task_1_testval = pkl.load(g)
    task_1_test = task_1_testval[:500]
    task_1_val = task_1_testval[500:]

    # Task 0: seen-class training data and unseen-class test data.
    examples, labels, class_map = image_load(config['class_file'],
                                             config['image_label'])
    # Task 0: seen-class test data.
    examples_0, labels_0, _ = image_load(config['class_file'],
                                         config['test_seen_classes'])

    datasets = split_byclass(config, examples, labels,
                             np.loadtxt(config['attributes_file']), class_map)
    datasets_0 = split_byclass(config, examples_0, labels_0,
                               np.loadtxt(config['attributes_file']), class_map)
    print('load the task 0 train: {} the task 1 as test: {}'.format(
        len(datasets[0][0]), len(datasets[0][1])))
    print('load task 0 test data {}'.format(len(datasets_0[0][0])))
    train_attr = F.normalize(datasets[0][3])

    # NOTE(review): best_cfg is an alias of config, not a copy. The keys set
    # here are therefore also visible through `config`, which is later passed
    # to train_syn_val() and test(). Check whether those read these keys
    # before turning this into a copy.
    best_cfg = config
    best_cfg['n_classes'] = datasets[0][3].size(0)
    best_cfg['n_train_lbl'] = datasets[0][3].size(0)
    best_cfg['n_test_lbl'] = datasets[0][4].size(0)

    task_1_train_set = grab_data(best_cfg, task_1_train, datasets[0][2], True)
    task_1_test_set = grab_data(best_cfg, task_1_test, datasets[0][2], False)
    task_1_val_set = grab_data(best_cfg, task_1_val, datasets[0][2], False)
    task_0_seen_test_set = grab_data(best_cfg, datasets_0[0][0],
                                     datasets_0[0][2], False)

    base_model = models.__dict__[config['arch']](pretrained=True)
    if config['arch'].startswith('resnet'):
        # Drop the last child module (the final classifier layer for
        # torchvision-style ResNets) and keep the feature trunk.
        FE_model = nn.Sequential(*list(base_model.children())[:-1])
    else:
        print('untested')
        raise NotImplementedError

    print('load pretrained FE_model')
    FE_path = './ckpts/{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['softmax_method'], config['arch'],
        config['task_id'], config['finetune'], 'checkpoint.pth')
    # map_location='cpu': the freshly built modules are still on the CPU
    # here, and GPU-saved checkpoints then also load on CPU-only hosts.
    FE_model.load_state_dict(
        torch.load(FE_path, map_location='cpu')['state_dict_FE'])
    for para in FE_model.parameters():
        para.requires_grad = False

    vae = VAE(encoder_layer_sizes=config['encoder_layer_sizes'],
              latent_size=config['latent_size'],
              decoder_layer_sizes=config['decoder_layer_sizes'],
              num_labels=config['num_labels'])
    vae_path = './ckpts/{}_{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['method'], config['softmax_method'],
        config['arch'], config['task_id'], config['finetune'], 'ckpt.pth')
    vae.load_state_dict(torch.load(vae_path, map_location='cpu'))
    for para in vae.parameters():
        para.requires_grad = False

    FE_model.eval()
    vae.eval()
    FE_model = FE_model.to(device)
    vae = vae.to(device)

    n_task_0_classes = 150   # label offset applied to task-1 labels below
    n_total_classes = 200    # output size of the GZSL classifier
    feat_dim = 2048          # input size of the linear classifier head

    # Real features for the task-1 training data.
    task_1_real_train = get_prev_feat(FE_model, task_1_train_set, CUDA)
    print('have got real trainval feats and labels')
    print('...GENERATING fake features...')
    # Synthetic task-0 features produced by generate_syn_feature from the
    # frozen VAE and the normalized training attributes.
    task_0_fake = generate_syn_feature(n_task_0_classes, vae, train_attr,
                                       config['syn_num'], config)
    train_X = torch.cat((task_0_fake[0].to(device),
                         task_1_real_train[0].to(device)))
    # Task-1 labels are shifted past the task-0 label range.
    train_Y = torch.cat((task_0_fake[1].to(device),
                         task_1_real_train[1].to(device) + n_task_0_classes))

    test_Dataset = PURE(train_X, train_Y)
    test_dataloader = torch.utils.data.DataLoader(test_Dataset,
                                                  batch_size=256,
                                                  shuffle=True)
    test_loss_net = nn.Linear(in_features=feat_dim,
                              out_features=n_total_classes).to(device)
    test_loss_net_optimizer = torch.optim.Adam(test_loss_net.parameters())
    print('...TRAIN test set CLASSIFIER...')
    best_loss_net = train_syn_val(config, FE_model, test_loss_net,
                                  task_1_val_set, test_dataloader,
                                  test_loss_net_optimizer, n_total_classes)
    test_loss_net = copy.deepcopy(best_loss_net)

    print('\n...TESTING... GZSL: 200 labels')
    test_0_acc, test_0_top1, _, true_labels_0, pred_labels_0 = test(
        config, FE_model, test_loss_net, task_0_seen_test_set, CUDA, 0,
        eval_mode=0, print_sign=1)
    test_1_acc, test_1_top1, _, true_labels_1, pred_labels_1 = test(
        config, FE_model, test_loss_net, task_1_test_set, CUDA, 1,
        eval_mode=1, print_sign=1)

    # Harmonic mean of the two accuracies. It is defined as 0 when both are
    # 0, instead of dividing by zero.
    acc_sum = test_0_acc + test_1_acc
    H = 2 * test_0_acc * test_1_acc / acc_sum if acc_sum else 0.0
    print(H)
    # Weighted mean: task-0 accuracy counts three times as much as task-1.
    OM = (3 * test_0_acc + test_1_acc) / 4
    print(OM)

    os.makedirs('results', exist_ok=True)
    file_name = '{}_{}_{}_{}.txt'.format(config['dataset'], config['arch'],
                                         config['method'], config['task_id'])
    with open('results/' + file_name, 'a') as fp:
        print(best_cfg, file=fp)
        print(
            'task B: {:.3f}, task A:{:.3f}, H= {:.3f}, OM= {:.3f} \n'.format(
                test_0_acc, test_1_acc, H, OM),
            file=fp)
def main():
    """Train the model selected on the command line and checkpoint it.

    Supported ``--model`` values are ``'vae'``, ``'ae'``, ``'hm'`` and
    ``'gmvae'``. Any other value is logged at CRITICAL level and the function
    returns.

    Training uses Adam on the train split of ``build_datasets(args.path)``.
    Roughly 50 times per epoch the model is evaluated on the test split with
    ``_eval`` and the loss is logged and recorded. The following files are
    written to ``args.save``:

    * ``epoch_<n>.pt`` after every epoch;
    * ``loss.pk``, the pickled list of recorded losses;
    * ``final.pt`` at the end.
    """
    parser = _build_parser()
    args = parser.parse_args()
    logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    save_path = args.save
    if args.model == 'vae':
        model = VAE(args.dim)
        logging.info('training VAE with dims: {}'.format(args.dim))
    elif args.model == 'ae':
        model = AE(args.dim)
        logging.info('training AE with dims: {}'.format(args.dim))
    elif args.model == 'hm':
        model = HM(args.color)
    elif args.model == 'gmvae':
        model = GMVAE()
    else:
        logging.critical('model unimplemented: %s', args.model)
        return

    save_path.mkdir(parents=True, exist_ok=True)

    model = model.float()
    model.to(device)
    optimizer = optim.Adam(model.parameters())
    train_ds, test_ds = build_datasets(args.path)

    # One loader serves every epoch: shuffle=True reshuffles on each pass.
    loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                        num_workers=args.workers,
                        pin_memory=torch.cuda.is_available())
    # len(loader) counts the trailing partial batch, so the "[batch i/N]"
    # log line never reports i > N (len(train_ds) // batch_size could).
    total_batches = len(loader)
    log_every = total_batches // 50 + 1
    save_every = 1  # hardcoded for now

    losses = []
    # Bound up front so the per-epoch checkpoint cannot raise NameError when
    # the training split yields no batches.
    loss = None
    for e in range(args.epochs):
        logging.info('epoch: %d of %d' % (e + 1, args.epochs))
        for i, x in enumerate(loader):
            x = x.to(device)
            optimizer.zero_grad()
            output = model(x)
            total_loss = model.loss_function(output)
            # Some models return a dict of loss terms; optimize its 'loss'.
            # TODO: generalize loss handling
            if isinstance(total_loss, dict):
                total_loss = total_loss['loss']
            total_loss.backward()
            optimizer.step()

            if i % log_every == 0:
                model.eval()
                loss = _eval(model, test_ds, device)
                model.train()
                # TODO: generalize printing
                logging.info('[batch %d/%d] ' % (i + 1, total_batches) +
                             model.print_loss(loss))
                losses.append({'iter': i, 'epoch': e, 'loss': loss})

        if e % save_every == 0:
            torch.save(
                {
                    'epoch': e + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                }, save_path / ('epoch_%d.pt' % (e + 1)))

    model.eval()
    loss = _eval(model, test_ds, device)
    model.train()
    logging.info('final loss: %s' % loss)
    # args.epochs equals the loop's final e + 1, and unlike the loop variable
    # it is still defined when args.epochs == 0.
    losses.append({'iter': 0, 'epoch': args.epochs, 'loss': loss})
    with open(save_path / 'loss.pk', 'wb') as pkf:
        pickle.dump(losses, pkf)
    torch.save(
        {
            'epoch': args.epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, save_path / 'final.pt')
    print('done!')
def mgan_images(args):
    """Run mGAN recovery over every image in ``args.img_dir`` and save results.

    The loop covers every combination of:

    * each image file;
    * each ``n_cuts`` value;
    * each forward-model configuration in ``forward_models[args.model]``;
    * each ``(z_init_mode, limit)`` pair from ``recovery_settings[args.model]``.

    For each combination the image is recovered with ``mgan_recover``, and
    ``recovered.pt``, ``metadata.pkl`` and ``psnr.pkl`` are written to the
    matching results folder. The original image, and viewable distorted
    images, are saved once per image folder. Combinations whose
    ``recovered.pt`` already exists are skipped unless ``args.overwrite`` is
    set.

    Raises:
        NotImplementedError: if ``args.model`` is not one of
            ``'mgan_began_cs'``, ``'mgan_vanilla_vae_cs'`` or
            ``'mgan_dcgan_cs'``.
    """
    if args.set_seed:
        torch.manual_seed(0)
        np.random.seed(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.makedirs(BASE_DIR, exist_ok=True)

    if args.model in ['mgan_began_cs']:
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif args.model in ['mgan_vanilla_vae_cs']:
        gen = VAE()
        # Load to CPU first so GPU-saved checkpoints also load on CPU-only
        # hosts. load_state_dict copies into gen's parameters regardless.
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt',
                       map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        # Only the decoder half of the VAE is used as the generator.
        gen = gen.decoder
        img_size = 128
    elif args.model in ['mgan_dcgan_cs']:
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'), map_location='cpu')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    else:
        raise NotImplementedError()

    img_shape = (3, img_size, img_size)

    # Work on a shallow copy. The list-valued keys are removed here and the
    # per-run values are written back below. Doing that on the shared
    # recovery_settings entry would make a second call for the same model
    # fail with KeyError on 'n_cuts_list'.
    metadata = dict(recovery_settings[args.model])
    n_cuts_list = metadata.pop('n_cuts_list')
    z_init_mode_list = metadata.pop('z_init_mode')
    limit_list = metadata.pop('limit')
    assert len(z_init_mode_list) == len(limit_list)

    forwards = forward_models[args.model]
    data_split = Path(args.img_dir).name

    for img_name in tqdm(sorted(os.listdir(args.img_dir)),
                         desc='Images',
                         leave=True,
                         disable=args.disable_tqdm):
        # Load image and get filename without extension
        orig_img = load_target_image(os.path.join(args.img_dir, img_name),
                                     img_size).to(DEVICE)
        img_basename, _ = os.path.splitext(img_name)

        for n_cuts in tqdm(n_cuts_list,
                           desc='N_cuts',
                           leave=False,
                           disable=args.disable_tqdm):
            metadata['n_cuts'] = n_cuts
            for f, f_args_list in tqdm(forwards.items(),
                                       desc='Forwards',
                                       leave=False,
                                       disable=args.disable_tqdm):
                for f_args in tqdm(f_args_list,
                                   desc=f'{f} Args',
                                   leave=False,
                                   disable=args.disable_tqdm):
                    # NOTE(review): this writes into the shared
                    # forward_models entry, as the original code did.
                    f_args['img_shape'] = img_shape
                    forward_model = get_forward_model(f, **f_args)

                    for z_init_mode, limit in zip(
                            tqdm(z_init_mode_list,
                                 desc='z_init_mode',
                                 leave=False,
                                 disable=args.disable_tqdm),
                            limit_list):
                        metadata['z_init_mode'] = z_init_mode
                        metadata['limit'] = limit

                        # Before doing recovery, check if results already
                        # exist and possibly skip
                        recovered_name = 'recovered.pt'
                        results_folder = get_results_folder(
                            image_name=img_basename,
                            model=args.model,
                            n_cuts=n_cuts,
                            split=data_split,
                            forward_model=forward_model,
                            recovery_params=dict_to_str(metadata),
                            base_dir=BASE_DIR)

                        os.makedirs(results_folder, exist_ok=True)

                        recovered_path = results_folder / recovered_name
                        if (os.path.exists(recovered_path)
                                and not args.overwrite):
                            print(
                                f'{recovered_path} already exists, skipping...'
                            )
                            continue

                        if args.run_name is not None:
                            current_run_name = (
                                f'{img_basename}.{forward_model}'
                                f'.{dict_to_str(metadata)}'
                                f'.{args.run_name}')
                        else:
                            current_run_name = None

                        recovered_img, distorted_img, _ = mgan_recover(
                            orig_img, gen, n_cuts, forward_model,
                            metadata['optimizer'], z_init_mode, limit,
                            metadata['z_lr'], metadata['n_steps'],
                            metadata['z_number'], metadata['restarts'],
                            args.run_dir, current_run_name,
                            args.disable_tqdm)

                        # Make images folder
                        img_folder = get_images_folder(split=data_split,
                                                       image_name=img_basename,
                                                       img_size=img_size,
                                                       base_dir=BASE_DIR)
                        os.makedirs(img_folder, exist_ok=True)

                        # Save original image if needed
                        original_img_path = img_folder / 'original.pt'
                        if not os.path.exists(original_img_path):
                            torch.save(orig_img, original_img_path)

                        # Save distorted image if needed
                        if forward_model.viewable:
                            distorted_img_path = (img_folder /
                                                  f'{forward_model}.pt')
                            if not os.path.exists(distorted_img_path):
                                torch.save(distorted_img, distorted_img_path)

                        # Save recovered image and metadata. Context managers
                        # make sure the pickle files are flushed and closed.
                        torch.save(recovered_img, recovered_path)
                        with open(results_folder / 'metadata.pkl',
                                  'wb') as fh:
                            pickle.dump(metadata, fh)
                        p = psnr(recovered_img, orig_img)
                        with open(results_folder / 'psnr.pkl', 'wb') as fh:
                            pickle.dump(p, fh)