def __init__(self, name, mu, b, tracked=True):
    """Initialise the prior from parents ``mu`` and ``b``.

    Args:
        name: Name of this prior; also used to name constant parents
            created here (``_<name>_mu``, ``_<name>_b``).
        mu: ``PriorTransform`` or constant. A constant is passed through
            ``jnp.atleast_1d`` and wrapped in a ``DeltaPrior``.
        b: ``PriorTransform`` or constant. A constant is wrapped in a
            ``DeltaPrior`` unchanged (no ``atleast_1d``).
        tracked: Forwarded to the base-class constructor.
    """
    # Constants become DeltaPrior parents; the trailing False is presumably
    # tracked=False for these internal nodes.
    if not isinstance(mu, PriorTransform):
        mu = DeltaPrior('_{}_mu'.format(name), jnp.atleast_1d(mu), False)
    if not isinstance(b, PriorTransform):
        b = DeltaPrior('_{}_b'.format(name), b, False)
    # Only the leading axis of the broadcast parent shape sizes U_dims.
    parent_shape = broadcast_shapes(get_shape(mu), get_shape(b))
    U_dims = parent_shape[0]
    super(LaplacePrior, self).__init__(name, U_dims, [mu, b], tracked)
def __init__(self, name, mu, gamma, tracked=True):
    """Initialise the prior from parents ``mu`` and ``gamma``.

    Args:
        name: Name of this prior; also used to name constant parents
            created here (``_<name>_mu``, ``_<name>_gamma``).
        mu: ``PriorTransform`` or constant. A constant is passed through
            ``jnp.atleast_1d`` and wrapped in a ``DeltaPrior``.
        gamma: ``PriorTransform`` or constant, handled the same way as ``mu``.
        tracked: Forwarded to the base-class constructor.
    """
    # Constants become DeltaPrior parents (trailing False presumably
    # tracked=False).
    if not isinstance(mu, PriorTransform):
        mu = DeltaPrior('_{}_mu'.format(name), jnp.atleast_1d(mu), False)
    if not isinstance(gamma, PriorTransform):
        gamma = DeltaPrior('_{}_gamma'.format(name), jnp.atleast_1d(gamma), False)
    # U_dims is the leading extent of the broadcast mu/gamma shape.
    parent_shape = broadcast_shapes(get_shape(mu), get_shape(gamma))
    U_dims = parent_shape[0]
    super(MVNDiagPrior, self).__init__(name, U_dims, [mu, gamma], tracked)
def __init__(self, name, mu, Gamma, ill_cond=False, tracked=True):
    """Initialise the prior from parents ``mu`` and ``Gamma``.

    Args:
        name: Name of this prior; also used to name constant parents
            created here (``_<name>_mu``, ``_<name>_Gamma``).
        mu: ``PriorTransform`` or constant. A constant is passed through
            ``jnp.atleast_1d`` and wrapped in a ``DeltaPrior``.
        Gamma: ``PriorTransform`` or constant. A constant is passed through
            ``jnp.atleast_2d`` and wrapped in a ``DeltaPrior``.
        ill_cond: Stored as ``self._ill_cond``; not otherwise used here.
        tracked: Forwarded to the base-class constructor.
    """
    self._ill_cond = ill_cond
    if not isinstance(mu, PriorTransform):
        mu = DeltaPrior('_{}_mu'.format(name), jnp.atleast_1d(mu), False)
    if not isinstance(Gamma, PriorTransform):
        Gamma = DeltaPrior('_{}_Gamma'.format(name), jnp.atleast_2d(Gamma), False)
    # mu's shape is broadcast against only the leading axis of Gamma.
    leading_shape = broadcast_shapes(get_shape(mu), get_shape(Gamma)[:1])
    U_dims = leading_shape[0]
    super(MVNPrior, self).__init__(name, U_dims, [mu, Gamma], tracked)
def __init__(self, name, low, high, tracked=True):
    """Initialise the prior from parents ``low`` and ``high``.

    Args:
        name: Name of this prior; also used to name constant parents
            created here (``_<name>_low``, ``_<name>_high``).
        low: ``PriorTransform`` or constant (constant wrapped as-is in a
            ``DeltaPrior``).
        high: ``PriorTransform`` or constant, handled like ``low``.
        tracked: Forwarded to the base-class constructor.
    """
    if not isinstance(low, PriorTransform):
        low = DeltaPrior('_{}_low'.format(name), low, False)
    if not isinstance(high, PriorTransform):
        high = DeltaPrior('_{}_high'.format(name), high, False)
    # The full broadcast shape is kept on the instance; U_dims comes from
    # tuple_prod of it (presumably the product of its entries).
    self._broadcast_shape = broadcast_shapes(get_shape(low), get_shape(high))
    U_dims = tuple_prod(self._broadcast_shape)
    super(UniformPrior, self).__init__(name, U_dims, [low, high], tracked)
def __init__(self, name, logits, tracked=True):
    """Initialise the prior from a ``logits`` parent plus a Gumbel parent.

    Args:
        name: Name of this prior; also used to name the internal parents
            (``_<name>_logits``, ``_<name>_gumbel``).
        logits: ``PriorTransform`` or constant. A constant is passed through
            ``jnp.atleast_1d`` and wrapped in a ``DeltaPrior``.
        tracked: Forwarded to the base-class constructor.
    """
    if not isinstance(logits, PriorTransform):
        logits = DeltaPrior('_{}_logits'.format(name), jnp.atleast_1d(logits), False)
    # U_dims is the leading axis of logits; it also sizes the Gumbel parent.
    # (Previously this value was recomputed identically a second time.)
    U_dims = get_shape(logits)[0]
    gumbel = Gumbel('_{}_gumbel'.format(name), U_dims, False)
    # Fixed output shape (1,) — presumably a single category index.
    self._shape = (1,)
    super(CategoricalPrior, self).__init__(name, U_dims, [gumbel, logits], tracked)
def __init__(self, name, T, x0, omega, tracked=True):
    """Initialise the prior from parents ``x0`` and ``omega``.

    Args:
        name: Name of this prior; also used to name constant parents
            created here (``_<name>_x0``, ``_<name>_omega``).
        T: Stored as ``self.T``; multiplies ``self.dim`` to give U_dims
            (presumably the number of walk steps — confirm).
        x0: ``PriorTransform`` or constant (constant wrapped as-is in a
            ``DeltaPrior``).
        omega: ``PriorTransform`` or constant, handled like ``x0``.
        tracked: Forwarded to the base-class constructor.
    """
    if not isinstance(x0, PriorTransform):
        x0 = DeltaPrior('_{}_x0'.format(name), x0, False)
    if not isinstance(omega, PriorTransform):
        omega = DeltaPrior('_{}_omega'.format(name), omega, False)
    # dim is the leading extent of the broadcast x0/omega shape.
    # NOTE(review): if both are scalars the broadcast shape is presumably ()
    # and indexing [0] fails — confirm whether scalars should be supported.
    joint_shape = broadcast_shapes(get_shape(x0), get_shape(omega))
    self.dim = joint_shape[0]
    self.T = T
    super(DiagGaussianWalkPrior, self).__init__(
        name, self.dim * self.T, [x0, omega], tracked)
def __init__(self, name, n, low, high, tracked=True):
    """Initialise the prior from parents ``low`` and ``high``.

    Args:
        name: Name of this prior; also used to name constant parents
            created here (``_<name>_low``, ``_<name>_high``).
        n: Stored as ``self._n`` and prepended as a new leading axis of
            ``self._broadcast_shape``.
        low: ``PriorTransform`` or constant (constant wrapped as-is in a
            ``DeltaPrior``).
        high: ``PriorTransform`` or constant, handled like ``low``.
        tracked: Forwarded to the base-class constructor.
    """
    if not isinstance(low, PriorTransform):
        low = DeltaPrior('_{}_low'.format(name), low, False)
    if not isinstance(high, PriorTransform):
        high = DeltaPrior('_{}_high'.format(name), high, False)
    self._n = n
    bounds_shape = broadcast_shapes(get_shape(low), get_shape(high))
    # Leading axis of length n, followed by the broadcast bounds shape.
    self._broadcast_shape = (self._n,) + bounds_shape
    U_dims = tuple_prod(self._broadcast_shape)
    super(ForcedIdentifiabilityPrior, self).__init__(name, U_dims, [low, high], tracked)
def __init__(self, name, T, x0, half_width, tracked=True):
    """Initialise the prior from parents ``x0`` and ``half_width``.

    Args:
        name: Name of this prior; also used to name constant parents
            created here (``_<name>_x0``, ``_<name>_half_width``).
        T: Stored as ``self.T``; multiplies ``self.dim`` to give U_dims
            (presumably the number of walk steps — confirm).
        x0: ``PriorTransform`` or constant (constant wrapped as-is in a
            ``DeltaPrior``).
        half_width: ``PriorTransform`` or constant, handled like ``x0``.
        tracked: Forwarded to the base-class constructor.
    """
    if not isinstance(x0, PriorTransform):
        x0 = DeltaPrior('_{}_x0'.format(name), x0, False)
    if not isinstance(half_width, PriorTransform):
        half_width = DeltaPrior('_{}_half_width'.format(name), half_width, False)
    # dim is the leading extent of the broadcast x0/half_width shape.
    joint_shape = broadcast_shapes(get_shape(x0), get_shape(half_width))
    self.dim = joint_shape[0]
    self.T = T
    super(SymmetricUniformWalkPrior, self).__init__(
        name, self.dim * self.T, [x0, half_width], tracked)
def __init__(self, name, logits, tracked=True):
    """Initialise the prior from a single ``logits`` parent.

    Args:
        name: Name of this prior; also used to name a constant parent
            created here (``_<name>_logits``).
        logits: ``PriorTransform`` or constant. A constant is passed through
            ``jnp.atleast_1d`` and wrapped in a ``DeltaPrior``.
        tracked: Forwarded to the base-class constructor.
    """
    if not isinstance(logits, PriorTransform):
        logits = DeltaPrior('_{}_logits'.format(name), jnp.atleast_1d(logits), False)
    # Output shape mirrors logits; U_dims is tuple_prod of that shape.
    logits_shape = get_shape(logits)
    self._shape = logits_shape
    U_dims = tuple_prod(logits_shape)
    super(BernoulliPrior, self).__init__(name, U_dims, [logits], tracked)
def __init__(self, name, pi, mu, gamma, tracked=True):
    """Initialise the prior from parents ``pi``, ``mu`` and ``gamma``.

    Args:
        name: Name of this prior; also used to name constant parents
            created here (``_<name>_pi``, ``_<name>_mu``, ``_<name>_gamma``).
        pi: ``PriorTransform`` or constant (constant wrapped as-is in a
            ``DeltaPrior``).
        mu: ``PriorTransform`` or constant. A constant is passed through
            ``jnp.atleast_2d`` and wrapped in a ``DeltaPrior``.
        gamma: ``PriorTransform`` or constant, handled like ``mu``.
        tracked: Forwarded to the base-class constructor.

    Raises:
        AssertionError: if the shape checks below fail (unless run with -O).
    """
    if not isinstance(pi, PriorTransform):
        pi = DeltaPrior('_{}_pi'.format(name), pi, False)
    if not isinstance(mu, PriorTransform):
        mu = DeltaPrior('_{}_mu'.format(name), jnp.atleast_2d(mu), False)
    if not isinstance(gamma, PriorTransform):
        gamma = DeltaPrior('_{}_gamma'.format(name), jnp.atleast_2d(gamma), False)
    # pi, mu and gamma must share axis 0 (presumably the component count);
    # mu and gamma must also share axis 1.
    assert get_shape(pi)[0] == get_shape(mu)[0]
    assert get_shape(pi)[0] == get_shape(gamma)[0]
    assert get_shape(mu)[1] == get_shape(gamma)[1]
    # One extra U dimension on top of the trailing mu/gamma extent
    # (presumably used to pick a component — confirm).
    trailing = broadcast_shapes(get_shape(mu)[-1:], get_shape(gamma)[-1:])[0]
    U_dims = 1 + trailing
    super(GMMDiagPrior, self).__init__(name, U_dims, [pi, mu, gamma], tracked)
def __init__(self, name, pi, low, high, tracked=True):
    """Initialise the prior from parents ``pi``, ``low`` and ``high``.

    Args:
        name: Name of this prior; also used to name constant parents
            created here (``_<name>_pi``, ``_<name>_low``, ``_<name>_high``).
        pi: ``PriorTransform`` or constant (constant wrapped as-is in a
            ``DeltaPrior``).
        low: ``PriorTransform`` or constant. A constant is passed through
            ``jnp.atleast_2d`` and wrapped in a ``DeltaPrior``.
        high: ``PriorTransform`` or constant, handled like ``low``.
        tracked: Forwarded to the base-class constructor.

    Raises:
        AssertionError: if the shape checks below fail (unless run with -O).
    """
    if not isinstance(pi, PriorTransform):
        pi = DeltaPrior('_{}_pi'.format(name), pi, False)
    if not isinstance(low, PriorTransform):
        low = DeltaPrior('_{}_low'.format(name), jnp.atleast_2d(low), False)
    if not isinstance(high, PriorTransform):
        high = DeltaPrior('_{}_high'.format(name), jnp.atleast_2d(high), False)
    # pi, low and high must share axis 0 (presumably the component count);
    # low and high must also share axis 1.
    assert get_shape(pi)[0] == get_shape(low)[0]
    assert get_shape(pi)[0] == get_shape(high)[0]
    assert get_shape(low)[1] == get_shape(high)[1]
    # One extra U dimension on top of the trailing low/high extent.
    trailing = broadcast_shapes(get_shape(low)[-1:], get_shape(high)[-1:])[0]
    U_dims = 1 + trailing
    super(UniformMixturePrior, self).__init__(name, U_dims, [pi, low, high], tracked)
def __init__(self, name, transform, pi, *components, tracked=True):
    """Initialise the prior from a weight parent ``pi`` and ``components``.

    Args:
        name: Name of this prior; also used to name a constant ``pi`` parent
            (``_<name>_pi``).
        transform: Stored as ``self._transform``; not used in this method.
        pi: ``PriorTransform`` or constant (constant wrapped as-is in a
            ``DeltaPrior``). Its leading axis must equal ``len(components)``.
        *components: ``PriorTransform`` instances whose ``to_shape`` values
            are broadcast together into ``self._shape``.
        tracked: Forwarded to the base-class constructor.

    Raises:
        AssertionError: if ``pi``'s leading axis does not match the number of
            components, or a component is not a ``PriorTransform``
            (unless run with -O).
    """
    self._transform = transform
    if not isinstance(pi, PriorTransform):
        pi = DeltaPrior('_{}_pi'.format(name), pi, False)
    assert get_shape(pi)[0] == len(components)
    shape = ()
    for component in components:
        assert isinstance(component, PriorTransform)
        shape = broadcast_shapes(shape, component.to_shape)
    self._shape = shape
    U_dims = 1
    # ``components`` is a tuple (it comes from *components); the previous
    # ``[pi] + components`` was list + tuple and always raised TypeError.
    # Build a flat list of parents instead.
    super(MixturePrior, self).__init__(name, U_dims, [pi, *components], tracked)
def __init__(self, name, b, tracked=True):
    """Initialise the prior from a single ``b`` parent.

    Args:
        name: Name of this prior; also used to name a constant parent
            created here (``_<name>_b``).
        b: ``PriorTransform`` or constant. A constant is passed through
            ``jnp.atleast_1d`` before being wrapped in a ``DeltaPrior``, so a
            scalar gives a length-1 parent.
        tracked: Forwarded to the base-class constructor.
    """
    if not isinstance(b, PriorTransform):
        # atleast_1d (as MVNDiagPrior does for gamma) is the identity for
        # inputs with >= 1 axis, and lets a scalar b work: get_shape(b)[0]
        # below needs at least one axis, otherwise it presumably fails.
        b = DeltaPrior('_{}_b'.format(name), jnp.atleast_1d(b), False)
    U_dims = get_shape(b)[0]
    super(HalfLaplacePrior, self).__init__(name, U_dims, [b], tracked)
def __init__(self, name, kernel: Kernel, X, *gp_params, tracked=False, jitter=1e-6):
    """Prior whose transform returns the kernel Gram matrix of ``X`` plus jitter.

    Args:
        name: Name of this prior.
        kernel: Called as ``kernel(X, X, *gp_params)`` inside the transform.
        X: First parent. The output shape is ``(n, n)`` with
            ``n = get_shape(X)[0]``.
        *gp_params: Remaining parents, passed to ``kernel`` after ``X, X``.
        tracked: Forwarded to the base-class constructor (defaults to False
            here, unlike the other priors in this file).
        jitter: Scale of the identity added to the Gram matrix
            (``jitter * eye(n)``). Defaults to the previously hard-coded
            ``1e-6``, so existing callers are unaffected. Presumably for
            numerical stability; raise it for ill-conditioned kernels or
            float32.
    """
    parents = [X] + list(gp_params)

    def _transform(X, *gp_params):
        # Gram matrix with a small diagonal added.
        return kernel(X, X, *gp_params) + jitter * jnp.eye(X.shape[0])

    n = get_shape(X)[0]
    super(GaussianProcessKernelPrior, self).__init__(
        name, _transform, (n, n), *parents, tracked=tracked)
def __init__(self, name, dist, tracked=True):
    """Prior whose transform applies ``jnp.transpose`` to a single parent.

    Args:
        name: Name of this prior; also used to name a constant parent
            created here (``_<name>_dist``).
        dist: ``PriorTransform`` or constant (constant wrapped as-is in a
            ``DeltaPrior``).
        tracked: Forwarded to the base-class constructor.
    """
    if not isinstance(dist, PriorTransform):
        dist = DeltaPrior('_{}_dist'.format(name), dist, False)

    def _transpose(x):
        return jnp.transpose(x)

    # Output shape is the parent's shape with its axes reversed.
    reversed_shape = get_shape(dist)[::-1]
    self._shape = reversed_shape
    # NOTE(review): the parent goes in as a one-element list here, whereas
    # GaussianProcessKernelPrior unpacks its parents (*gp_params) into a
    # constructor with the same (name, transform, to_shape, ...) layout —
    # confirm the base-class signature.
    super(TransposePrior, self).__init__(
        name, _transpose, reversed_shape, [dist], tracked=tracked)