示例#1
0
    def test_precision(self):
        """Checks cached q(x_t | x_0) against a float64 NumPy matrix power.

        The reference raises the first cached matrix (``cache[0]``) of the
        nearest-neighbour state to ``powers[t]`` in double precision and takes
        the row for token 0.
        """
        vocab_size = 1000
        total_steps = 100
        rng = jrandom.PRNGKey(0)
        token_embeddings = jrandom.normal(rng, (vocab_size, 32))

        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=5e-4,
            beta_max=5e-3,
            num_steps=total_steps,
        )
        process = diffusion.NearestNeighborCachedDiffusion(vocab_size,
                                                           beta_schedule,
                                                           use_numpy=True)

        cached_state = process.update_state(token_embeddings)
        process.set_state(cached_state)

        step = 99
        tokens = jnp.array([0, 1, 2])
        first_row = process.get_qt_given_q0(tokens,
                                            step,
                                            make_one_hot=True,
                                            return_logits=False)[0]

        # Build the reference in float64 to isolate precision loss in the
        # diffusion implementation.
        base_matrix = np.array(cached_state['matrix_power_state'].cache[0],
                               np.float64)
        reference = np.linalg.matrix_power(base_matrix,
                                           process.powers[step])[0]

        np.testing.assert_array_almost_equal(first_row, reference)
示例#2
0
    def test_state_init(self):
        """Checks the transition state built by NearestNeighborDiffusion.

        The expected matrix is the symmetrised 2-nearest-neighbour adjacency of
        four 2-D points, normalised by ``num_sinkhorn_iterations`` rounds of
        row-then-column normalisation plus one final column normalisation.
        """
        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=1e-3,
            beta_max=1e-2,
            num_steps=100,
        )
        process = diffusion.NearestNeighborDiffusion(dim=4,
                                                     schedule=beta_schedule,
                                                     knn=2)

        points = jnp.array([[0., 0.], [1., 1.], [3., 3.], [1., 0.]])
        # Row i flags the two points closest (Euclidean) to points[i].
        knn_adjacency = jnp.array([
            [0, 1, 0, 1],
            [1, 0, 0, 1],
            [0, 1, 0, 1],
            [1, 1, 0, 0],
        ])
        target = knn_adjacency + knn_adjacency.T

        def normalize(m, axis):
            return m / m.sum(axis, keepdims=True)

        for _ in range(process.num_sinkhorn_iterations):
            # Rows first, then columns.
            target = normalize(normalize(target, 1), 0)
        target = normalize(target, 0)

        built_state = process.update_state(points)
        np.testing.assert_array_almost_equal(target, built_state, decimal=3)
示例#3
0
    def test_autoregressive_diffusion(self):
        """Checks marginals, samples and posteriors of AutoRegressiveDiffusion.

        At step t the last t positions are expected to collapse onto the final
        token id (``dim - 1``) while all earlier positions stay unchanged.
        """
        seq_len = 100
        dim = 100
        last_token = dim - 1
        tokens = jnp.arange(seq_len, dtype=jnp.int32)
        shifted_tokens = jnp.arange(seq_len, dtype=jnp.int32) + 1

        # Soft target: equal-weight mixture of two one-hot sequences.
        soft_q0 = (losses.onehot(tokens, dim) +
                   losses.onehot(shifted_tokens, dim)) / 2

        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=1e-3,
            beta_max=1e-1,
            num_steps=seq_len,
        )
        process = diffusion.AutoRegressiveDiffusion(dim=dim,
                                                    schedule=beta_schedule)

        for t in range(1, seq_len):
            hard_qt = process.get_qt_given_q0(q0=tokens,
                                              t=t,
                                              make_one_hot=True)
            predicted = hard_qt.argmax(-1)
            np.testing.assert_array_equal(predicted[:-t], tokens[:-t])
            np.testing.assert_array_equal(predicted[-t:], last_token)

            soft_qt = process.get_qt_given_q0(q0=soft_q0, t=t)
            np.testing.assert_array_equal(soft_qt[:-t], soft_q0[:-t])
            np.testing.assert_array_equal(soft_qt[-t:][:, last_token], 1.)

        rng = jrandom.PRNGKey(0)
        _, first_sample = process.sample_and_compute_posterior_q(
            key=rng,
            x_0=tokens,
            t=0,
            make_one_hot=True,
            return_logits=False)
        np.testing.assert_array_equal(first_sample[:-1], tokens[:-1])
        np.testing.assert_array_equal(first_sample[-1], last_token)

        posterior, later_sample = process.sample_and_compute_posterior_q(
            key=rng,
            x_0=soft_q0,
            t=10,
            make_one_hot=False,
            return_logits=False)
        later_sample_q = losses.onehot(later_sample, dim)
        # NOTE(review): position -12 is covered by neither of the next two
        # checks ([:-12] stops at -13); confirm whether the gap is intended.
        np.testing.assert_array_equal(posterior[:-12], later_sample_q[:-12])
        np.testing.assert_array_equal(posterior[-11], soft_q0[-11])
        np.testing.assert_array_equal(posterior[-10:][:, last_token], 1.0)
示例#4
0
    def test_slow_and_fast(self):
        """Fast and slow BetaDiagonalDiffusion inference must agree.

        Compares q_t matrices, q(x_t | x_0), posteriors and samples at every
        step, then checks both variants reach the uniform distribution at
        t=100.
        """
        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='standard',
            beta_min=5e-4,
            beta_max=5e-2,
            num_steps=100,
        )

        x0 = jnp.array([0, 1, 2])
        rng = jrandom.PRNGKey(0)
        dim = 16

        fast, slow = (diffusion.BetaDiagonalDiffusion(dim=dim,
                                                      schedule=beta_schedule,
                                                      use_fast_inference=flag)
                      for flag in (True, False))

        def assert_close(actual, desired):
            np.testing.assert_array_almost_equal(actual, desired, decimal=3)

        for t in range(100):
            assert_close(slow.get_qt_matrix(t), fast.get_qt_matrix(t))
            assert_close(
                slow.get_qt_given_q0(q0=x0, t=t, make_one_hot=True),
                fast.get_qt_given_q0(q0=x0, t=t, make_one_hot=True))

            slow_post, slow_samples = slow.sample_and_compute_posterior_q(
                rng, x_0=x0, t=t, make_one_hot=True)
            fast_post, fast_samples = fast.sample_and_compute_posterior_q(
                rng, x_0=x0, t=t, make_one_hot=True)

            assert_close(slow_post, fast_post)
            np.testing.assert_array_equal(slow_samples, fast_samples)

        for model in (fast, slow):
            final_q = model.get_qt_given_q0(q0=x0, t=100, make_one_hot=True)
            np.testing.assert_allclose(final_q, 1 / dim, rtol=1e-5)
示例#5
0
    def test_large_matrices(self):
        """Tests precision for large matrices.

        With fast inference on a 1000-token vocabulary, q(x_t | x_0) at t=100
        must still yield rows that sum to one.
        """
        rng = jrandom.PRNGKey(0)
        vocab, seq_len = 1000, 64
        tokens = jrandom.randint(rng, (seq_len, ), 0, vocab)

        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=5e-4,
            beta_max=5e-2,
            num_steps=100,
        )
        process = diffusion.MaskDiffusion(vocab,
                                          beta_schedule,
                                          use_fast_inference=True)

        final_q = process.get_qt_given_q0(tokens, 100, make_one_hot=True)
        np.testing.assert_array_almost_equal(final_q.sum(axis=-1), 1.0)
示例#6
0
    def test_compute_posterior(self):
        """At t=0, q(x_t | x_0) keeps all probability on the input token."""
        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=1e-3,
            beta_max=1e-3,
            num_steps=100,
        )
        process = diffusion.BetaDiagonalDiffusion(dim=100,
                                                  schedule=beta_schedule)

        tokens = jnp.ones((2, ), jnp.int32)
        marginal = process.get_qt_given_q0(tokens, 0, make_one_hot=True)

        self.assertEqual(marginal.shape, (2, 100))
        first_row = marginal[0]
        self.assertAlmostEqual(float(first_row[1]), 1.0)
        self.assertAlmostEqual(float(first_row[0]), 0.0)
示例#7
0
    def test_toy_example(self):
        """Checks a two-state BetaDiagonalDiffusion against hand-built matrices.

        With a constant beta of 1e-3 the one-step matrix has 1 - beta/2 on the
        diagonal and beta/2 off it, q_5 is its fifth power, and the posterior
        for observed sample 1 is proportional to ``one_step[1] * q_5[:, 1]``.
        """
        beta = 1e-3
        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=beta,
            beta_max=beta,
            num_steps=100,
        )
        process = diffusion.BetaDiagonalDiffusion(dim=2,
                                                  schedule=beta_schedule)
        rng = jrandom.PRNGKey(0)

        stay, move = 1 - beta / 2, beta / 2
        one_step = np.array([[stay, move], [move, stay]], dtype=np.float64)
        np.testing.assert_array_almost_equal(one_step, process.get(0))

        steps = 5
        five_step = np.linalg.matrix_power(one_step, steps)
        np.testing.assert_array_almost_equal(five_step,
                                             process.get_qt_matrix(steps))

        # q(x_5 | x_0 = s) must match column s of the fifth power.
        for start in (0, 1):
            start_tokens = jnp.full((1, ), start, jnp.int32)
            marginal = process.get_qt_given_q0(start_tokens,
                                               t=steps,
                                               make_one_hot=True)
            np.testing.assert_array_almost_equal(marginal[0],
                                                 five_step[:, start])

        observed = jnp.ones((1, ), jnp.int32)
        posterior, _ = process.sample_and_compute_posterior_q(
            rng,
            observed,
            steps,
            return_logits=False,
            samples=jnp.ones((1, 1), jnp.int32))

        unnormalized = one_step[1] * five_step[:, 1]
        np.testing.assert_array_almost_equal(posterior[0, 0],
                                             unnormalized / unnormalized.sum())
示例#8
0
    def test_product_fast(self):
        """Fast-inference q_t: identity at t=0 and the one-step matrix at t=1."""
        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=1e-3,
            beta_max=1e-3,
            num_steps=100,
        )
        process = diffusion.BetaDiagonalDiffusion(dim=100,
                                                  schedule=beta_schedule,
                                                  use_fast_inference=True)

        self.assertTrue(process.supports_efficient_inference())

        np.testing.assert_array_almost_equal(process.get_qt_matrix(0),
                                             jnp.eye(100))
        np.testing.assert_array_almost_equal(process.get_qt_matrix(1),
                                             process.get(0))
示例#9
0
    def test_expm_precision(self, expm_type, dim, knn):
        """Checks matrix-exponential q(x_t | x_0) against jax.scipy's expm.

        Builds the symmetrised kNN adjacency of random embeddings and turns it
        into a rate matrix whose rows sum to zero. It then compares
        ``expm(min_exponent * powers[t] * rate)[x0]`` with the diffusion
        output at every 5th step.

        Args:
          expm_type: forwarded to ``NearestNeighborCachedDiffusion``.
          dim: number of embeddings, also the diffusion's state dimension.
          knn: number of nearest neighbours per embedding.
        """
        key = jrandom.PRNGKey(0)
        embeddings = jrandom.normal(key, (dim, 32))
        x0 = jrandom.randint(key, (64, ), 0, dim)

        num_steps = 128

        schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear', beta_min=5e-3, beta_max=5e-2, num_steps=num_steps)

        diff = diffusion.NearestNeighborCachedDiffusion(
            dim,
            schedule,
            use_numpy=True,
            use_matrix_exponential=True,
            expm_type=expm_type,
            knn=knn)

        state = diff.update_state(embeddings)
        diff.set_state(state)

        neighbors = model_utils.get_nearest_neighbors(embeddings,
                                                      k=knn,
                                                      include_self=False,
                                                      num_chunks=10)

        matrix = jnp.zeros((dim, dim), jnp.float32)
        matrix = matrix.at[neighbors, jnp.arange(dim)[:, None]].set(1.)

        matrix = matrix + matrix.T
        # Off-diagonal rates minus the row sums on the diagonal, so every
        # row of the generator sums to zero.
        transition_rate = matrix - jnp.diagflat(jnp.sum(matrix, axis=1))

        min_exponent = diff.min_exponent

        # Every 5th step in [0, num_steps). The previous form
        # `range(num_steps, 5)` has start > stop, which is an empty range,
        # so no step was ever checked.
        for t in range(0, num_steps, 5):
            q_t = diff.get_qt_given_q0(x0, t, make_one_hot=True)

            power = diff.powers[t]
            transition = jax.scipy.linalg.expm(min_exponent * power *
                                               transition_rate)
            expected = transition[x0]

            np.testing.assert_array_almost_equal(q_t, expected)
示例#10
0
    def test_beta_diagonal_diffusion(self):
        """One-step BetaDiagonalDiffusion matrices at the schedule endpoints.

        Each matrix should equal ``(1 - beta) * I + beta * ones / dim``, with
        beta = 1e-3 at t=0 and beta = 1e-1 at t=100.
        """
        dim = 100
        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=1e-3,
            beta_max=1e-1,
            num_steps=100,
        )
        process = diffusion.BetaDiagonalDiffusion(dim=dim,
                                                  schedule=beta_schedule)

        def expected_matrix(beta):
            return (1 - beta) * jnp.eye(dim) + beta * jnp.ones(
                (dim, dim)) / dim

        for step, beta in ((0, 1e-3), (100, 1e-1)):
            np.testing.assert_array_almost_equal(process.get(step),
                                                 expected_matrix(beta))
示例#11
0
    def test_all_models(self, diffusion_cls):
        """Common invariants every diffusion class should satisfy.

        When the instance exposes ``get``, its one-step matrices must be
        column-stochastic and ``get_qt_matrix(0)`` must be the identity. For
        every class, t=0 must leave one-hot and soft inputs unchanged and keep
        rows normalised, and t=100 (the final schedule step) must reach
        ``stationary_probs``.

        Args:
          diffusion_cls: diffusion class under test, constructed with
            ``dim`` and ``schedule`` keyword arguments.
        """
        schedule = diffusion.create_discrete_diffusion_schedule(
            kind='standard',
            beta_min=1e-3,
            beta_max=1e-3,
            num_steps=100,
        )

        dim = 100
        length = 100
        key = jrandom.PRNGKey(0)

        x0 = jrandom.randint(key, (length, ), 0, dim)
        diff = diffusion_cls(dim=dim, schedule=schedule)

        # Probe the diffusion instance, not the imported `diffusion` module:
        # `get` / `get_qt_matrix` are methods of the object under test.
        if hasattr(diff, 'get'):
            for t in (0, 10, 99):
                np.testing.assert_allclose(diff.get(t).sum(0), 1.0, rtol=1e-6)

            np.testing.assert_allclose(diff.get_qt_matrix(0),
                                       jnp.eye(dim),
                                       rtol=1e-6)

        expected = losses.onehot(x0, dim)
        result = diff.get_qt_given_q0(q0=x0, t=0, make_one_hot=True)
        np.testing.assert_allclose(result, expected)

        expected = jax.nn.softmax(jrandom.normal(key, (length, dim)))
        result = diff.get_qt_given_q0(q0=expected, t=0, make_one_hot=False)
        np.testing.assert_allclose(result, expected)

        q0 = jax.nn.softmax(jrandom.normal(key, (length, dim)))
        result = diff.get_qt_given_q0(q0=q0, t=0, make_one_hot=False)
        np.testing.assert_allclose(result.sum(axis=-1), 1.0, rtol=1e-6)

        expected = diff.stationary_probs(x0.shape)
        result = diff.get_qt_given_q0(q0=x0, t=100, make_one_hot=True)
        np.testing.assert_allclose(result, expected)
示例#12
0
    def test_cached_reverse(self):
        """Slow and fast cached nearest-neighbour diffusion must agree.

        Both instances share one state. Their posteriors and samples at t=30
        are compared, then ``qt_reverse`` for token 0 is checked against row 0
        of ``get(30)``.
        """
        num_steps, dim = 100, 20
        rng = jrandom.PRNGKey(0)

        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear', beta_min=5e-4, beta_max=5e-3, num_steps=num_steps)

        embedding = jrandom.normal(rng, (dim, 32))

        slow, fast = (diffusion.NearestNeighborCachedDiffusion(
            dim, beta_schedule, use_slow_get=flag) for flag in (True, False))

        shared_state = slow.update_state(embedding)
        for model in (slow, fast):
            model.set_state(shared_state)

        tokens = jnp.array([0, 1, 2])
        step = 30
        slow_q, slow_sample = slow.sample_and_compute_posterior_q(
            rng, tokens, step, make_one_hot=True)
        fast_q, fast_sample = fast.sample_and_compute_posterior_q(
            rng, tokens, step, make_one_hot=True)

        np.testing.assert_array_almost_equal(slow_q, fast_q, decimal=3)
        np.testing.assert_array_equal(slow_sample, fast_sample)

        first_reverse = slow.qt_reverse(tokens, step, make_one_hot=True)[0]
        first_row = slow.get(step)[0]
        np.testing.assert_array_almost_equal(first_reverse,
                                             first_row,
                                             decimal=3)
示例#13
0
    def test_band_diagonal(self):
        """Checks BandDiagonalDiffusion one-step matrices and q_0.

        One-step matrices must be column-stochastic. With width=5 the (0, 0)
        entry is expected to be ``1 - beta + beta / 3``, and
        ``get_qt_matrix(0)`` must be the identity.
        """
        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=1e-3,
            beta_max=1e-3,
            num_steps=100,
        )
        process = diffusion.BandDiagonalDiffusion(dim=100,
                                                  schedule=beta_schedule,
                                                  width=5)

        for step in (0, 10):
            np.testing.assert_allclose(process.get(step).sum(0),
                                       1.0,
                                       rtol=1e-6)

        beta0 = beta_schedule(0)
        np.testing.assert_allclose(process.get(0)[0, 0],
                                   1 - beta0 + beta0 / 3,
                                   rtol=1e-6)

        np.testing.assert_allclose(process.get_qt_matrix(0),
                                   jnp.eye(100),
                                   rtol=1e-6)
示例#14
0
    def test_sample_and_posterior(self):
        """At t=0 both posterior and sample for input token 1 collapse onto 1."""
        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=1e-3,
            beta_max=1e-3,
            num_steps=100,
        )
        process = diffusion.BetaDiagonalDiffusion(dim=100,
                                                  schedule=beta_schedule)

        rng = jrandom.PRNGKey(0)
        tokens = jnp.ones((1, ), jnp.int32)

        posterior, drawn = process.sample_and_compute_posterior_q(
            rng, tokens, 0, return_logits=False)

        self.assertEqual(posterior.shape, (1, 100))
        self.assertAlmostEqual(posterior[0, 1], 1.0, places=5)

        self.assertEqual(drawn.shape, (1, ))
        np.testing.assert_array_equal(drawn, jnp.array([1]))
示例#15
0
    def test_mask_diffusion(self):
        """Checks MaskDiffusion one-step matrices and q_0.

        One-step matrices must be column-stochastic, entry (0, 0) at step t
        must equal ``1 - beta_t``, and ``get_qt_matrix(0)`` must be the
        identity.
        """
        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='linear',
            beta_min=1e-3,
            beta_max=1e-1,
            num_steps=100,
        )
        process = diffusion.MaskDiffusion(dim=100, schedule=beta_schedule)

        for step in (0, 10):
            np.testing.assert_allclose(process.get(step).sum(0),
                                       1.0,
                                       rtol=1e-6)

        for step in (0, 1):
            np.testing.assert_allclose(process.get(step)[0, 0],
                                       1.0 - beta_schedule(step),
                                       rtol=1e-6)

        np.testing.assert_allclose(process.get_qt_matrix(0),
                                   jnp.eye(100),
                                   rtol=1e-6)
示例#16
0
    def test_slow_and_fast(self):
        """Compares fast and slow inference.

        For MaskDiffusion, q_t matrices, q(x_t | x_0), row normalisation,
        posteriors and samples must match at every step. At t=100 both
        variants must put all mass on token ``dim - 1``.
        """
        beta_schedule = diffusion.create_discrete_diffusion_schedule(
            kind='standard',
            beta_min=5e-4,
            beta_max=5e-2,
            num_steps=100,
        )

        dim = 16
        length = 16

        fast, slow = (diffusion.MaskDiffusion(dim=dim,
                                              schedule=beta_schedule,
                                              use_fast_inference=flag)
                      for flag in (True, False))

        rng = jrandom.PRNGKey(0)
        x0 = jrandom.randint(rng, (length, ), 0, dim)

        def assert_close(actual, desired):
            np.testing.assert_array_almost_equal(actual, desired, decimal=3)

        for t in range(100):
            assert_close(slow.get_qt_matrix(t), fast.get_qt_matrix(t))

            slow_marginal = slow.get_qt_given_q0(q0=x0, t=t, make_one_hot=True)
            fast_marginal = fast.get_qt_given_q0(q0=x0, t=t, make_one_hot=True)
            assert_close(slow_marginal, fast_marginal)
            for marginal in (slow_marginal, fast_marginal):
                assert_close(marginal.sum(axis=-1), 1.)

            slow_post, slow_samples = slow.sample_and_compute_posterior_q(
                rng, x_0=x0, t=t, make_one_hot=True)
            fast_post, fast_samples = fast.sample_and_compute_posterior_q(
                rng, x_0=x0, t=t, make_one_hot=True)

            assert_close(slow_post, fast_post)
            np.testing.assert_array_equal(slow_samples, fast_samples)

        for model in (fast, slow):
            final_q = model.get_qt_given_q0(q0=x0, t=100, make_one_hot=True)
            final_target = losses.onehot(jnp.full(x0.shape, dim - 1), dim)
            np.testing.assert_allclose(final_q, final_target, rtol=1e-6)