# Example #1 (0 votes)
    def create_approximate_posterior():
        """Build the approximate posterior q(z|x).

        Returns either a plain conditional diagonal Normal (with a conv
        encoder producing its parameters) or, for any other configured
        type, that Normal wrapped in an inverted normalizing flow.
        """
        if args.approximate_posterior_type == 'diagonal-normal':
            # Convolutional encoder maps the context directly to the
            # Normal's mean/log-std parameters.
            encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
            return distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=encoder)

        # Flow posterior: a linear encoder parameterizes the base Normal;
        # the transform is a stack of (linear, base) steps followed by one
        # final linear transform, applied in inverse around the base.
        encoder = nn.Linear(args.context_features,
                            2 * args.latent_features)
        base_distribution = distributions_.ConditionalDiagonalNormal(
            shape=[args.latent_features], context_encoder=encoder)

        flow_steps = [
            transforms.CompositeTransform([
                create_linear_transform(),
                create_base_transform(
                    step, context_features=args.context_features)
            ])
            for step in range(args.num_flow_steps)
        ]
        full_transform = transforms.CompositeTransform(
            [transforms.CompositeTransform(flow_steps),
             create_linear_transform()])
        return flows.Flow(
            transforms.InverseTransform(full_transform), base_distribution)
# Example #2 (0 votes)
def eval_reconstruct(num_bits,
                     batch_size,
                     seed,
                     num_reconstruct_batches,
                     _log,
                     output_path=''):
    """Evaluate invertibility of the flow's transform on training data.

    Passes batches through the transform composed with its inverse (which
    should be the identity), accumulates the absolute reconstruction
    error, prints the maximum, and saves a 6x6 image grid of originals
    and reconstructions from the first batch to `output_path`.
    """
    torch.set_grad_enabled(False)

    device = set_device()

    torch.manual_seed(seed)
    np.random.seed(seed)

    train_dataset, _, (c, h, w) = get_train_valid_data()

    flow = create_flow(c, h, w).to(device)
    flow.eval()

    loader = DataLoader(dataset=train_dataset,
                        batch_size=batch_size,
                        shuffle=True)

    # Forward transform followed by its own inverse: ideally the identity.
    identity_transform = transforms.CompositeTransform(
        [flow._transform,
         transforms.InverseTransform(flow._transform)])

    differences = []
    for index, (images, _) in enumerate(
            tqdm(load_num_batches(loader, num_reconstruct_batches),
                 total=num_reconstruct_batches)):
        images = images.to(device)
        reconstructions, _ = identity_transform(images)
        differences.append(torch.abs(reconstructions - images))

        if index == 0:
            # Dump a 36-image grid of originals vs. reconstructions for
            # visual inspection (preprocessing inverted back to pixels).
            originals_grid = Preprocess(num_bits).inverse(images[:36, ...])
            reconstructions_grid = Preprocess(num_bits).inverse(
                reconstructions[:36, ...])

            save_image(originals_grid.cpu(),
                       os.path.join(output_path, 'invertibility_orig.png'),
                       nrow=6,
                       padding=0)

            save_image(reconstructions_grid.cpu(),
                       os.path.join(output_path, 'invertibility_rec.png'),
                       nrow=6,
                       padding=0)

    differences = torch.cat(differences)

    print('max abs diff: {:.4f}'.format(torch.max(differences).item()))
# Example #3 (0 votes)
def train_flow(flow, train_dataset, val_dataset, dataset_dims, device,
               batch_size, num_steps, learning_rate, cosine_annealing,
               warmup_fraction, temperatures, num_bits, num_workers, intervals,
               multi_gpu, actnorm, optimizer_checkpoint, start_step, eta_min,
               _log):
    """Train `flow` on `train_dataset` for `num_steps` optimizer steps.

    Maximizes log-likelihood (reported as bits/dim). Periodically: logs the
    loss and learning rate to TensorBoard, draws sample grids at each of
    `temperatures`, evaluates validation log-prob (saving `flow_best.pt` on
    improvement), checks reconstruction error of transform∘inverse on a
    fixed random batch, and checkpoints `flow_last.pt`/`optimizer_last.pt`.
    The `intervals` dict provides the step periods for the keys 'log',
    'sample', 'eval', 'save' and 'reconstruct'. `start_step`/`optimizer_checkpoint`
    allow resuming a previous run.
    """
    # NOTE(review): `fso` is a module-level object (sacred file observer,
    # presumably) providing the run directory — not visible in this chunk.
    run_dir = fso.dir

    flow = flow.to(device)

    summary_writer = SummaryWriter(run_dir, max_queue=100)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              num_workers=num_workers)

    if val_dataset:
        val_loader = DataLoader(dataset=val_dataset,
                                batch_size=batch_size,
                                num_workers=num_workers)
    else:
        val_loader = None

    # Random batch and identity transform for reconstruction evaluation.
    random_batch, _ = next(
        iter(
            DataLoader(
                dataset=train_dataset,
                batch_size=batch_size,
                num_workers=
                0  # Faster than starting all workers just to get a single batch.
            )))
    # transform followed by its inverse — should reproduce the input exactly.
    identity_transform = transforms.CompositeTransform(
        [flow._transform,
         transforms.InverseTransform(flow._transform)])

    optimizer = torch.optim.Adam(flow.parameters(), lr=learning_rate)

    if optimizer_checkpoint is not None:
        optimizer.load_state_dict(torch.load(optimizer_checkpoint))
        _log.info(
            'Optimizer state loaded from {}'.format(optimizer_checkpoint))

    if cosine_annealing:
        if warmup_fraction == 0.:
            # Plain cosine decay over the full run; resume from start_step
            # if this is a continued run.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer,
                T_max=num_steps,
                last_epoch=-1 if start_step == 0 else start_step,
                eta_min=eta_min)
        else:
            # Linear warm-up for the first warmup_fraction of steps, then
            # cosine decay (project-local scheduler).
            scheduler = optim.CosineAnnealingWarmUpLR(
                optimizer=optimizer,
                warm_up_epochs=int(warmup_fraction * num_steps),
                total_epochs=num_steps,
                last_epoch=-1 if start_step == 0 else start_step,
                eta_min=eta_min)
    else:
        scheduler = None

    def nats_to_bits_per_dim(x):
        # Convert a log-density in nats to bits per image dimension.
        c, h, w = dataset_dims
        return autils.nats_to_bits_per_dim(x, c, h, w)

    _log.info('Starting training...')

    best_val_log_prob = None
    start_time = None
    num_batches = num_steps - start_step

    for step, (batch,
               _) in enumerate(load_num_batches(loader=train_loader,
                                                num_batches=num_batches),
                               start=start_step):
        if step == 0:
            start_time = time.time(
            )  # Runtime estimate will be more accurate if set here.

        flow.train()

        optimizer.zero_grad()

        batch = batch.to(device)

        if multi_gpu:
            if actnorm and step == 0:
                # If using actnorm, data-dependent initialization doesn't work with data_parallel,
                # so pass a single batch on a single GPU before the first step.
                flow.log_prob(batch[:batch.shape[0] //
                                    torch.cuda.device_count(), ...])

            # Split along the batch dimension and put each split on a separate GPU. All available
            # GPUs are used.
            log_density = nn.parallel.data_parallel(LogProbWrapper(flow),
                                                    batch)
        else:
            log_density = flow.log_prob(batch)

        # Negative mean log-likelihood in bits/dim — minimized below.
        loss = -nats_to_bits_per_dim(torch.mean(log_density))

        loss.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()
            summary_writer.add_scalar('learning_rate',
                                      scheduler.get_lr()[0], step)

        summary_writer.add_scalar('loss', loss.item(), step)

        if best_val_log_prob:
            summary_writer.add_scalar('best_val_log_prob', best_val_log_prob,
                                      step)

        flow.eval()  # Everything beyond this point is evaluation.

        if step % intervals['log'] == 0:
            elapsed_time = time.time() - start_time
            progress = autils.progress_string(elapsed_time, step, num_steps)
            _log.info("It: {}/{} loss: {:.3f} [{}]".format(
                step, num_steps, loss, progress))

        if step % intervals['sample'] == 0:
            # One 8x8 sample grid per temperature, side by side in a figure.
            fig, axs = plt.subplots(1,
                                    len(temperatures),
                                    figsize=(4 * len(temperatures), 4))
            for temperature, ax in zip(temperatures, axs.flat):
                with torch.no_grad():
                    # Temperature scales the base noise before inversion.
                    noise = flow._distribution.sample(64) * temperature
                    samples, _ = flow._transform.inverse(noise)
                    samples = Preprocess(num_bits).inverse(samples)

                autils.imshow(make_grid(samples, nrow=8), ax)

                ax.set_title('T={:.2f}'.format(temperature))

            summary_writer.add_figure(tag='samples',
                                      figure=fig,
                                      global_step=step)

            plt.close(fig)

        if step > 0 and step % intervals['eval'] == 0 and (val_loader
                                                           is not None):
            if multi_gpu:

                def log_prob_fn(batch):
                    return nn.parallel.data_parallel(LogProbWrapper(flow),
                                                     batch.to(device))
            else:

                def log_prob_fn(batch):
                    return flow.log_prob(batch.to(device))

            val_log_prob = autils.eval_log_density(log_prob_fn=log_prob_fn,
                                                   data_loader=val_loader)
            val_log_prob = nats_to_bits_per_dim(val_log_prob).item()

            _log.info("It: {}/{} val_log_prob: {:.3f}".format(
                step, num_steps, val_log_prob))
            summary_writer.add_scalar('val_log_prob', val_log_prob, step)

            # Checkpoint the model whenever validation log-prob improves.
            if best_val_log_prob is None or val_log_prob > best_val_log_prob:
                best_val_log_prob = val_log_prob

                torch.save(flow.state_dict(),
                           os.path.join(run_dir, 'flow_best.pt'))
                _log.info(
                    'It: {}/{} best val_log_prob improved, saved flow_best.pt'.
                    format(step, num_steps))

        # Always checkpoint at the save interval and on the final step.
        if step > 0 and (step % intervals['save'] == 0
                         or step == (num_steps - 1)):
            torch.save(optimizer.state_dict(),
                       os.path.join(run_dir, 'optimizer_last.pt'))
            torch.save(flow.state_dict(), os.path.join(run_dir,
                                                       'flow_last.pt'))
            _log.info(
                'It: {}/{} saved optimizer_last.pt and flow_last.pt'.format(
                    step, num_steps))

        if step > 0 and step % intervals['reconstruct'] == 0:
            # Numerical invertibility check on the fixed random batch.
            with torch.no_grad():
                random_batch_ = random_batch.to(device)
                random_batch_rec, logabsdet = identity_transform(random_batch_)

                max_abs_diff = torch.max(
                    torch.abs(random_batch_rec - random_batch_))
                # logabsdet of transform∘inverse should be ~0 as well.
                max_logabsdet = torch.max(logabsdet)

            # fig, axs = plt.subplots(1, 2, figsize=(8, 4))
            # autils.imshow(make_grid(Preprocess(num_bits).inverse(random_batch[:36, ...]),
            #                         nrow=6), axs[0])
            # autils.imshow(make_grid(Preprocess(num_bits).inverse(random_batch_rec[:36, ...]),
            #                         nrow=6), axs[1])
            # summary_writer.add_figure(tag='reconstr', figure=fig, global_step=step)
            # plt.close(fig)

            summary_writer.add_scalar(tag='max_reconstr_abs_diff',
                                      scalar_value=max_abs_diff.item(),
                                      global_step=step)
            summary_writer.add_scalar(tag='max_reconstr_logabsdet',
                                      scalar_value=max_logabsdet.item(),
                                      global_step=step)
# Example #4 (0 votes)
 def __init__(self, squashing_transform, cdf_transform):
     """Compose: squash, apply the CDF transform, then un-squash."""
     chain = [
         squashing_transform,
         cdf_transform,
         transforms.InverseTransform(squashing_transform),
     ]
     super().__init__(chain)