Example #1
def _create_vector_linear_transform(linear_transform_type, features):
    if linear_transform_type == "permutation":
        return transforms.RandomPermutation(features=features)
    elif linear_transform_type == "lu":
        return transforms.CompositeTransform(
            [transforms.RandomPermutation(features=features), transforms.LULinear(features, identity_init=True)]
        )
    elif linear_transform_type == "svd":
        return transforms.CompositeTransform(
            [transforms.RandomPermutation(features=features), transforms.SVDLinear(features, num_householder=10, identity_init=True)]
        )
    else:
        raise ValueError("Unknown linear transform type: {}".format(linear_transform_type))
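
A minimal usage sketch for this helper, assuming `transforms` refers to an nflows-style transforms module (e.g. `from nflows import transforms`) whose transforms return an `(outputs, logabsdet)` pair:

import torch
from nflows import transforms  # assumed source of the transforms module used above

linear = _create_vector_linear_transform("lu", features=8)
x = torch.randn(16, 8)              # batch of 16 vectors with 8 features
y, logabsdet = linear(x)            # nflows-style transforms return (outputs, logabsdet)
print(y.shape, logabsdet.shape)     # torch.Size([16, 8]) torch.Size([16])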
Example #2
def _create_postprocessing(dim, multi_scale, postprocessing,
                           postprocessing_channel_factor,
                           postprocessing_layers, res, context_features,
                           tail_bound, num_bins):
    # TODO: take context_features into account here
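    # Postprocessing modes handled below:
    #   "linear":         LU-decomposed linear layer over all features
    #   "partial_linear": LU linear restricted to the masked subset of features
    #   "partial_mlp":    alternating LU linear / LogTanh layers on the masked subset
    #   "partial_nsf":    RQ-NSF vector transform on the masked subset
    #   "permutation":    random permutation of all features
    #   "none":           identity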

    if postprocessing == "linear":
        final_transform = transforms.LULinear(dim, identity_init=True)
        logger.debug("LULinear(%s)", dim)

    elif postprocessing == "partial_linear":
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim,
                channels_per_level=postprocessing_channel_factor *
                np.array([1, 2, 4, 8], dtype=int),
                resolution=res)
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)

        partial_transform = transforms.LULinear(partial_dim,
                                                identity_init=True)
        final_transform = transforms.PartialTransform(mask, partial_transform)
        logger.debug("PartialTransform (LULinear) (%s)", partial_dim)

    elif postprocessing == "partial_mlp":
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim,
                channels_per_level=postprocessing_channel_factor *
                np.array([1, 2, 4, 8], dtype=int),
                resolution=res)
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)

        partial_transforms = [
            transforms.LULinear(partial_dim, identity_init=True)
        ]
        logger.debug("PartialTransform (LULinear) (%s)", partial_dim)
        for _ in range(postprocessing_layers - 1):
            partial_transforms.append(transforms.LogTanh(cut_point=1))
            logger.debug("PartialTransform (LogTanh) (%s)", partial_dim)
            partial_transforms.append(
                transforms.LULinear(partial_dim, identity_init=True))
            logger.debug("PartialTransform (LULinear) (%s)", partial_dim)
        partial_transform = transforms.CompositeTransform(partial_transforms)

        final_transform = transforms.CompositeTransform([
            transforms.PartialTransform(mask, partial_transform),
            transforms.MaskBasedPermutation(mask)
        ])
        logging.debug("MaskBasedPermutation (%s)", mask)

    elif postprocessing == "partial_nsf":
        if multi_scale:
            mask = various.create_mlt_channel_mask(
                dim,
                channels_per_level=postprocessing_channel_factor *
                np.array([1, 2, 4, 16], dtype=int),
                resolution=res)
            partial_dim = torch.sum(mask.to(dtype=torch.int)).item()
        else:
            partial_dim = postprocessing_channel_factor * 1024
            mask = various.create_split_binary_mask(dim, partial_dim)

        partial_transform = create_vector_transform(
            dim=partial_dim,
            flow_steps=postprocessing_layers,
            linear_transform_type="permutation",
            tail_bound=tail_bound,
            num_bins=num_bins)
        logging.debug("RQ-NSF transform on %s features with %s steps",
                      partial_dim, postprocessing_layers)

        final_transform = transforms.CompositeTransform([
            transforms.PartialTransform(mask, partial_transform),
            transforms.MaskBasedPermutation(mask)
        ])
        logging.debug("MaskBasedPermutation (%s)", mask)

    elif postprocessing == "permutation":
        # Random permutation
        final_transform = transforms.RandomPermutation(dim)
        logger.debug("RandomPermutation(%s)", dim)

    elif postprocessing == "none":
        final_transform = transforms.IdentityTransform()

    else:
        raise NotImplementedError(postprocessing)
    return final_transform
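
A minimal call sketch. The "linear" branch only uses `dim`, so the remaining arguments are placeholder values chosen here for illustration; `transforms`, `various`, `create_vector_transform`, and `logger` are assumed to be available at module level, as in the repository these examples come from:

final_transform = _create_postprocessing(
    dim=3 * 32 * 32,
    multi_scale=False,
    postprocessing="linear",
    postprocessing_channel_factor=1,
    postprocessing_layers=2,
    res=32,
    context_features=None,
    tail_bound=3.0,
    num_bins=8,
)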
Example #3
    def __init__(self, num_channels, using_cache=False, identity_init=True):
        super().__init__(num_channels, using_cache, identity_init)
        self.permutation = transforms.RandomPermutation(num_channels, dim=1)
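
This excerpt is only an `__init__` method. A plausible containing class, sketched here with a hypothetical name and under the assumption that it subclasses nflows' `OneByOneConvolution` (whose constructor takes the same three arguments), would look like:

from nflows import transforms  # assumed source of the transforms module

class PermutedOneByOneConvolution(transforms.OneByOneConvolution):  # hypothetical class name
    def __init__(self, num_channels, using_cache=False, identity_init=True):
        super().__init__(num_channels, using_cache, identity_init)
        # dim=1 permutes along the channel axis of (batch, channels, height, width) inputs
        self.permutation = transforms.RandomPermutation(num_channels, dim=1)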
Example #4
def create_image_transform(
    c,
    h,
    w,
    levels=3,
    hidden_channels=96,
    steps_per_level=7,
    alpha=0.05,
    num_bits=8,
    preprocessing="glow",
    multi_scale=True,
    use_resnet=True,
    dropout_prob=0.0,
    num_res_blocks=3,
    coupling_layer_type="rational_quadratic_spline",
    use_batchnorm=False,
    use_actnorm=True,
    spline_params=None,
):
    dim = c * h * w
    if not isinstance(hidden_channels, list):
        hidden_channels = [hidden_channels] * levels
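
    # Overall structure: each level applies a squeeze followed by `steps_per_level`
    # coupling steps and a 1x1 convolution; levels are combined either through a
    # multi-scale composite (which factors out part of the output after each level)
    # or a plain composite that ends with a reshape to a flat vector.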

    if multi_scale:
        mct = transforms.MultiscaleCompositeTransform(num_transforms=levels)
        for level, level_hidden_channels in zip(range(levels),
                                                hidden_channels):
            logger.debug("Level %s", level)
            squeeze_transform = transforms.SqueezeTransform()
            c, h, w = squeeze_transform.get_output_shape(c, h, w)
            logger.debug("  c, h, w = %s, %s, %s", c, h, w)

            logger.debug("  SqueezeTransform()")
            transform_level = transforms.CompositeTransform(
                [squeeze_transform] + [
                    _create_image_transform_step(
                        c,
                        level_hidden_channels,
                        actnorm=use_actnorm,
                        coupling_layer_type=coupling_layer_type,
                        spline_params=spline_params,
                        use_resnet=use_resnet,
                        num_res_blocks=num_res_blocks,
                        resnet_batchnorm=use_batchnorm,
                        dropout_prob=dropout_prob,
                    ) for _ in range(steps_per_level)
                ] + [transforms.OneByOneConvolution(c)]  # End each level with a linear transformation.
            )
            logger.debug("  OneByOneConvolution(%s)", c)

            new_shape = mct.add_transform(transform_level, (c, h, w))
            if new_shape:  # If not last layer
                c, h, w = new_shape
                logger.debug("  new_shape = %s, %s, %s", c, h, w)
    else:
        all_transforms = []

        for level, level_hidden_channels in zip(range(levels),
                                                hidden_channels):
            squeeze_transform = transforms.SqueezeTransform()
            c, h, w = squeeze_transform.get_output_shape(c, h, w)

            transform_level = transforms.CompositeTransform(
                [squeeze_transform] + [
                    _create_image_transform_step(
                        c,
                        level_hidden_channels,
                        actnorm=use_actnorm,
                        coupling_layer_type=coupling_layer_type,
                        spline_params=spline_params,
                        use_resnet=use_resnet,
                        num_res_blocks=num_res_blocks,
                        resnet_batchnorm=use_batchnorm,
                        dropout_prob=dropout_prob,
                    ) for _ in range(steps_per_level)
                ] + [transforms.OneByOneConvolution(c)]  # End each level with a linear transformation.
            )
            all_transforms.append(transform_level)

        all_transforms.append(
            transforms.ReshapeTransform(input_shape=(c, h, w),
                                        output_shape=(c * h * w, )))
        mct = transforms.CompositeTransform(all_transforms)

    # Inputs to the model in [0, 2 ** num_bits]

    if preprocessing == "glow":
        # Map to [-0.5,0.5]
        preprocess_transform = transforms.AffineScalarTransform(
            scale=(1.0 / 2**num_bits), shift=-0.5)
    elif preprocessing == "realnvp":
        preprocess_transform = transforms.CompositeTransform([
            # Map to [0,1]
            transforms.AffineScalarTransform(scale=(1.0 / 2**num_bits)),
            # Map into unconstrained space as done in RealNVP
            transforms.AffineScalarTransform(shift=alpha, scale=(1 - alpha)),
            transforms.Logit(),
        ])

    elif preprocessing == "realnvp_2alpha":
        preprocess_transform = transforms.CompositeTransform([
            transforms.AffineScalarTransform(scale=(1.0 / 2**num_bits)),
            transforms.AffineScalarTransform(shift=alpha,
                                             scale=(1 - 2.0 * alpha)),
            transforms.Logit(),
        ])
    else:
        raise RuntimeError(
            "Unknown preprocessing type: {}".format(preprocessing))

    # Random permutation
    permutation = transforms.RandomPermutation(dim)
    logger.debug("RandomPermutation(%s)", dim)

    return transforms.CompositeTransform(
        [preprocess_transform, mct, permutation])
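
A minimal usage sketch. It assumes `_create_image_transform_step` is defined alongside this function (it is not shown in these examples) and that `spline_params` follows the usual neural-spline-flow dictionary layout; both are assumptions made here for illustration only:

import torch

# Assumed spline configuration; the exact keys depend on _create_image_transform_step.
spline_params = {
    "num_bins": 4,
    "tail_bound": 3.0,
    "min_bin_width": 1e-3,
    "min_bin_height": 1e-3,
    "min_derivative": 1e-3,
    "apply_unconditional_transform": False,
}

transform = create_image_transform(
    c=3, h=32, w=32, levels=2, steps_per_level=2,
    hidden_channels=32, spline_params=spline_params,
)

x = 255.0 * torch.rand(4, 3, 32, 32)   # inputs in [0, 2**num_bits], as the comment above notes
z, logabsdet = transform(x)            # z is flattened to shape (4, 3*32*32)
print(z.shape, logabsdet.shape)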