Example #1
def create_run(device, temp, pixelsnail_batch, ckpt_epoch, pixelsnail_ckpt_epoch, hier, architecture, num_embeddings, neighborhood, selection_fn, dataset, size, **kwargs):
    experiment_name = util_funcs.create_experiment_name(architecture, dataset, num_embeddings, neighborhood, selection_fn, size, **kwargs)
    vqvae_checkpoint_name = util_funcs.create_checkpoint_name(experiment_name, ckpt_epoch)

    # Hard-coded PixelSNAIL checkpoint name; the generic pattern would be:
    # pixelsnail_checkpoint_name = f'pixelsnail_{experiment_name}_{hier}_{str(pixelsnail_ckpt_epoch + 1).zfill(3)}.pt'
    pixelsnail_checkpoint_name = 'pixelsnail_vqvae_imagenet_num_embeddings[512]_neighborhood[1]_selectionFN[vanilla]_size[128]_bottom_420.pt'

    model_vqvae = load_model('vqvae', vqvae_checkpoint_name, device, architecture, num_embeddings, neighborhood, selection_fn, **kwargs)
    # model_top = load_model('pixelsnail_top', args.top, device)  # only needed for a two-level (top + bottom) prior
    model_bottom = load_model('pixelsnail_bottom', pixelsnail_checkpoint_name, device, size=size, **kwargs)

    num_samples = 50000
    sampled_directory = os.path.join('sampled_images', pixelsnail_checkpoint_name).replace('.pt', '')
    if os.path.exists(sampled_directory):
        shutil.rmtree(sampled_directory)
    os.makedirs(sampled_directory)

    for sample_ind in tqdm(range(num_samples), 'Sampling image for: {}'.format(pixelsnail_checkpoint_name)):

        # top_sample = sample_model(model_top, device, args.batch, [32, 32], args.temp)
        bottom_sample = sample_model(
            model_bottom, device, pixelsnail_batch, [size // 4, size // 4], temp, condition=None
            # model_bottom, device, args.batch, [64, 64], args.temp, condition=top_sample
        )

        # Decode the sampled bottom-level codes back to image space
        decoded_sample = model_vqvae.decode_code(bottom_sample)
        # decoded_sample = model_vqvae._modules['module'].decode_code(bottom_sample)  # if the VQ-VAE is wrapped in DataParallel
        # decoded_sample = model_vqvae.decode_code(top_sample, bottom_sample)  # two-level variant
        decoded_sample = decoded_sample.clamp(-1, 1)

        filename = 'sampled_{}.png'.format(sample_ind)
        target_path = os.path.join(sampled_directory, filename)
        save_image(decoded_sample, target_path, normalize=True, range=(-1, 1))
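
For context, sample_model is not shown in this listing; a minimal sketch of an autoregressive sampler with this call signature, assuming the prior returns per-position logits of shape [batch, n_class, H, W] (the model(...) return signature here is an assumption, not the repository's actual code):

import torch

@torch.no_grad()
def sample_model(model, device, batch, size, temperature, condition=None):
    # Fill the [H, W] code map cell by cell; at each position, sample from
    # the tempered softmax over the model's logits for that position.
    row = torch.zeros(batch, *size, dtype=torch.int64, device=device)
    for i in range(size[0]):
        for j in range(size[1]):
            out, _ = model(row[:, : i + 1, :], condition=condition)  # assumed (logits, cache) output
            prob = torch.softmax(out[:, :, i, j] / temperature, dim=1)
            row[:, i, j] = torch.multinomial(prob, 1).squeeze(-1)
    return row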
Example #2
def create_extraction_run(size, device, dataset, data_path, num_workers,
                          num_embeddings, architecture, ckpt_epoch,
                          neighborhood, selection_fn, embed_dim, **kwargs):
    train_dataset, test_dataset = get_dataset(dataset, data_path, size)

    print('Creating named datasets')
    # We don't really use the "Named" part, but I'm keeping it to stay close to the original code repository
    train_named_dataset = NamedDataset(train_dataset)
    test_named_dataset = NamedDataset(test_dataset)

    print('Creating data loaders')
    train_loader = DataLoader(train_named_dataset,
                              batch_size=kwargs['vae_batch'],
                              shuffle=False,
                              num_workers=num_workers)
    test_loader = DataLoader(test_named_dataset,
                             batch_size=kwargs['vae_batch'],
                             shuffle=False,
                             num_workers=num_workers)

    # This is still the VQ-VAE experiment name and path
    experiment_name = util_funcs.create_experiment_name(
        architecture, dataset, num_embeddings, neighborhood, selection_fn,
        size, **kwargs)
    checkpoint_name = util_funcs.create_checkpoint_name(
        experiment_name, ckpt_epoch)
    checkpoint_path = f'checkpoint/{checkpoint_name}'

    print('Loading model')
    model = get_model(architecture,
                      num_embeddings,
                      device,
                      neighborhood,
                      selection_fn,
                      embed_dim,
                      parallel=False,
                      **kwargs)
    model.load_state_dict(torch.load(checkpoint_path), strict=False)
    model = model.to(device)
    model.eval()

    print('Creating LMDB DBs')
    map_size = 100 * 1024 * 1024 * 1024  # maximum size the databases may grow to
    db_name = checkpoint_name[:-3] + '_dataset[{}]'.format(
        dataset
    )  # combines the experiment name and the epoch the codes were taken from
    os.makedirs(os.path.join('codes', 'train_codes'), exist_ok=True)
    os.makedirs(os.path.join('codes', 'test_codes'), exist_ok=True)
    train_env = lmdb.open(
        os.path.join('codes', 'train_codes', db_name),
        map_size=map_size)  # will hold the encodings of the train samples
    test_env = lmdb.open(
        os.path.join('codes', 'test_codes', db_name),
        map_size=map_size)  # will hold the encodings of the test samples

    print('Extracting')
    if architecture == 'vqvae':
        extract(train_env, train_loader, model, device, 'train')
        extract(test_env, test_loader, model, device, 'test')
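
extract is defined elsewhere in the repository; a minimal sketch of what it plausibly does, assuming each NamedDataset batch yields (img, label, filename) and that the model exposes an encode method returning discrete code indices (CodeRow, the batch layout, and the encode signature are all assumptions):

import pickle
from collections import namedtuple

import torch
from tqdm import tqdm

CodeRow = namedtuple('CodeRow', ['code', 'filename'])  # hypothetical row layout

def extract(lmdb_env, loader, model, device, split):
    # Encode every batch to discrete code indices and store one pickled row
    # per sample, keyed by a running integer, plus a 'length' bookkeeping key.
    index = 0
    with lmdb_env.begin(write=True) as txn:
        for img, _, filename in tqdm(loader, desc='extracting {} codes'.format(split)):
            with torch.no_grad():
                code_ids = model.encode(img.to(device))  # assumed to return code indices
            for name, code in zip(filename, code_ids.cpu().numpy()):
                txn.put(str(index).encode('utf-8'),
                        pickle.dumps(CodeRow(code=code, filename=name)))
                index += 1
        txn.put(b'length', str(index).encode('utf-8'))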
Example #3
def get_checkpoint_names(architecture, ckpt_epoch, dataset, hier, kwargs,
                         neighborhood, num_embeddings, pixelsnail_ckpt_epoch,
                         selection_fn, size):
    experiment_name = util_funcs.create_experiment_name(
        architecture, dataset, num_embeddings, neighborhood, selection_fn,
        size, **kwargs)
    vqvae_checkpoint_name = util_funcs.create_checkpoint_name(
        experiment_name, ckpt_epoch)
    pixelsnail_checkpoint_name = f'pixelsnail_{experiment_name}_{hier}_{str(pixelsnail_ckpt_epoch + 1).zfill(3)}.pt'
    return pixelsnail_checkpoint_name, vqvae_checkpoint_name
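
A call with illustrative values (every argument below is an example, not a repository default); note that kwargs is passed positionally as a plain dict:

pixelsnail_ckpt, vqvae_ckpt = get_checkpoint_names(
    architecture='vqvae', ckpt_epoch=420, dataset='imagenet', hier='bottom',
    kwargs={}, neighborhood=1, num_embeddings=512,
    pixelsnail_ckpt_epoch=419, selection_fn='vanilla', size=128)
# pixelsnail_ckpt would end in '_bottom_420.pt'; the exact experiment-name
# format is determined by util_funcs.create_experiment_name.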
Example #4
def create_run(architecture, dataset, num_embeddings, num_workers,
               selection_fn, neighborhood, device, embed_dim, size, **kwargs):
    global args, scheduler

    # Get VQVAE experiment name
    experiment_name = util_funcs.create_experiment_name(
        architecture,
        dataset,
        num_embeddings,
        neighborhood,
        selection_fn=selection_fn,
        size=size,
        **kwargs)

    # Prepare logger
    writer = SummaryWriter(
        os.path.join('runs', 'pixelsnail_' + experiment_name + '2',
                     str(datetime.datetime.now())))

    # Load datasets
    test_loader, train_loader = load_datasets(args, experiment_name,
                                              num_workers, dataset)

    # Create model and optimizer
    model, optimizer = prepare_model_parts(train_loader)

    # Get checkpoint path for underlying VQ-VAE model
    checkpoint_name = util_funcs.create_checkpoint_name(
        experiment_name, kwargs['ckpt_epoch'])
    checkpoint_path = f'checkpoint/{checkpoint_name}'

    # Load underlying VQ-VAE model for logging purposes
    vqvae_model = get_model(architecture,
                            num_embeddings,
                            device,
                            neighborhood,
                            selection_fn,
                            embed_dim,
                            parallel=False,
                            **kwargs)
    vqvae_model.load_state_dict(torch.load(checkpoint_path), strict=False)
    vqvae_model = vqvae_model.to(device)
    vqvae_model.eval()

    # Train model
    (train_coefficients_loss, train_num_nonzeros_loss, train_atom_loss, train_losses,
     test_coefficients_loss, test_num_nonzeros_loss, test_atom_loss, test_losses) = run_train(
         args, experiment_name, model, optimizer, scheduler, test_loader,
         train_loader, writer, vqvae_model)

    return (train_coefficients_loss, train_num_nonzeros_loss, train_atom_loss, train_losses,
            test_coefficients_loss, test_num_nonzeros_loss, test_atom_loss, test_losses)
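
prepare_model_parts and run_train are defined elsewhere in this script; a minimal sketch of what prepare_model_parts plausibly does, with PixelSNAIL hyper-parameter values that are assumptions rather than the repository's actual settings:

import torch

def prepare_model_parts(train_loader):
    # Build the PixelSNAIL prior over the bottom code map plus an Adam
    # optimizer; hyper-parameters would normally come from the global args.
    model = PixelSNAIL(
        shape=[32, 32], n_class=args.num_embeddings, channel=256,
        kernel_size=5, n_block=4, n_res_block=4, res_channel=256)
    model = model.to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    return model, optimizer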
Example #5
def load_datasets(args, experiment_name, num_workers, dataset):
    db_name = util_funcs.create_checkpoint_name(
        experiment_name, args.ckpt_epoch)[:-3] + '_dataset[{}]'.format(dataset)
    train_dataset = LMDBDataset(
        os.path.join('..', 'codes', 'train_codes', db_name), args.architecture)
    test_dataset = LMDBDataset(
        os.path.join('..', 'codes', 'test_codes', db_name), args.architecture)
    train_loader = DataLoader(train_dataset,
                              batch_size=1,
                              shuffle=True,
                              num_workers=num_workers,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=True,
                             num_workers=num_workers,
                             drop_last=True)
    return test_loader, train_loader
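
LMDBDataset reads back the rows written by the extraction script in Example #2; a minimal sketch consistent with the hypothetical CodeRow layout sketched there (the repository's actual class may store and return more fields):

import pickle

import lmdb
import torch

class LMDBDataset(torch.utils.data.Dataset):
    def __init__(self, path, architecture):
        self.architecture = architecture
        self.env = lmdb.open(path, readonly=True, lock=False)
        with self.env.begin() as txn:
            self.length = int(txn.get(b'length').decode('utf-8'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin() as txn:
            row = pickle.loads(txn.get(str(index).encode('utf-8')))
        # Returning (filename, code) matches Example #6, which takes the
        # codes from batch[1].
        return row.filename, torch.from_numpy(row.code)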
Example #6
def create_run(architecture, dataset, num_embeddings, num_workers,
               selection_fn, neighborhood, device, size, ckpt_epoch, embed_dim,
               **kwargs):
    global args, scheduler

    print('Creating data loaders')
    experiment_name = util_funcs.create_experiment_name(
        architecture, dataset, num_embeddings, neighborhood, selection_fn,
        size, **kwargs)
    checkpoint_name = util_funcs.create_checkpoint_name(
        experiment_name, ckpt_epoch)
    checkpoint_path = f'checkpoint/{checkpoint_name}'

    test_loader, train_loader = load_datasets(args, experiment_name,
                                              num_workers, dataset)

    print('Loading model')
    model = get_model(architecture,
                      num_embeddings,
                      device,
                      neighborhood,
                      selection_fn,
                      embed_dim,
                      parallel=False,
                      **kwargs)
    model.load_state_dict(torch.load(os.path.join('..', checkpoint_path)),
                          strict=False)
    model = model.to(device)
    model.eval()

    # Decode one batch of stored codes back to pixels as a quick sanity check
    for batch in train_loader:
        print('decoding')
        X = model.decode_code(batch[1].to(device))

        print('decoded')
        utils.save_image(
            X,
            'X_img.png',
            nrow=1,
            normalize=True,
            range=(-1, 1),
        )
        break  # only inspect the first batch
Example #7
def get_PSNR(size, device, dataset, data_path, num_workers, num_embeddings,
             architecture, ckpt_epoch, neighborhood, selection_fn, embed_dim,
             **kwargs):
    print('Setting up dataset')
    _, test_dataset = get_dataset(dataset, data_path, size)

    print('Creating data loaders')
    test_loader = DataLoader(test_dataset,
                             batch_size=kwargs['vae_batch'],
                             shuffle=True,
                             num_workers=num_workers)

    experiment_name = util_funcs.create_experiment_name(
        architecture, dataset, num_embeddings, neighborhood, selection_fn,
        size, **kwargs)
    checkpoint_name = util_funcs.create_checkpoint_name(
        experiment_name, ckpt_epoch)
    checkpoint_path = f"{kwargs['checkpoint_path']}/{checkpoint_name}"
    _validate_args(checkpoint_path)

    print('Calculating PSNR for: {}'.format(checkpoint_name))

    print('Loading model')
    model = get_model(architecture,
                      num_embeddings,
                      device,
                      neighborhood,
                      selection_fn,
                      embed_dim,
                      parallel=False,
                      **kwargs)
    model.load_state_dict(torch.load(os.path.join('..', checkpoint_path),
                                     map_location='cuda:0'),
                          strict=False)
    model = model.to(device)
    model.eval()

    mse = nn.MSELoss()
    MAX_i = 255
    to_MAX = lambda t: (t + 1) * MAX_i / 2
    psnr_term = 20 * torch.log10(torch.ones(1, 1) * MAX_i)

    # calculate PSNR over test set
    sparsity = 0
    top_percentiles = 0
    total_num_zeros = 0  # distinct from the num_zeros tensor the model returns each batch
    psnrs = 0
    done = 0

    for batch in tqdm(test_loader, desc='Calculating PSNR'):
        with torch.no_grad():
            img = batch[0].to(device)
            out, _, num_quantization_steps, mean_D, mean_Z, norm_Z, top_percentile, num_zeros = model(
                img)

        if done == 0 and os.path.isdir(
                args.sample_save_path):  # save the first test batch
            _save_tensors(
                experiment_name,
                img,
                out,
            )

        cur_psnr = psnr_term.item() - 10 * torch.log10(
            mse(to_MAX(out), to_MAX(img)))

        # Gather data
        psnrs += cur_psnr
        sparsity += norm_Z.mean()
        top_percentiles += top_percentile.mean()
        total_num_zeros += num_zeros.mean()
        done += 1

    # Average the accumulated statistics over the test set
    avg_psnr = (psnrs / float(done)).item()
    avg_top_percentiles = (top_percentiles / float(done)).item()
    avg_num_zeros = (total_num_zeros / float(done)).item()
    avg_sparsity = (sparsity / float(done)).item()

    print('#' * 30)
    print('Experiment name: {}'.format(experiment_name))
    print('Epoch: {}'.format(ckpt_epoch))
    print('avg_psnr: {}'.format(avg_psnr))
    print('avg_sparsity: {}'.format(avg_sparsity))
    print('is_quantize_coefs: {}'.format(kwargs['is_quantize_coefs']))

    # dump params and stats into a JSON file
    _save_root = os.path.join(args.sample_save_path, experiment_name)
    os.makedirs(_save_root, exist_ok=True)
    with open(f'{_save_root}/eval_result_{_NOW}.json', 'w') as fp:
        json.dump(
            dict(
                # --- params
                dataset=args.dataset,
                selection_fn=args.selection_fn,
                embed_dim=args.embed_dim,
                num_atoms=args.num_embeddings,
                image_size=args.size,
                batch_size=args.vae_batch,
                num_strides=args.num_strides,
                num_nonzero=args.num_nonzero,
                normalize_x=args.normalize_x,
                normalize_d=args.normalize_dict,
                epoch=args.ckpt_epoch,
                # --- stats
                psnr=avg_psnr,
                compression=None,  # todo - calculate this
                atom_bits=None,  # todo - calculate this
                non_zero_mean=avg_sparsity,
                non_zero_99pct=avg_top_percentiles,
                # tmp=args.tmp,
            ),
            fp,
            indent=1)
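
The metric accumulated above is the standard peak signal-to-noise ratio, PSNR = 20 * log10(MAX_i) - 10 * log10(MSE), computed on tensors rescaled from [-1, 1] to [0, 255]. A quick self-contained check of the same formula (the noise level is arbitrary):

import torch
import torch.nn as nn

mse = nn.MSELoss()
MAX_i = 255
to_MAX = lambda t: (t + 1) * MAX_i / 2  # rescale [-1, 1] to [0, 255]

img = torch.rand(1, 3, 8, 8) * 2 - 1
noisy = (img + 0.1 * torch.randn_like(img)).clamp(-1, 1)
psnr = 20 * torch.log10(torch.tensor(255.0)) - 10 * torch.log10(mse(to_MAX(noisy), to_MAX(img)))
print(psnr.item())  # roughly 26 dB for noise std 0.1; identical inputs would give infinity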