def __init__(self, name, T, x0, omega, tracked=True):
    """Initialise a DiagGaussianWalkPrior spanning ``T`` steps.

    Arguments that are not already ``PriorTransform`` instances are wrapped in
    untracked ``DeltaPrior`` constants named ``_<name>_x0`` and
    ``_<name>_omega``.

    Args:
        name: Name of this prior.
        T: Number of steps; stored as ``self.T``.
        x0: Starting value, either a ``PriorTransform`` or a constant.
        omega: Second walk parameter, either a ``PriorTransform`` or a
            constant (presumably the per-step scale -- confirm).
        tracked: Forwarded unchanged to the base ``__init__``.
    """
    x0 = (x0 if isinstance(x0, PriorTransform)
          else DeltaPrior('_{}_x0'.format(name), x0, False))
    omega = (omega if isinstance(omega, PriorTransform)
             else DeltaPrior('_{}_omega'.format(name), omega, False))
    # Only the leading axis of the broadcast parent shape is kept as the
    # per-step dimension (assumes x0/omega are at least 1-D -- confirm).
    joint_shape = broadcast_shapes(get_shape(x0), get_shape(omega))
    self.dim = joint_shape[0]
    self.T = T
    # The base class is given dim * T as its dimension count.
    n_dims = self.dim * self.T
    super(DiagGaussianWalkPrior, self).__init__(name, n_dims, [x0, omega], tracked)
def __init__(self, name, n, low, high, tracked=True):
    """Initialise a ForcedIdentifiabilityPrior with ``n`` entries per bound.

    ``low`` / ``high`` that are not ``PriorTransform`` instances become
    untracked ``DeltaPrior`` constants named ``_<name>_low`` / ``_<name>_high``.

    Args:
        name: Name of this prior.
        n: Size of the new leading axis prepended to the broadcast shape of
            ``low`` and ``high``; stored as ``self._n``.
        low: A ``PriorTransform`` or constant (presumably the lower bound).
        high: A ``PriorTransform`` or constant (presumably the upper bound).
        tracked: Forwarded unchanged to the base ``__init__``.
    """
    low = (low if isinstance(low, PriorTransform)
           else DeltaPrior('_{}_low'.format(name), low, False))
    high = (high if isinstance(high, PriorTransform)
            else DeltaPrior('_{}_high'.format(name), high, False))
    self._n = n
    bound_shape = broadcast_shapes(get_shape(low), get_shape(high))
    # Stored shape is (n,) + broadcast(low, high); the base class receives
    # tuple_prod of that shape (presumably its element count) as U dims.
    self._broadcast_shape = (self._n,) + bound_shape
    super(ForcedIdentifiabilityPrior, self).__init__(
        name, tuple_prod(self._broadcast_shape), [low, high], tracked)
def __init__(self, name, pi, mu, gamma, tracked=True):
    """Initialise a GMMDiagPrior from ``pi``, ``mu`` and ``gamma``.

    Constant (non-``PriorTransform``) arguments are wrapped in untracked
    ``DeltaPrior`` instances. ``mu`` and ``gamma`` are first promoted with
    ``jnp.atleast_2d``, so they carry a leading component axis and a second
    (feature) axis.

    Args:
        name: Name of this prior.
        pi: Per-component values (presumably mixture weights); its first
            axis must match the first axis of ``mu`` and ``gamma``.
        mu: Per-component parameters, 2-D after promotion (presumably the
            means -- confirm).
        gamma: Per-component parameters with the same 2-D layout as ``mu``
            (presumably diagonal scales -- confirm).
        tracked: Forwarded unchanged to the base ``__init__``.

    Raises:
        AssertionError: if the leading sizes of ``pi``, ``mu`` and ``gamma``
            disagree, or ``mu`` and ``gamma`` differ on their second axis
            (only when assertions are enabled).
    """
    pi = (pi if isinstance(pi, PriorTransform)
          else DeltaPrior('_{}_pi'.format(name), pi, False))
    mu = (mu if isinstance(mu, PriorTransform)
          else DeltaPrior('_{}_mu'.format(name), jnp.atleast_2d(mu), False))
    gamma = (gamma if isinstance(gamma, PriorTransform)
             else DeltaPrior('_{}_gamma'.format(name), jnp.atleast_2d(gamma), False))
    # Component counts must agree across pi/mu/gamma, and mu/gamma must share
    # the same width on their second axis.
    assert get_shape(pi)[0] == get_shape(mu)[0] == get_shape(gamma)[0] \
        and get_shape(mu)[1] == get_shape(gamma)[1]
    # One extra U dimension on top of the broadcast trailing axis of mu/gamma
    # (presumably used to select the mixture component).
    n_features = broadcast_shapes(get_shape(mu)[-1:], get_shape(gamma)[-1:])[0]
    super(GMMDiagPrior, self).__init__(name, 1 + n_features, [pi, mu, gamma], tracked)
def __init__(self, name, pi, low, high, tracked=True):
    """Initialise a UniformMixturePrior from ``pi``, ``low`` and ``high``.

    Constant (non-``PriorTransform``) arguments are wrapped in untracked
    ``DeltaPrior`` instances. ``low`` and ``high`` are first promoted with
    ``jnp.atleast_2d``, so they carry a leading component axis and a second
    (feature) axis.

    Args:
        name: Name of this prior.
        pi: Per-component values (presumably mixture weights); its first
            axis must match the first axis of ``low`` and ``high``.
        low: Per-component lower values, 2-D after promotion.
        high: Per-component upper values, same 2-D layout as ``low``.
        tracked: Forwarded unchanged to the base ``__init__``.

    Raises:
        AssertionError: if the leading sizes of ``pi``, ``low`` and ``high``
            disagree, or ``low`` and ``high`` differ on their second axis
            (only when assertions are enabled).
    """
    pi = (pi if isinstance(pi, PriorTransform)
          else DeltaPrior('_{}_pi'.format(name), pi, False))
    low = (low if isinstance(low, PriorTransform)
           else DeltaPrior('_{}_low'.format(name), jnp.atleast_2d(low), False))
    high = (high if isinstance(high, PriorTransform)
            else DeltaPrior('_{}_high'.format(name), jnp.atleast_2d(high), False))
    # Component counts must agree across pi/low/high, and low/high must share
    # the same width on their second axis.
    assert get_shape(pi)[0] == get_shape(low)[0] == get_shape(high)[0] \
        and get_shape(low)[1] == get_shape(high)[1]
    # One extra U dimension on top of the broadcast trailing axis of low/high
    # (presumably used to select the mixture component).
    n_features = broadcast_shapes(get_shape(low)[-1:], get_shape(high)[-1:])[0]
    super(UniformMixturePrior, self).__init__(name, 1 + n_features, [pi, low, high], tracked)
def __init__(self, name, T, x0, half_width, tracked=True):
    """Initialise a SymmetricUniformWalkPrior spanning ``T`` steps.

    Arguments that are not already ``PriorTransform`` instances are wrapped in
    untracked ``DeltaPrior`` constants named ``_<name>_x0`` and
    ``_<name>_half_width``.

    Args:
        name: Name of this prior.
        T: Number of steps; stored as ``self.T``.
        x0: Starting value, either a ``PriorTransform`` or a constant.
        half_width: Either a ``PriorTransform`` or a constant (presumably
            the half-width of each symmetric uniform step -- confirm).
        tracked: Forwarded unchanged to the base ``__init__``.
    """
    x0 = (x0 if isinstance(x0, PriorTransform)
          else DeltaPrior('_{}_x0'.format(name), x0, False))
    half_width = (half_width if isinstance(half_width, PriorTransform)
                  else DeltaPrior('_{}_half_width'.format(name), half_width, False))
    # Only the leading axis of the broadcast parent shape is kept as the
    # per-step dimension (assumes the parents are at least 1-D -- confirm).
    joint_shape = broadcast_shapes(get_shape(x0), get_shape(half_width))
    self.dim = joint_shape[0]
    self.T = T
    # The base class is given dim * T as its dimension count.
    n_dims = self.dim * self.T
    super(SymmetricUniformWalkPrior, self).__init__(name, n_dims, [x0, half_width], tracked)
def __init__(self, name, transform, to_shape, *params, tracked=True):
    """Initialise a DeterministicTransformPrior.

    Every positional parameter that is not already a ``PriorTransform`` is
    wrapped in an untracked ``DeltaPrior`` named ``_<name>_param[<i>]``, where
    ``i`` is its position in ``params``.

    Args:
        name: Name of this prior.
        transform: Callable stored as ``self._transform`` (presumably applied
            to the parent values -- confirm in the base class).
        to_shape: Declared output shape, stored as ``self._to_shape``.
        *params: Parent priors or constants, in order.
        tracked: Keyword-only; forwarded unchanged to the base ``__init__``.
    """
    parents = [
        param if isinstance(param, PriorTransform)
        else DeltaPrior('_{}_param[{:d}]'.format(name, index), param, False)
        for index, param in enumerate(params)
    ]
    self._transform = transform
    self._to_shape = to_shape
    # This prior requests zero U dimensions from the base class.
    super(DeterministicTransformPrior, self).__init__(name, 0, parents, tracked)
def __init__(self, name, transform, pi, *components, tracked=True):
    """Initialise a MixturePrior over already-constructed component priors.

    Args:
        name: Name of this prior.
        transform: Stored as ``self._transform``.
        pi: Per-component values (presumably mixture weights), either a
            ``PriorTransform`` or a constant, which is wrapped in an untracked
            ``DeltaPrior`` named ``_<name>_pi``. ``get_shape(pi)[0]`` must
            equal the number of components.
        *components: ``PriorTransform`` instances; ``self._shape`` is the
            broadcast of all their ``to_shape`` values.
        tracked: Keyword-only; forwarded unchanged to the base ``__init__``.

    Raises:
        AssertionError: if the number of components does not match ``pi``, or
            a component is not a ``PriorTransform`` (only when assertions
            are enabled).
    """
    self._transform = transform
    if not isinstance(pi, PriorTransform):
        pi = DeltaPrior('_{}_pi'.format(name), pi, False)
    assert get_shape(pi)[0] == len(components), \
        "pi has {} entries but {} components were given".format(
            get_shape(pi)[0], len(components))
    shape = ()
    for component in components:
        assert isinstance(component, PriorTransform), \
            "mixture components must be PriorTransform instances"
        shape = broadcast_shapes(shape, component.to_shape)
    self._shape = shape
    # A single U dimension (presumably used to select the active component).
    U_dims = 1
    # ``components`` is a tuple (collected by *components). It must be
    # converted before concatenation: ``[pi] + components`` raises TypeError.
    parents = [pi] + list(components)
    super(MixturePrior, self).__init__(name, U_dims, parents, tracked)
def __init__(self, name, dist, tracked=True):
    """Initialise a TransposePrior whose value is ``jnp.transpose`` of ``dist``.

    Args:
        name: Name of this prior.
        dist: Parent prior, or a constant, which is wrapped in an untracked
            ``DeltaPrior`` named ``_<name>_dist``.
        tracked: Forwarded to the base ``__init__`` as a keyword argument.
    """
    if not isinstance(dist, PriorTransform):
        dist = DeltaPrior('_{}_dist'.format(name), dist, False)
    # Declared shape is the parent's shape reversed, matching what
    # jnp.transpose with no ``axes`` argument does to an array.
    transposed_shape = get_shape(dist)[::-1]
    self._shape = transposed_shape
    # The base call follows DeterministicTransformPrior's layout
    # (name, transform, to_shape, *params, tracked=). NOTE(review): this
    # assumes TransposePrior subclasses it, since the class header is not
    # visible here. That base collects parents from *params and wraps any
    # non-PriorTransform in a DeltaPrior. Passing ``[dist]`` would therefore
    # make the list itself a single constant parameter and drop the parent
    # prior, so ``dist`` is passed directly.
    super(TransposePrior, self).__init__(
        name, lambda x: jnp.transpose(x), transposed_shape, dist, tracked=tracked)