Exemple #1
0
def create_base_transform(i):
    """Build the ``i``-th coupling layer selected by ``args.base_transform_type``.

    Supported types are 'rq' (``PiecewiseRationalQuadraticCouplingTransform``),
    'affine' (``AffineCouplingTransform``) and 'rl'
    (``PiecewiseRationalLinearCouplingTransform``). Every variant masks
    features with ``utils.create_alternating_binary_mask`` over ``dim``
    features, with the ``even`` flag alternating with ``i``. Every variant is
    conditioned by an ``nn_.ResidualNet`` whose width, depth, dropout and
    batch-norm settings come from ``args``. The spline variants additionally
    use ``args.num_bins`` bins and ``apply_unconditional_transform=False``.

    Args:
        i: index of the layer in the flow; only its parity is used.

    Returns:
        The constructed coupling transform.

    Raises:
        ValueError: if ``args.base_transform_type`` is not a supported type.
    """
    def make_mask():
        # Alternate the mask parity between consecutive layers.
        return utils.create_alternating_binary_mask(
            features=dim,
            even=(i % 2 == 0)
        )

    def make_net(in_features, out_features):
        # Conditioner network shared by all coupling variants.
        return nn_.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=args.hidden_features,
            num_blocks=args.num_transform_blocks,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm
        )

    transform_type = args.base_transform_type
    if transform_type == 'rq':
        return transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=make_mask(),
            transform_net_create_fn=make_net,
            num_bins=args.num_bins,
            apply_unconditional_transform=False,
        )
    if transform_type == 'affine':
        return transforms.AffineCouplingTransform(
            mask=make_mask(),
            transform_net_create_fn=make_net,
        )
    if transform_type == 'rl':
        return transforms.PiecewiseRationalLinearCouplingTransform(
            mask=make_mask(),
            transform_net_create_fn=make_net,
            num_bins=args.num_bins,
            apply_unconditional_transform=False,
        )
    raise ValueError(
        'Unknown base_transform_type: {!r} (expected one of '
        "'rq', 'affine', 'rl')".format(transform_type))
Exemple #2
0
def create_base_transform(i):
    """Build the ``i``-th coupling layer selected by ``args.base_transform_type``.

    Supported types are 'affine', 'rq-coupling' and 'rl-coupling'. Every
    variant masks features with ``utils.create_alternating_binary_mask`` over
    ``dim`` features, with the ``even`` flag alternating with ``i``. Every
    variant is conditioned by an ``nn_.ResidualNet`` with 32 hidden features,
    2 blocks and batch norm. The spline variants use ``tails='linear'``,
    ``tail_bound=5``, ``args.num_bins`` bins and
    ``apply_unconditional_transform=False``.

    Args:
        i: index of the layer in the flow; only its parity is used.

    Returns:
        The constructed coupling transform.

    Raises:
        ValueError: if ``args.base_transform_type`` is not a supported type.
    """
    def make_mask():
        # Alternate the mask parity between consecutive layers.
        return utils.create_alternating_binary_mask(features=dim,
                                                    even=(i % 2 == 0))

    def make_net(in_features, out_features):
        # Small fixed-size conditioner shared by all variants.
        return nn_.ResidualNet(in_features=in_features,
                               out_features=out_features,
                               hidden_features=32,
                               num_blocks=2,
                               use_batch_norm=True)

    transform_type = args.base_transform_type
    if transform_type == 'affine':
        return transforms.AffineCouplingTransform(
            mask=make_mask(),
            transform_net_create_fn=make_net)
    if transform_type == 'rq-coupling':
        return transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=make_mask(),
            transform_net_create_fn=make_net,
            tails='linear',
            tail_bound=5,
            num_bins=args.num_bins,
            apply_unconditional_transform=False)
    if transform_type == 'rl-coupling':
        return transforms.PiecewiseRationalLinearCouplingTransform(
            mask=make_mask(),
            transform_net_create_fn=make_net,
            tails='linear',
            tail_bound=5,
            num_bins=args.num_bins,
            apply_unconditional_transform=False)
    # Fail loudly: returning None here would only surface later, far from
    # the misconfigured argument.
    raise ValueError(
        'Unknown base_transform_type: {!r} (expected one of '
        "'affine', 'rq-coupling', 'rl-coupling')".format(transform_type))
Exemple #3
0
def create_base_transform(i):
    """Build the ``i``-th base transform selected by ``args.base_transform_type``.

    Coupling types ('affine-coupling', 'quadratic-coupling', 'rq-coupling',
    'rl-coupling') mask ``features`` with
    ``utils.create_alternating_binary_mask``, with the ``even`` flag
    alternating with ``i``. They are conditioned by an ``nn_.ResidualNet``
    configured from ``args`` (ReLU activation, no context).

    Autoregressive types ('affine-autoregressive', 'quadratic-autoregressive',
    'rq-autoregressive', 'rl-autoregressive') use the masked autoregressive
    transforms with residual blocks and ``random_mask=False``.

    All spline variants use ``args.num_bins`` bins, ``tails='linear'`` and
    ``args.tail_bound``.

    Args:
        i: index of the layer in the flow; only its parity is used (coupling
            types only).

    Returns:
        The constructed transform.

    Raises:
        ValueError: if ``args.base_transform_type`` is not a supported type.
    """
    def coupling_mask():
        # Alternate the mask parity between consecutive layers.
        return utils.create_alternating_binary_mask(features,
                                                    even=(i % 2 == 0))

    def residual_net(in_features, out_features):
        # Conditioner network shared by every coupling variant.
        return nn_.ResidualNet(in_features=in_features,
                               out_features=out_features,
                               hidden_features=args.hidden_features,
                               context_features=None,
                               num_blocks=args.num_transform_blocks,
                               activation=F.relu,
                               dropout_probability=args.dropout_probability,
                               use_batch_norm=args.use_batch_norm)

    def spline_coupling_kwargs():
        # Keyword arguments common to the spline coupling transforms.
        return dict(
            mask=coupling_mask(),
            transform_net_create_fn=residual_net,
            num_bins=args.num_bins,
            tails='linear',
            tail_bound=args.tail_bound,
            apply_unconditional_transform=args.apply_unconditional_transform)

    def autoregressive_kwargs(spline):
        # Keyword arguments common to the masked autoregressive transforms;
        # spline variants additionally take the bin/tail settings.
        kwargs = dict(features=features,
                      hidden_features=args.hidden_features,
                      context_features=None,
                      num_blocks=args.num_transform_blocks,
                      use_residual_blocks=True,
                      random_mask=False,
                      activation=F.relu,
                      dropout_probability=args.dropout_probability,
                      use_batch_norm=args.use_batch_norm)
        if spline:
            kwargs.update(num_bins=args.num_bins,
                          tails='linear',
                          tail_bound=args.tail_bound)
        return kwargs

    transform_type = args.base_transform_type

    # Coupling transforms.
    if transform_type == 'affine-coupling':
        return transforms.AffineCouplingTransform(
            mask=coupling_mask(),
            transform_net_create_fn=residual_net)
    if transform_type == 'quadratic-coupling':
        return transforms.PiecewiseQuadraticCouplingTransform(
            **spline_coupling_kwargs())
    if transform_type == 'rq-coupling':
        return transforms.PiecewiseRationalQuadraticCouplingTransform(
            **spline_coupling_kwargs())
    if transform_type == 'rl-coupling':
        return transforms.PiecewiseRationalLinearCouplingTransform(
            **spline_coupling_kwargs())

    # Masked autoregressive transforms.
    if transform_type == 'affine-autoregressive':
        return transforms.MaskedAffineAutoregressiveTransform(
            **autoregressive_kwargs(spline=False))
    if transform_type == 'quadratic-autoregressive':
        return transforms.MaskedPiecewiseQuadraticAutoregressiveTransform(
            **autoregressive_kwargs(spline=True))
    if transform_type == 'rq-autoregressive':
        return transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
            **autoregressive_kwargs(spline=True))
    if transform_type == 'rl-autoregressive':
        return transforms.MaskedPiecewiseRationalLinearAutoregressiveTransform(
            **autoregressive_kwargs(spline=True))

    raise ValueError(
        'Unknown base_transform_type: {!r}'.format(transform_type))
Exemple #4
0
#                 use_batch_norm=True
#             ),
#             num_bins=args.num_bins,
#             apply_unconditional_transform=False
#     )
# ])

# Flow built from rational-linear spline coupling layers. The first layer's
# mask uses even=True and the second uses even=False. Each layer is
# conditioned by an nn_.ResidualNet (32 hidden features, 2 blocks, batch norm)
# and uses tails='linear', tail_bound=5 and args.num_bins bins.
# NOTE(review): this excerpt is truncated -- the remaining arguments of the
# second layer and the closing brackets of the list/call are not shown.
transform = transforms.CompositeTransform([
    # transforms.Sigmoid(),
    # Layer 1: alternating mask with even=True.
    transforms.PiecewiseRationalLinearCouplingTransform(
            mask=utils.create_alternating_binary_mask(features=dim, even=True),
            transform_net_create_fn=lambda in_features, out_features: nn_.ResidualNet(
                in_features=in_features,
                out_features=out_features,
                hidden_features=32,
                num_blocks=2,
                use_batch_norm=True
            ),
            tails='linear',
            tail_bound=5,
            num_bins=args.num_bins,
            apply_unconditional_transform=False
    ),
    # Layer 2: alternating mask with even=False.
    transforms.PiecewiseRationalLinearCouplingTransform(
            mask=utils.create_alternating_binary_mask(features=dim, even=False),
            transform_net_create_fn=lambda in_features, out_features: nn_.ResidualNet(
                in_features=in_features,
                out_features=out_features,
                hidden_features=32,
                num_blocks=2,
                use_batch_norm=True
            ),
Exemple #5
0
grid_loader = data.DataLoader(dataset=grid_dataset,
                              batch_size=1000,
                              drop_last=False)

# Create the model. The base distribution is TweakedUniform with low=0 and
# high=1 in each of the `dim` dimensions.
distribution = uniform.TweakedUniform(low=torch.zeros(dim),
                                      high=torch.ones(dim))
# `args.num_flow_steps` rational-linear spline coupling layers; the mask
# parity alternates with the step index. tails=None with tail_bound=1 --
# presumably so the splines act directly on [0, 1] to match the base
# distribution (confirm against the transform's documentation).
transform = transforms.CompositeTransform([
    transforms.PiecewiseRationalLinearCouplingTransform(
        mask=utils.create_alternating_binary_mask(features=dim,
                                                  even=(i % 2 == 0)),
        transform_net_create_fn=lambda in_features, out_features: nn_.ResidualNet(
            in_features=in_features,
            out_features=out_features,
            hidden_features=args.hidden_features,
            num_blocks=args.num_transform_blocks,
            dropout_probability=args.dropout_probability,
            use_batch_norm=args.use_batch_norm),
        num_bins=args.num_bins,
        tails=None,
        tail_bound=1,
        # apply_unconditional_transform is not passed, so the transform's
        # default is used.
        min_bin_width=args.min_bin_width)
    for i in range(args.num_flow_steps)
])

flow = flows.Flow(transform, distribution).to(device)
path = os.path.join(cutils.get_final_root(), '{}-final.t'.format(dataset_name))
# Remap stored tensors onto `device` so that a checkpoint saved on a GPU can
# be restored on a CPU-only machine (and vice versa).
state_dict = torch.load(path, map_location=device)
flow.load_state_dict(state_dict)
flow.eval()
Exemple #6
0
def create_transform_step(num_channels, hidden_channels, actnorm,
                          coupling_layer_type, spline_params, use_resnet,
                          num_res_blocks, resnet_batchnorm, dropout_prob):
    """Build one flow step: optional ActNorm, a 1x1 convolution, then a coupling layer.

    Args:
        num_channels: channel count used for the mid-split coupling mask,
            ``ActNorm`` and ``OneByOneConvolution``.
        hidden_channels: hidden width of the conditioner network.
        actnorm: if true, prepend ``transforms.ActNorm(num_channels)``.
        coupling_layer_type: one of 'cubic_spline', 'quadratic_spline',
            'rational_quadratic_spline', 'rational_linear_spline', 'affine',
            'additive'.
        spline_params: mapping with keys 'tail_bound', 'num_bins',
            'apply_unconditional_transform', 'min_bin_width',
            'min_bin_height', and additionally 'min_derivative' for the
            rational splines. It is only read for spline coupling types.
        use_resnet: use ``nn_.ConvResidualNet`` as the conditioner instead of
            ``ConvNet``.
        num_res_blocks: number of residual blocks (resnet conditioner only).
        resnet_batchnorm: enable batch norm in the residual conditioner.
        dropout_prob: dropout probability; must be 0 unless ``use_resnet``.

    Returns:
        A ``transforms.CompositeTransform`` of the step's transforms.

    Raises:
        ValueError: if ``dropout_prob`` is non-zero while ``use_resnet`` is
            false.
        RuntimeError: if ``coupling_layer_type`` is not recognised.
    """
    if use_resnet:

        def create_convnet(in_channels, out_channels):
            return nn_.ConvResidualNet(in_channels=in_channels,
                                       out_channels=out_channels,
                                       hidden_channels=hidden_channels,
                                       num_blocks=num_res_blocks,
                                       use_batch_norm=resnet_batchnorm,
                                       dropout_probability=dropout_prob)
    else:
        # Only the residual conditioner receives dropout_prob, so a non-zero
        # value would otherwise be silently ignored.
        if dropout_prob != 0.:
            raise ValueError(
                'dropout_prob={} requires use_resnet=True; the plain ConvNet '
                'conditioner is built without dropout.'.format(dropout_prob))

        def create_convnet(in_channels, out_channels):
            return ConvNet(in_channels, hidden_channels, out_channels)

    mask = utils.create_mid_split_binary_mask(num_channels)

    def spline_kwargs(with_min_derivative):
        # Keyword arguments shared by all spline coupling layers. Built lazily
        # so that non-spline types never read spline_params.
        kwargs = dict(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails='linear',
            tail_bound=spline_params['tail_bound'],
            num_bins=spline_params['num_bins'],
            apply_unconditional_transform=spline_params[
                'apply_unconditional_transform'],
            min_bin_width=spline_params['min_bin_width'],
            min_bin_height=spline_params['min_bin_height'])
        if with_min_derivative:
            kwargs['min_derivative'] = spline_params['min_derivative']
        return kwargs

    if coupling_layer_type == 'cubic_spline':
        coupling_layer = transforms.PiecewiseCubicCouplingTransform(
            **spline_kwargs(with_min_derivative=False))
    elif coupling_layer_type == 'quadratic_spline':
        coupling_layer = transforms.PiecewiseQuadraticCouplingTransform(
            **spline_kwargs(with_min_derivative=False))
    elif coupling_layer_type == 'rational_quadratic_spline':
        coupling_layer = transforms.PiecewiseRationalQuadraticCouplingTransform(
            **spline_kwargs(with_min_derivative=True))
    elif coupling_layer_type == 'rational_linear_spline':
        coupling_layer = transforms.PiecewiseRationalLinearCouplingTransform(
            **spline_kwargs(with_min_derivative=True))
    elif coupling_layer_type == 'affine':
        coupling_layer = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    elif coupling_layer_type == 'additive':
        coupling_layer = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    else:
        raise RuntimeError(
            'Unknown coupling_layer_type: {!r}'.format(coupling_layer_type))

    step_transforms = []
    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))
    step_transforms.extend(
        [transforms.OneByOneConvolution(num_channels), coupling_layer])

    return transforms.CompositeTransform(step_transforms)