Example #1
    def get_action_probs(self, actor_params, states: np.ndarray, actions: np.ndarray):
        # TODO: make std an argument so it can still be modified after compilation?
        action_means = self.actor(actor_params, states)
        cov = jnp.diag(jnp.repeat(self.std, self.action_dim))
        action_probs = multivariate_normal.pdf(actions, action_means, cov)

        return action_probs
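A minimal, self-contained sketch of the same pattern (the shapes and values here are toy assumptions, not part of the snippet): jax.scipy.stats.multivariate_normal.pdf broadcasts over a leading batch axis, so a batch of actions can be scored against a batch of means under one shared diagonal covariance.

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

action_dim, std = 3, 0.5
action_means = jnp.zeros((8, action_dim))    # batch of 8 mean actions
actions = 0.1 * jnp.ones((8, action_dim))    # batch of 8 candidate actions
cov = jnp.diag(jnp.repeat(std, action_dim))  # shared diagonal covariance

probs = multivariate_normal.pdf(actions, action_means, cov)  # shape (8,)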
Example #2
def bootstrap_step(state, y):
    latent_t, state_t, key = state
    key_latent, key_state, key_reindex, key_next = random.split(key, 4)

    # Discrete states
    latent_t = random.categorical(key_latent,
                                  jnp.log(transition_matrix[latent_t]),
                                  shape=(nparticles, ))
    # Continuous states
    state_mean = jnp.einsum("nm,sm->sn", A, state_t) + B[latent_t]
    state_t = random.multivariate_normal(key_state, mean=state_mean, cov=Q)

    # Compute weights
    weights_t = multivariate_normal.pdf(y, mean=state_t, cov=C)
    indices_t = random.categorical(key_reindex,
                                   jnp.log(weights_t),
                                   shape=(nparticles, ))

    # Reindex (resample) the particles
    state_t = state_t[indices_t, ...]
    latent_t = latent_t[indices_t, ...]
    # weights_t = jnp.ones(nparticles) / nparticles

    mu_t = state_t.mean(axis=0)

    return (latent_t, state_t, key_next), (mu_t, latent_t, state_t)
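A hedged sketch of driving a step function like this with jax.lax.scan. Every global the snippet relies on (transition_matrix, A, B, Q, C, nparticles) is filled in with toy values here, so the shapes rather than the numbers are the point.

import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import multivariate_normal

nparticles, state_dim, n_latent = 100, 2, 3
transition_matrix = jnp.ones((n_latent, n_latent)) / n_latent
A, Q, C = jnp.eye(state_dim), jnp.eye(state_dim), jnp.eye(state_dim)
B = jnp.zeros((n_latent, state_dim))

init = (jnp.zeros(nparticles, dtype=jnp.int32),  # initial discrete states
        jnp.zeros((nparticles, state_dim)),      # initial continuous states
        random.PRNGKey(0))
ys = jnp.zeros((50, state_dim))                  # one observation per step

# scan threads the carry through time and stacks the per-step outputs
_, (mus, latents, states) = jax.lax.scan(bootstrap_step, init, ys)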
Example #3
    def act(self, key, actor_params, state: np.ndarray):
        action_mean = self.actor(actor_params, state)
        cov = jnp.diag(jnp.repeat(self.std, self.action_dim))

        action = jax.random.multivariate_normal(key, action_mean, cov)
        action_prob = multivariate_normal.pdf(action, action_mean, cov)

        # stop_gradient marks the sampled action and its density as constants,
        # so no gradient flows back through them
        return stop_gradient(action), stop_gradient(action_prob)
Example #4
def compute_w_gmm(x, **kwargs):
    bounds = kwargs['bounds']
    lb = bounds['lb']
    ub = bounds['ub']
    # rescale x to the unit hypercube defined by the bounds
    x = (x - lb) / (ub - lb)
    weights, means, covs = kwargs['gmm_vars']
    gmm_mode = lambda w, mu, cov: w * multivariate_normal.pdf(x, mu, cov)
    w = np.sum(vmap(gmm_mode)(weights, means, covs), axis=0)
    return w
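A minimal sketch of calling it. The kwargs layout (a bounds dict and a (weights, means, covs) triple under gmm_vars) is read off the function body; the values are toy assumptions, and the snippet's np and vmap are presumably jax.numpy and jax.vmap.

import jax.numpy as np  # the snippet's np appears to be jax.numpy
from jax import vmap
from jax.scipy.stats import multivariate_normal

weights = np.array([0.3, 0.7])                 # mixture weights
means = np.zeros((2, 2))                       # two 2-D component means
covs = np.stack([np.eye(2), 0.5 * np.eye(2)])  # one covariance per component
bounds = {'lb': np.zeros(2), 'ub': np.ones(2)}

x = np.full((5, 2), 0.5)                       # five query points
w = compute_w_gmm(x, bounds=bounds, gmm_vars=(weights, means, covs))  # shape (5,)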
Example #5
    def variational_entropy(self, zeta, phi):
        L = self.L(phi)
        probs = multivariate_normal.pdf(zeta, mean=phi[: self.latent_dim], cov=L @ L.T)
        return -(probs * jnp.log(probs)).sum()
Example #6
    def pdf(self, x):
        return multivariate_normal.pdf(x, self.mu, self.cov)
    def get_marginals(self,
                      parameter_estimates=None,
                      invF=None,
                      ranges=None,
                      gridsize=None):
        """
        Creates list of 1D and 2D marginal distributions ready for plotting

        The marginal distribution lists are built from the full distribution
        array. For every parameter the full distribution is summed over every
        other parameter to get the 1D marginals, and for every pair of
        parameters the 2D marginals are calculated by summing over the
        remaining parameters. The list is made up of n_params lists which
        contain n_columns objects each. The value of the distribution comes
        from a Gaussian approximation with mean given by the parameter
        estimates and covariance given by the inverse Fisher information.

        Parameters
        ----------
        parameter_estimates: float(n_targets, n_params) or None, default=None
            The parameter estimates of each target data. If None the class
            instance parameter estimates are used
        invF: float(n_targets, n_params, n_params) or None, default=None
            The inverse Fisher information matrix for each target. If None the
            class instance inverse Fisher information matrices are used
        ranges : list or None, default=None
            A list of arrays containing the gridpoints for the marginal
            distribution for each parameter. If None the class instance ranges
            are used determined by the prior range
        gridsize : list or None, default=None
            If using own `ranges` then the gridsize for these ranges must be
            passed (not checked)

        Returns
        -------
        list of lists:
            The 1D and 2D marginal distributions for each parameter (or pair)

        Todo
        ----
        Need to multiply the distribution by the prior to get the posterior
        Maybe move to TensorFlow probability?
        Make sure that using several Fisher estimates works
        """
        if parameter_estimates is None:
            parameter_estimates = self.parameter_estimates
        n_targets = parameter_estimates.shape[0]
        if invF is None:
            invF = self.invF
        if ranges is None:
            ranges = self.ranges
        if gridsize is None:
            gridsize = self.gridsize
        marginals = []
        for row in range(self.n_params):
            marginals.append([])
            for column in range(self.n_params):
                if column == row:
                    marginals[row].append(
                        jax.vmap(lambda mean, _invF: norm.pdf(
                            ranges[column], mean, np.sqrt(_invF)))(
                                parameter_estimates[:, column], invF[:, column,
                                                                     column]))
                elif column < row:
                    X, Y = np.meshgrid(ranges[row], ranges[column])
                    unravelled = np.vstack([X.ravel(), Y.ravel()]).T
                    marginals[row].append(
                        jax.vmap(lambda mean, _invF: multivariate_normal.pdf(
                            unravelled, mean, _invF).reshape(
                                ((gridsize[column], gridsize[row]))))(
                                    parameter_estimates[:, [row, column]],
                                    invF[:, [row, row, column, column],
                                         [row, column, row, column]].reshape(
                                             (n_targets, 2, 2))))
        return marginals
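The fancy indexing near the end is the subtle part: for each target it extracts the 2x2 block of the inverse Fisher matrix belonging to the (row, column) parameter pair. A tiny standalone check with made-up numbers:

import numpy as np

invF = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # 2 targets, 3 parameters
row, column = 2, 0
block = invF[:, [row, row, column, column],
             [row, column, row, column]].reshape((2, 2, 2))
# block[t] == [[invF[t, row, row],    invF[t, row, column]],
#              [invF[t, column, row], invF[t, column, column]]]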
Example #8
def randloc():
    return 4 * np.random.uniform() - 2


def randscale():
    return 3 * np.random.uniform() + 0.2


loc1 = jnp.array([randloc(), randloc()]).astype(jnp.float32)
loc2 = jnp.array([randloc(), randloc()]).astype(jnp.float32)
scale1 = jnp.diag(jnp.array([randscale(), randscale()]).astype(jnp.float32))
scale2 = jnp.diag(jnp.array([randscale(), randscale()]).astype(jnp.float32))
ratios = np.random.uniform(size=2)
coeffs = ratios / ratios.sum()
contour_xs = jnp.linspace(-3, 3, 75)
contour_ys = jnp.linspace(-3, 3, 75)
density = lambda x: (coeffs[0] * pdf(x, loc1, scale1) +
                     coeffs[1] * pdf(x, loc2, scale2))
print(f"Loc1: {loc1}\nScale1: {scale1}")
print(f"Loc2: {loc2}\nScale2: {scale2}")
print(coeffs)

zs = np.zeros((len(contour_xs), len(contour_ys)))
for (i, x) in enumerate(contour_xs):
    points = jnp.array([[x, y] for y in contour_ys])
    zs[i] = density(points)
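The grid loop can also be collapsed into a single batched call, since multivariate_normal.pdf (and hence density) maps over a leading batch axis; a sketch under that assumption:

grid = jnp.stack(jnp.meshgrid(contour_xs, contour_ys, indexing="ij"), axis=-1)
zs = density(grid.reshape(-1, 2)).reshape(len(contour_xs), len(contour_ys))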


def update(model):
    plt.clf()
    plt.xlim([-3, 3])
    plt.ylim([-3, 3])
Example #9
    # print('Delta:', sn.delta)
    # print('Psi: ', sn.psi)
    # print('Eigvals psi: ', jnp.linalg.eigvalsh(sn.psi))
    # print('alpha:', sn.alpha)
    # print('Omega:', sn.omega)
    # print('Eigvals Omega: ', jnp.linalg.eigvalsh(sn.omega))

    rng, key = random.split(rng)
    data = sn.sample(key, shape=(n, ))

    data_skew = pd.DataFrame(data, columns=['x', 'y'])

    x = jnp.linspace(jnp.min(data[:, 0]), jnp.max(data[:, 0]), l)
    y = jnp.linspace(jnp.min(data[:, 1]), jnp.max(data[:, 1]), l)
    xy = jnp.array(list(product(x, y)))
    Z_skew = sn.pdf(xy).reshape(l, l).T
    Z_norm = mvn.pdf(xy, jnp.zeros(p), cov=sn.cov).reshape(l, l).T

    g = sns.jointplot(data=data_skew, x='x', y='y', alpha=0.3)
    g.ax_joint.contour(x,
                       y,
                       Z_norm,
                       colors='k',
                       alpha=0.7,
                       linestyles='dashed')
    g.ax_joint.contour(x, y, Z_skew, colors='k')
    plt.show()

    # print(sn.logpdf(data))