コード例 #1
0
 def __init__(self, name, mu, b, tracked=True):
     """Wire up the ``mu`` and ``b`` parents of this Laplace prior.

     Args:
         name: Name of this prior; also used to name any auto-created
             ``DeltaPrior`` parents.
         mu: Location parameter. A ``PriorTransform`` is used as-is; any
             other value is promoted with ``jnp.atleast_1d`` and wrapped in
             an untracked ``DeltaPrior`` named ``_<name>_mu``.
         b: Scale parameter. A ``PriorTransform`` is used as-is; any other
             value is wrapped unchanged (no promotion) in an untracked
             ``DeltaPrior`` named ``_<name>_b``.
         tracked: Passed through to the base-class constructor.
     """
     loc = mu if isinstance(mu, PriorTransform) else DeltaPrior(
         '_{}_mu'.format(name), jnp.atleast_1d(mu), False)
     scale = b if isinstance(b, PriorTransform) else DeltaPrior(
         '_{}_b'.format(name), b, False)
     # U_dims is the leading entry of the broadcast of both parent shapes.
     joint_shape = broadcast_shapes(get_shape(loc), get_shape(scale))
     super(LaplacePrior, self).__init__(name, joint_shape[0], [loc, scale],
                                        tracked)
コード例 #2
0
 def __init__(self, name, mu, gamma, tracked=True):
     """Wire up the ``mu`` and ``gamma`` parents of this diagonal MVN prior.

     Args:
         name: Name of this prior; also used to name any auto-created
             ``DeltaPrior`` parents.
         mu: A ``PriorTransform``, or a value that is promoted with
             ``jnp.atleast_1d`` and wrapped in an untracked ``DeltaPrior``
             named ``_<name>_mu``.
         gamma: A ``PriorTransform``, or a value that is promoted with
             ``jnp.atleast_1d`` and wrapped in an untracked ``DeltaPrior``
             named ``_<name>_gamma``.
         tracked: Passed through to the base-class constructor.
     """
     if isinstance(mu, PriorTransform):
         mean = mu
     else:
         mean = DeltaPrior('_{}_mu'.format(name), jnp.atleast_1d(mu), False)
     if isinstance(gamma, PriorTransform):
         diag_scale = gamma
     else:
         diag_scale = DeltaPrior('_{}_gamma'.format(name),
                                 jnp.atleast_1d(gamma), False)
     # U_dims is the leading entry of the broadcast of both parent shapes.
     shared_shape = broadcast_shapes(get_shape(mean), get_shape(diag_scale))
     super(MVNDiagPrior, self).__init__(name, shared_shape[0],
                                        [mean, diag_scale], tracked)
コード例 #3
0
 def __init__(self, name, mu, Gamma, ill_cond=False, tracked=True):
     """Wire up the ``mu`` and ``Gamma`` parents of this MVN prior.

     Args:
         name: Name of this prior; also used to name any auto-created
             ``DeltaPrior`` parents.
         mu: A ``PriorTransform``, or a value that is promoted with
             ``jnp.atleast_1d`` and wrapped in an untracked ``DeltaPrior``
             named ``_<name>_mu``.
         Gamma: A ``PriorTransform``, or a value that is promoted with
             ``jnp.atleast_2d`` and wrapped in an untracked ``DeltaPrior``
             named ``_<name>_Gamma``. Only its first axis enters the
             U_dims computation.
         ill_cond: Stored as ``self._ill_cond``; not used here (presumably
             read elsewhere in the class — confirm).
         tracked: Passed through to the base-class constructor.
     """
     self._ill_cond = ill_cond
     mean = (mu if isinstance(mu, PriorTransform) else
             DeltaPrior('_{}_mu'.format(name), jnp.atleast_1d(mu), False))
     matrix = (Gamma if isinstance(Gamma, PriorTransform) else
               DeltaPrior('_{}_Gamma'.format(name), jnp.atleast_2d(Gamma),
                          False))
     # Broadcast mu's shape against Gamma's leading axis only; the leading
     # entry of the result becomes U_dims.
     mean_shape = get_shape(mean)
     U_dims = broadcast_shapes(mean_shape, get_shape(matrix)[:1])[0]
     super(MVNPrior, self).__init__(name, U_dims, [mean, matrix], tracked)
コード例 #4
0
    def __init__(self, name, low, high, tracked=True):
        """Wire up the ``low``/``high`` parents of this uniform prior.

        Args:
            name: Name of this prior; also used to name any auto-created
                ``DeltaPrior`` parents.
            low: A ``PriorTransform``, or a value wrapped unchanged in an
                untracked ``DeltaPrior`` named ``_<name>_low``.
            high: A ``PriorTransform``, or a value wrapped unchanged in an
                untracked ``DeltaPrior`` named ``_<name>_high``.
            tracked: Passed through to the base-class constructor.

        The broadcast of the two parent shapes is kept in
        ``self._broadcast_shape``; the base class receives ``tuple_prod`` of
        that shape as U_dims.
        """
        lower = (low if isinstance(low, PriorTransform)
                 else DeltaPrior('_{}_low'.format(name), low, False))
        upper = (high if isinstance(high, PriorTransform)
                 else DeltaPrior('_{}_high'.format(name), high, False))
        self._broadcast_shape = broadcast_shapes(get_shape(lower),
                                                 get_shape(upper))
        super(UniformPrior, self).__init__(
            name, tuple_prod(self._broadcast_shape), [lower, upper], tracked)
コード例 #5
0
 def __init__(self, name, T, x0, omega, tracked=True):
     """Wire up the ``x0``/``omega`` parents and sizes of this walk prior.

     Args:
         name: Name of this prior; also used to name any auto-created
             ``DeltaPrior`` parents.
         T: Number of steps; stored as ``self.T``.
         x0: Presumably the starting point of the walk (confirm against
             ``forward``). A ``PriorTransform`` is used as-is; any other
             value is promoted with ``jnp.atleast_1d`` and wrapped in an
             untracked ``DeltaPrior`` named ``_<name>_x0``.
         omega: Presumably the per-step scale (confirm against ``forward``).
             Handled like ``x0``; the ``DeltaPrior`` is named
             ``_<name>_omega``.
         tracked: Passed through to the base-class constructor.

     ``self.dim`` is the leading entry of the broadcast of the two parent
     shapes, and the base class receives ``self.dim * self.T`` as U_dims.
     """
     # Promote constants to at least 1-D: with two scalar constants the
     # broadcast shape would be empty and the ``[0]`` below would raise
     # IndexError. ``atleast_1d`` leaves inputs that are already >= 1-D
     # unchanged, so those callers see no difference.
     if not isinstance(x0, PriorTransform):
         x0 = DeltaPrior('_{}_x0'.format(name), jnp.atleast_1d(x0), False)
     if not isinstance(omega, PriorTransform):
         omega = DeltaPrior('_{}_omega'.format(name), jnp.atleast_1d(omega),
                            False)
     self.dim = broadcast_shapes(get_shape(x0), get_shape(omega))[0]
     self.T = T
     super(DiagGaussianWalkPrior, self).__init__(name, self.dim * self.T,
                                                 [x0, omega], tracked)
コード例 #6
0
    def __init__(self, name, n, low, high, tracked=True):
        """Wire up the ``low``/``high`` parents and the length-``n`` axis.

        Args:
            name: Name of this prior; also used to name any auto-created
                ``DeltaPrior`` parents.
            n: Stored as ``self._n``; prepended as the leading axis of
                ``self._broadcast_shape``.
            low: A ``PriorTransform``, or a value wrapped unchanged in an
                untracked ``DeltaPrior`` named ``_<name>_low``.
            high: A ``PriorTransform``, or a value wrapped unchanged in an
                untracked ``DeltaPrior`` named ``_<name>_high``.
            tracked: Passed through to the base-class constructor.

        The base class receives ``tuple_prod((n,) + broadcast(low, high))``
        as U_dims.
        """
        lower = (low if isinstance(low, PriorTransform)
                 else DeltaPrior('_{}_low'.format(name), low, False))
        upper = (high if isinstance(high, PriorTransform)
                 else DeltaPrior('_{}_high'.format(name), high, False))
        self._n = n
        bounds_shape = broadcast_shapes(get_shape(lower), get_shape(upper))
        self._broadcast_shape = (self._n,) + bounds_shape
        super(ForcedIdentifiabilityPrior, self).__init__(
            name, tuple_prod(self._broadcast_shape), [lower, upper], tracked)
コード例 #7
0
 def __init__(self, name, pi, mu, gamma, tracked=True):
     """Wire up the ``pi``/``mu``/``gamma`` parents of this diagonal GMM prior.

     Args:
         name: Name of this prior; also used to name any auto-created
             ``DeltaPrior`` parents.
         pi: A ``PriorTransform``, or a value wrapped unchanged in an
             untracked ``DeltaPrior`` named ``_<name>_pi``.
         mu: A ``PriorTransform``, or a value promoted with
             ``jnp.atleast_2d`` and wrapped in an untracked ``DeltaPrior``
             named ``_<name>_mu``.
         gamma: A ``PriorTransform``, or a value promoted with
             ``jnp.atleast_2d`` and wrapped in an untracked ``DeltaPrior``
             named ``_<name>_gamma``.
         tracked: Passed through to the base-class constructor.

     Raises:
         AssertionError: if ``pi``, ``mu`` and ``gamma`` disagree on their
             first axis, or ``mu`` and ``gamma`` disagree on their second.
     """
     weights = (pi if isinstance(pi, PriorTransform)
                else DeltaPrior('_{}_pi'.format(name), pi, False))
     means = (mu if isinstance(mu, PriorTransform)
              else DeltaPrior('_{}_mu'.format(name), jnp.atleast_2d(mu),
                              False))
     scales = (gamma if isinstance(gamma, PriorTransform)
               else DeltaPrior('_{}_gamma'.format(name), jnp.atleast_2d(gamma),
                               False))
     assert (get_shape(weights)[0] == get_shape(means)[0] == get_shape(scales)[0]
             and get_shape(means)[1] == get_shape(scales)[1])
     feature_dim = broadcast_shapes(get_shape(means)[-1:],
                                    get_shape(scales)[-1:])[0]
     # U_dims = 1 + feature size; the extra 1 is presumably used to pick the
     # mixture component (confirm against ``forward``).
     super(GMMDiagPrior, self).__init__(name, 1 + feature_dim,
                                        [weights, means, scales], tracked)
コード例 #8
0
 def __init__(self, name, pi, low, high, tracked=True):
     """Wire up the ``pi``/``low``/``high`` parents of this uniform mixture.

     Args:
         name: Name of this prior; also used to name any auto-created
             ``DeltaPrior`` parents.
         pi: A ``PriorTransform``, or a value wrapped unchanged in an
             untracked ``DeltaPrior`` named ``_<name>_pi``.
         low: A ``PriorTransform``, or a value promoted with
             ``jnp.atleast_2d`` and wrapped in an untracked ``DeltaPrior``
             named ``_<name>_low``.
         high: A ``PriorTransform``, or a value promoted with
             ``jnp.atleast_2d`` and wrapped in an untracked ``DeltaPrior``
             named ``_<name>_high``.
         tracked: Passed through to the base-class constructor.

     Raises:
         AssertionError: if ``pi``, ``low`` and ``high`` disagree on their
             first axis, or ``low`` and ``high`` disagree on their second.
     """
     weights = (pi if isinstance(pi, PriorTransform)
                else DeltaPrior('_{}_pi'.format(name), pi, False))
     lower = (low if isinstance(low, PriorTransform)
              else DeltaPrior('_{}_low'.format(name), jnp.atleast_2d(low),
                              False))
     upper = (high if isinstance(high, PriorTransform)
              else DeltaPrior('_{}_high'.format(name), jnp.atleast_2d(high),
                              False))
     assert (get_shape(weights)[0] == get_shape(lower)[0] == get_shape(upper)[0]
             and get_shape(lower)[1] == get_shape(upper)[1])
     feature_dim = broadcast_shapes(get_shape(lower)[-1:],
                                    get_shape(upper)[-1:])[0]
     # U_dims = 1 + feature size of the low/high bounds.
     super(UniformMixturePrior, self).__init__(name, 1 + feature_dim,
                                               [weights, lower, upper],
                                               tracked)
コード例 #9
0
 def __init__(self, name, T, x0, half_width, tracked=True):
     """Wire up the ``x0``/``half_width`` parents and sizes of this walk prior.

     Args:
         name: Name of this prior; also used to name any auto-created
             ``DeltaPrior`` parents.
         T: Number of steps; stored as ``self.T``.
         x0: Presumably the starting point of the walk (confirm against
             ``forward``). A ``PriorTransform`` is used as-is; any other
             value is promoted with ``jnp.atleast_1d`` and wrapped in an
             untracked ``DeltaPrior`` named ``_<name>_x0``.
         half_width: Handled like ``x0``; the ``DeltaPrior`` is named
             ``_<name>_half_width``.
         tracked: Passed through to the base-class constructor.

     ``self.dim`` is the leading entry of the broadcast of the two parent
     shapes, and the base class receives ``self.dim * self.T`` as U_dims.
     """
     # Promote constants to at least 1-D: with two scalar constants the
     # broadcast shape would be empty and the ``[0]`` below would raise
     # IndexError. ``atleast_1d`` leaves inputs that are already >= 1-D
     # unchanged, so those callers see no difference.
     if not isinstance(x0, PriorTransform):
         x0 = DeltaPrior('_{}_x0'.format(name), jnp.atleast_1d(x0), False)
     if not isinstance(half_width, PriorTransform):
         half_width = DeltaPrior('_{}_half_width'.format(name),
                                 jnp.atleast_1d(half_width), False)
     self.dim = broadcast_shapes(get_shape(x0), get_shape(half_width))[0]
     self.T = T
     super(SymmetricUniformWalkPrior, self).__init__(
         name, self.dim * self.T, [x0, half_width], tracked)
コード例 #10
0
 def __init__(self, name, transform, pi, *components, tracked=True):
     """Wire up the mixture weights and component parents of this prior.

     Args:
         name: Name of this prior; also used to name an auto-created
             ``DeltaPrior`` parent.
         transform: Stored as ``self._transform``; not used here.
         pi: A ``PriorTransform``, or a value wrapped unchanged in an
             untracked ``DeltaPrior`` named ``_<name>_pi``. Its first axis
             must equal the number of components.
         *components: ``PriorTransform`` instances. ``self._shape`` is the
             broadcast of all of their ``to_shape`` values.
         tracked: Passed through to the base-class constructor.

     Raises:
         AssertionError: if ``pi``'s first axis does not match the number of
             components, or a component is not a ``PriorTransform``.
     """
     self._transform = transform
     if not isinstance(pi, PriorTransform):
         pi = DeltaPrior('_{}_pi'.format(name), pi, False)
     assert get_shape(pi)[0] == len(components)
     shape = ()
     for component in components:
         assert isinstance(component, PriorTransform)
         shape = broadcast_shapes(shape, component.to_shape)
     self._shape = shape
     U_dims = 1
     # ``*components`` is collected as a tuple, and ``list + tuple`` raises
     # TypeError, so convert before concatenating the parent list.
     super(MixturePrior, self).__init__(name, U_dims,
                                        [pi] + list(components), tracked)
コード例 #11
0
    def forward(self, U, low, high, **kwargs):
        """Map ``U`` to values scaled into ``low + (high - low) * t``.

        Writing ``x`` for ``U`` reshaped to ``self.to_shape`` and indexing
        its leading axis from 0, the result is::

            t[k] = prod_{j >= k} x[j] ** (1 / (j + 1))
            theta = low + (high - low) * t

        For ``x`` in (0, 1], each factor is <= 1, so ``t[k] <= t[k + 1]``
        along the leading axis. This assumes ``scan`` is ``jax.lax.scan``;
        confirm against the imports. An ``x`` of 0 gives ``log`` = -inf,
        so ``t`` is 0 there.

        Args:
            U: Input array; reshaped to ``self.to_shape``. Its leading axis
                must have length ``self._n`` to line up with the step
                indices. Values are assumed to lie in (0, 1]; confirm
                against the caller.
            low: Value of the ``low`` parent.
            high: Value of the ``high`` parent.
            **kwargs: Unused.

        Returns:
            ``theta`` as defined above.
        """
        log_x = jnp.log(jnp.reshape(U, self.to_shape))

        # Scan body: accumulate log t. The scan runs in reverse (last
        # element first); step i is the 1-based position along the leading
        # axis and adds log(x) / i. The result is
        # log t[k] = sum_{j >= k} log(x[j]) / (j + 1).
        # This is the standard way to build sorted Uniform(0, 1) draws: the
        # max of i iid uniforms has CDF t**i, which inverts to x**(1/i).
        # The inner ``log_x`` is the per-step slice and shadows the outer
        # array.
        def body(state, X):
            (log_theta,) = state
            (log_x, i) = X
            log_theta = log_x / i + log_theta
            return (log_theta,), (log_theta,)

        # Carry starts at log t = 0 (t = 1), shaped like the broadcast of
        # low and high.
        log_init_theta = jnp.zeros(broadcast_shapes(low.shape, high.shape))
        # With reverse=True, jax.lax.scan still returns stacked outputs in
        # the original index order, so log_theta[k] lines up with x[k].
        _, (log_theta,) = scan(body, (log_init_theta,), (log_x, jnp.arange(1, self._n + 1)), reverse=True)
        theta = low + (high - low) * jnp.exp(log_theta)
        return theta