def InitLogitTempTransform(temperature: float = 1.0, jitted: bool = False):
    """Build the initialization functions for a ``LogitTemperature`` layer.

    Parameters
    ----------
    temperature : float
        value forwarded to every ``LogitTemperature`` constructed here
    jitted : bool
        if True, the forward function used by ``transform`` and
        ``transform_and_bijector`` is wrapped in ``jax.jit``

    Returns
    -------
    InitLayersFunctions
        container holding the four closures defined below
    """
    forward = LogitTemperature(temperature=temperature).forward
    f = jax.jit(forward) if jitted else forward

    def transform(inputs, **kwargs):
        # extra keyword arguments are accepted but not used
        return f(inputs)

    def bijector(inputs=None, **kwargs):
        # nothing is fitted, so the inputs are not consulted
        return LogitTemperature(temperature=temperature)

    def transform_and_bijector(inputs, **kwargs):
        return f(inputs), LogitTemperature(temperature=temperature)

    def transform_gradient_bijector(inputs, **kwargs):
        layer = LogitTemperature(temperature=temperature)
        # this path calls the bijector method directly, not the cached `f`
        outputs, logabsdet = layer.forward_and_log_det(inputs)
        return outputs, logabsdet, layer

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
def InitLogitTransform(jitted: bool = False):
    """Build the initialization functions for an ``Inverse(Sigmoid())`` layer.

    Parameters
    ----------
    jitted : bool
        if True, the forward function used by ``transform`` and
        ``transform_and_bijector`` is wrapped in ``jax.jit``

    Returns
    -------
    InitLayersFunctions
        container holding the four closures defined below
    """
    forward = Inverse(Sigmoid()).forward
    f = jax.jit(forward) if jitted else forward

    def transform(inputs, **kwargs):
        # extra keyword arguments are accepted but not used
        return f(inputs)

    def bijector(inputs=None, **kwargs):
        # nothing is fitted, so the inputs are not consulted
        return Inverse(Sigmoid())

    def transform_and_bijector(inputs, **kwargs):
        return f(inputs), Inverse(Sigmoid())

    def transform_gradient_bijector(inputs, **kwargs):
        layer = Inverse(Sigmoid())
        # this path calls the bijector method directly, not the cached `f`
        outputs, logabsdet = layer.forward_and_log_det(inputs)
        return outputs, logabsdet, layer

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
def InitRandomRotation(rng: PRNGKey, jitted=False):
    """Build the initialization functions for a random ``Rotation`` layer.

    Parameters
    ----------
    rng : PRNGKey
        key stored by the factory; it is split with ``jax.random.split``
        whenever a closure is called without a usable ``rng`` keyword
    jitted : bool
        if True, both partial applications of ``get_random_rotation`` are
        wrapped in ``jax.jit``

    Returns
    -------
    InitLayersFunctions
        container holding the five closures defined below
    """
    # `jax.partial` was only ever a re-export of `functools.partial` and is
    # missing from newer JAX releases, so use the standard library directly.
    from functools import partial

    key = rng

    f = partial(get_random_rotation, return_params=True)
    f_slim = partial(get_random_rotation, return_params=False)

    if jitted:
        f = jax.jit(f)
        f_slim = jax.jit(f_slim)

    def _select_rng(kwargs):
        # NOTE(review): a caller-supplied "rng" is unpacked exactly like the
        # result of `jax.random.split(key, 2)`, i.e. it is expected to hold
        # two entries -- confirm against the callers.
        supplied = kwargs.get("rng")
        if supplied is None:
            supplied = jax.random.split(key, 2)
        _, subkey = supplied
        return subkey

    def init_params(inputs, **kwargs):
        _, params = f(_select_rng(kwargs), inputs)
        return params

    def params_and_transform(inputs, **kwargs):
        outputs, params = f(_select_rng(kwargs), inputs)
        return outputs, params

    def transform(inputs, **kwargs):
        # pass the key first, matching every other call of the helper
        return f_slim(_select_rng(kwargs), inputs)

    def bijector(inputs, **kwargs):
        params = init_params(inputs, **kwargs)
        return Rotation(rotation=params.rotation)

    def bijector_and_transform(inputs, **kwargs):
        outputs, params = params_and_transform(inputs, **kwargs)
        return outputs, Rotation(rotation=params.rotation)

    return InitLayersFunctions(
        bijector=bijector,
        bijector_and_transform=bijector_and_transform,
        transform=transform,
        params=init_params,
        params_and_transform=params_and_transform,
    )
def InitPCARotation(jitted=False):
    """Build the initialization functions for a PCA-based ``Rotation`` layer.

    Parameters
    ----------
    jitted : bool
        if True, the partial application of ``get_pca_params`` is wrapped in
        ``jax.jit``

    Returns
    -------
    InitLayersFunctions
        container holding the four closures defined below
    """
    # `jax.partial` was only ever a re-export of `functools.partial` and is
    # missing from newer JAX releases, so use the standard library directly.
    from functools import partial

    f = partial(get_pca_params, return_params=True)

    if jitted:
        f = jax.jit(f)

    def _fit_rotation(inputs, **kwargs):
        # Fit the parameters on `inputs` and wrap the rotation in a bijector.
        params = f(inputs, **kwargs)
        return Rotation(rotation=params.rotation)

    def transform(inputs, **kwargs):
        return _fit_rotation(inputs, **kwargs).forward(inputs)

    def bijector(inputs, **kwargs):
        return _fit_rotation(inputs, **kwargs)

    def transform_and_bijector(inputs, **kwargs):
        layer = _fit_rotation(inputs, **kwargs)
        outputs = layer.forward(inputs)
        return outputs, layer

    def transform_gradient_bijector(inputs, **kwargs):
        layer = _fit_rotation(inputs, **kwargs)
        outputs, logabsdet = layer.forward_and_log_det(inputs)
        return outputs, logabsdet, layer

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
def InitSigmoidTransform(eps: float = 1e-5, jitted: bool = False):
    """Build the initialization functions for a ``Sigmoid`` layer.

    Parameters
    ----------
    eps : float
        inputs are clipped to ``[eps, 1 - eps]`` before the forward pass
    jitted : bool
        if True, the forward function used by ``transform`` and
        ``transform_and_bijector`` is wrapped in ``jax.jit``

    Returns
    -------
    InitLayersFunctions
        container holding the four closures defined below
    """
    forward = Sigmoid().forward
    f = jax.jit(forward) if jitted else forward

    def _clipped(inputs):
        # NOTE(review): clipping happens only inside these closures; the
        # returned ``Sigmoid`` bijector itself does not clip -- confirm
        # that this asymmetry is intended.
        return jnp.clip(inputs, eps, 1 - eps)

    def transform(inputs, **kwargs):
        return f(_clipped(inputs))

    def bijector(inputs=None, **kwargs):
        # nothing is fitted, so the inputs are not consulted
        return Sigmoid()

    def transform_and_bijector(inputs, **kwargs):
        return f(_clipped(inputs)), Sigmoid()

    def transform_gradient_bijector(inputs, **kwargs):
        bounded = _clipped(inputs)
        layer = Sigmoid()
        outputs, logabsdet = layer.forward_and_log_det(bounded)
        return outputs, logabsdet, layer

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
def InitInverseGaussCDF(eps: float = 1e-5, jitted=False):
    """Build the initialization functions for an ``InverseGaussCDF`` layer.

    Parameters
    ----------
    eps : float
        value forwarded to every ``InverseGaussCDF`` constructed here
    jitted : bool
        if True, the forward function used by ``transform`` and
        ``transform_and_bijector`` is wrapped in ``jax.jit``

    Returns
    -------
    InitLayersFunctions
        container holding the four closures defined below
    """
    # instance whose bound `forward` backs the cached function `f`
    base_layer = InverseGaussCDF(eps=eps)
    f = jax.jit(base_layer.forward) if jitted else base_layer.forward

    def transform(inputs, **kwargs):
        # extra keyword arguments are accepted but not used
        return f(inputs)

    def bijector(inputs=None, **kwargs):
        # nothing is fitted, so the inputs are not consulted
        return InverseGaussCDF(eps=eps)

    def transform_and_bijector(inputs, **kwargs):
        return f(inputs), InverseGaussCDF(eps=eps)

    def transform_gradient_bijector(inputs, **kwargs):
        layer = InverseGaussCDF(eps=eps)
        # this path calls the bijector method directly, not the cached `f`
        outputs, logabsdet = layer.forward_and_log_det(inputs)
        return outputs, logabsdet, layer

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
def transform_gradient_bijector(inputs, **kwargs): params = jax.vmap(f, out_axes=0, in_axes=(1, ))(inputs) bijector = MarginalUniformizeTransform( support=params.support, quantiles=params.quantiles, support_pdf=params.support_pdf, empirical_pdf=params.empirical_pdf, ) outputs, logabsdet = bijector.forward_and_log_det(inputs) return outputs, logabsdet, bijector return InitLayersFunctions( transform=transform, bijector=bijector, transform_and_bijector=transform_and_bijector, transform_gradient_bijector=transform_gradient_bijector, ) def init_kde_params( X: jnp.ndarray, bw: float = 0.1, support_extension: Union[int, float] = 10, precision: int = 1_000, return_params: bool = True, ): # generate support points lb, ub = get_domain_extension(X, support_extension) support = jnp.linspace(lb, ub, precision)
def InitHouseHolder(n_reflections: int, method: str = "random") -> Callable:
    """Build the initialization functions for a ``HouseHolder`` layer.

    Every returned closure initializes a weight matrix with
    ``init_householder_weights`` and wraps it in a ``HouseHolder`` bijector.

    Parameters
    ----------
    n_reflections : int
        the number of householder reflections
    method : str
        initialization method forwarded to ``init_householder_weights``

    Returns
    -------
    InitLayersFunctions
        container holding the four closures defined below; each closure
        takes ``inputs``, ``n_features`` and an optional ``rng`` key
    """

    def _init_bijector(inputs: Array, n_features: int, rng: PRNGKey) -> HouseHolder:
        # Shared by all closures: initialize the weights, wrap them.
        V = init_householder_weights(
            rng=rng,
            n_features=n_features,
            n_reflections=n_reflections,
            method=method,
            X=inputs,
        )
        return HouseHolder(V=V)

    def bijector(
        inputs: Array, n_features: int, rng: PRNGKey = None, **kwargs
    ) -> HouseHolder:
        return _init_bijector(inputs, n_features, rng)

    def transform_and_bijector(
        inputs: Array, n_features: int, rng: PRNGKey = None, **kwargs
    ) -> Tuple[Array, HouseHolder]:
        layer = _init_bijector(inputs, n_features, rng)
        outputs = layer.forward(inputs=inputs)
        return outputs, layer

    def transform(
        inputs: Array, n_features: int, rng: PRNGKey = None, **kwargs
    ) -> Array:
        return _init_bijector(inputs, n_features, rng).forward(inputs=inputs)

    def transform_gradient_bijector(
        inputs: Array, n_features: int, rng: PRNGKey = None, **kwargs
    ) -> Tuple[Array, Array, HouseHolder]:
        layer = _init_bijector(inputs, n_features, rng)
        outputs, logabsdet = layer.forward_and_log_det(inputs=inputs)
        return outputs, logabsdet, layer

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
def InitPiecewiseRationalQuadraticCDF(
    n_bins: int,
    range_min: float,
    range_max: float,
    identity_init: bool = False,
    boundary_slopes: str = "identity",
    min_bin_size: float = 1e-4,
    min_knot_slope: float = 1e-4,
):
    """Build the initialization functions for a rational-quadratic spline layer.

    Every returned closure creates a bijector with ``init_spline_params``
    using the settings captured here.

    Parameters
    ----------
    n_bins : int
        forwarded to ``init_spline_params``
    range_min : float
        lower end of the range; must be strictly less than ``range_max``
    range_max : float
        upper end of the range
    identity_init : bool
        forwarded to ``init_spline_params``
    boundary_slopes : str
        forwarded to ``init_spline_params``
    min_bin_size : float
        must be positive
    min_knot_slope : float
        must be positive; forwarded to ``init_spline_params``

    Returns
    -------
    InitLayersFunctions
        container holding the four closures defined below

    Raises
    ------
    ValueError
        if ``range_min >= range_max``, ``min_bin_size <= 0`` or
        ``min_knot_slope <= 0``
    """
    # preliminary checks of parameters
    if range_min >= range_max:
        raise ValueError(
            f"`range_min` must be less than `range_max`; "
            f"Got: {range_min} and {range_max}"
        )
    if min_bin_size <= 0:
        raise ValueError(
            f"Minimum bin size must be positive; " f"Got {min_bin_size}"
        )
    if min_knot_slope <= 0:
        raise ValueError(
            f"Minimum knot slope must be positive; " f"Got {min_knot_slope}"
        )

    # NOTE(review): `min_bin_size` is validated above but never handed to
    # `init_spline_params` -- confirm whether that helper should receive it.

    def _init_bijector(rng: PRNGKey, shape: int) -> Bijector:
        # Shared by all closures so the configuration lives in one place.
        return init_spline_params(
            n_bins=n_bins,
            rng=rng,
            shape=shape,
            identity_init=identity_init,
            min_knot_slope=min_knot_slope,
            range_min=range_min,
            range_max=range_max,
            boundary_slopes=boundary_slopes,
        )

    def bijector(
        inputs: Array = None, rng: PRNGKey = None, shape: int = None, **kwargs
    ) -> Bijector:
        return _init_bijector(rng, shape)

    def transform_and_bijector(
        inputs: Array = None, rng: PRNGKey = None, shape: int = None, **kwargs
    ) -> Tuple[Array, Bijector]:
        layer = _init_bijector(rng, shape)
        outputs = layer.forward(inputs=inputs)
        return outputs, layer

    def transform(
        inputs: Array = None, rng: PRNGKey = None, shape: int = None, **kwargs
    ) -> Array:
        return _init_bijector(rng, shape).forward(inputs=inputs)

    def transform_gradient_bijector(
        inputs: Array = None, rng: PRNGKey = None, shape: int = None, **kwargs
    ) -> Tuple[Array, Array, Bijector]:
        layer = _init_bijector(rng, shape)
        outputs, logabsdet = layer.forward_and_log_det(inputs=inputs)
        return outputs, logabsdet, layer

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )
def InitMixtureLogisticCDF(
    n_components: int, init_method: str = "gmm", seed: int = 123
) -> Callable:
    """Build the initialization functions for a ``MixtureLogisticCDF`` layer.

    Every returned closure initializes the mixture parameters with
    ``init_mixture_weights`` and wraps them in a ``MixtureLogisticCDF``
    bijector.

    Parameters
    ----------
    n_components : int
        forwarded to ``init_mixture_weights``
    init_method : str
        forwarded to ``init_mixture_weights`` as ``method``
    seed : int
        passed in place of the random key whenever a closure is called
        with ``rng=None``

    Returns
    -------
    InitLayersFunctions
        container holding the four closures defined below; each closure
        takes ``inputs`` plus optional ``n_features`` and ``rng``
    """

    def _init_bijector(inputs, n_features, rng) -> MixtureLogisticCDF:
        # Shared by all closures; fall back to the second axis of `inputs`
        # when the number of features is not given explicitly.
        prior_logits, means, log_scales = init_mixture_weights(
            rng=seed if rng is None else rng,
            n_features=n_features if n_features is not None else inputs.shape[1],
            n_components=n_components,
            method=init_method,
            X=inputs,
        )
        return MixtureLogisticCDF(
            means=means, log_scales=log_scales, prior_logits=prior_logits
        )

    def bijector(
        inputs, n_features: int = None, rng: PRNGKey = None, **kwargs
    ) -> MixtureLogisticCDF:
        return _init_bijector(inputs, n_features, rng)

    def transform_and_bijector(
        inputs, n_features: int = None, rng: PRNGKey = None, **kwargs
    ):
        layer = _init_bijector(inputs, n_features, rng)
        outputs = layer.forward(inputs=inputs)
        return outputs, layer

    def transform(inputs, n_features: int = None, rng: PRNGKey = None, **kwargs):
        return _init_bijector(inputs, n_features, rng).forward(inputs=inputs)

    def transform_gradient_bijector(
        inputs, n_features: int = None, rng: PRNGKey = None, **kwargs
    ):
        layer = _init_bijector(inputs, n_features, rng)
        outputs, logabsdet = layer.forward_and_log_det(inputs=inputs)
        return outputs, logabsdet, layer

    return InitLayersFunctions(
        transform=transform,
        bijector=bijector,
        transform_and_bijector=transform_and_bijector,
        transform_gradient_bijector=transform_gradient_bijector,
    )