def get_action_probs(self, actor_params, states: np.array, actions: np.array):
    """Evaluate the Gaussian policy density of ``actions`` given ``states``.

    The policy mean is ``self.actor(actor_params, states)``. The covariance is
    diagonal, with ``self.std`` repeated ``self.action_dim`` times on the
    diagonal.

    NOTE(review): ``self.std`` goes onto the covariance diagonal unchanged, so
    it acts as a variance rather than a standard deviation -- confirm intent.

    Returns:
        Density values from ``multivariate_normal.pdf`` for ``actions``.
    """
    # TODO: make std an argument so it can be modified after compilation?
    means = self.actor(actor_params, states)
    variances = jnp.repeat(self.std, self.action_dim)
    return multivariate_normal.pdf(actions, means, jnp.diag(variances))
def bootstrap_step(state, y):
    """Advance a bootstrap particle filter for a switching linear model by one step.

    Relies on module-level ``transition_matrix``, ``nparticles``, ``A``, ``B``,
    ``Q`` and ``C`` defined elsewhere in the file.

    Args:
        state: carry tuple ``(latent_t, state_t, key)``. It holds the discrete
            regime index of every particle, the continuous particle states, and
            a JAX PRNG key.
        y: the current observation, scored against each particle with
            covariance ``C``.

    Returns:
        ``(carry, outputs)``. ``carry = (latent_t, state_t, key_next)``, and
        ``outputs = (mu_t, latent_t, state_t)``, where ``mu_t`` is the mean of
        the resampled particle states. (This layout suits ``jax.lax.scan``,
        presumably the intended driver.)
    """
    latent_t, state_t, key = state
    key_latent, key_state, key_reindex, key_next = random.split(key, 4)

    # Discrete states: draw each particle's next regime from its transition row.
    latent_t = random.categorical(
        key_latent, jnp.log(transition_matrix[latent_t]), shape=(nparticles,))

    # Continuous states: A @ x for every particle plus the regime offset B[z].
    state_mean = jnp.einsum("nm,sm->sn", A, state_t) + B[latent_t]
    state_t = random.multivariate_normal(key_state, mean=state_mean, cov=Q)

    # Compute log-weights directly. Using log(pdf(...)) underflows to -inf for
    # every particle when y is far from all of them, which leaves categorical
    # resampling with no finite logits.
    log_weights_t = multivariate_normal.logpdf(y, mean=state_t, cov=C)
    indices_t = random.categorical(key_reindex, log_weights_t, shape=(nparticles,))

    # Multinomial resampling; afterwards every particle carries equal weight.
    state_t = state_t[indices_t, ...]
    latent_t = latent_t[indices_t, ...]

    mu_t = state_t.mean(axis=0)
    return (latent_t, state_t, key_next), (mu_t, latent_t, state_t)
def act(self, key, actor_params, state: np.array):
    """Sample an action from the Gaussian policy and return it with its density.

    The mean is ``self.actor(actor_params, state)``. The covariance is diagonal,
    with ``self.std`` repeated ``self.action_dim`` times on the diagonal.

    NOTE(review): ``self.std`` goes onto the covariance diagonal unchanged, so
    it acts as a variance rather than a standard deviation -- confirm intent.

    Returns:
        ``(action, action_prob)``, both wrapped in ``stop_gradient``.
    """
    mean = self.actor(actor_params, state)
    covariance = jnp.diag(jnp.repeat(self.std, self.action_dim))
    sampled = jax.random.multivariate_normal(key, mean, covariance)
    density = multivariate_normal.pdf(sampled, mean, covariance)
    # Block gradients through both the sample and its density.
    return stop_gradient(sampled), stop_gradient(density)
def compute_w_gmm(x, **kwargs):
    """Evaluate a Gaussian-mixture density at ``x`` after rescaling to the unit box.

    Keyword Args:
        bounds: mapping with ``'lb'`` and ``'ub'`` entries. ``x`` is mapped to
            ``(x - lb) / (ub - lb)`` before evaluation.
        gmm_vars: triple ``(weights, means, covs)``, stacked along a leading
            component axis that is vmapped over.

    Returns:
        The sum over components of ``weight * N(x_scaled; mean, cov)``.

    NOTE(review): the density is evaluated in the rescaled coordinates, and no
    ``1 / prod(ub - lb)`` Jacobian factor is applied -- confirm this is intended.
    """
    box = kwargs['bounds']
    lower, upper = box['lb'], box['ub']
    x_scaled = (x - lower) / (upper - lower)

    def component_density(weight, mean, cov):
        return weight * multivariate_normal.pdf(x_scaled, mean, cov)

    weights, means, covs = kwargs['gmm_vars']
    per_component = vmap(component_density)(weights, means, covs)
    return np.sum(per_component, axis=0)
def variational_entropy(self, zeta, phi):
    """Return ``-sum(p * log p)`` of the variational Gaussian evaluated at ``zeta``.

    The Gaussian has mean ``phi[:self.latent_dim]`` and covariance ``L @ L.T``,
    where ``L = self.L(phi)``.

    Log-densities come from ``logpdf`` rather than ``log(pdf)``. When a point's
    density underflows to 0, ``log(pdf)`` is ``-inf`` and ``0 * -inf`` is
    ``nan``, which poisons the sum. In contrast, ``exp(logpdf) * logpdf`` is
    ``0`` there.

    NOTE(review): this sums ``-p log p`` over the supplied points instead of
    averaging ``-log q`` over samples -- confirm which estimator callers expect.
    """
    L = self.L(phi)
    log_probs = multivariate_normal.logpdf(
        zeta, mean=phi[: self.latent_dim], cov=L @ L.T)
    probs = jnp.exp(log_probs)
    return -(probs * log_probs).sum()
def pdf(self, x):
    """Return the multivariate normal density at ``x``.

    The mean is ``self.mu`` and the covariance is ``self.cov``.
    """
    mean, covariance = self.mu, self.cov
    return multivariate_normal.pdf(x, mean, covariance)
def get_marginals(self, parameter_estimates=None, invF=None, ranges=None,
                  gridsize=None):
    """Create the 1D and 2D Gaussian marginal distributions used for plotting.

    Each target is described by a Gaussian centred on its parameter estimate,
    with covariance equal to its inverse Fisher information matrix.
    - The 1D marginal of parameter ``i`` is the normal density with variance
      ``invF[:, i, i]``, evaluated on ``ranges[i]``.
    - The 2D marginal of a pair ``(row, column)`` is the bivariate normal with
      the matching 2x2 sub-block of ``invF``, evaluated on the meshgrid of
      ``ranges[row]`` and ``ranges[column]``.

    The result is a lower-triangular list of lists: ``marginals[row]`` holds
    ``row + 1`` entries, one for each ``column <= row``.
    - The diagonal entry ``marginals[i][i]`` has shape
      ``(n_targets, len(ranges[i]))``.
    - The off-diagonal entry ``marginals[row][column]`` (``column < row``) has
      shape ``(n_targets, gridsize[column], gridsize[row])``.

    Parameters
    ----------
    parameter_estimates: float(n_targets, n_params) or None, default=None
        The parameter estimates of each target data. If None, the class
        instance parameter estimates are used.
    invF: float(n_targets, n_params, n_params) or None, default=None
        The inverse Fisher information matrix for each target. If None, the
        class instance inverse Fisher information matrices are used.
    ranges : list or None, default=None
        A list of arrays containing the gridpoints for the marginal
        distribution of each parameter. If None, the class instance ranges
        (determined by the prior range) are used.
    gridsize : list or None, default=None
        If using your own `ranges`, the gridsize for these ranges must be
        passed (not checked).

    Returns
    -------
    list of lists:
        The 1D and 2D marginal distributions for each parameter (or pair).

    Todo
    ----
    Need to multiply the distribution by the prior to get the posterior.
    Maybe move to TensorFlow probability?
    Make sure that using several Fisher estimates works.
    """
    # jax.numpy is needed inside the vmapped functions: there ``_invF`` is a
    # tracer, and if ``np`` is plain NumPy, ``np.sqrt`` cannot consume it.
    import jax.numpy as jnp

    if parameter_estimates is None:
        parameter_estimates = self.parameter_estimates
    n_targets = parameter_estimates.shape[0]
    if invF is None:
        invF = self.invF
    if ranges is None:
        ranges = self.ranges
    if gridsize is None:
        gridsize = self.gridsize

    marginals = []
    for row in range(self.n_params):
        marginals.append([])
        for column in range(self.n_params):
            if column == row:
                marginals[row].append(
                    jax.vmap(lambda mean, _invF: norm.pdf(
                        ranges[column], mean, jnp.sqrt(_invF)))(
                        parameter_estimates[:, column],
                        invF[:, column, column]))
            elif column < row:
                X, Y = np.meshgrid(ranges[row], ranges[column])
                unravelled = np.vstack([X.ravel(), Y.ravel()]).T
                # Per-target 2x2 covariance [[F_rr, F_rc], [F_cr, F_cc]].
                pair_invF = invF[
                    :, [row, row, column, column],
                    [row, column, row, column]].reshape((n_targets, 2, 2))
                marginals[row].append(
                    jax.vmap(lambda mean, _invF: multivariate_normal.pdf(
                        unravelled, mean, _invF).reshape(
                        (gridsize[column], gridsize[row])))(
                        parameter_estimates[:, [row, column]], pair_invF))
            # column > row: upper triangle is intentionally left out.
    return marginals
return 4 * np.random.uniform() - 2 def randscale(): return 3 * np.random.uniform() + 0.2 loc1 = jnp.array([randloc(), randloc()]).astype(jnp.float32) loc2 = jnp.array([randloc(), randloc()]).astype(jnp.float32) scale1 = jnp.diag(jnp.array([randscale(), randscale()]).astype(jnp.float32)) scale2 = jnp.diag(jnp.array([randscale(), randscale()]).astype(jnp.float32)) ratios = np.random.uniform(size=2) coeffs = ratios / ratios.sum() contour_xs = jnp.linspace(-3, 3, 75) contour_ys = jnp.linspace(-3, 3, 75) density = lambda x: coeffs[0] * pdf(x, loc1, scale1) + coeffs[1] * pdf( x, loc2, scale2) print(f"Loc1: {loc1}\nScale1: {scale1}") print(f"Loc2: {loc2}\nScale2: {scale2}") print(coeffs) zs = np.zeros((len(contour_xs), len(contour_ys))) for (i, x) in enumerate(contour_xs): points = jnp.array([[x, y] for y in contour_ys]) zs[i] = density(points) def update(model): plt.clf() plt.xlim([-3, 3]) plt.ylim([-3, 3])
# print('Delta:', sn.delta) # print('Psi: ', sn.psi) # print('Eigvals psi: ', jnp.linalg.eigvalsh(sn.psi)) # print('alpha:', sn.alpha) # print('Omega:', sn.omega) # print('Eigvals Omega: ', jnp.linalg.eigvalsh(sn.omega)) rng, key = random.split(rng) data = sn.sample(key, shape=(n, )) data_skew = pd.DataFrame(data, columns=['x', 'y']) x = jnp.linspace(jnp.min(data[:, 0]), jnp.max(data[:, 0]), l) y = jnp.linspace(jnp.min(data[:, 1]), jnp.max(data[:, 1]), l) xy = jnp.array(list(product(x, y))) Z_skew = sn.pdf(xy).reshape(l, l).T Z_norm = mvn.pdf(xy, jnp.zeros(p, ), cov=sn.cov).reshape(l, l).T g = sns.jointplot(data=data_skew, x='x', y='y', alpha=0.3) g.ax_joint.contour(x, y, Z_norm, colors='k', alpha=0.7, linestyles='dashed') g.ax_joint.contour(x, y, Z_skew, colors='k') plt.show() # print(sn.logpdf(data))