Example #1
0
    def create_approximate_posterior():
        """Build the approximate posterior q(z | context).

        If ``args.approximate_posterior_type == 'diagonal-normal'`` the result is
        a ``ConditionalDiagonalNormal`` whose parameters come from a
        ``ConvEncoder`` applied to the context. For any other value the result
        is a ``flows.Flow`` built as follows:
          * Its base is a ``ConditionalDiagonalNormal`` conditioned through a
            linear layer.
          * Its transform is the inverse of ``args.num_flow_steps``
            (linear, base) step pairs followed by one extra linear transform.
        """
        if args.approximate_posterior_type == 'diagonal-normal':
            encoder = nn_.ConvEncoder(
                context_features=args.context_features,
                channels_multiplier=16,
                dropout_probability=args.dropout_probability_encoder_decoder)
            return distributions_.ConditionalDiagonalNormal(
                shape=[args.latent_features], context_encoder=encoder)

        # The linear context encoder emits 2 * latent_features values,
        # presumably one half for the means and one for the scales.
        encoder = nn.Linear(args.context_features, 2 * args.latent_features)
        base = distributions_.ConditionalDiagonalNormal(
            shape=[args.latent_features], context_encoder=encoder)

        # Per step: the linear transform is created first, then the base
        # transform (creation order matters for any randomly initialized layer).
        steps = []
        for step_index in range(args.num_flow_steps):
            linear_part = create_linear_transform()
            base_part = create_base_transform(
                step_index, context_features=args.context_features)
            steps.append(transforms.CompositeTransform([linear_part, base_part]))

        # Keep the nested CompositeTransform layout unchanged; it presumably
        # determines the module / state_dict naming of saved checkpoints.
        stacked = transforms.CompositeTransform(
            [transforms.CompositeTransform(steps), create_linear_transform()])
        return flows.Flow(transforms.InverseTransform(stacked), base)
Example #2
0
def create_transform():
    """Return a flat composite of ``args.num_flow_steps`` flow steps.

    Step ``i`` is ``CompositeTransform([create_linear_transform(),
    create_base_transform(i)])``. One extra ``create_linear_transform()`` is
    appended after the last step.
    """
    pieces = []
    for step_index in range(args.num_flow_steps):
        pieces.append(
            transforms.CompositeTransform(
                [create_linear_transform(),
                 create_base_transform(step_index)]))
    # Trailing linear transform after the final step.
    pieces.append(create_linear_transform())
    return transforms.CompositeTransform(pieces)
Example #3
0
File: uci.py Project: xqding/nsf
def create_linear_transform():
    """Return the linear mixing layer selected by ``args.linear_transform_type``.

    Supported values (all operate on ``features`` dimensions):
      * ``'permutation'``: a ``RandomPermutation``.
      * ``'lu'``: ``RandomPermutation`` followed by an identity-initialized
        ``LULinear``.
      * ``'svd'``: ``RandomPermutation`` followed by an identity-initialized
        ``SVDLinear`` with ``num_householder=10``.

    Raises:
        ValueError: if ``args.linear_transform_type`` is none of the above.
    """
    linear_type = args.linear_transform_type
    if linear_type == 'permutation':
        return transforms.RandomPermutation(features=features)
    if linear_type == 'lu':
        return transforms.CompositeTransform([
            transforms.RandomPermutation(features=features),
            transforms.LULinear(features, identity_init=True)
        ])
    if linear_type == 'svd':
        return transforms.CompositeTransform([
            transforms.RandomPermutation(features=features),
            transforms.SVDLinear(features, num_householder=10, identity_init=True)
        ])
    # Previously a bare ValueError with no message; name the offending value.
    raise ValueError(
        "Unknown linear_transform_type {!r}; expected 'permutation', 'lu' "
        "or 'svd'.".format(linear_type))
Example #4
0
    def create_prior():
        """Build the prior p(z) over the ``args.latent_features``-dim latent space.

        If ``args.prior_type == 'standard-normal'`` the result is a fixed
        ``StandardNormal``. Otherwise the result is a ``flows.Flow`` as follows:
          * Its base is that same standard normal.
          * Its transform is ``args.num_flow_steps`` (linear, base) step pairs
            followed by one extra linear transform.
        """
        base = distributions_.StandardNormal((args.latent_features, ))
        if args.prior_type == 'standard-normal':
            return base

        layers = []
        for index in range(args.num_flow_steps):
            layers.append(
                transforms.CompositeTransform(
                    [create_linear_transform(),
                     create_base_transform(index)]))
        # The nesting (steps grouped, then a trailing linear transform) is kept
        # as-is; it presumably determines checkpoint state_dict keys.
        stacked = transforms.CompositeTransform(layers)
        return flows.Flow(
            transforms.CompositeTransform([stacked, create_linear_transform()]),
            base)
Example #5
0
def eval_reconstruct(num_bits,
                     batch_size,
                     seed,
                     num_reconstruct_batches,
                     _log,
                     output_path=''):
    """Numerically check that the flow's transform is invertible.

    Runs ``num_reconstruct_batches`` shuffled training batches through the
    flow's transform followed by its inverse. It then does two things:
      * Saves the first 36 original and reconstructed images of the first
        batch as 6x6 grids (``invertibility_orig.png`` /
        ``invertibility_rec.png`` in ``output_path``).
      * Prints the maximum absolute reconstruction error.

    Args:
        num_bits: bit depth handed to ``Preprocess`` when mapping data back
            to image space for saving.
        batch_size: batch size of the training DataLoader.
        seed: seed for both the torch and numpy RNGs.
        num_reconstruct_batches: number of batches to evaluate.
        _log: logger (not used here; kept for interface compatibility).
        output_path: directory the two PNG files are written to.

    Raises:
        ValueError: if no batch was evaluated (e.g. num_reconstruct_batches <= 0).
    """
    # Scoped no_grad instead of torch.set_grad_enabled(False): the previous
    # global switch was never undone and leaked into the caller.
    with torch.no_grad():
        device = set_device()

        torch.manual_seed(seed)
        np.random.seed(seed)

        train_dataset, _, (c, h, w) = get_train_valid_data()

        flow = create_flow(c, h, w).to(device)
        flow.eval()

        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True)

        # transform followed by its inverse; ideally the identity map.
        identity_transform = transforms.CompositeTransform(
            [flow._transform,
             transforms.InverseTransform(flow._transform)])
        preprocess = Preprocess(num_bits)

        # Keep only one scalar per batch rather than every full-size diff
        # tensor; the max of per-batch maxima equals the global max.
        batch_max_diffs = []
        for batch_index, (batch, _) in enumerate(
                tqdm(load_num_batches(train_loader, num_reconstruct_batches),
                     total=num_reconstruct_batches)):
            batch = batch.to(device)
            batch_rec, _ = identity_transform(batch)
            batch_max_diffs.append(torch.max(torch.abs(batch_rec - batch)))

            if batch_index == 0:
                originals = preprocess.inverse(batch[:36, ...])
                reconstructions = preprocess.inverse(batch_rec[:36, ...])

                save_image(originals.cpu(),
                           os.path.join(output_path, 'invertibility_orig.png'),
                           nrow=6,
                           padding=0)

                save_image(reconstructions.cpu(),
                           os.path.join(output_path, 'invertibility_rec.png'),
                           nrow=6,
                           padding=0)

        if not batch_max_diffs:
            raise ValueError(
                'No batches were evaluated; num_reconstruct_batches must be '
                'positive (got {}).'.format(num_reconstruct_batches))

        max_abs_diff = torch.max(torch.stack(batch_max_diffs))

    print('max abs diff: {:.4f}'.format(max_abs_diff.item()))
Example #6
0
    def __init__(self,
                 features,
                 hidden_features,
                 num_layers,
                 num_blocks_per_layer,
                 use_volume_preserving=False,
                 activation=F.relu,
                 dropout_probability=0.,
                 batch_norm_within_layers=False,
                 batch_norm_between_layers=False):
        """Stack ``num_layers`` coupling layers with alternating masks.

        Args:
            features: dimensionality of the data.
            hidden_features: width of each conditioner ResidualNet.
            num_layers: number of coupling layers.
            num_blocks_per_layer: residual blocks per conditioner.
            use_volume_preserving: additive couplings if true, else affine.
            activation: activation used inside the conditioners.
            dropout_probability: dropout inside the conditioners.
            batch_norm_within_layers: batch norm inside the conditioners.
            batch_norm_between_layers: insert ``BatchNorm`` after each coupling.
        """
        coupling_cls = (transforms.AdditiveCouplingTransform
                        if use_volume_preserving
                        else transforms.AffineCouplingTransform)

        # +1/-1 mask: even indices start at -1, odd indices at +1.
        mask = torch.ones(features)
        mask[::2] = -1

        def make_resnet(in_features, out_features):
            # Conditioner network factory handed to each coupling layer.
            return nn_.ResidualNet(
                in_features,
                out_features,
                hidden_features=hidden_features,
                num_blocks=num_blocks_per_layer,
                activation=activation,
                dropout_probability=dropout_probability,
                use_batch_norm=batch_norm_within_layers,
            )

        layers = []
        for _ in range(num_layers):
            layers.append(
                coupling_cls(mask=mask, transform_net_create_fn=make_resnet))
            # Flip the sign pattern in place for the next layer.
            # NOTE(review): the same tensor object is passed to every layer;
            # this relies on the coupling transform not keeping a reference.
            mask *= -1
            if batch_norm_between_layers:
                layers.append(transforms.BatchNorm(features=features))

        super().__init__(
            transform=transforms.CompositeTransform(layers),
            distribution=distributions.StandardNormal([features]),
        )
Example #7
0
    def __init__(self,
                 features,
                 hidden_features,
                 num_layers,
                 num_blocks_per_layer,
                 use_residual_blocks=True,
                 use_random_masks=False,
                 use_random_permutations=False,
                 activation=F.relu,
                 dropout_probability=0.,
                 batch_norm_within_layers=False,
                 batch_norm_between_layers=False):
        """Stack ``num_layers`` (permutation, masked affine autoregressive) pairs.

        Args:
            features: dimensionality of the data.
            hidden_features: hidden width of each autoregressive network.
            num_layers: number of (permutation, autoregressive) pairs.
            num_blocks_per_layer: blocks per autoregressive network.
            use_residual_blocks: forwarded as ``use_residual_blocks``.
            use_random_masks: forwarded as ``random_mask``.
            use_random_permutations: ``RandomPermutation`` if true, else
                ``ReversePermutation``.
            activation: activation inside the autoregressive networks.
            dropout_probability: dropout inside the autoregressive networks.
            batch_norm_within_layers: forwarded as ``use_batch_norm``.
            batch_norm_between_layers: append ``BatchNorm`` after each pair.
        """
        make_permutation = (transforms.RandomPermutation
                            if use_random_permutations
                            else transforms.ReversePermutation)

        layers = []
        for _ in range(num_layers):
            # List items are evaluated left to right: permutation first.
            layers.extend([
                make_permutation(features),
                transforms.MaskedAffineAutoregressiveTransform(
                    features=features,
                    hidden_features=hidden_features,
                    num_blocks=num_blocks_per_layer,
                    use_residual_blocks=use_residual_blocks,
                    random_mask=use_random_masks,
                    activation=activation,
                    dropout_probability=dropout_probability,
                    use_batch_norm=batch_norm_within_layers,
                ),
            ])
            if batch_norm_between_layers:
                layers.append(transforms.BatchNorm(features))

        super().__init__(
            transform=transforms.CompositeTransform(layers),
            distribution=distributions.StandardNormal([features]),
        )
Example #8
0
        return transforms.PiecewiseRationalLinearCouplingTransform(
            mask=utils.create_alternating_binary_mask(features=dim,
                                                      even=(i % 2 == 0)),
            transform_net_create_fn=lambda in_features, out_features: nn_.
            ResidualNet(in_features=in_features,
                        out_features=out_features,
                        hidden_features=32,
                        num_blocks=2,
                        use_batch_norm=True),
            tails='linear',
            tail_bound=5,
            num_bins=args.num_bins,
            apply_unconditional_transform=False)


# Ten base transforms (indices 0..9) chained into a single flow.
transform = transforms.CompositeTransform(
    list(map(create_base_transform, range(10))))

flow = flows.Flow(transform, distribution).to(device)

n_params = utils.get_num_parameters(flow)
print('There are {} trainable parameters in this model.'.format(n_params))

# Adam with a cosine-annealed learning rate decaying to eta_min=0 over
# args.num_training_steps scheduler steps.
optimizer = optim.Adam(flow.parameters(), lr=args.learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=args.num_training_steps, eta_min=0)

# TensorBoard logs go to <log root>/<dataset name>/<timestamp>.
timestamp = cutils.get_timestamp()
log_dir = os.path.join(cutils.get_log_root(), args.dataset_name, timestamp)
writer = SummaryWriter(log_dir=log_dir)
def create_linear_transform():
    """Return a random permutation over ``feature_dim`` features.

    The permutation is wrapped in a one-element ``CompositeTransform``.
    """
    permutation = transforms.RandomPermutation(features=feature_dim)
    return transforms.CompositeTransform([permutation])
Example #10
0
# Two rational-linear spline coupling layers with complementary alternating
# masks (even=True first, then even=False). Each layer is conditioned by a
# small ResidualNet (32 hidden features, 2 blocks, batch norm).
transform = transforms.CompositeTransform([
    transforms.PiecewiseRationalLinearCouplingTransform(
        mask=utils.create_alternating_binary_mask(features=dim, even=even),
        transform_net_create_fn=lambda in_features, out_features: nn_.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=32,
            num_blocks=2,
            use_batch_norm=True
        ),
        tails='linear',
        tail_bound=5,
        num_bins=args.num_bins,
        apply_unconditional_transform=False
    )
    for even in (True, False)
])
Example #11
0
                in_features=in_features,
                out_features=out_features,
                hidden_features=args.hidden_features,
                num_blocks=args.num_transform_blocks,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm
            ),
            num_bins=args.num_bins,
            apply_unconditional_transform=False,
        )
    else:
        raise ValueError


# Chain args.num_flow_steps base transforms into a single flow.
transform = transforms.CompositeTransform(
    list(map(create_base_transform, range(args.num_flow_steps))))

flow = flows.Flow(transform, distribution).to(device)

n_params = utils.get_num_parameters(flow)
print('There are {} trainable parameters in this model.'.format(n_params))

# Adam with cosine annealing over args.n_total_steps scheduler steps.
optimizer = optim.Adam(flow.parameters(), lr=args.learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                 T_max=args.n_total_steps)

# NOTE(review): `timestamp` is not part of `log_dir` here, so every run for
# the same dataset logs into the same directory; confirm this is intended.
timestamp = cutils.get_timestamp()
log_dir = os.path.join(cutils.get_log_root(), args.dataset_name)
writer = SummaryWriter(log_dir=log_dir)
Example #12
0
grid_loader = data.DataLoader(dataset=grid_dataset,
                              batch_size=1000,
                              drop_last=False)

# create model
# Base distribution: uniform on the unit hypercube [0, 1]^dim.
distribution = uniform.TweakedUniform(low=torch.zeros(dim),
                                      high=torch.ones(dim))
# args.num_flow_steps spline coupling layers with alternating masks, each
# configured with tails=None and tail_bound=1.
transform = transforms.CompositeTransform([
    transforms.PiecewiseRationalLinearCouplingTransform(
        mask=utils.create_alternating_binary_mask(features=dim,
                                                  even=(i % 2 == 0)),
        transform_net_create_fn=lambda in_features, out_features: nn_.
        ResidualNet(in_features=in_features,
                    out_features=out_features,
                    hidden_features=args.hidden_features,
                    num_blocks=args.num_transform_blocks,
                    dropout_probability=args.dropout_probability,
                    use_batch_norm=args.use_batch_norm),
        num_bins=args.num_bins,
        tails=None,
        tail_bound=1,
        # apply_unconditional_transform=args.apply_unconditional_transform,
        min_bin_width=args.min_bin_width) for i in range(args.num_flow_steps)
])

flow = flows.Flow(transform, distribution).to(device)

# Load the final trained weights. map_location remaps stored tensors onto the
# current device, so a checkpoint saved on a GPU also loads on a CPU-only
# machine (without it torch.load fails when the saved device is unavailable).
path = os.path.join(cutils.get_final_root(), '{}-final.t'.format(dataset_name))
state_dict = torch.load(path, map_location=device)
flow.load_state_dict(state_dict)
flow.eval()
Example #13
0
def train_flow(flow, train_dataset, val_dataset, dataset_dims, device,
               batch_size, num_steps, learning_rate, cosine_annealing,
               warmup_fraction, temperatures, num_bits, num_workers, intervals,
               multi_gpu, actnorm, optimizer_checkpoint, start_step, eta_min,
               _log):
    """Train ``flow`` by maximum likelihood with periodic logging/checkpoints.

    Runs steps ``start_step .. num_steps - 1`` of Adam on the negative mean
    log-density of training batches, converted to bits per dimension.
    After each step the following fire when ``step % intervals[key] == 0``:
      * ``'log'``: log loss and a progress estimate via ``_log``.
      * ``'sample'``: draw 64 samples per temperature; add a figure to
        TensorBoard.
      * ``'eval'`` (step > 0, only with ``val_dataset``): compute the
        validation log-density; save ``flow_best.pt`` when it improves.
      * ``'save'`` (step > 0, also on the final step): save
        ``optimizer_last.pt`` and ``flow_last.pt``.
      * ``'reconstruct'`` (step > 0): on a fixed random batch, log the max
        reconstruction error and the max logabsdet of transform + inverse.

    Checkpoints and TensorBoard events go to ``fso.dir`` (presumably a
    sacred file-storage observer's run directory).
    """
    run_dir = fso.dir

    flow = flow.to(device)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              num_workers=num_workers)

    if val_dataset:
        val_loader = DataLoader(dataset=val_dataset,
                                batch_size=batch_size,
                                num_workers=num_workers)
    else:
        val_loader = None

    # Random batch and identity transform for reconstruction evaluation.
    random_batch, _ = next(
        iter(
            DataLoader(
                dataset=train_dataset,
                batch_size=batch_size,
                num_workers=
                0  # Faster than starting all workers just to get a single batch.
            )))
    identity_transform = transforms.CompositeTransform(
        [flow._transform,
         transforms.InverseTransform(flow._transform)])

    optimizer = torch.optim.Adam(flow.parameters(), lr=learning_rate)

    if optimizer_checkpoint is not None:
        optimizer.load_state_dict(torch.load(optimizer_checkpoint))
        _log.info(
            'Optimizer state loaded from {}'.format(optimizer_checkpoint))

    if cosine_annealing:
        if warmup_fraction == 0.:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer,
                T_max=num_steps,
                last_epoch=-1 if start_step == 0 else start_step,
                eta_min=eta_min)
        else:
            scheduler = optim.CosineAnnealingWarmUpLR(
                optimizer=optimizer,
                warm_up_epochs=int(warmup_fraction * num_steps),
                total_epochs=num_steps,
                last_epoch=-1 if start_step == 0 else start_step,
                eta_min=eta_min)
    else:
        scheduler = None

    def nats_to_bits_per_dim(x):
        c, h, w = dataset_dims
        return autils.nats_to_bits_per_dim(x, c, h, w)

    _log.info('Starting training...')

    best_val_log_prob = None
    start_time = None
    num_batches = num_steps - start_step

    summary_writer = SummaryWriter(run_dir, max_queue=100)
    try:
        for step, (batch, _) in enumerate(
                load_num_batches(loader=train_loader, num_batches=num_batches),
                start=start_step):
            if start_time is None:
                # Start the clock on the first step actually executed. The old
                # `step == 0` check left start_time as None when resuming
                # (start_step > 0), crashing at the first log step.
                # NOTE(review): when resuming, progress_string still receives
                # the absolute step, so its estimate may be skewed.
                start_time = time.time()

            flow.train()

            optimizer.zero_grad()

            batch = batch.to(device)

            if multi_gpu:
                if actnorm and step == 0:
                    # When using actnorm, data-dependent initialization doesn't
                    # work with data_parallel, so pass a single-GPU-sized slice
                    # on one GPU before the first step.
                    flow.log_prob(batch[:batch.shape[0] //
                                        torch.cuda.device_count(), ...])

                # Split along the batch dimension and put each split on a
                # separate GPU. All available GPUs are used.
                log_density = nn.parallel.data_parallel(LogProbWrapper(flow),
                                                        batch)
            else:
                log_density = flow.log_prob(batch)

            loss = -nats_to_bits_per_dim(torch.mean(log_density))

            loss.backward()
            optimizer.step()

            if scheduler is not None:
                scheduler.step()
                # Log the LR the optimizer will actually use. scheduler.get_lr()
                # is deprecated outside step() and returns a wrong value for
                # chainable schedulers such as CosineAnnealingLR.
                summary_writer.add_scalar('learning_rate',
                                          optimizer.param_groups[0]['lr'],
                                          step)

            loss_value = loss.item()
            summary_writer.add_scalar('loss', loss_value, step)

            # `is not None`: a best value of exactly 0.0 must still be logged.
            if best_val_log_prob is not None:
                summary_writer.add_scalar('best_val_log_prob',
                                          best_val_log_prob, step)

            flow.eval()  # Everything beyond this point is evaluation.

            if step % intervals['log'] == 0:
                elapsed_time = time.time() - start_time
                progress = autils.progress_string(elapsed_time, step, num_steps)
                _log.info("It: {}/{} loss: {:.3f} [{}]".format(
                    step, num_steps, loss_value, progress))

            if step % intervals['sample'] == 0:
                # squeeze=False keeps `axs` a 2-D array even for a single
                # temperature, so `.flat` always works.
                fig, axs = plt.subplots(1,
                                        len(temperatures),
                                        figsize=(4 * len(temperatures), 4),
                                        squeeze=False)
                for temperature, ax in zip(temperatures, axs.flat):
                    with torch.no_grad():
                        noise = flow._distribution.sample(64) * temperature
                        samples, _ = flow._transform.inverse(noise)
                        samples = Preprocess(num_bits).inverse(samples)

                    autils.imshow(make_grid(samples, nrow=8), ax)

                    ax.set_title('T={:.2f}'.format(temperature))

                summary_writer.add_figure(tag='samples',
                                          figure=fig,
                                          global_step=step)

                plt.close(fig)

            if (step > 0 and step % intervals['eval'] == 0
                    and val_loader is not None):
                if multi_gpu:

                    def log_prob_fn(val_batch):
                        return nn.parallel.data_parallel(
                            LogProbWrapper(flow), val_batch.to(device))
                else:

                    def log_prob_fn(val_batch):
                        return flow.log_prob(val_batch.to(device))

                val_log_prob = autils.eval_log_density(log_prob_fn=log_prob_fn,
                                                       data_loader=val_loader)
                val_log_prob = nats_to_bits_per_dim(val_log_prob).item()

                _log.info("It: {}/{} val_log_prob: {:.3f}".format(
                    step, num_steps, val_log_prob))
                summary_writer.add_scalar('val_log_prob', val_log_prob, step)

                if best_val_log_prob is None or val_log_prob > best_val_log_prob:
                    best_val_log_prob = val_log_prob

                    torch.save(flow.state_dict(),
                               os.path.join(run_dir, 'flow_best.pt'))
                    _log.info(
                        'It: {}/{} best val_log_prob improved, saved flow_best.pt'.
                        format(step, num_steps))

            if step > 0 and (step % intervals['save'] == 0
                             or step == (num_steps - 1)):
                torch.save(optimizer.state_dict(),
                           os.path.join(run_dir, 'optimizer_last.pt'))
                torch.save(flow.state_dict(), os.path.join(run_dir,
                                                           'flow_last.pt'))
                _log.info(
                    'It: {}/{} saved optimizer_last.pt and flow_last.pt'.format(
                        step, num_steps))

            if step > 0 and step % intervals['reconstruct'] == 0:
                with torch.no_grad():
                    random_batch_ = random_batch.to(device)
                    random_batch_rec, logabsdet = identity_transform(
                        random_batch_)

                    max_abs_diff = torch.max(
                        torch.abs(random_batch_rec - random_batch_))
                    max_logabsdet = torch.max(logabsdet)

                summary_writer.add_scalar(tag='max_reconstr_abs_diff',
                                          scalar_value=max_abs_diff.item(),
                                          global_step=step)
                summary_writer.add_scalar(tag='max_reconstr_logabsdet',
                                          scalar_value=max_logabsdet.item(),
                                          global_step=step)
    finally:
        # Flush pending events and release the event file, even on error.
        summary_writer.close()
Example #14
0
def create_transform(c, h, w, levels, hidden_channels, steps_per_level, alpha,
                     num_bits, preprocessing, multi_scale):
    """Build the image flow: a preprocessing transform, then ``levels`` levels.

    Each level is a squeeze, then ``steps_per_level`` flow steps, then a
    1x1 convolution.

    Args:
        c, h, w: input channels, height and width.
        levels: number of levels.
        hidden_channels: hidden width for each level's steps. Either a single
            value shared by all levels, or a list with at least ``levels``
            entries (extra entries are ignored).
        steps_per_level: number of ``create_transform_step`` blocks per level.
        alpha: offset used by the RealNVP-style logit preprocessing.
        num_bits: input bit depth; inputs are expected in [0, 2 ** num_bits].
        preprocessing: ``'glow'``, ``'realnvp'`` or ``'realnvp_2alpha'``.
        multi_scale: if true, levels are added to a
            ``MultiscaleCompositeTransform``, whose ``add_transform`` may return
            a new shape for the next level. Otherwise levels are chained and
            the result is flattened to a ``c * h * w`` vector.

    Returns:
        ``CompositeTransform([preprocess_transform, mct])``.

    Raises:
        ValueError: if ``hidden_channels`` is a list with fewer than ``levels``
            entries (previously ``zip`` silently built fewer levels).
        RuntimeError: if ``preprocessing`` is not recognized.
    """
    if not isinstance(hidden_channels, list):
        hidden_channels = [hidden_channels] * levels
    elif len(hidden_channels) < levels:
        raise ValueError(
            'hidden_channels has {} entries but levels={}'.format(
                len(hidden_channels), levels))

    def create_level(c, h, w, level_hidden_channels):
        """Return one level (squeeze + steps + 1x1 conv) and its output shape."""
        squeeze_transform = transforms.SqueezeTransform()
        c, h, w = squeeze_transform.get_output_shape(c, h, w)
        # End each level with a linear transformation (1x1 convolution).
        transform_level = transforms.CompositeTransform(
            [squeeze_transform]
            + [create_transform_step(c, level_hidden_channels)
               for _ in range(steps_per_level)]
            + [transforms.OneByOneConvolution(c)])
        return transform_level, (c, h, w)

    level_channels = hidden_channels[:levels]

    if multi_scale:
        mct = transforms.MultiscaleCompositeTransform(num_transforms=levels)
        for level_hidden_channels in level_channels:
            transform_level, (c, h, w) = create_level(c, h, w,
                                                      level_hidden_channels)
            new_shape = mct.add_transform(transform_level, (c, h, w))
            if new_shape:  # If not last layer
                c, h, w = new_shape
    else:
        all_transforms = []
        for level_hidden_channels in level_channels:
            transform_level, (c, h, w) = create_level(c, h, w,
                                                      level_hidden_channels)
            all_transforms.append(transform_level)

        all_transforms.append(
            transforms.ReshapeTransform(input_shape=(c, h, w),
                                        output_shape=(c * h * w, )))
        mct = transforms.CompositeTransform(all_transforms)

    # Inputs to the model in [0, 2 ** num_bits]

    if preprocessing == 'glow':
        # Map to [-0.5,0.5]
        preprocess_transform = transforms.AffineScalarTransform(
            scale=(1. / 2**num_bits), shift=-0.5)
    elif preprocessing == 'realnvp':
        preprocess_transform = transforms.CompositeTransform([
            # Map to [0,1]
            transforms.AffineScalarTransform(scale=(1. / 2**num_bits)),
            # Map into unconstrained space as done in RealNVP
            transforms.AffineScalarTransform(shift=alpha, scale=(1 - alpha)),
            transforms.Logit()
        ])

    elif preprocessing == 'realnvp_2alpha':
        preprocess_transform = transforms.CompositeTransform([
            # Map to [0,1], then to [alpha, 1 - alpha] before the logit.
            transforms.AffineScalarTransform(scale=(1. / 2**num_bits)),
            transforms.AffineScalarTransform(shift=alpha,
                                             scale=(1 - 2. * alpha)),
            transforms.Logit()
        ])
    else:
        raise RuntimeError(
            'Unknown preprocessing type: {}'.format(preprocessing))

    return transforms.CompositeTransform([preprocess_transform, mct])
Example #15
0
def create_transform_step(num_channels, hidden_channels, actnorm,
                          coupling_layer_type, spline_params, use_resnet,
                          num_res_blocks, resnet_batchnorm, dropout_prob):
    """Build one flow step: optional ActNorm, 1x1 convolution, coupling layer.

    Args:
        num_channels: channels of the input to this step.
        hidden_channels: hidden width of the coupling conditioner network.
        actnorm: prepend ``ActNorm(num_channels)`` if true.
        coupling_layer_type: ``'cubic_spline'``, ``'quadratic_spline'``,
            ``'rational_quadratic_spline'``, ``'affine'`` or ``'additive'``.
        spline_params: mapping with the spline settings. It needs
            ``tail_bound``, ``num_bins``, ``apply_unconditional_transform``,
            ``min_bin_width`` and ``min_bin_height``, plus ``min_derivative``
            for the rational-quadratic spline. It is only read for spline
            coupling types.
        use_resnet: use ``nn_.ConvResidualNet`` as conditioner, else ``ConvNet``.
        num_res_blocks: residual blocks in the ConvResidualNet.
        resnet_batchnorm: batch norm inside the ConvResidualNet.
        dropout_prob: dropout inside the ConvResidualNet; must be 0 otherwise.

    Returns:
        ``CompositeTransform([ActNorm?, OneByOneConvolution, coupling_layer])``.

    Raises:
        ValueError: if ``dropout_prob`` is non-zero and ``use_resnet`` is false.
        RuntimeError: if ``coupling_layer_type`` is not recognized.
    """
    if use_resnet:

        def create_convnet(in_channels, out_channels):
            return nn_.ConvResidualNet(in_channels=in_channels,
                                       out_channels=out_channels,
                                       hidden_channels=hidden_channels,
                                       num_blocks=num_res_blocks,
                                       use_batch_norm=resnet_batchnorm,
                                       dropout_probability=dropout_prob)
    else:
        # The plain ConvNet conditioner is built without a dropout argument,
        # so a non-zero dropout_prob would be silently ignored.
        if dropout_prob != 0.:
            raise ValueError(
                'dropout_prob={} requires use_resnet=True; the plain ConvNet '
                'conditioner does not use dropout.'.format(dropout_prob))

        def create_convnet(in_channels, out_channels):
            return ConvNet(in_channels, hidden_channels, out_channels)

    # Mask splitting the channels into two halves (mid split).
    mask = utils.create_mid_split_binary_mask(num_channels)

    def spline_kwargs():
        """Keyword arguments shared by all spline coupling layers."""
        return dict(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails='linear',
            tail_bound=spline_params['tail_bound'],
            num_bins=spline_params['num_bins'],
            apply_unconditional_transform=spline_params[
                'apply_unconditional_transform'],
            min_bin_width=spline_params['min_bin_width'],
            min_bin_height=spline_params['min_bin_height'])

    if coupling_layer_type == 'cubic_spline':
        coupling_layer = transforms.PiecewiseCubicCouplingTransform(
            **spline_kwargs())
    elif coupling_layer_type == 'quadratic_spline':
        coupling_layer = transforms.PiecewiseQuadraticCouplingTransform(
            **spline_kwargs())
    elif coupling_layer_type == 'rational_quadratic_spline':
        coupling_layer = transforms.PiecewiseRationalQuadraticCouplingTransform(
            min_derivative=spline_params['min_derivative'], **spline_kwargs())
    elif coupling_layer_type == 'affine':
        coupling_layer = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    elif coupling_layer_type == 'additive':
        coupling_layer = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    else:
        raise RuntimeError(
            'Unknown coupling_layer_type: {!r}'.format(coupling_layer_type))

    step_transforms = []

    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))

    step_transforms.extend(
        [transforms.OneByOneConvolution(num_channels), coupling_layer])

    return transforms.CompositeTransform(step_transforms)