def _create_vector_base_transform(
    i,
    base_transform_type,
    features,
    hidden_features,
    num_transform_blocks,
    dropout_probability,
    use_batch_norm,
    num_bins,
    tail_bound,
    apply_unconditional_transform,
    context_features,
):
    transform_net_create_fn = lambda in_features, out_features: nn_.ResidualNet(
        in_features=in_features,
        out_features=out_features,
        hidden_features=hidden_features,
        context_features=context_features,
        num_blocks=num_transform_blocks,
        activation=F.relu,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
    )

    if base_transform_type == "affine-coupling":
        return transforms.AffineCouplingTransform(
            mask=various.create_alternating_binary_mask(features,
                                                        even=(i % 2 == 0)),
            transform_net_create_fn=transform_net_create_fn)
    elif base_transform_type == "quadratic-coupling":
        return transforms.PiecewiseQuadraticCouplingTransform(
            mask=various.create_alternating_binary_mask(features,
                                                        even=(i % 2 == 0)),
            transform_net_create_fn=transform_net_create_fn,
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            apply_unconditional_transform=apply_unconditional_transform,
        )
    elif base_transform_type == "rq-coupling":
        return transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=various.create_alternating_binary_mask(features,
                                                        even=(i % 2 == 0)),
            transform_net_create_fn=transform_net_create_fn,
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            apply_unconditional_transform=apply_unconditional_transform,
        )
    elif base_transform_type == "affine-autoregressive":
        return transforms.MaskedAffineAutoregressiveTransform(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
    elif base_transform_type == "quadratic-autoregressive":
        return transforms.MaskedPiecewiseQuadraticAutoregressiveTransform(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            num_blocks=num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
    elif base_transform_type == "rq-autoregressive":
        return transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_bins=num_bins,
            tails="linear",
            tail_bound=tail_bound,
            num_blocks=num_transform_blocks,
            use_residual_blocks=True,
            random_mask=False,
            activation=F.relu,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
    else:
        raise ValueError
Example #2
0
def _create_image_transform_step(
    num_channels,
    hidden_channels=96,
    context_channels=None,
    actnorm=True,
    coupling_layer_type="rational_quadratic_spline",
    num_res_blocks=3,
    resnet_batchnorm=True,
    dropout_prob=0.0,
    num_bins=8,
    tail_bound=3.0,
):
    def create_convnet(in_channels, out_channels):
        net = nn_.ConvResidualNet(
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            context_channels=context_channels,
            num_blocks=num_res_blocks,
            use_batch_norm=resnet_batchnorm,
            dropout_probability=dropout_prob,
        )
        return net

    mask = various.create_mid_split_binary_mask(num_channels)

    if coupling_layer_type == "cubic_spline":
        coupling_layer = transforms.PiecewiseCubicCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=tail_bound,
            num_bins=num_bins,
            apply_unconditional_transform=False,
            min_bin_width=0.001,
            min_bin_height=0.001,
        )
    elif coupling_layer_type == "quadratic_spline":
        coupling_layer = transforms.PiecewiseQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=tail_bound,
            num_bins=num_bins,
            apply_unconditional_transform=False,
            min_bin_width=0.001,
            min_bin_height=0.001,
        )
    elif coupling_layer_type == "rational_quadratic_spline":
        coupling_layer = transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=tail_bound,
            num_bins=num_bins,
            apply_unconditional_transform=False,
            min_bin_width=0.001,
            min_bin_height=0.001,
            min_derivative=0.001,
        )
    elif coupling_layer_type == "affine":
        coupling_layer = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    elif coupling_layer_type == "additive":
        coupling_layer = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    else:
        raise RuntimeError("Unknown coupling_layer_type")

    step_transforms = []

    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))

    step_transforms.extend(
        [transforms.OneByOneConvolution(num_channels), coupling_layer])

    logger.debug("  Flow based on %s", coupling_layer_type)

    return transforms.CompositeTransform(step_transforms)
Example #3
0
def _create_image_transform_step(
    num_channels,
    hidden_channels=96,
    actnorm=True,
    coupling_layer_type="rational_quadratic_spline",
    spline_params=None,
    use_resnet=True,
    num_res_blocks=3,
    resnet_batchnorm=True,
    dropout_prob=0.0,
):
    if use_resnet:

        def create_convnet(in_channels, out_channels):
            net = nn_.ConvResidualNet(
                in_channels=in_channels,
                out_channels=out_channels,
                hidden_channels=hidden_channels,
                num_blocks=num_res_blocks,
                use_batch_norm=resnet_batchnorm,
                dropout_probability=dropout_prob,
            )
            return net

    else:
        if dropout_prob != 0.0:
            raise ValueError()

        def create_convnet(in_channels, out_channels):
            return ConvNet(in_channels, hidden_channels, out_channels)

    if spline_params is None:
        spline_params = {
            "apply_unconditional_transform": False,
            "min_bin_height": 0.001,
            "min_bin_width": 0.001,
            "min_derivative": 0.001,
            "num_bins": 4,
            "tail_bound": 3.0,
        }

    mask = various.create_mid_split_binary_mask(num_channels)

    if coupling_layer_type == "cubic_spline":
        coupling_layer = transforms.PiecewiseCubicCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=spline_params["tail_bound"],
            num_bins=spline_params["num_bins"],
            apply_unconditional_transform=spline_params[
                "apply_unconditional_transform"],
            min_bin_width=spline_params["min_bin_width"],
            min_bin_height=spline_params["min_bin_height"],
        )
    elif coupling_layer_type == "quadratic_spline":
        coupling_layer = transforms.PiecewiseQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=spline_params["tail_bound"],
            num_bins=spline_params["num_bins"],
            apply_unconditional_transform=spline_params[
                "apply_unconditional_transform"],
            min_bin_width=spline_params["min_bin_width"],
            min_bin_height=spline_params["min_bin_height"],
        )
    elif coupling_layer_type == "rational_quadratic_spline":
        coupling_layer = transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=spline_params["tail_bound"],
            num_bins=spline_params["num_bins"],
            apply_unconditional_transform=spline_params[
                "apply_unconditional_transform"],
            min_bin_width=spline_params["min_bin_width"],
            min_bin_height=spline_params["min_bin_height"],
            min_derivative=spline_params["min_derivative"],
        )
    elif coupling_layer_type == "affine":
        coupling_layer = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    elif coupling_layer_type == "additive":
        coupling_layer = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    else:
        raise RuntimeError("Unknown coupling_layer_type")

    step_transforms = []

    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))

    step_transforms.extend(
        [transforms.OneByOneConvolution(num_channels), coupling_layer])

    logger.debug("  Flow based on %s", coupling_layer_type)

    return transforms.CompositeTransform(step_transforms)