Example #1
def generate_data(
        rng=0,
        num_points=10000,
        sig_mean=(-1.0, 1.0),
        bup_mean=(2.5, 2.0),
        bdown_mean=(-2.5, -1.5),
        b_mean=(1.0, -1.0),
):
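    # Note: float means keep multivariate_normal's inferred dtype floating, and
    # all four draws below reuse PRNGKey(rng), so the components share the same
    # underlying noise and differ only by their means (common random numbers).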
    sig = multivariate_normal(
        PRNGKey(rng),
        jnp.asarray(sig_mean),
        jnp.asarray([[1, 0], [0, 1]]),
        shape=(num_points, ),
    )
    bkg_up = multivariate_normal(
        PRNGKey(rng),
        jnp.asarray(bup_mean),
        jnp.asarray([[1, 0], [0, 1]]),
        shape=(num_points, ),
    )
    bkg_down = multivariate_normal(
        PRNGKey(rng),
        jnp.asarray(bdown_mean),
        jnp.asarray([[1, 0], [0, 1]]),
        shape=(num_points, ),
    )

    bkg_nom = multivariate_normal(
        PRNGKey(rng),
        jnp.asarray(b_mean),
        jnp.asarray([[1, 0], [0, 1]]),
        shape=(num_points, ),
    )
    return sig, bkg_nom, bkg_up, bkg_down
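The snippet assumes multivariate_normal and PRNGKey are imported at module level; a minimal sketch of the assumed imports and a call:

import jax.numpy as jnp
from jax.random import PRNGKey, multivariate_normal

sig, bkg_nom, bkg_up, bkg_down = generate_data(rng=0, num_points=1000)
print(sig.shape)  # (1000, 2): one 2-D Gaussian blob per component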
Example #2
def draw_state(val, key, params):
    """
    Simulate one step of a system that evolves as

        z_t = A z_{t-1} + B_k + eps,    eps ~ N(0, Q).

    Parameters
    ----------
    val: tuple (int, jnp.array)
        (latent value of system, state value of system)
    key: jax.random.PRNGKey
    params: PRBPFParamsDiscrete
    """
    latent_old, state_old = val
    key_latent, key_state, key_obs = random.split(key, 3)

    # random.categorical expects (unnormalised) log-probabilities
    probabilities = params.transition_matrix[latent_old, :]
    latent_new = random.categorical(key_latent, jnp.log(probabilities))

    state_mean = params.A @ state_old + params.B[latent_new, :]
    state_new = random.multivariate_normal(key_state, state_mean, params.Q)
    obs_new = random.multivariate_normal(key_obs, params.C @ state_new,
                                         params.R)

    return (latent_new, state_new), (latent_new, state_new, obs_new)
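The (carry, output) return pair is shaped for jax.lax.scan; a minimal driver, assuming params is a populated PRBPFParamsDiscrete and init_latent is an integer index such as jnp.asarray(0):

import jax
from jax import random

def simulate(key, init_latent, init_state, params, nsteps):
    keys = random.split(key, nsteps)
    step = lambda val, k: draw_state(val, k, params)
    _, (latents, states, obs) = jax.lax.scan(step, (init_latent, init_state), keys)
    return latents, states, obs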
Example #3
    def filter(self, key, init_state, sample_obs, nsamples=2000):
        """
        init_state: array(state_size,)
            Initial state estimate
        sample_obs: array(nsteps, obs_size)
            Samples of the observations
        """
        nsteps = sample_obs.shape[0]
        mu_hist = jnp.zeros((nsteps, 2))
        keys = split(key, nsteps)
        for t, key_t in enumerate(keys):
            key_z, key_x = split(key_t)
            if t == 0:
                zt_rvs = random.multivariate_normal(key_z, init_state, self.Q,
                                                    (nsamples, ))
            else:
                zt_rvs = random.multivariate_normal(key_z, self.fz(zt_rvs),
                                                    self.Q)

            xt_rvs = random.multivariate_normal(key_x, self.fx(zt_rvs), self.R)

            # Weight by the observation likelihood (covariance R, not Q)
            weights_t = stats.multivariate_normal.pdf(sample_obs[t], xt_rvs,
                                                      self.R)
            weights_t = weights_t / weights_t.sum()

            mu_t = (zt_rvs * weights_t[:, None]).sum(axis=0)
            # index_update was removed from newer JAX; use the .at idiom
            mu_hist = mu_hist.at[t].set(mu_t)
        return mu_hist
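A sketch of a call, assuming the enclosing class (here called model) defines fz, fx, Q and R as used above, and obs_seq is an array of shape (nsteps, obs_size):

key = random.PRNGKey(0)
mu_hist = model.filter(key, init_state=jnp.zeros(2), sample_obs=obs_seq)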
Example #4
  def bcR(self, rng=None, aspect_ratio=1.0):
    """bcR creates a random boundary condition for a rectangular domain.

    The domain is defined by aspect_ratio, and the boundary is a random
    3rd-order polynomial. The rng argument makes the result reproducible.
    The resulting boundary conditions are not periodic.
    """
    if rng is None:
      rng = random.PRNGKey(1)

    x = self.bcmesh
    n = self.n
    n_y = equations.num_row(n, aspect_ratio)
    y = np.linspace(0, 1, num=n_y)
    # rng is guaranteed non-None here, so the 16 coefficients of the four
    # cubic polynomials can be drawn directly.
    coeffs = random.multivariate_normal(rng, np.zeros(16),
                                        np.diag(np.ones(16)))
    left = coeffs[0] * y**3 + coeffs[1] * y**2 + coeffs[2] * y + coeffs[3]
    right = coeffs[4] * y**3 + coeffs[5] * y**2 + coeffs[6] * y + coeffs[7]
    lower = coeffs[8] * x**3 + coeffs[9] * x**2 + coeffs[10] * x + coeffs[11]
    upper = coeffs[12] * x**3 + coeffs[13] * x**2 + coeffs[14] * x + coeffs[15]
    shape = 2 * x.shape
    source = onp.zeros(shape)
    source[0, :] = upper
    source[n_y - 1, :] = lower
    source[0:n_y, -1] = right
    source[0:n_y, 0] = left
    # The (n + 1)**2 factor gives the boundary values the correct scale
    # (1/h**2 for grid spacing h) when they enter the discretized equation.
    return source * (n + 1)**2
Example #5
  def bcL(self, rng=None):
    """bcL creates a random boundary condition for an L-shaped domain.

    The boundary is a random 3rd-order polynomial in sine functions.
    The rng argument makes the result reproducible. Sine functions are
    chosen so that the boundary is periodic and has no discontinuities.
    """
    if rng is None:
      rng = random.PRNGKey(1)
    n = self.n
    x = onp.sin(self.bcmesh * np.pi)
    n_y = (np.floor((n + 1) / 2) - 1).astype(int)
    # rng is guaranteed non-None here, so the 16 coefficients of the four
    # cubic polynomials can be drawn directly.
    coeffs = random.multivariate_normal(rng, np.zeros(16),
                                        np.diag(np.ones(16)))
    left = coeffs[0] * x**3 + coeffs[1] * x**2 + coeffs[2] * x  #+ coeffs[3]
    right = coeffs[4] * x**3 + coeffs[5] * x**2 + coeffs[6] * x  #+ coeffs[7]
    lower = coeffs[8] * x**3 + coeffs[9] * x**2 + coeffs[10] * x  #+ coeffs[11]
    upper = coeffs[12] * x**3 + coeffs[13] * x**2 + coeffs[14] * x  #+ coeffs[15]
    shape = 2 * x.shape
    source = onp.zeros(shape)
    source[0, :] = upper
    source[n_y - 1, n_y - 1:] = lower[:n - n_y + 1]
    source[n_y - 1:, n_y - 1] = right[:n - n_y + 1]
    source[:, 0] = left
    source[-1, :n_y - 1] = right[n:n - n_y:-1]
    source[:n_y - 1, -1] = lower[n:n - n_y:-1]
    # The (n + 1)**2 factor gives the boundary values the correct scale
    # (1/h**2 for grid spacing h) when they enter the discretized equation.
    return source * (n + 1)**2
Example #6
    def __sample_step(self, input_vals, obs):
        key, state_t = input_vals
        key_system, key_obs, key = random.split(key, 3)

        state_t = random.multivariate_normal(key_system, self.fz(state_t), self.Q(state_t))
        obs_t = random.multivariate_normal(key_obs, self.fx(state_t, *obs), self.R(state_t, *obs))

        return (key, state_t), (state_t, obs_t)
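The return pair again matches jax.lax.scan's (carry, output) contract; a sketch of driving the step from inside the class, where obs_seq is assumed to stack whatever per-step inputs fx and R expect:

import jax

(key_final, state_final), (states, observations) = jax.lax.scan(
    self.__sample_step, (key, state_0), obs_seq)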
Example #7
    def sample(self, key, n_samples=1, sample_initial_state=False):
        """
        Simulate a run of n_sample independent stochastic
        linear dynamical systems

        Parameters
        ----------
        key: jax.random.PRNGKey
            Seed of initial random states
        n_samples: int
            Number of independent linear systems with shared dynamics (optional)
        sample_initial_state: bool
            Whether to sample the initial state or use the specified mu0

        Returns
        -------
        * array(n_samples, timesteps, state_size):
            Simulation of Latent states
        * array(n_samples, timesteps, observation_size):
            Simulation of observed states
        """
        key_z1, key_system_noise, key_obs_noise = random.split(key, 3)
        if not sample_initial_state:
            state_t = self.mu0 * jnp.ones((n_samples, self.state_size))
        else:
            state_t = random.multivariate_normal(key_z1, self.mu0, self.Sigma0, (n_samples,))

        # Generate all future noise terms
        zeros_state = jnp.zeros(self.state_size)
        zeros_obs = jnp.zeros(self.observation_size)
        system_noise = random.multivariate_normal(key_system_noise, zeros_state, self.Q, (n_samples, self.timesteps))
        obs_noise = random.multivariate_normal(key_obs_noise, zeros_obs, self.R, (n_samples, self.timesteps))
        
        state_hist = jnp.zeros((n_samples, self.timesteps, self.state_size))
        obs_hist = jnp.zeros((n_samples, self.timesteps, self.observation_size))

        obs_t = jnp.einsum("ij,sj->si", self.C, state_t) + obs_noise[:, 0, :]

        # index_update was removed from newer JAX; use the .at idiom
        state_hist = state_hist.at[:, 0, :].set(state_t)
        obs_hist = obs_hist.at[:, 0, :].set(obs_t)

        for t in range(1, self.timesteps):
            system_noise_t = system_noise[:, t, :]
            obs_noise_t = obs_noise[:, t, :]

            state_new = jnp.einsum("ij,sj->si", self.A, state_t) + system_noise_t
            obs_t = jnp.einsum("ij,sj->si", self.C, state_new) + obs_noise_t
            state_t = state_new

            state_hist = state_hist.at[:, t, :].set(state_t)
            obs_hist = obs_hist.at[:, t, :].set(obs_t)

        if n_samples == 1:
            state_hist = state_hist[0, ...]
            obs_hist = obs_hist[0, ...]
        return state_hist, obs_hist
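The einsum pattern "ij,sj->si" used above is a batched matrix-vector product: it applies the matrix to each sample row, i.e. x @ A.T. A quick check:

import jax.numpy as jnp

A = jnp.arange(4.0).reshape(2, 2)
x = jnp.ones((5, 2))
assert jnp.allclose(jnp.einsum("ij,sj->si", A, x), x @ A.T)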
Example #8
    def sample(self, key, x0, T, nsamples, dt=0.01, noisy=False):
        """
        Sample a trajectory from the system. First, we integrate the ODE
        up to time T, then we keep nsamples equally-spaced points. Finally,
        we map the latent states through fx to obtain the observations

        Parameters
        ----------
        key: jax.random.PRNGKey
            Initial seed
        x0: array(state_size)
            Initial state of simulation
        T: float
            Final time of integration
        nsamples: int
            Number of observations to take from the total integration
        dt: float
            integration step size
        noisy: bool
            Whether to (naively) add noise to the state space

        Returns
        -------
        * array(nsamples, state_size)
            State-space values
        * array(nsamples, obs_size)
            Observed-space values
        * int
            Number of observations skipped between one
            datapoint and the next
        """
        nsteps = ceil(T / dt)
        jump_size = ceil(nsteps / nsamples)
        correction = nsamples - ceil(nsteps / jump_size)
        nsteps += correction * jump_size

        key_state, key_obs = random.split(key)
        state_noise = random.multivariate_normal(key_state,
                                                 jnp.zeros(self.state_size),
                                                 self.Q, (nsteps, ))
        obs_noise = random.multivariate_normal(key_obs,
                                               jnp.zeros(self.obs_size),
                                               self.R, (nsteps, ))
        simulation = self._rk2(x0, self.fz, nsteps, dt)

        if noisy:
            simulation = simulation + jnp.sqrt(dt) * state_noise

        sample_state = simulation[::jump_size]
        sample_obs = jnp.apply_along_axis(
            self.fx, 1, sample_state) + obs_noise[:len(sample_state)]

        return sample_state, sample_obs, jump_size
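A sketch of a call, assuming model defines fz, fx, Q, R, state_size and obs_size as used above:

key = random.PRNGKey(0)
sample_state, sample_obs, jump_size = model.sample(
    key, x0=jnp.array([1.0, 0.0]), T=10.0, nsamples=100, dt=0.01, noisy=True)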
Example #9
def driven_Langevin_move(x, potential, dt, A_function, b_function, potential_parameter, A_parameter, b_parameter, key):
    """
    driven Langevin Algorithm; a driven langevin propagator is
    N(x_t; \Theta_t^i(x_{t-1})*(f_t-dt*b_t^i)(x_{t-1}), dt*\Theta_t^i(x_{t-1}))
    where \Theta_t^i(x_{t-1}) = (I_d + 2*dt*A_t^i(x_{t-1}))^{-1},
    f_t(x_{t-1}) = x_{t-1} + 0.5*dt*\nabla \pi_t(x_{t-1})

    arguments:
        x : jnp.array(N)
            current position (or latent variable)
        potential : function
            potential function (takes args x and parameters)
        dt : float
            incremental time
        A_function : function
            covariance driver function
        b_function : function
            mean driver function
        potential_parameter : jnp.array
            parameters passed to potential
        A_parameter : jnp.array()
            second argument of A function
        b_parameter : jnp.array()
            second argument of b function
        key : float
            randomization key

    returns
        out : jnp.array(N)
            multivariate gaussian proposal
    """
    A, b, f, theta = driven_Langevin_parameters(x, potential, dt, A_function, b_function, potential_parameter, A_parameter, b_parameter)
    mu, cov = driven_mu_cov(b, f, theta, dt)
    return random.multivariate_normal(key, mu, cov)
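Propagating a whole batch of particles is a vmap over the position and key axes, as in the ULA test further below; a sketch, where x_batch, A_fn, b_fn and the parameter arrays are assumed to be defined:

from jax import random, vmap

batch_move = vmap(driven_Langevin_move,
                  in_axes=(0, None, None, None, None, None, None, None, 0))
keys = random.split(random.PRNGKey(0), x_batch.shape[0])
x_next = batch_move(x_batch, potential, 1e-2, A_fn, b_fn,
                    potential_params, A_params, b_params, keys)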
Example #10
def Wishart(key, dof, scale, shape=None):
    if scale is None:
        scale = jnp.eye(shape)
    batch_shape = ()
    if jnp.ndim(scale) > 2:
        batch_shape = scale.shape[:-2]
    p = scale.shape[-1]

    if dof is None:
        dof = p
    if jnp.ndim(dof) > 0:
        raise ValueError("only scalar dof implemented")
    # Note: `~` is bitwise (so ~True == -2 is truthy); use a plain comparison
    if int(dof) != dof:
        raise ValueError(
            "dof should be integer-like (i.e. int(dof) == dof should return true)"
        )
    dof = int(dof)

    if shape is not None:
        if batch_shape != ():
            assert batch_shape == shape, "Disagreement in batch shape between scale and shape"
        else:
            batch_shape = shape

    mn = jnp.zeros(shape=batch_shape + (p, ))
    mvn_shape = (dof, ) + batch_shape

    mvn = random.multivariate_normal(key, mean=mn, cov=scale, shape=mvn_shape)
    if jnp.ndim(mvn) > 2:
        mvn = jnp.swapaxes(mvn, 0, -2)

    S = jnp.einsum('...ji,...jk', mvn, mvn)

    return S
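A quick sanity check (a sketch): the mean of a Wishart with dof ν and scale V is νV, so a large batch of draws should average to dof * scale:

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
S = Wishart(key, dof=5, scale=jnp.eye(3), shape=(10000, ))
print(S.shape)         # (10000, 3, 3)
print(S.mean(axis=0))  # approximately 5 * I_3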
Example #11
    def sample(self, key, shape=None):
        """
        Sample from the skewnormal distribution.

        Arguments:
         - key:
            a PRNGKey used as the random key.
         - shape: (Optional)
            optional, a tuple of nonnegative integers specifying the
            result batch shape; that is, the prefix of the result shape
            excluding the last axis. Must be broadcast-compatible
            with `mean.shape[:-1]` and `cov.shape[:-2]`. The default (None)
            produces a result batch shape by broadcasting together the
            batch shapes of `mean` and `cov`.
        """
        if shape is None:
            shape = (1, )
        X = random.multivariate_normal(
            key=key,
            shape=shape,
            mean=jnp.zeros(self.k + 1),
            cov=self.omega,
        )
        X0 = jnp.expand_dims(X[:, -1], 1)  # latent sign variable
        X = X[:, :-1]
        Z = jnp.where(X0 > 0, X, -X)  # reflect where the latent draw is negative
        return self.loc + jnp.einsum('i,ji->ji', self.scale, Z)
Example #12
def bootstrap_step(state, y):
    latent_t, state_t, key = state
    key_latent, key_state, key_reindex, key_next = random.split(key, 4)

    # Discrete states
    latent_t = random.categorical(key_latent,
                                  jnp.log(transition_matrix[latent_t]),
                                  shape=(nparticles, ))
    # Continuous states
    state_mean = jnp.einsum("nm,sm->sn", A, state_t) + B[latent_t]
    state_t = random.multivariate_normal(key_state, mean=state_mean, cov=Q)

    # Compute weights
    weights_t = multivariate_normal.pdf(y, mean=state_t, cov=C)
    indices_t = random.categorical(key_reindex,
                                   jnp.log(weights_t),
                                   shape=(nparticles, ))

    # Reindex and compute weights
    state_t = state_t[indices_t, ...]
    latent_t = latent_t[indices_t, ...]
    # weights_t = jnp.ones(nparticles) / nparticles

    mu_t = state_t.mean(axis=0)

    return (latent_t, state_t, key_next), (mu_t, latent_t, state_t)
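bootstrap_step closes over transition_matrix, A, B, Q, C and nparticles, and its return pair matches jax.lax.scan; a sketch of running the filter over an observation sequence ys, where latent_0 and state_0 are assumed initial particle arrays:

import jax
from jax import random

init = (latent_0, state_0, random.PRNGKey(314))
_, (mu_hist, latent_hist, state_hist) = jax.lax.scan(bootstrap_step, init, ys)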
Example #13
def test_1D_ULA_propagator(key=random.PRNGKey(0), num_runs=1000):
    """
    take a batch of 1000 particles distributed according to N(0,2), run dynamics with ULA for 1000 steps with dt=0.01 on a potential whose invariant is N(0,2)
    and assert that the mean and variance is unchanged within a tolerance
    """
    key, genkey = random.split(key)
    potential, (mu, cov), dG = get_nondefault_potential_initializer(1)
    x_ula_starter = random.multivariate_normal(key=genkey,
                                               mean=mu,
                                               cov=cov,
                                               shape=[num_runs])
    dt = 1e-2
    batch_ula_move = vmap(ULA_move, in_axes=(0, None, None, 0, None))
    potential_parameter = jnp.array([0.])

    for i in tqdm.trange(100):
        key, ula_keygen = random.split(key, 2)
        ula_keys = random.split(ula_keygen, num_runs)
        x_ULA = batch_ula_move(x_ula_starter, potential, dt, ula_keys,
                               potential_parameter)
        x_ula_starter = x_ULA

    ula_mean, ula_std = x_ula_starter.mean(), x_ula_starter.std()

    assert checker_function(ula_mean, 0.2)
    assert checker_function(ula_std - jnp.sqrt(2), 0.2)
Example #14
 def rvs(self, prng_key: DeviceArray, n_samples: int) -> DeviceArray:
     return random.multivariate_normal(
         key=prng_key,
         mean=self._mean,
         cov=self._cov,
         shape=(n_samples, ),
     )
Example #15
    def __filter_step(self, state, obs_t):
        nsamples = self.nsamples
        indices = jnp.arange(nsamples)
        zt_rvs, key_t = state

        key_t, key_reindex, key_next = random.split(key_t, 3)
        # 1. Draw new points from the dynamic model
        zt_rvs = random.multivariate_normal(key_t, self.fz(zt_rvs), self.Q)

        # 2. Calculate unnormalised weights
        xt_rvs = self.fx(zt_rvs)
        weights_t = stats.multivariate_normal.pdf(obs_t, xt_rvs, self.R)

        # 3. Resampling
        pi = random.choice(key_reindex,
                           indices,
                           p=weights_t,
                           shape=(nsamples, ))
        zt_rvs = zt_rvs[pi, ...]
        weights_t = jnp.ones(nsamples) / nsamples

        # 4. Compute the latent-state estimate as the weighted particle mean
        #    (the weights are uniform after resampling)
        mu_t = jnp.einsum("im,i->m", zt_rvs, weights_t)

        return (zt_rvs, key_next), mu_t
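The resampling in step 3 is equivalent to the categorical draw used in bootstrap_step above: both pick indices with probability proportional to the (unnormalised) weights. A sketch, reusing the names from this method:

pi = random.categorical(key_reindex, jnp.log(weights_t), shape=(nsamples, ))
zt_rvs = zt_rvs[pi, ...]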
Example #16
    def testMultivariateNormalCovariance(self):
        # test code based on https://github.com/google/jax/issues/1869
        N = 100000
        cov = jnp.array([[0.19, 0.00, -0.13, 0.00], [0.00, 0.29, 0.00, -0.23],
                         [-0.13, 0.00, 0.39, 0.00], [0.00, -0.23, 0.00, 0.49]])
        mean = jnp.zeros(4)

        out_np = np.random.RandomState(0).multivariate_normal(mean, cov, N)

        key = random.PRNGKey(0)
        out_jnp = random.multivariate_normal(key,
                                             mean=mean,
                                             cov=cov,
                                             shape=(N, ))

        var_np = out_np.var(axis=0)
        var_jnp = out_jnp.var(axis=0)
        self.assertAllClose(var_np,
                            var_jnp,
                            rtol=1e-2,
                            atol=1e-2,
                            check_dtypes=False)

        var_np = np.cov(out_np, rowvar=False)
        var_jnp = np.cov(out_jnp, rowvar=False)
        self.assertAllClose(var_np,
                            var_jnp,
                            rtol=1e-2,
                            atol=1e-2,
                            check_dtypes=False)
Example #17
    def test_integrated_pos_enc(self):
        num_dims = 2  # The number of input dimensions.
        min_deg = 0
        max_deg = 4
        num_samples = 100000
        rng = random.PRNGKey(0)
        for _ in range(5):
            # Generate a coordinate's mean and covariance matrix.
            key, rng = random.split(rng)
            mean = random.normal(key, (2, ))
            key, rng = random.split(rng)
            half_cov = jax.random.normal(key, [num_dims] * 2)
            cov = half_cov @ half_cov.T
            for diag in [False, True]:
                # Generate an IPE.
                enc = mip.integrated_pos_enc(
                    (mean, jnp.diag(cov) if diag else cov),
                    min_deg,
                    max_deg,
                    diag,
                )

                # Draw samples, encode them, and take their mean.
                key, rng = random.split(rng)
                samples = random.multivariate_normal(key, mean, cov,
                                                     [num_samples])
                enc_samples = mip.pos_enc(samples,
                                          min_deg,
                                          max_deg,
                                          append_identity=False)
                enc_gt = jnp.mean(enc_samples, 0)
                self.assertAllClose(enc, enc_gt, rtol=1e-2, atol=1e-2)
Example #18
    def sample(self, rng_key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        draws = random.multivariate_normal(rng_key,
                                           mean=self.mu,
                                           cov=self.covariance_matrix,
                                           shape=shape)

        return draws
Example #19
def plot_mlp_prediction(key, xobs, yobs, xtest, fw, w, Sw, ax, n_samples=100):
    W_samples = multivariate_normal(key, w, Sw, (n_samples,))
    sample_yhat = fw(W_samples, xtest[:, None])
    for sample in sample_yhat: # sample curves
        ax.plot(xtest, sample, c="tab:gray", alpha=0.07)
    ax.plot(xtest, sample_yhat.mean(axis=0)) # mean of posterior predictive
    ax.scatter(xobs, yobs, s=14, c="none", edgecolor="black", label="observations", alpha=0.5)
    ax.set_xlim(xobs.min(), xobs.max())
Example #20
 def prior_sample(self, X=None, num_samps=1, key=0):
     if X is None:
         X = self.X
     N = X.shape[0]
     m = np.zeros(N)
      K = self.kernel(X, X) + 1e-12 * np.eye(N)  # jitter keeps K positive definite
     s = multivariate_normal(PRNGKey(key), m, K, shape=[num_samps])
     return s.T
Example #21
 def sample(self, n_samples, key=None):
     if key is None:
         self.threadkey, key = random.split(self.threadkey)
     out = random.multivariate_normal(key,
                                      self.gauss_mean,
                                      self.gauss_cov,
                                      shape=(n_samples, ))
     return vmap(self.bananify)(out.reshape((n_samples, self.d)))
Example #22
    def sample(self, rng_key, sample_shape=()):
        # random.multivariate_normal automatically adds the event shape
        # to the shape passed as argument.
        shape = sample_shape + self.batch_shape
        draws = random.multivariate_normal(
            rng_key, mean=self.mu, cov=self.covariance_matrix, shape=shape
        )

        return draws
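The comment above is worth a concrete check: the event dimension is appended after the requested shape, so a mean of size d with shape=(a, b) gives draws of shape (a, b, d). A minimal sketch:

import jax.numpy as jnp
from jax import random

draws = random.multivariate_normal(
    random.PRNGKey(0), mean=jnp.zeros(3), cov=jnp.eye(3), shape=(4, 2))
print(draws.shape)  # (4, 2, 3)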
Example #23
 def sample(self, n_samples, key=None):
     """mutates self.threadkey if key is None"""
     if key is None:
         self.threadkey, key = random.split(self.threadkey)
     out = random.multivariate_normal(key,
                                      self.mean,
                                      self.gauss_cov,
                                      shape=(n_samples, ))
     return vmap(self.squiggle)(out.reshape((n_samples, self.d)))
Example #24
    def sample(self, key, x0, nsteps):
        """
        Sample nsteps points from a discrete-time nonlinear system

        Parameters
        ----------
        key: jax.random.PRNGKey
        x0: array(state_size)
            Initial state of simulation
        nsteps: int
            Total number of steps to sample from the system

        Returns
        -------
        * array(nsteps, state_size)
            State-space values
        * array(nsteps, obs_size)
            Observed-space values
        """
        key, key_system_noise, key_obs_noise = random.split(key, 3)

        state_hist = jnp.zeros((nsteps, self.state_size))
        obs_hist = jnp.zeros((nsteps, self.obs_size))

        state_t = x0.copy()
        obs_t = self.fx(state_t)

        state_noise = random.multivariate_normal(
            key_system_noise, jnp.zeros((self.state_size, )), self.Q,
            (nsteps, ))
        obs_noise = random.multivariate_normal(key_obs_noise,
                                               jnp.zeros((self.obs_size, )),
                                               self.R, (nsteps, ))
        # index_update was removed from newer JAX; use the .at idiom
        state_hist = state_hist.at[0].set(state_t)
        obs_hist = obs_hist.at[0].set(obs_t)

        for t in range(1, nsteps):
            state_t = self.fz(state_t) + state_noise[t]
            obs_t = self.fx(state_t) + obs_noise[t]

            state_hist = state_hist.at[t].set(state_t)
            obs_hist = obs_hist.at[t].set(obs_t)

        return state_hist, obs_hist
Example #25
 def sample(self, n_samples, key=None):
     """mutates self.threadkey if key is None"""
     if key is None:
         self.threadkey, key = random.split(self.threadkey)
     out = random.multivariate_normal(key,
                                      self.mean,
                                      self.cov,
                                      shape=(n_samples, ))
     shape = (n_samples, self.d)
     return out.reshape(shape)
Example #26
 def sample(rng, params, num_samples=1):
     cluster_samples = []
     for mean, cov in zip(means, covariances):
         rng, temp_rng = random.split(rng)
         cluster_sample = random.multivariate_normal(
             temp_rng, mean, cov, (num_samples, ))
         cluster_samples.append(cluster_sample)
     samples = np.dstack(cluster_samples)
     idx = random.categorical(rng, weights, shape=(num_samples, 1, 1))
     return np.squeeze(np.take_along_axis(samples, idx, -1))
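sample closes over means, covariances and weights rather than reading params; a minimal setup might look like this (a sketch with made-up mixture parameters, with weights given as logits since random.categorical expects them):

import jax.numpy as np  # the snippet uses jax.numpy under the name np
from jax import random

means = [np.zeros(2), 3.0 * np.ones(2)]
covariances = [np.eye(2), 0.5 * np.eye(2)]
weights = np.log(np.array([0.3, 0.7]))

draws = sample(random.PRNGKey(0), params=None, num_samples=5)  # shape (5, 2)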
Example #27
def experiment(key, K, cliques0, cliques1, n):
    p = K.shape[0]
    X = random.multivariate_normal(key, jnp.zeros(p), jnp.linalg.inv(K), (n, ))
    ssd = X.T @ X
    Shat = ssd / n  # + np.finfo(np.float32).eps * jnp.eye(p)

    Khat0 = itpropscaling(cliques0, Shat)
    Khat1 = itpropscaling(cliques1, Shat)

    w = n * (jnp.linalg.slogdet(Khat1)[1] - jnp.linalg.slogdet(Khat0)[1])
    return jnp.exp(-w / 2)
Example #28
 def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
                                  shape):
   r = np.random.RandomState(0)
   key = random.PRNGKey(0)
   eff_batch_size = mean_batch_size \
     if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size
   mean = r.randn(*(mean_batch_size + (dim,)))
   cov_factor = r.randn(*(cov_batch_size + (dim, dim)))
   cov = np.einsum('...ij,...kj->...ik', cov_factor, cov_factor)
   cov += 1e-3 * np.eye(dim)
   shape = shape + eff_batch_size
   samples = random.multivariate_normal(key, mean, cov, shape=shape)
   assert samples.shape == shape + (dim,)
Example #29
    def _generate_user_preferences(self):
        '''
        _generate_user_preferences creates an N x K array of Bernoulli
        probabilities for how likely a user is to purchase an item from
        each bin (each row sums to 1).
        '''
        # One column is strongly preferred; the rest taper off across the bins.
        user_prefs = random.multivariate_normal(
            self.rng_key,
            np.array([0.9] + [0.1 / (self.num_bins - 1)] * (self.num_bins - 1)),
            0.005 * np.eye(self.num_bins),
            shape=(self.num_users, ))
        # All possible permutations of the bins
        all_perm = onp.array(list(itertools.permutations(range(self.num_bins))))
        # Randomly select a permutation for each row of user_prefs
        temp = all_perm[random.randint(self.rng_key, (self.num_users, ), 0,
                                       len(all_perm))]
        # Randomly permute the columns of each row in user_prefs
        user_prefs = (user_prefs.flatten()[
            (temp + self.num_bins * np.arange(self.num_users)[..., np.newaxis]
             ).flatten()]).reshape(user_prefs.shape)
        # Normalize user preferences so each row sums to 1
        user_prefs = abs(user_prefs) * (
            1 / np.sum(abs(user_prefs), axis=1))[:, np.newaxis]

        self.user_prefs = user_prefs # Set class values
        self._generate_user_context() # Generate the various context variables that will be provided to the algorithm
Example #30
 def sample(
     self,
     objective: Union[Quadratic, LeastSquares],
     prng_key: jnp.ndarray,
     num_samples: int = 1,
     **_unused_kwargs,
 ) -> jnp.ndarray:
     """Generates exact samples from a quadratic posterior (Gaussian)."""
     if isinstance(objective, LeastSquares):
         objective = Quadratic.from_least_squares(objective)
     state_mean = objective.solve()
     state_cov = jnp.linalg.pinv(objective.A)
     samples = random.multivariate_normal(prng_key,
                                          state_mean,
                                          state_cov,
                                          shape=(num_samples, ))
     return samples