def bijector(
    inputs, n_features: int = None, rng: PRNGKey = None, **kwargs
) -> MixtureGaussianCDF:
    """Initialize a ``MixtureGaussianCDF`` bijector from data.

    Mixture parameters (prior logits, means, log-scales) are produced by
    ``init_mixture_weights`` using ``n_components`` and ``init_method`` from
    the enclosing scope.

    Args:
        inputs: data forwarded to ``init_mixture_weights`` as ``X``. When
            ``n_features`` is not given, ``inputs.shape[1]`` is read, so this
            presumably is 2-D (samples, features) -- TODO confirm.
        n_features: number of features; defaults to ``inputs.shape[1]``.
        rng: random key for initialization; when ``None`` the enclosing-scope
            ``seed`` is used instead.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        The initialized ``MixtureGaussianCDF`` bijector.
    """
    if n_features is None:
        n_features = inputs.shape[1]
    # The key is passed as ``rng=``, the same keyword the other
    # initializers in this file use when calling ``init_mixture_weights``.
    prior_logits, means, log_scales = init_mixture_weights(
        rng=seed if rng is None else rng,
        n_features=n_features,
        n_components=n_components,
        method=init_method,
        X=inputs,
    )
    mixture_cdf = MixtureGaussianCDF(
        means=means, log_scales=log_scales, prior_logits=prior_logits
    )
    return mixture_cdf
def transform_gradient_bijector(
    inputs, n_features: int = None, rng: PRNGKey = None, **kwargs
) -> tuple:
    """Initialize a ``MixtureGaussianCDF`` from data and apply it to the data.

    Mixture parameters are produced by ``init_mixture_weights`` using
    ``n_components`` and ``init_method`` from the enclosing scope; the
    resulting bijector is then run forward on ``inputs``.

    Args:
        inputs: data used both to initialize the mixture (``X=inputs``) and
            as the forward-transform input. When ``n_features`` is not given,
            ``inputs.shape[1]`` is read, so this presumably is 2-D
            (samples, features) -- TODO confirm.
        n_features: number of features; defaults to ``inputs.shape[1]``.
        rng: random key for initialization; when ``None`` the enclosing-scope
            ``seed`` is used instead.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        A 3-tuple ``(outputs, logabsdet, bijector)``: the forward-transformed
        inputs, the log-absolute-determinant returned by the bijector's
        ``forward_and_log_det``, and the initialized ``MixtureGaussianCDF``.
    """
    if n_features is None:
        n_features = inputs.shape[1]
    prior_logits, means, log_scales = init_mixture_weights(
        rng=seed if rng is None else rng,
        n_features=n_features,
        n_components=n_components,
        method=init_method,
        X=inputs,
    )
    mixture_cdf = MixtureGaussianCDF(
        means=means, log_scales=log_scales, prior_logits=prior_logits
    )
    # Forward transform plus its log|det J|.
    # NOTE(review): method name assumed to be ``forward_and_log_det``
    # (lower-case "det"); confirm against the bijector base class.
    outputs, logabsdet = mixture_cdf.forward_and_log_det(inputs=inputs)
    return outputs, logabsdet, mixture_cdf
def transform(inputs, n_features: int = None, rng: PRNGKey = None, **kwargs):
    """Initialize a mixture-CDF bijector from data and return the transformed data.

    Mixture parameters are produced by ``init_mixture_weights`` using
    ``n_components`` and ``init_method`` from the enclosing scope, wrapped in
    a ``MixtureLogisticCDF``, and applied forward to ``inputs``. Only the
    transformed values are returned; the bijector itself is discarded.

    Args:
        inputs: data used both to initialize the mixture (``X=inputs``) and
            as the forward-transform input. When ``n_features`` is not given,
            ``inputs.shape[1]`` is read, so this presumably is 2-D
            (samples, features) -- TODO confirm.
        n_features: number of features; defaults to ``inputs.shape[1]``.
        rng: random key for initialization; when ``None`` the enclosing-scope
            ``seed`` is used instead.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        The result of ``bijector.forward(inputs=inputs)`` (the transformed
        data, not a bijector -- hence no ``-> MixtureLogisticCDF`` annotation).
    """
    if n_features is None:
        n_features = inputs.shape[1]
    prior_logits, means, log_scales = init_mixture_weights(
        rng=seed if rng is None else rng,
        n_features=n_features,
        n_components=n_components,
        method=init_method,
        X=inputs,
    )
    # NOTE(review): this builds a MixtureLogisticCDF, while `bijector` and
    # `transform_gradient_bijector` in this file build MixtureGaussianCDF;
    # confirm which CDF family is intended here.
    mixture_cdf = MixtureLogisticCDF(
        means=means, log_scales=log_scales, prior_logits=prior_logits
    )
    # Forward transform only; the log-det term is not needed by this caller.
    outputs = mixture_cdf.forward(inputs=inputs)
    return outputs