def create_vector_transform(
    dim,
    flow_steps,
    linear_transform_type="permutation",
    base_transform_type="rq-coupling",
    hidden_features=100,
    num_transform_blocks=2,
    dropout_probability=0.0,
    use_batch_norm=False,
    num_bins=8,
    tail_bound=3,
    apply_unconditional_transform=False,
    context_features=None,
):
    transform = transforms.CompositeTransform([
        transforms.CompositeTransform([
            _create_vector_linear_transform(linear_transform_type, dim),
            _create_vector_base_transform(
                i,
                base_transform_type,
                dim,
                hidden_features,
                num_transform_blocks,
                dropout_probability,
                use_batch_norm,
                num_bins,
                tail_bound,
                apply_unconditional_transform,
                context_features,
            ),
        ]) for i in range(flow_steps)
    ] + [_create_vector_linear_transform(linear_transform_type, dim)])
    return transform
Example #2
0
def _create_preprocessing(alpha, c, h, num_bits, preprocessing, w):
    # Preprocessing
    # Inputs to the model in [0, 2 ** num_bits]
    if preprocessing == "glow":
        # Map to [-0.5,0.5]
        preprocess_transform = transforms.AffineScalarTransform(
            scale=(1.0 / 2**num_bits), shift=-0.5)
        logger.debug("Preprocessing: Glow")
    elif preprocessing == "realnvp":
        preprocess_transform = transforms.CompositeTransform([
            # Map to [0,1]
            transforms.AffineScalarTransform(scale=(1.0 / 2**num_bits)),
            # Map into unconstrained space as done in RealNVP
            transforms.AffineScalarTransform(shift=alpha, scale=(1 - alpha)),
            transforms.Logit(),
        ])
        logger.debug("Preprocessing: RealNVP")
    elif preprocessing == "realnvp_2alpha":
        preprocess_transform = transforms.CompositeTransform([
            transforms.AffineScalarTransform(scale=(1.0 / 2**num_bits)),
            transforms.AffineScalarTransform(shift=alpha,
                                             scale=(1 - 2.0 * alpha)),
            transforms.Logit(),
        ])
        logger.debug("Preprocessing: RealNVP2alpha")
    elif preprocessing == "unflatten":
        preprocess_transform = transforms.ReshapeTransform(
            input_shape=(c * h * w, ), output_shape=(c, h, w))
        logger.debug("Preprocessing: Unflattening from %s to (%s, %s, %s)",
                     c * h * w, c, h, w)
    else:
        raise RuntimeError(
            "Unknown preprocessing type: {}".format(preprocessing))
    return preprocess_transform
Example #3
0
def _create_vector_linear_transform(linear_transform_type, features):
    if linear_transform_type == "permutation":
        return transforms.RandomPermutation(features=features)
    elif linear_transform_type == "lu":
        return transforms.CompositeTransform([transforms.RandomPermutation(features=features), transforms.LULinear(features, identity_init=True)])
    elif linear_transform_type == "svd":
        return transforms.CompositeTransform(
            [transforms.RandomPermutation(features=features), transforms.SVDLinear(features, num_householder=10, identity_init=True)]
        )
    else:
        raise ValueError
Example #4
0
def _create_image_transform_step(
    num_channels,
    hidden_channels=96,
    context_channels=None,
    actnorm=True,
    coupling_layer_type="rational_quadratic_spline",
    num_res_blocks=3,
    resnet_batchnorm=True,
    dropout_prob=0.0,
    num_bins=8,
    tail_bound=3.0,
):
    def create_convnet(in_channels, out_channels):
        net = nn_.ConvResidualNet(
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            context_channels=context_channels,
            num_blocks=num_res_blocks,
            use_batch_norm=resnet_batchnorm,
            dropout_probability=dropout_prob,
        )
        return net

    mask = various.create_mid_split_binary_mask(num_channels)

    if coupling_layer_type == "cubic_spline":
        coupling_layer = transforms.PiecewiseCubicCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=tail_bound,
            num_bins=num_bins,
            apply_unconditional_transform=False,
            min_bin_width=0.001,
            min_bin_height=0.001,
        )
    elif coupling_layer_type == "quadratic_spline":
        coupling_layer = transforms.PiecewiseQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=tail_bound,
            num_bins=num_bins,
            apply_unconditional_transform=False,
            min_bin_width=0.001,
            min_bin_height=0.001,
        )
    elif coupling_layer_type == "rational_quadratic_spline":
        coupling_layer = transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=tail_bound,
            num_bins=num_bins,
            apply_unconditional_transform=False,
            min_bin_width=0.001,
            min_bin_height=0.001,
            min_derivative=0.001,
        )
    elif coupling_layer_type == "affine":
        coupling_layer = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    elif coupling_layer_type == "additive":
        coupling_layer = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    else:
        raise RuntimeError("Unknown coupling_layer_type")

    step_transforms = []

    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))

    step_transforms.extend(
        [transforms.OneByOneConvolution(num_channels), coupling_layer])

    logger.debug("  Flow based on %s", coupling_layer_type)

    return transforms.CompositeTransform(step_transforms)
Example #5
0
def _create_postprocessing(dim, multi_scale, postprocessing,
                           postprocessing_channel_factor,
                           postprocessing_layers, res, context_features,
                           tail_bound, num_bins):
    # TODO: take context_features into account here

    if postprocessing == "linear":
        final_transform = transforms.LULinear(dim, identity_init=True)
        logger.debug("LULinear(%s)", dim)

    elif postprocessing == "partial_linear":
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim,
                channels_per_level=postprocessing_channel_factor *
                np.array([1, 2, 4, 8], dtype=np.int),
                resolution=res)
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)

        partial_transform = transforms.LULinear(partial_dim,
                                                identity_init=True)
        final_transform = transforms.PartialTransform(mask, partial_transform)
        logger.debug("PartialTransform (LULinear) (%s)", partial_dim)

    elif postprocessing == "partial_mlp":
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim,
                channels_per_level=postprocessing_channel_factor *
                np.array([1, 2, 4, 8], dtype=np.int),
                resolution=res)
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)

        partial_transforms = [
            transforms.LULinear(partial_dim, identity_init=True)
        ]
        logger.debug("PartialTransform (LULinear) (%s)", partial_dim)
        for _ in range(postprocessing_layers - 1):
            partial_transforms.append(transforms.LogTanh(cut_point=1))
            logger.debug("PartialTransform (LogTanh) (%s)", partial_dim)
            partial_transforms.append(
                transforms.LULinear(partial_dim, identity_init=True))
            logger.debug("PartialTransform (LULinear) (%s)", partial_dim)
        partial_transform = transforms.CompositeTransform(partial_transforms)

        final_transform = transforms.CompositeTransform([
            transforms.PartialTransform(mask, partial_transform),
            transforms.MaskBasedPermutation(mask)
        ])
        logging.debug("MaskBasedPermutation (%s)", mask)

    elif postprocessing == "partial_nsf":
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim,
                channels_per_level=postprocessing_channel_factor *
                np.array([1, 2, 4, 16], dtype=np.int),
                resolution=res)
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)

        partial_transform = create_vector_transform(
            dim=partial_dim,
            flow_steps=postprocessing_layers,
            linear_transform_type="permutation",
            tail_bound=tail_bound,
            num_bins=num_bins)
        logging.debug("RQ-NSF transform on %s features with %s steps",
                      partial_dim, postprocessing_layers)

        final_transform = transforms.CompositeTransform([
            transforms.PartialTransform(mask, partial_transform),
            transforms.MaskBasedPermutation(mask)
        ])
        logging.debug("MaskBasedPermutation (%s)", mask)

    elif postprocessing == "permutation":
        # Random permutation
        final_transform = transforms.RandomPermutation(dim)
        logger.debug("RandomPermutation(%s)", dim)

    elif postprocessing == "none":
        final_transform = transforms.IdentityTransform()

    else:
        raise NotImplementedError(postprocessing)
    return final_transform
Example #6
0
def create_image_transform(
    c,
    h,
    w,
    levels=3,
    hidden_channels=96,
    steps_per_level=7,
    alpha=0.05,
    num_bits=8,
    preprocessing="glow",
    multi_scale=True,
    dropout_prob=0.0,
    num_res_blocks=3,
    coupling_layer_type="rational_quadratic_spline",
    use_batchnorm=True,
    use_actnorm=True,
    postprocessing="permutation",
    postprocessing_layers=2,
    postprocessing_channel_factor=2,
    context_features=None,
    num_bins=8,
    tail_bound=3.0,
):
    assert h == w
    res = h
    dim = c * h * w

    if not isinstance(hidden_channels, list):
        hidden_channels = [hidden_channels] * levels

    preprocess_transform = _create_preprocessing(alpha, c, h, num_bits,
                                                 preprocessing, w)

    # Main part
    if multi_scale:
        logger.debug("Input: c, h, w = %s, %s, %s", c, h, w)
        mct = transforms.MultiscaleCompositeTransform(num_transforms=levels)
        for level, level_hidden_channels in zip(range(levels),
                                                hidden_channels):
            logger.debug("Level %s", level)
            squeeze_transform = transforms.SqueezeTransform()
            c, h, w = squeeze_transform.get_output_shape(c, h, w)
            logger.debug("  c, h, w = %s, %s, %s", c, h, w)

            logger.debug("  SqueezeTransform()")
            transform_level = transforms.CompositeTransform(
                [squeeze_transform] + [
                    _create_image_transform_step(
                        c,
                        level_hidden_channels,
                        actnorm=use_actnorm,
                        coupling_layer_type=coupling_layer_type,
                        num_bins=num_bins,
                        tail_bound=tail_bound,
                        num_res_blocks=num_res_blocks,
                        resnet_batchnorm=use_batchnorm,
                        dropout_prob=dropout_prob,
                        context_channels=context_features,
                    ) for _ in range(steps_per_level)
                ] + [transforms.OneByOneConvolution(c)
                     ]  # End each level with a linear transformation.
            )
            logger.debug("  OneByOneConvolution(%s)", c)

            new_shape = mct.add_transform(transform_level, (c, h, w))
            if new_shape:  # If not last layer
                c, h, w = new_shape
                logger.debug("  new_shape = %s, %s, %s", c, h, w)
    else:
        all_transforms = []

        for level, level_hidden_channels in zip(range(levels),
                                                hidden_channels):
            squeeze_transform = transforms.SqueezeTransform()
            c, h, w = squeeze_transform.get_output_shape(c, h, w)

            transform_level = transforms.CompositeTransform(
                [squeeze_transform] + [
                    _create_image_transform_step(
                        c,
                        level_hidden_channels,
                        actnorm=use_actnorm,
                        coupling_layer_type=coupling_layer_type,
                        num_res_blocks=num_res_blocks,
                        resnet_batchnorm=use_batchnorm,
                        dropout_prob=dropout_prob,
                        context_channels=context_features,
                    ) for _ in range(steps_per_level)
                ] + [transforms.OneByOneConvolution(c)
                     ]  # End each level with a linear transformation.
            )
            all_transforms.append(transform_level)

        all_transforms.append(
            transforms.ReshapeTransform(input_shape=(c, h, w),
                                        output_shape=(c * h * w, )))
        mct = transforms.CompositeTransform(all_transforms)

    # Final transformation
    final_transform = _create_postprocessing(dim,
                                             multi_scale,
                                             postprocessing,
                                             postprocessing_channel_factor,
                                             postprocessing_layers,
                                             res,
                                             context_features,
                                             num_bins=num_bins,
                                             tail_bound=tail_bound)

    return transforms.CompositeTransform(
        [preprocess_transform, mct, final_transform])
Example #7
0
def _create_image_transform_step(
    num_channels,
    hidden_channels=96,
    actnorm=True,
    coupling_layer_type="rational_quadratic_spline",
    spline_params=None,
    use_resnet=True,
    num_res_blocks=3,
    resnet_batchnorm=True,
    dropout_prob=0.0,
):
    if use_resnet:

        def create_convnet(in_channels, out_channels):
            net = nn_.ConvResidualNet(
                in_channels=in_channels,
                out_channels=out_channels,
                hidden_channels=hidden_channels,
                num_blocks=num_res_blocks,
                use_batch_norm=resnet_batchnorm,
                dropout_probability=dropout_prob,
            )
            return net

    else:
        if dropout_prob != 0.0:
            raise ValueError()

        def create_convnet(in_channels, out_channels):
            return ConvNet(in_channels, hidden_channels, out_channels)

    if spline_params is None:
        spline_params = {
            "apply_unconditional_transform": False,
            "min_bin_height": 0.001,
            "min_bin_width": 0.001,
            "min_derivative": 0.001,
            "num_bins": 4,
            "tail_bound": 3.0,
        }

    mask = various.create_mid_split_binary_mask(num_channels)

    if coupling_layer_type == "cubic_spline":
        coupling_layer = transforms.PiecewiseCubicCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=spline_params["tail_bound"],
            num_bins=spline_params["num_bins"],
            apply_unconditional_transform=spline_params[
                "apply_unconditional_transform"],
            min_bin_width=spline_params["min_bin_width"],
            min_bin_height=spline_params["min_bin_height"],
        )
    elif coupling_layer_type == "quadratic_spline":
        coupling_layer = transforms.PiecewiseQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=spline_params["tail_bound"],
            num_bins=spline_params["num_bins"],
            apply_unconditional_transform=spline_params[
                "apply_unconditional_transform"],
            min_bin_width=spline_params["min_bin_width"],
            min_bin_height=spline_params["min_bin_height"],
        )
    elif coupling_layer_type == "rational_quadratic_spline":
        coupling_layer = transforms.PiecewiseRationalQuadraticCouplingTransform(
            mask=mask,
            transform_net_create_fn=create_convnet,
            tails="linear",
            tail_bound=spline_params["tail_bound"],
            num_bins=spline_params["num_bins"],
            apply_unconditional_transform=spline_params[
                "apply_unconditional_transform"],
            min_bin_width=spline_params["min_bin_width"],
            min_bin_height=spline_params["min_bin_height"],
            min_derivative=spline_params["min_derivative"],
        )
    elif coupling_layer_type == "affine":
        coupling_layer = transforms.AffineCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    elif coupling_layer_type == "additive":
        coupling_layer = transforms.AdditiveCouplingTransform(
            mask=mask, transform_net_create_fn=create_convnet)
    else:
        raise RuntimeError("Unknown coupling_layer_type")

    step_transforms = []

    if actnorm:
        step_transforms.append(transforms.ActNorm(num_channels))

    step_transforms.extend(
        [transforms.OneByOneConvolution(num_channels), coupling_layer])

    logger.debug("  Flow based on %s", coupling_layer_type)

    return transforms.CompositeTransform(step_transforms)
Example #8
0
def create_image_transform(
    c,
    h,
    w,
    levels=3,
    hidden_channels=96,
    steps_per_level=7,
    alpha=0.05,
    num_bits=8,
    preprocessing="glow",
    multi_scale=True,
    use_resnet=True,
    dropout_prob=0.0,
    num_res_blocks=3,
    coupling_layer_type="rational_quadratic_spline",
    use_batchnorm=False,
    use_actnorm=True,
    spline_params=None,
):
    dim = c * h * w
    if not isinstance(hidden_channels, list):
        hidden_channels = [hidden_channels] * levels

    if multi_scale:
        mct = transforms.MultiscaleCompositeTransform(num_transforms=levels)
        for level, level_hidden_channels in zip(range(levels),
                                                hidden_channels):
            logger.debug("Level %s", level)
            squeeze_transform = transforms.SqueezeTransform()
            c, h, w = squeeze_transform.get_output_shape(c, h, w)
            logger.debug("  c, h, w = %s, %s, %s", c, h, w)

            logger.debug("  SqueezeTransform()")
            transform_level = transforms.CompositeTransform(
                [squeeze_transform] + [
                    _create_image_transform_step(
                        c,
                        level_hidden_channels,
                        actnorm=use_actnorm,
                        coupling_layer_type=coupling_layer_type,
                        spline_params=spline_params,
                        use_resnet=use_resnet,
                        num_res_blocks=num_res_blocks,
                        resnet_batchnorm=use_batchnorm,
                        dropout_prob=dropout_prob,
                    ) for _ in range(steps_per_level)
                ] + [transforms.OneByOneConvolution(c)
                     ]  # End each level with a linear transformation.
            )
            logger.debug("  OneByOneConvolution(%s)", c)

            new_shape = mct.add_transform(transform_level, (c, h, w))
            if new_shape:  # If not last layer
                c, h, w = new_shape
                logger.debug("  new_shape = %s, %s, %s", c, h, w)
    else:
        all_transforms = []

        for level, level_hidden_channels in zip(range(levels),
                                                hidden_channels):
            squeeze_transform = transforms.SqueezeTransform()
            c, h, w = squeeze_transform.get_output_shape(c, h, w)

            transform_level = transforms.CompositeTransform(
                [squeeze_transform] + [
                    _create_image_transform_step(
                        c,
                        level_hidden_channels,
                        actnorm=use_actnorm,
                        coupling_layer_type=coupling_layer_type,
                        spline_params=spline_params,
                        use_resnet=use_resnet,
                        num_res_blocks=num_res_blocks,
                        resnet_batchnorm=use_batchnorm,
                        dropout_prob=dropout_prob,
                    ) for _ in range(steps_per_level)
                ] + [transforms.OneByOneConvolution(c)
                     ]  # End each level with a linear transformation.
            )
            all_transforms.append(transform_level)

        all_transforms.append(
            transforms.ReshapeTransform(input_shape=(c, h, w),
                                        output_shape=(c * h * w, )))
        mct = transforms.CompositeTransform(all_transforms)

    # Inputs to the model in [0, 2 ** num_bits]

    if preprocessing == "glow":
        # Map to [-0.5,0.5]
        preprocess_transform = transforms.AffineScalarTransform(
            scale=(1.0 / 2**num_bits), shift=-0.5)
    elif preprocessing == "realnvp":
        preprocess_transform = transforms.CompositeTransform([
            # Map to [0,1]
            transforms.AffineScalarTransform(scale=(1.0 / 2**num_bits)),
            # Map into unconstrained space as done in RealNVP
            transforms.AffineScalarTransform(shift=alpha, scale=(1 - alpha)),
            transforms.Logit(),
        ])

    elif preprocessing == "realnvp_2alpha":
        preprocess_transform = transforms.CompositeTransform([
            transforms.AffineScalarTransform(scale=(1.0 / 2**num_bits)),
            transforms.AffineScalarTransform(shift=alpha,
                                             scale=(1 - 2.0 * alpha)),
            transforms.Logit(),
        ])
    else:
        raise RuntimeError(
            "Unknown preprocessing type: {}".format(preprocessing))

    # Random permutation
    permutation = transforms.RandomPermutation(dim)
    logger.debug("RandomPermutation(%s)", dim)

    return transforms.CompositeTransform(
        [preprocess_transform, mct, permutation])