예제 #1
0
    def softmax_grad(self, softmax: jnp.ndarray) -> jnp.ndarray:
        """Return the Jacobian of softmax given its output values.

        For a probability vector ``p`` the Jacobian is
        ``J = diag(p) - p p^T``, i.e. ``J[i, j] = p[i] * (delta_ij - p[j])``.

        Args:
            softmax (jnp.ndarray): Softmax output. It is flattened to a
                column vector, so any shape is treated as a single vector
                of ``softmax.size`` probabilities.

        Returns:
            jnp.ndarray: Square matrix of shape ``(softmax.size, softmax.size)``.
        """
        col = jnp.reshape(softmax, (-1, 1))
        outer = col @ col.T  # p p^T
        return jnp.diagflat(col) - outer
예제 #2
0
    def test_expm_precision(self, expm_type, dim, knn):
        """Compare cached matrix-exponential marginals against a direct expm.

        Builds a symmetric nearest-neighbour adjacency matrix from random
        embeddings and converts it into a rate matrix whose rows sum to zero.
        It then checks that ``diff.get_qt_given_q0`` matches the rows of
        ``expm(beta_min * power * transition_rate)`` selected by ``x0``.

        Args:
          expm_type: Forwarded to ``NearestNeighborCachedDiffusion``.
          dim: Number of embeddings; ``x0`` is drawn from ``[0, dim)``.
          knn: Number of nearest neighbours per embedding.
        """
        # Use independent keys for the two draws; re-using one PRNG key for
        # several draws produces correlated samples.
        embed_key, x0_key = jrandom.split(jrandom.PRNGKey(0))
        embeddings = jrandom.normal(embed_key, (dim, 32))
        x0 = jrandom.randint(x0_key, (64, ), 0, dim)

        num_steps = 128

        schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear', beta_min=5e-3, beta_max=5e-2, num_steps=num_steps)

        diff = diffusion.NearestNeighborCachedDiffusion(
            dim,
            schedule,
            use_numpy=True,
            use_matrix_exponential=True,
            expm_type=expm_type,
            knn=knn)

        state = diff.update_state(embeddings)
        diff.set_state(state)

        neighbors = model_utils.get_nearest_neighbors(embeddings,
                                                      k=knn,
                                                      include_self=False,
                                                      num_chunks=10)

        # Sets matrix[neighbors[i, j], i] = 1 (assumes neighbors is
        # (dim, knn) -- TODO confirm), then symmetrises.
        matrix = jnp.zeros((dim, dim), jnp.float32)
        matrix = matrix.at[neighbors, jnp.arange(dim)[:, None]].set(1.)

        matrix = matrix + matrix.T
        # Subtracting each row sum on the diagonal makes every row sum to zero.
        transition_rate = matrix - jnp.diagflat(jnp.sum(matrix, axis=1))

        beta_min = diff.min_exponent

        # Check every fifth timestep. The previous ``range(num_steps, 5)`` was
        # empty (start > stop), so no assertion was ever executed.
        timesteps = range(0, num_steps, 5)
        for t in timesteps:
            q_t = diff.get_qt_given_q0(x0, t, make_one_hot=True)

            power = diff.powers[t]
            transition = jax.scipy.linalg.expm(beta_min * power *
                                               transition_rate)
            expected = transition[x0]

            np.testing.assert_array_almost_equal(q_t, expected)
예제 #3
0
def diagflat(v, k=0):
  """Place the flattened input on the k-th diagonal of a 2-D array.

  Thin wrapper around ``jnp.diagflat``. ``v`` is first unwrapped with
  ``_remove_jaxarray``, and the result is wrapped back into ``JaxArray``.

  Args:
    v: Input values; flattened before being placed on the diagonal.
    k: Diagonal offset forwarded to ``jnp.diagflat`` (0 is the main one).

  Returns:
    JaxArray wrapping the resulting 2-D array.
  """
  unwrapped = _remove_jaxarray(v)
  result = jnp.diagflat(unwrapped, k=k)
  return JaxArray(result)
예제 #4
0
 def chol_to_param(self, chol):
     """Convert a lower-triangular factor into its parameter matrix.

     After ``self._standardize``, the strictly lower triangle is kept as is,
     the diagonal is mapped through ``self.diag_bij.inv``, and everything
     above the diagonal is zeroed.

     Args:
         chol: Matrix whose lower triangle and diagonal are used.

     Returns:
         Matrix with the same strict lower triangle and a transformed diagonal.
     """
     std = self._standardize(chol)
     strict_lower = np.tril(std, -1)
     mapped_diag = self.diag_bij.inv(np.diagonal(std))
     return strict_lower + np.diagflat(mapped_diag)
예제 #5
0
    def param_to_chol(self, param):
        """Convert a parameter matrix into a lower-triangular factor.

        After ``self._standardize``, the strictly lower triangle is kept as
        is, the diagonal is mapped through ``self.diag_bij``, and everything
        above the diagonal is zeroed.

        Args:
            param: Matrix whose lower triangle and diagonal are used.

        Returns:
            Lower-triangular matrix with a transformed diagonal.
        """
        std = self._standardize(param)
        strict_lower = np.tril(std, -1)
        mapped_diag = self.diag_bij(np.diagonal(std))
        return strict_lower + np.diagflat(mapped_diag)