Exemplo n.º 1
0
def create_base_transform(i):
    """Build the coupling transform used for flow step ``i``.

    The transform class is selected by ``args.base_transform_type``:

    * ``'rq'``: ``PiecewiseRationalQuadraticCouplingTransform`` with
      ``args.num_bins`` bins and ``apply_unconditional_transform=False``.
    * ``'affine'``: ``AffineCouplingTransform``.

    Both variants use an alternating binary mask over ``dim`` features and a
    ``ResidualNet`` configured from ``args`` as the conditioning network.

    Args:
        i: Index of the flow step. Only its parity is used: it is passed as
            the ``even`` flag of the alternating binary mask.

    Returns:
        The constructed coupling transform.

    Raises:
        ValueError: If ``args.base_transform_type`` is neither ``'rq'`` nor
            ``'affine'``.
    """
    transform_type = args.base_transform_type
    # Validate up front so a misconfiguration is reported with the offending
    # value before any mask or network is built.
    if transform_type not in ('rq', 'affine'):
        raise ValueError(
            "Unknown base_transform_type {!r}; expected 'rq' or "
            "'affine'.".format(transform_type))

    def create_transform_net(in_features, out_features):
        # Conditioning network shared by both coupling variants.
        return nn_.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=args.hidden_features,
            num_blocks=args.num_transform_blocks,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm)

    mask = utils.create_alternating_binary_mask(features=dim,
                                                even=(i % 2 == 0))

    if transform_type == 'rq':
        return transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_transform_net,
            num_bins=args.num_bins,
            apply_unconditional_transform=False,
        )
    return transforms.AffineCouplingTransform(
        mask=mask, transform_net_create_fn=create_transform_net)
Exemplo n.º 2
0
 def create_base_transform(i, context_features=None):
     """Build the transform for step ``i`` according to ``args.prior_type``.

     Supported values of ``args.prior_type``:

     * ``'affine-coupling'``: ``AffineCouplingTransform``.
     * ``'rq-coupling'``: ``PiecewiseRationalQuadraticCouplingTransform``
       with linear tails.
     * ``'affine-autoregressive'``: ``MaskedAffineAutoregressiveTransform``.
     * ``'rq-autoregressive'``:
       ``MaskedPiecewiseRationalQuadraticAutoregressiveTransform`` with
       linear tails.

     The coupling variants use an alternating binary mask over
     ``args.latent_features`` features and a ``ResidualNet`` conditioning
     network; all sizes and regularisation settings are read from ``args``.

     Args:
         i: Index of the flow step. Only its parity is used (as the ``even``
             flag of the alternating mask), and only by the coupling types.
         context_features: Forwarded unchanged as ``context_features`` to the
             conditioning network / autoregressive transform.

     Returns:
         The constructed transform.

     Raises:
         ValueError: If ``args.prior_type`` is not one of the supported
             values listed above.
     """
     prior_type = args.prior_type

     if prior_type in ('affine-coupling', 'rq-coupling'):
         mask = utils.create_alternating_binary_mask(
             features=args.latent_features, even=(i % 2 == 0))

         def create_transform_net(in_features, out_features):
             # Conditioning network shared by both coupling variants.
             return nn_.ResidualNet(
                 in_features=in_features,
                 out_features=out_features,
                 hidden_features=args.hidden_features,
                 context_features=context_features,
                 num_blocks=args.num_transform_blocks,
                 activation=F.relu,
                 dropout_probability=args.dropout_probability,
                 use_batch_norm=args.use_batch_norm)

         if prior_type == 'affine-coupling':
             return transforms.AffineCouplingTransform(
                 mask=mask, transform_net_create_fn=create_transform_net)
         return transforms.PiecewiseRationalQuadraticCouplingTransform(
             mask=mask,
             transform_net_create_fn=create_transform_net,
             num_bins=args.num_bins,
             tails='linear',
             tail_bound=args.tail_bound,
             apply_unconditional_transform=args.
             apply_unconditional_transform,
         )

     if prior_type in ('affine-autoregressive', 'rq-autoregressive'):
         # Keyword arguments common to both masked autoregressive variants.
         autoregressive_kwargs = dict(
             features=args.latent_features,
             hidden_features=args.hidden_features,
             context_features=context_features,
             num_blocks=args.num_transform_blocks,
             use_residual_blocks=True,
             random_mask=False,
             activation=F.relu,
             dropout_probability=args.dropout_probability,
             use_batch_norm=args.use_batch_norm)

         if prior_type == 'affine-autoregressive':
             return transforms.MaskedAffineAutoregressiveTransform(
                 **autoregressive_kwargs)
         return transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
             num_bins=args.num_bins,
             tails='linear',
             tail_bound=args.tail_bound,
             **autoregressive_kwargs)

     # Report the offending value instead of raising a bare ValueError.
     raise ValueError(
         "Unknown prior_type {!r}; expected one of 'affine-coupling', "
         "'rq-coupling', 'affine-autoregressive' or "
         "'rq-autoregressive'.".format(prior_type))
Exemplo n.º 3
0
def create_base_transform(i):
    """Return the coupling transform for flow step ``i``.

    When ``args.base_transform_type`` equals ``'affine'`` an
    ``AffineCouplingTransform`` is built; any other value yields a
    ``PiecewiseRationalQuadraticCouplingTransform`` with linear tails, a tail
    bound of 5, ``args.num_bins`` bins and no unconditional transform.

    Both variants share an alternating binary mask over ``dim`` features
    (``even`` is set from the parity of ``i``) and a fixed-size
    ``ResidualNet`` (32 hidden features, 2 blocks, batch norm enabled).
    """

    def make_residual_net(in_features, out_features):
        # Fixed-size conditioning network used by either coupling layer.
        return nn_.ResidualNet(in_features=in_features,
                               out_features=out_features,
                               hidden_features=32,
                               num_blocks=2,
                               use_batch_norm=True)

    wants_affine = args.base_transform_type == 'affine'
    if wants_affine:
        coupling_cls = transforms.AffineCouplingTransform
    else:
        coupling_cls = transforms.PiecewiseRationalQuadraticCouplingTransform

    is_even_step = (i % 2 == 0)
    layer_kwargs = {
        'mask': utils.create_alternating_binary_mask(features=dim,
                                                     even=is_even_step),
        'transform_net_create_fn': make_residual_net,
    }
    if not wants_affine:
        # Spline-only settings; the affine layer takes none of these.
        layer_kwargs['tails'] = 'linear'
        layer_kwargs['tail_bound'] = 5
        layer_kwargs['num_bins'] = args.num_bins
        layer_kwargs['apply_unconditional_transform'] = False

    return coupling_cls(**layer_kwargs)
Exemplo n.º 4
0
def create_transform_step(num_channels, hidden_channels, actnorm,
                          coupling_layer_type, spline_params, use_resnet,
                          num_res_blocks, resnet_batchnorm, dropout_prob):
    """Build one flow step: optional ActNorm, 1x1 convolution, coupling layer.

    Args:
        num_channels: Channel count passed to the mid-split binary mask, to
            ``ActNorm`` and to ``OneByOneConvolution``.
        hidden_channels: Hidden channel count of the coupling layer's
            convolutional network.
        actnorm: If truthy, an ``ActNorm`` layer is placed first in the step.
        coupling_layer_type: One of ``'cubic_spline'``,
            ``'quadratic_spline'``, ``'rational_quadratic_spline'``,
            ``'affine'`` or ``'additive'``.
        spline_params: Mapping read only for the spline types. Required keys:
            ``'tail_bound'``, ``'num_bins'``,
            ``'apply_unconditional_transform'``, ``'min_bin_width'`` and
            ``'min_bin_height'``; ``'rational_quadratic_spline'`` additionally
            reads ``'min_derivative'``.
        use_resnet: If truthy, the coupling network is a
            ``nn_.ConvResidualNet``; otherwise a plain ``ConvNet``.
        num_res_blocks: Number of blocks of the residual network (used only
            when ``use_resnet`` is truthy).
        resnet_batchnorm: ``use_batch_norm`` flag of the residual network
            (used only when ``use_resnet`` is truthy).
        dropout_prob: Dropout probability of the residual network. Must be 0
            when ``use_resnet`` is falsy.

    Returns:
        A ``transforms.CompositeTransform`` holding, in order, the optional
        ``ActNorm``, a ``OneByOneConvolution`` and the coupling layer.

    Raises:
        ValueError: If ``use_resnet`` is falsy and ``dropout_prob`` is
            non-zero.
        RuntimeError: If ``coupling_layer_type`` is not a supported value.
    """
    if use_resnet:

        def create_convnet(in_channels, out_channels):
            return nn_.ConvResidualNet(in_channels=in_channels,
                                       out_channels=out_channels,
                                       hidden_channels=hidden_channels,
                                       num_blocks=num_res_blocks,
                                       use_batch_norm=resnet_batchnorm,
                                       dropout_probability=dropout_prob)
    else:
        # ConvNet is built without any dropout argument, so a non-zero
        # request cannot be honoured; fail loudly rather than ignore it.
        if dropout_prob != 0.:
            raise ValueError(
                'dropout_prob must be 0 when use_resnet is False '
                '(got {!r}): dropout is only applied by the residual '
                'network.'.format(dropout_prob))

        def create_convnet(in_channels, out_channels):
            return ConvNet(in_channels, hidden_channels, out_channels)

    mask = utils.create_mid_split_binary_mask(num_channels)

    spline_types = ('cubic_spline', 'quadratic_spline',
                    'rational_quadratic_spline')

    if coupling_layer_type in spline_types:
        # Keyword arguments shared by all three spline coupling layers.
        # spline_params is only touched here, so the non-spline types work
        # without it.
        spline_kwargs = dict(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails='linear',
            tail_bound=spline_params['tail_bound'],
            num_bins=spline_params['num_bins'],
            apply_unconditional_transform=spline_params[
                'apply_unconditional_transform'],
            min_bin_width=spline_params['min_bin_width'],
            min_bin_height=spline_params['min_bin_height'])

        if coupling_layer_type == 'cubic_spline':
            coupling_layer = transforms.PiecewiseCubicCouplingTransform(
                **spline_kwargs)
        elif coupling_layer_type == 'quadratic_spline':
            coupling_layer = transforms.PiecewiseQuadraticCouplingTransform(
                **spline_kwargs)
        else:
            # Only the rational-quadratic spline takes a minimum derivative.
            coupling_layer = transforms.PiecewiseRationalQuadraticCouplingTransform(
                min_derivative=spline_params['min_derivative'],
                **spline_kwargs)
    elif coupling_layer_type == 'affine':
        coupling_layer = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    elif coupling_layer_type == 'additive':
        coupling_layer = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    else:
        raise RuntimeError('Unknown coupling_layer_type')

    # The coupling layer is constructed before the other layers (as in the
    # original code) so that module creation order is unchanged.
    step_transforms = []

    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))

    step_transforms.extend(
        [transforms.OneByOneConvolution(num_channels), coupling_layer])

    return transforms.CompositeTransform(step_transforms)
Exemplo n.º 5
0
# Base distribution: the project's TweakedUniform with lower bound 0 and
# upper bound 1 in each of the `dim` dimensions.
distribution = distributions.TweakedUniform(
    low=torch.zeros(dim),
    high=torch.ones(dim)
)

# The flow is a stack of `args.num_flow_steps` rational-quadratic coupling
# layers whose alternating masks flip the `even` flag from one step to the
# next.
transform = transforms.CompositeTransform([
    # Each step stays wrapped in its own single-element CompositeTransform:
    # module nesting determines parameter names in the state dict, which
    # presumably have to match the checkpoint loaded below.
    transforms.CompositeTransform([
        transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=utils.create_alternating_binary_mask(
                features=dim,
                even=(i % 2 == 0)
            ),
            transform_net_create_fn=lambda in_features, out_features:
            nn_.ResidualNet(
                in_features=in_features,
                out_features=out_features,
                hidden_features=args.hidden_features,
                num_blocks=args.num_transform_blocks,
                dropout_probability=args.dropout_probability,
                use_batch_norm=args.use_batch_norm
            ),
            num_bins=args.num_bins,
            tails=None,
            tail_bound=1,
            # apply_unconditional_transform is not passed here, so the
            # transform's own default applies.
            min_bin_width=args.min_bin_width
        ),
    ]) for i in range(args.num_flow_steps)
])

flow = flows.Flow(transform, distribution).to(device)
path = os.path.join(cutils.get_final_root(), '{}-final.t'.format(dataset_name))
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded on a CPU-only machine.
state_dict = torch.load(path, map_location=device)