# Example #1
def main():
    """Entry point: load a trained model checkpoint and sample images.

    Parses CLI args, picks a device, restores the requested model's
    weights, and writes sampled images to the output directory.
    """
    args = _build_parser().parse_args()

    logging.basicConfig(format="%(levelname)s: %(message)s",
                        level=logging.DEBUG)

    # Prefer the GPU when one is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Dispatch table instead of an if/elif chain.
    constructors = {'vae': VAE, 'hm': HM}
    constructor = constructors.get(args.model)
    if constructor is None:
        logging.critical('model unimplemented: %s' % args.model)
        return
    model = constructor().double().to(device)

    if not args.out.exists():
        args.out.mkdir(parents=True)

    # train_test_split=1 routes everything into the test split.
    _, test_ds = build_datasets(args.im_path, train_test_split=1)

    # Restore trained weights and switch to inference mode.
    checkpoint = torch.load(args.save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    _sample_images(model, args.batch, args.samples, test_ds, args.out)
# Example #2
    def reset_gen():
        """Construct a fresh generator for the current ``args.model``.

        Returns:
            (gen, img_size): generator in eval mode on DEVICE, and the
            side length of the images it produces.

        Raises:
            NotImplementedError: for any unrecognized model name.
        """
        name = args.model
        if name in ('iagan_began_cs',):
            net = Generator128(64)
            net = load_trained_net(
                net,
                ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                 '.n_cuts=0/gen_ckpt.49.pt'))
            return net.eval().to(DEVICE), 128
        if name in ('iagan_dcgan_cs',):
            net = dcgan_generator()
            weights = torch.load(
                ('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                 '.b1_0.5.lr_0.0002.pt'))
            net.load_state_dict(weights)
            return net.eval().to(DEVICE), 64
        if name in ('iagan_vanilla_vae_cs',):
            net = VAE()
            weights = torch.load(
                './vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt')
            net.load_state_dict(weights)
            net = net.eval().to(DEVICE)
            # Only the decoder half is used as the generator.
            return net.decoder, 128
        raise NotImplementedError()
def main():
    """Sample from a trained VAE and record nearest test-set matches.

    Loads the checkpointed VAE, draws samples, scans the test set for
    closest images, and saves the winning pairs to ``args.out``.
    """
    args = _build_parser().parse_args()

    logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)

    # Run on the GPU if one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Guard clause: only the VAE is supported here.
    if args.model != 'vae':
        logging.critical('model unimplemented: %s' % args.model)
        return
    model = VAE().double().to(device)

    if not args.out.parent.exists():
        args.out.parent.mkdir()

    # train_test_split=1 routes everything into the test split.
    _, test_ds = build_datasets(args.im_path, train_test_split=1)

    # Load trained weights and freeze the model for inference.
    checkpoint = torch.load(args.save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    with torch.no_grad():
        samps = model.sample(args.samples).reshape(-1, *IM_DIMS, 3)

    loader = DataLoader(test_ds,
                        batch_size=args.batch,
                        num_workers=args.workers,
                        pin_memory=torch.cuda.is_available())

    # Stream the test set, updating the best-match record per chunk.
    record = _init_record(samps)
    with tqdm(total=TOTAL_IMAGES) as pbar:
        for chunk in loader:
            _update_winner(chunk.reshape(-1, *IM_DIMS, 3), record, pbar)

    np.save(args.out, record['pair'])
    print('final distances:', record['distance'])
# Example #4
 def reset_gen(model):
     """Build a freshly-loaded generator for *model*.

     Args:
         model: one of 'began', 'vae', or 'dcgan'.

     Returns:
         (gen, img_size): the generator in eval mode on DEVICE, and the
         side length of the images it produces.

     Raises:
         NotImplementedError: for any unrecognized model name.
     """
     if model == 'began':
         gen = Generator128(64)
         gen = load_trained_net(
             gen,
             ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
              '.n_cuts=0/gen_ckpt.49.pt'))
         gen = gen.eval().to(DEVICE)
         img_size = 128
     elif model == 'vae':
         gen = VAE()
         t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt')
         gen.load_state_dict(t)
         gen = gen.eval().to(DEVICE)
         # Only the decoder half is used as the generator.
         gen = gen.decoder
         img_size = 128
     elif model == 'dcgan':
         gen = dcgan_generator()
         t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                         '.b1_0.5.lr_0.0002.pt'))
         gen.load_state_dict(t)
         gen = gen.eval().to(DEVICE)
         img_size = 64
     else:
         # Previously an unknown name fell through to the return and
         # raised UnboundLocalError; fail loudly instead, matching the
         # sibling reset_gen helpers elsewhere in this file.
         raise NotImplementedError()
     return gen, img_size
# Accumulators for LDA classification accuracy computed on PCA-projected
# latent codes, one list per attribute pairing. Presumably appended to
# further down — the filling code is outside this view; TODO confirm.
white_to_asian_lda_acc_pca = []
white_to_black_lda_acc_pca = []
asian_to_black_lda_acc_pca = []
male_to_female_lda_acc_pca = []

for conf in configs:
    print('processing conf: {}'.format(conf.name))
    out_dir = out_path / conf.name
    if not out_dir.exists():
        out_dir.mkdir(parents=True)

    # set up model
    model = VAE(**conf.params)
    ckpt = torch.load(conf.save_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # encode data
    test_len = len(test_ds)
    ldims = model.latent_dims
    mu_points = np.zeros((test_len, ldims))
    var_points = np.zeros((test_len, ldims))
    feats = []

    if not (out_dir / 'mu_points.npy').exists():
        for i in tqdm(range(test_len)):
            im, feat = test_ds[i]
            im = im.unsqueeze(0)
            mu, var = model.encode(im)

            with torch.no_grad():
def main():
    """Evaluate generalized zero-shot performance across two tasks.

    Loads a frozen feature extractor and a frozen VAE, generates
    synthetic features for the 150 task-0 classes, mixes them with real
    task-1 training features, trains a linear classifier on the mix,
    then tests on both task-0 and task-1 test sets and logs the harmonic
    mean (H) and a weighted average (OM) of the two accuracies.

    NOTE(review): several `.cuda()` calls below are unconditional, so
    this path appears to require a GPU even though CUDA is probed —
    confirm before running on CPU-only machines.
    """
    # Probe for CUDA; cudnn.benchmark speeds up fixed-size workloads.
    CUDA = False
    if torch.cuda.is_available():
        CUDA = True
        print('cuda available')
        torch.backends.cudnn.benchmark = True
    config = config_process(parser.parse_args())
    print(config)

    with open('pkl/task_1_train.pkl', 'rb') as f:
        task_1_train = pkl.load(f)
    # with open('pkl/task_1_test.pkl', 'rb') as g:
    #     task_1_test = pkl.load(g)
    ######################## task_1_testval.pkl
    with open('pkl/task_1_test.pkl', 'rb') as g:
        task_1_testval = pkl.load(g)

    # Split the held-out pickle: first 500 entries test, remainder val.
    task_1_test = task_1_testval[:500]
    task_1_val = task_1_testval[500:]
    ###################################### task 1  test + val
    ###### task 0:seen training data and unseen test data
    examples, labels, class_map = image_load(config['class_file'],
                                             config['image_label'])
    ###### task 0: seen test data
    examples_0, labels_0, class_map_0 = image_load(config['class_file'],
                                                   config['test_seen_classes'])

    datasets = split_byclass(config, examples, labels,
                             np.loadtxt(config['attributes_file']), class_map)
    datasets_0 = split_byclass(config, examples_0, labels_0,
                               np.loadtxt(config['attributes_file']),
                               class_map)
    print('load the task 0 train: {} the task 1 as test: {}'.format(
        len(datasets[0][0]), len(datasets[0][1])))
    print('load task 0 test data {}'.format(len(datasets_0[0][0])))

    # L2-normalize the class-attribute matrices (rows = classes).
    train_attr = F.normalize(datasets[0][3])
    test_attr = F.normalize(datasets[0][4])

    # NOTE(review): this aliases (does not copy) config — the updates
    # below mutate the original config dict as well.
    best_cfg = config
    best_cfg['n_classes'] = datasets[0][3].size(0)
    best_cfg['n_train_lbl'] = datasets[0][3].size(0)
    best_cfg['n_test_lbl'] = datasets[0][4].size(0)

    task_1_train_set = grab_data(best_cfg, task_1_train, datasets[0][2], True)
    task_1_test_set = grab_data(best_cfg, task_1_test, datasets[0][2], False)
    task_1_val_set = grab_data(best_cfg, task_1_val, datasets[0][2], False)
    task_0_seen_test_set = grab_data(best_cfg, datasets_0[0][0],
                                     datasets_0[0][2], False)

    # Feature extractor: pretrained torchvision backbone minus its
    # final (classification) layer; only resnets are supported.
    base_model = models.__dict__[config['arch']](pretrained=True)
    if config['arch'].startswith('resnet'):
        FE_model = nn.Sequential(*list(base_model.children())[:-1])
    else:
        print('untested')
        raise NotImplementedError

    print('load pretrained FE_model')
    #######3 task id 'softmax'
    FE_path = './ckpts/{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['softmax_method'], config['arch'],
        config['task_id'], config['finetune'], 'checkpoint.pth')

    # Restore and freeze the feature extractor.
    FE_model.load_state_dict(torch.load(FE_path)['state_dict_FE'])
    for name, para in FE_model.named_parameters():
        para.requires_grad = False

    vae = VAE(encoder_layer_sizes=config['encoder_layer_sizes'],
              latent_size=config['latent_size'],
              decoder_layer_sizes=config['decoder_layer_sizes'],
              num_labels=config['num_labels'])

    vae_path = './ckpts/{}_{}_{}_{}_task_id_{}_finetune_{}_{}'.format(
        config['dataset'], config['method'], config['softmax_method'],
        config['arch'], config['task_id'], config['finetune'], 'ckpt.pth')
    # Restore and freeze the feature-generating VAE.
    vae.load_state_dict(torch.load(vae_path))
    for name, para in vae.named_parameters():
        para.requires_grad = False

    FE_model.eval()
    vae.eval()
    # print(vae)
    if CUDA:
        FE_model = FE_model.cuda()
        vae = vae.cuda()

    #seen
    task_1_real_train = get_prev_feat(FE_model, task_1_train_set, CUDA)
    # task_0_real_val = get_prev_feat(FE_model, task_0_val_set, CUDA)

    print('have got real trainval feats and labels')
    print('...GENERATING fake features...')
    # Synthesize features for the 150 task-0 classes from attributes.
    task_0_fake = generate_syn_feature(150, vae, train_attr, config['syn_num'],
                                       config)

    # Training set = fake task-0 features + real task-1 features; the
    # task-1 labels are shifted by 150 so label spaces do not collide.
    train_X = torch.cat((task_0_fake[0].cuda(), task_1_real_train[0].cuda()))
    train_Y = torch.cat(
        (task_0_fake[1].cuda(), task_1_real_train[1].cuda() + 150))
    # train_X = task_0_fake[0].cuda()
    # train_Y = task_0_fake[1].cuda()
    test_Dataset = PURE(train_X, train_Y)

    test_dataloader = torch.utils.data.DataLoader(test_Dataset,
                                                  batch_size=256,
                                                  shuffle=True)

    # Linear classifier over 2048-d features into the 200 joint labels.
    test_loss_net = nn.Linear(in_features=2048, out_features=200).cuda()
    test_loss_net_optimizer = torch.optim.Adam(test_loss_net.parameters())

    print('...TRAIN test set CLASSIFIER...')
    # train_syn_val(test_loss_net,task_0_val_set, test_dataloader, test_loss_net_optimizer, 200)
    best_loss_net = train_syn_val(config, FE_model, test_loss_net,
                                  task_1_val_set, test_dataloader,
                                  test_loss_net_optimizer, 200)
    # Keep a detached copy of the best validation-selected classifier.
    test_loss_net = copy.deepcopy(best_loss_net)
    print('\n...TESTING... GZSL: 200 labels')

    # test_loss_net = torch.load('hist_files/vae_time_2_doulbe_distill_loss_net.pth')
    # torch.save(test_loss_net,'vae_time_2_loss_net.pth')
    test_0_acc, test_0_top1, _, true_labels_0, pred_labels_0 = test(
        config,
        FE_model,
        test_loss_net,
        task_0_seen_test_set,
        CUDA,
        0,
        eval_mode=0,
        print_sign=1)
    test_1_acc, test_1_top1, _, true_labels_1, pred_labels_1 = test(
        config,
        FE_model,
        test_loss_net,
        task_1_test_set,
        CUDA,
        1,
        eval_mode=1,
        print_sign=1)
    # H: harmonic mean of the two accuracies; OM: 3:1 weighted average.
    H = 2 * test_0_acc * test_1_acc / (test_0_acc + test_1_acc)
    print(H)
    OM = (3 * test_0_acc + test_1_acc) / 4
    print(OM)
    if not os.path.exists('results'):
        os.makedirs('results')

    # Append this run's config and metrics to a per-setup results file.
    file_name = '{}_{}_{}_{}.txt'.format(config['dataset'], config['arch'],
                                         config['method'], config['task_id'])
    with open('results/' + file_name, 'a') as fp:
        print(best_cfg, file=fp)
        print(
            'task B: {:.3f}, task A:{:.3f}, H= {:.3f}, OM= {:.3f}  \n'.format(
                test_0_acc, test_1_acc, H, OM),
            file=fp)
# Example #7
def main():
    """Train the selected model, checkpointing and logging eval loss.

    Builds the model named by --model, trains it on the train split with
    Adam, periodically evaluates on the test split, saves per-epoch
    checkpoints plus a final one, and pickles the loss history.
    """
    parser = _build_parser()
    args = parser.parse_args()

    logging.basicConfig(format="%(levelname)s: %(message)s",
                        level=logging.DEBUG)

    # Prefer the GPU when one is present.
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    # Select model architecture from CLI; unknown names abort early.
    model = None
    save_path = args.save
    if args.model == 'vae':
        model = VAE(args.dim)
        logging.info('training VAE with dims: {}'.format(args.dim))
    elif args.model == 'ae':
        model = AE(args.dim)
        logging.info('training AE with dims: {}'.format(args.dim))
    elif args.model == 'hm':
        model = HM(args.color)
    elif args.model == 'gmvae':
        model = GMVAE()
    else:
        logging.critical('model unimplemented: %s' % args.model)
        return

    if not save_path.exists():
        save_path.mkdir(parents=True)

    model = model.float()
    model.to(device)
    optimizer = optim.Adam(model.parameters())

    train_ds, test_ds = build_datasets(args.path)

    # NOTE(review): if args.epochs == 0, `e` and `loss` are never bound
    # and the post-loop code below raises NameError — confirm callers
    # always pass epochs >= 1.
    losses = []
    for e in range(args.epochs):
        logging.info('epoch: %d of %d' % (e + 1, args.epochs))

        loader = DataLoader(train_ds,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.workers,
                            pin_memory=torch.cuda.is_available())
        total_batches = len(train_ds) // args.batch_size

        # Evaluate roughly 50 times per epoch (+1 avoids log_every == 0).
        log_every = total_batches // 50 + 1
        save_every = 1  # hardcoded for now
        for i, x in enumerate(loader):
            x = x.to(device)
            optimizer.zero_grad()
            output = model(x)
            total_loss = model.loss_function(output)
            if type(total_loss) is dict:  # TODO: generalize loss handling
                total_loss = total_loss['loss']

            total_loss.backward()
            optimizer.step()

            if i % log_every == 0:
                # Switch to eval mode for the held-out evaluation, then
                # back to train mode to continue optimizing.
                model.eval()
                loss = _eval(model, test_ds, device)
                model.train()

                logging.info('[batch %d/%d] ' % (i + 1, total_batches) +
                             model.print_loss(loss))
                # TODO: generalize printing
                # print_params = (i+1, total_batches, loss['loss'], loss['mse'], loss['kld'])
                # logging.info('[batch %d/%d] loss: %f, mse: %f, kld: %f' % print_params)
                # print_params = (i+1, total_batches, loss)
                # logging.info('[batch %d/%d] loss: %f' % print_params)
                losses.append({'iter': i, 'epoch': e, 'loss': loss})

        # Periodic checkpoint (every epoch with save_every == 1); `loss`
        # here is the most recent eval loss from the inner loop.
        if e % save_every == 0:
            torch.save(
                {
                    'epoch': e + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                }, save_path / ('epoch_%d.pt' % (e + 1)))

    # Final evaluation after training completes.
    model.eval()
    loss = _eval(model, test_ds, device)
    model.train()

    logging.info('final loss: %s' % loss)
    losses.append({'iter': 0, 'epoch': e + 1, 'loss': loss})

    # Persist the full loss history for later analysis.
    with open(save_path / 'loss.pk', 'wb') as pkf:
        pickle.dump(losses, pkf)

    torch.save(
        {
            'epoch': args.epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, save_path / 'final.pt')
    print('done!')
# Example #8
def mgan_images(args):
    """Run mGAN compressed-sensing recovery over a directory of images.

    For every image in args.img_dir, and every combination of n_cuts,
    forward model, and (z_init_mode, limit), recovers the image with
    mgan_recover and saves recovered/distorted/original tensors plus
    metadata and PSNR under BASE_DIR. Existing results are skipped
    unless args.overwrite is set.
    """
    # Optional determinism: fixed seeds and deterministic cudnn kernels.
    if args.set_seed:
        torch.manual_seed(0)
        np.random.seed(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.makedirs(BASE_DIR, exist_ok=True)

    # Pick the generator checkpoint matching the requested model.
    if args.model in ['mgan_began_cs']:
        gen = Generator128(64)
        gen = load_trained_net(
            gen, ('./checkpoints/celeba_began.withskips.bs32.cosine.min=0.25'
                  '.n_cuts=0/gen_ckpt.49.pt'))
        gen = gen.eval().to(DEVICE)
        img_size = 128
    elif args.model in ['mgan_vanilla_vae_cs']:
        gen = VAE()
        t = torch.load('./vae_checkpoints/vae_bs=128_beta=1.0/epoch_19.pt')
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        # Only the decoder half is used as the generator.
        gen = gen.decoder
        img_size = 128
    elif args.model in ['mgan_dcgan_cs']:
        gen = dcgan_generator()
        t = torch.load(('./dcgan_checkpoints/netG.epoch_24.n_cuts_0.bs_64'
                        '.b1_0.5.lr_0.0002.pt'))
        gen.load_state_dict(t)
        gen = gen.eval().to(DEVICE)
        img_size = 64
    else:
        raise NotImplementedError()

    img_shape = (3, img_size, img_size)
    # The swept parameters are popped out of metadata so that the dict
    # that remains describes only the fixed recovery settings; the swept
    # values are re-inserted one at a time inside the loops below.
    metadata = recovery_settings[args.model]
    n_cuts_list = metadata['n_cuts_list']
    del (metadata['n_cuts_list'])

    z_init_mode_list = metadata['z_init_mode']
    limit_list = metadata['limit']
    assert len(z_init_mode_list) == len(limit_list)
    del (metadata['z_init_mode'])
    del (metadata['limit'])

    forwards = forward_models[args.model]

    # The split name (e.g. train/test) is taken from the directory name.
    data_split = Path(args.img_dir).name
    for img_name in tqdm(sorted(os.listdir(args.img_dir)),
                         desc='Images',
                         leave=True,
                         disable=args.disable_tqdm):
        # Load image and get filename without extension
        orig_img = load_target_image(os.path.join(args.img_dir, img_name),
                                     img_size).to(DEVICE)
        img_basename, _ = os.path.splitext(img_name)

        for n_cuts in tqdm(n_cuts_list,
                           desc='N_cuts',
                           leave=False,
                           disable=args.disable_tqdm):
            metadata['n_cuts'] = n_cuts
            for i, (f, f_args_list) in enumerate(
                    tqdm(forwards.items(),
                         desc='Forwards',
                         leave=False,
                         disable=args.disable_tqdm)):
                for f_args in tqdm(f_args_list,
                                   desc=f'{f} Args',
                                   leave=False,
                                   disable=args.disable_tqdm):

                    f_args['img_shape'] = img_shape
                    forward_model = get_forward_model(f, **f_args)

                    # z_init_mode and limit are swept in lockstep (the
                    # assert above guarantees equal lengths).
                    for z_init_mode, limit in zip(
                            tqdm(z_init_mode_list,
                                 desc='z_init_mode',
                                 leave=False), limit_list):
                        metadata['z_init_mode'] = z_init_mode
                        metadata['limit'] = limit

                        # Before doing recovery, check if results already exist
                        # and possibly skip
                        recovered_name = 'recovered.pt'
                        results_folder = get_results_folder(
                            image_name=img_basename,
                            model=args.model,
                            n_cuts=n_cuts,
                            split=data_split,
                            forward_model=forward_model,
                            recovery_params=dict_to_str(metadata),
                            base_dir=BASE_DIR)

                        os.makedirs(results_folder, exist_ok=True)

                        recovered_path = results_folder / recovered_name
                        if os.path.exists(
                                recovered_path) and not args.overwrite:
                            print(
                                f'{recovered_path} already exists, skipping...'
                            )
                            continue

                        # Optional run name for experiment tracking.
                        if args.run_name is not None:
                            current_run_name = (
                                f'{img_basename}.{forward_model}'
                                f'.{dict_to_str(metadata)}'
                                f'.{args.run_name}')
                        else:
                            current_run_name = None

                        recovered_img, distorted_img, _ = mgan_recover(
                            orig_img, gen, n_cuts, forward_model,
                            metadata['optimizer'], z_init_mode, limit,
                            metadata['z_lr'], metadata['n_steps'],
                            metadata['z_number'], metadata['restarts'],
                            args.run_dir, current_run_name, args.disable_tqdm)

                        # Make images folder
                        img_folder = get_images_folder(split=data_split,
                                                       image_name=img_basename,
                                                       img_size=img_size,
                                                       base_dir=BASE_DIR)
                        os.makedirs(img_folder, exist_ok=True)

                        # Save original image if needed
                        original_img_path = img_folder / 'original.pt'
                        if not os.path.exists(original_img_path):
                            torch.save(orig_img, original_img_path)

                        # Save distorted image if needed
                        if forward_model.viewable:
                            distorted_img_path = img_folder / f'{forward_model}.pt'
                            if not os.path.exists(distorted_img_path):
                                torch.save(distorted_img, distorted_img_path)

                        # Save recovered image and metadata
                        torch.save(recovered_img, recovered_path)
                        pickle.dump(
                            metadata,
                            open(results_folder / 'metadata.pkl', 'wb'))
                        p = psnr(recovered_img, orig_img)
                        pickle.dump(p, open(results_folder / 'psnr.pkl', 'wb'))