Example #1
0
def InitLogitTempTransform(temperature: float = 1.0, jitted: bool = False):
    """Build the layer functions for a temperature-scaled logit bijector.

    Parameters
    ----------
    temperature : float, default=1.0
        Temperature forwarded to every ``LogitTemperature`` instance.
    jitted : bool, default=False
        If True, the forward pass used by ``transform`` and
        ``transform_and_bijector`` is compiled with ``jax.jit``.

    Returns
    -------
    InitLayersFunctions
        Bundle of ``transform``, ``bijector``, ``transform_and_bijector`` and
        ``transform_gradient_bijector`` closures.
    """

    def new_logit_temp():
        # Every closure hands out its own freshly constructed bijector.
        return LogitTemperature(temperature=temperature)

    forward_fn = new_logit_temp().forward
    if jitted:
        forward_fn = jax.jit(forward_fn)

    def transform(inputs, **kwargs):
        return forward_fn(inputs)

    def bijector(inputs=None, **kwargs):
        return new_logit_temp()

    def transform_and_bijector(inputs, **kwargs):
        outputs = forward_fn(inputs)
        return outputs, new_logit_temp()

    def transform_gradient_bijector(inputs, **kwargs):
        logit_temp = new_logit_temp()
        outputs, logabsdet = logit_temp.forward_and_log_det(inputs)
        return outputs, logabsdet, logit_temp

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
Example #2
0
def InitLogitTransform(jitted: bool = False):
    """Build the layer functions for a logit bijector, ``Inverse(Sigmoid())``.

    Parameters
    ----------
    jitted : bool, default=False
        If True, the forward pass used by ``transform`` and
        ``transform_and_bijector`` is compiled with ``jax.jit``.

    Returns
    -------
    InitLayersFunctions
        Bundle of ``transform``, ``bijector``, ``transform_and_bijector`` and
        ``transform_gradient_bijector`` closures.
    """

    def new_logit():
        # The logit is expressed as the inverse of a sigmoid bijector.
        return Inverse(Sigmoid())

    forward_fn = new_logit().forward
    if jitted:
        forward_fn = jax.jit(forward_fn)

    def transform(inputs, **kwargs):
        return forward_fn(inputs)

    def bijector(inputs=None, **kwargs):
        return new_logit()

    def transform_and_bijector(inputs, **kwargs):
        outputs = forward_fn(inputs)
        return outputs, new_logit()

    def transform_gradient_bijector(inputs, **kwargs):
        logit = new_logit()
        outputs, logabsdet = logit.forward_and_log_det(inputs)
        return outputs, logabsdet, logit

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
Example #3
0
def InitRandomRotation(rng: PRNGKey, jitted=False):
    """Build the layer functions for a random orthogonal rotation.

    Parameters
    ----------
    rng : PRNGKey
        Base key. When a call does not pass ``rng=...`` in its kwargs, this
        key is split in two and the second half is used. Repeated calls
        without an explicit ``rng`` therefore draw the same rotation.
    jitted : bool, default=False
        If True, the ``get_random_rotation`` partials are compiled with
        ``jax.jit``.

    Returns
    -------
    InitLayersFunctions
        Bundle with ``bijector``, ``bijector_and_transform``, ``transform``,
        ``params`` and ``params_and_transform`` closures.
    """
    import functools

    key = rng

    # `jax.partial` was only an alias of `functools.partial` and has been
    # removed from recent JAX releases; use the stdlib version directly.
    f = functools.partial(get_random_rotation, return_params=True)
    f_slim = functools.partial(get_random_rotation, return_params=False)

    if jitted:
        f = jax.jit(f)
        f_slim = jax.jit(f_slim)

    def _subkey(kwargs):
        # Take the second of a pair of keys: either the caller's ``rng`` kwarg
        # or a split of the base key.
        # NOTE(review): a single PRNGKey passed as ``rng`` would be unpacked
        # into its raw words here — confirm callers pass a split pair.
        _, subkey = kwargs.get("rng", jax.random.split(key, 2))
        return subkey

    def init_params(inputs, **kwargs):
        _, params = f(_subkey(kwargs), inputs)
        return params

    def params_and_transform(inputs, **kwargs):
        outputs, params = f(_subkey(kwargs), inputs)
        return outputs, params

    def transform(inputs, **kwargs):
        # f_slim shares get_random_rotation's (rng, inputs) signature with f,
        # so the key must be passed explicitly rather than the data alone.
        outputs = f_slim(_subkey(kwargs), inputs)
        return outputs

    def bijector(inputs, **kwargs):
        params = init_params(inputs, **kwargs)
        return Rotation(rotation=params.rotation)

    def bijector_and_transform(inputs, **kwargs):
        outputs, params = params_and_transform(inputs, **kwargs)
        return outputs, Rotation(rotation=params.rotation)

    return InitLayersFunctions(
        bijector=bijector,
        bijector_and_transform=bijector_and_transform,
        transform=transform,
        params=init_params,
        params_and_transform=params_and_transform,
    )
Example #4
0
def InitPCARotation(jitted=False):
    """Build the layer functions for a PCA rotation fitted to the inputs.

    Every closure refits the PCA parameters on the ``inputs`` it receives and
    wraps the resulting rotation in a ``Rotation`` bijector.

    Parameters
    ----------
    jitted : bool, default=False
        If True, the ``get_pca_params`` partial is compiled with ``jax.jit``.

    Returns
    -------
    InitLayersFunctions
        Bundle of ``transform``, ``bijector``, ``transform_and_bijector`` and
        ``transform_gradient_bijector`` closures.
    """
    import functools

    # `jax.partial` has been removed from recent JAX releases; it was an alias
    # of `functools.partial`.
    f = functools.partial(get_pca_params, return_params=True)

    if jitted:
        f = jax.jit(f)

    def _fit_rotation(inputs, **kwargs):
        # Fit PCA on `inputs` and wrap the rotation as a bijector.
        # NOTE(review): extra kwargs are forwarded verbatim to get_pca_params;
        # confirm it accepts whatever callers pass.
        params = f(inputs, **kwargs)
        return Rotation(rotation=params.rotation)

    def transform(inputs, **kwargs):
        return _fit_rotation(inputs, **kwargs).forward(inputs)

    def bijector(inputs, **kwargs):
        return _fit_rotation(inputs, **kwargs)

    def transform_and_bijector(inputs, **kwargs):
        rotation = _fit_rotation(inputs, **kwargs)
        outputs = rotation.forward(inputs)
        return outputs, rotation

    def transform_gradient_bijector(inputs, **kwargs):
        rotation = _fit_rotation(inputs, **kwargs)
        outputs, logabsdet = rotation.forward_and_log_det(inputs)
        return outputs, logabsdet, rotation

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
Example #5
0
def InitSigmoidTransform(eps: float = 1e-5, jitted: bool = False):
    """Build the layer functions for a sigmoid bijector.

    Parameters
    ----------
    eps : float, default=1e-5
        Forward outputs are clipped to ``[eps, 1 - eps]``. A saturated
        sigmoid therefore never returns exactly 0 or 1, which would be
        infinite under a following logit / inverse-CDF layer. The sigmoid's
        inputs are real-valued and are left unclipped.
    jitted : bool, default=False
        If True, the forward pass used by ``transform`` and
        ``transform_and_bijector`` is compiled with ``jax.jit``.

    Returns
    -------
    InitLayersFunctions
        Bundle of ``transform``, ``bijector``, ``transform_and_bijector`` and
        ``transform_gradient_bijector`` closures.
    """

    if jitted:
        f = jax.jit(Sigmoid().forward)
    else:
        f = Sigmoid().forward

    def _clip_unit(outputs):
        # Keep outputs strictly inside (0, 1).
        return jnp.clip(outputs, eps, 1 - eps)

    def transform(inputs, **kwargs):
        return _clip_unit(f(inputs))

    def bijector(inputs=None, **kwargs):
        return Sigmoid()

    def transform_and_bijector(inputs, **kwargs):
        outputs = _clip_unit(f(inputs))
        return outputs, Sigmoid()

    def transform_gradient_bijector(inputs, **kwargs):
        bijector = Sigmoid()

        # logabsdet is that of the exact sigmoid at the (unclipped) inputs.
        outputs, logabsdet = bijector.forward_and_log_det(inputs)

        return _clip_unit(outputs), logabsdet, bijector

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
Example #6
0
def InitInverseGaussCDF(eps: float = 1e-5, jitted=False):
    """Build the layer functions for an inverse Gaussian-CDF bijector.

    Parameters
    ----------
    eps : float, default=1e-5
        Forwarded to every ``InverseGaussCDF`` instance.
    jitted : bool, default=False
        If True, the forward pass used by ``transform`` and
        ``transform_and_bijector`` is compiled with ``jax.jit``.

    Returns
    -------
    InitLayersFunctions
        Bundle of ``transform``, ``bijector``, ``transform_and_bijector`` and
        ``transform_gradient_bijector`` closures.
    """

    # Named distinctly so it is not shadowed by the nested `bijector` factory
    # defined below. It only supplies the forward function `f`.
    base_bijector = InverseGaussCDF(eps=eps)

    if jitted:
        f = jax.jit(base_bijector.forward)
    else:
        f = base_bijector.forward

    def transform(inputs, **kwargs):
        return f(inputs)

    def bijector(inputs=None, **kwargs):
        return InverseGaussCDF(eps=eps)

    def transform_and_bijector(inputs, **kwargs):
        outputs = f(inputs)
        return outputs, InverseGaussCDF(eps=eps)

    def transform_gradient_bijector(inputs, **kwargs):
        layer = InverseGaussCDF(eps=eps)
        outputs, logabsdet = layer.forward_and_log_det(inputs)
        return outputs, logabsdet, layer

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
Example #7
0
    def transform_gradient_bijector(inputs, **kwargs):
        """Fit marginal parameters, then apply the uniformizing bijector.

        Returns the transformed outputs, the log-abs-determinant from
        ``forward_and_log_det``, and the fitted bijector.
        """
        # vmap maps `f` over axis 1 of `inputs` and stacks its results along
        # axis 0 of each returned field.
        fit_marginals = jax.vmap(f, in_axes=(1, ), out_axes=0)
        params = fit_marginals(inputs)

        uniformizer = MarginalUniformizeTransform(
            support=params.support,
            quantiles=params.quantiles,
            support_pdf=params.support_pdf,
            empirical_pdf=params.empirical_pdf,
        )
        outputs, logabsdet = uniformizer.forward_and_log_det(inputs)
        return outputs, logabsdet, uniformizer

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )


def init_kde_params(
    X: jnp.ndarray,
    bw: float = 0.1,
    support_extension: Union[int, float] = 10,
    precision: int = 1_000,
    return_params: bool = True,
):
    # generate support points
    lb, ub = get_domain_extension(X, support_extension)
    support = jnp.linspace(lb, ub, precision)
Example #8
0
def InitHouseHolder(n_reflections: int,
                    method: str = "random") -> "InitLayersFunctions":
    """Build the layer functions for a Householder-reflection bijector.

    A product of Householder reflections is a useful way to parameterize an
    orthogonal matrix.

    Parameters
    ----------
    n_reflections : int
        The number of Householder reflections.
    method : str, default="random"
        Initialization method forwarded to ``init_householder_weights``.

    Returns
    -------
    InitLayersFunctions
        ``transform``, ``bijector``, ``transform_and_bijector`` and
        ``transform_gradient_bijector`` closures. Each one takes ``inputs``,
        ``n_features`` (the number of features of the data) and an optional
        ``rng``, and initializes fresh Householder weights on every call.
    """

    def _init_bijector(inputs, n_features, rng) -> HouseHolder:
        # Initialize the reflection weights and wrap them in a bijector.
        V = init_householder_weights(
            rng=rng,
            n_features=n_features,
            n_reflections=n_reflections,
            method=method,
            X=inputs,
        )
        return HouseHolder(V=V)

    def bijector(inputs: Array,
                 n_features: int,
                 rng: PRNGKey = None,
                 **kwargs) -> HouseHolder:
        return _init_bijector(inputs, n_features, rng)

    def transform_and_bijector(inputs: Array,
                               n_features: int,
                               rng: PRNGKey = None,
                               **kwargs) -> Tuple[Array, HouseHolder]:
        householder = _init_bijector(inputs, n_features, rng)
        outputs = householder.forward(inputs=inputs)
        return outputs, householder

    def transform(inputs: Array,
                  n_features: int,
                  rng: PRNGKey = None,
                  **kwargs) -> Array:
        householder = _init_bijector(inputs, n_features, rng)
        return householder.forward(inputs=inputs)

    def transform_gradient_bijector(
            inputs: Array,
            n_features: int,
            rng: PRNGKey = None,
            **kwargs) -> Tuple[Array, Array, HouseHolder]:
        householder = _init_bijector(inputs, n_features, rng)
        outputs, logabsdet = householder.forward_and_log_det(inputs=inputs)
        return outputs, logabsdet, householder

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
Example #9
0
def InitPiecewiseRationalQuadraticCDF(
    n_bins: int,
    range_min: float,
    range_max: float,
    identity_init: bool = False,
    boundary_slopes: str = "identity",
    min_bin_size: float = 1e-4,
    min_knot_slope: float = 1e-4,
):
    """Build the layer functions for a piecewise rational-quadratic spline CDF.

    Parameters
    ----------
    n_bins : int
        Number of spline bins.
    range_min, range_max : float
        Spline range. ``range_min`` must be strictly less than ``range_max``.
    identity_init : bool, default=False
        Forwarded to ``init_spline_params``.
    boundary_slopes : str, default="identity"
        Forwarded to ``init_spline_params``.
    min_bin_size : float, default=1e-4
        Must be positive. It is validated here but not currently forwarded
        (see the note below).
    min_knot_slope : float, default=1e-4
        Must be positive. Forwarded to ``init_spline_params``.

    Returns
    -------
    InitLayersFunctions
        ``transform``, ``bijector``, ``transform_and_bijector`` and
        ``transform_gradient_bijector`` closures. Each one takes optional
        ``inputs``, ``rng`` and ``shape`` and builds a fresh spline bijector.

    Raises
    ------
    ValueError
        If ``range_min >= range_max``, or ``min_bin_size`` or
        ``min_knot_slope`` is not positive.
    """
    # preliminary checks of parameters
    if range_min >= range_max:
        raise ValueError(f"`range_min` must be less than `range_max`; "
                         f"Got: {range_min} and {range_max}")
    if min_bin_size <= 0:
        raise ValueError(f"Minimum bin size must be positive; "
                         f"Got {min_bin_size}")
    if min_knot_slope <= 0:
        raise ValueError(f"Minimum knot slope must be positive; "
                         f"Got {min_knot_slope}")

    def _init_bijector(rng, shape) -> Bijector:
        # Build a fresh spline bijector from the factory-level settings.
        # NOTE(review): `min_bin_size` is validated above but not forwarded
        # here — confirm whether init_spline_params should receive it.
        return init_spline_params(
            n_bins=n_bins,
            rng=rng,
            shape=shape,
            identity_init=identity_init,
            min_knot_slope=min_knot_slope,
            range_min=range_min,
            range_max=range_max,
            boundary_slopes=boundary_slopes,
        )

    def bijector(inputs: Array = None,
                 rng: PRNGKey = None,
                 shape: int = None,
                 **kwargs) -> Bijector:
        return _init_bijector(rng, shape)

    def transform_and_bijector(inputs: Array = None,
                               rng: PRNGKey = None,
                               shape: int = None,
                               **kwargs) -> Tuple[Array, Bijector]:
        spline = _init_bijector(rng, shape)
        outputs = spline.forward(inputs=inputs)
        return outputs, spline

    def transform(inputs: Array = None,
                  rng: PRNGKey = None,
                  shape: int = None,
                  **kwargs) -> Array:
        spline = _init_bijector(rng, shape)
        return spline.forward(inputs=inputs)

    def transform_gradient_bijector(
            inputs: Array = None,
            rng: PRNGKey = None,
            shape: int = None,
            **kwargs) -> Tuple[Array, Array, Bijector]:
        spline = _init_bijector(rng, shape)
        outputs, logabsdet = spline.forward_and_log_det(inputs=inputs)
        return outputs, logabsdet, spline

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
Example #10
0
def InitMixtureLogisticCDF(n_components: int,
                           init_method: str = "gmm",
                           seed: int = 123) -> "InitLayersFunctions":
    """Build the layer functions for a mixture-of-logistics CDF bijector.

    Parameters
    ----------
    n_components : int
        Number of mixture components.
    init_method : str, default="gmm"
        Initialization method forwarded to ``init_mixture_weights``.
    seed : int, default=123
        Used as the ``rng`` for ``init_mixture_weights`` when a call does not
        supply its own ``rng``.

    Returns
    -------
    InitLayersFunctions
        ``transform``, ``bijector``, ``transform_and_bijector`` and
        ``transform_gradient_bijector`` closures. Each one takes ``inputs``,
        an optional ``n_features`` (defaults to ``inputs.shape[1]``) and an
        optional ``rng``, and fits fresh mixture weights on ``inputs``.
    """

    def _init_bijector(inputs, n_features, rng) -> MixtureLogisticCDF:
        # Fit mixture weights on `inputs` and wrap them in a bijector. The
        # three sibling call sites all passed the key as `rng=`; the one
        # that passed `seed=` was inconsistent with them.
        prior_logits, means, log_scales = init_mixture_weights(
            rng=seed if rng is None else rng,
            n_features=n_features
            if n_features is not None else inputs.shape[1],
            n_components=n_components,
            method=init_method,
            X=inputs,
        )
        return MixtureLogisticCDF(means=means,
                                  log_scales=log_scales,
                                  prior_logits=prior_logits)

    def bijector(inputs,
                 n_features: int = None,
                 rng: PRNGKey = None,
                 **kwargs) -> MixtureLogisticCDF:
        return _init_bijector(inputs, n_features, rng)

    def transform_and_bijector(
            inputs,
            n_features: int = None,
            rng: PRNGKey = None,
            **kwargs) -> Tuple[Array, MixtureLogisticCDF]:
        mixture = _init_bijector(inputs, n_features, rng)
        outputs = mixture.forward(inputs=inputs)
        return outputs, mixture

    def transform(inputs,
                  n_features: int = None,
                  rng: PRNGKey = None,
                  **kwargs) -> Array:
        mixture = _init_bijector(inputs, n_features, rng)
        return mixture.forward(inputs=inputs)

    def transform_gradient_bijector(
            inputs,
            n_features: int = None,
            rng: PRNGKey = None,
            **kwargs) -> Tuple[Array, Array, MixtureLogisticCDF]:
        mixture = _init_bijector(inputs, n_features, rng)
        # Method name is `forward_and_log_det` (lower-case `det`), matching
        # every other bijector used in this module.
        outputs, logabsdet = mixture.forward_and_log_det(inputs=inputs)
        return outputs, logabsdet, mixture

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )