Example #1
0
    def test_draw_samples_with_broadcast_no_numpy_verification(self, dtype, a, a_is_samples, rv_shape, num_samples):
        """Check dtype and shape of Dirichlet samples without a NumPy reference.

        ``a`` is converted to an MXNet array and given a sample dimension when
        ``a_is_samples`` is false. Only the dtype and the shape of the drawn
        samples are verified; the sample values themselves are not compared.
        """
        a_mx = mx.nd.array(a, dtype=dtype)
        if not a_is_samples:
            a_mx = add_sample_dimension(mx.nd, a_mx)

        dirichlet = Dirichlet.define_variable(alpha=Variable(), shape=rv_shape, dtype=dtype).factor
        variables = {dirichlet.alpha.uuid: a_mx}
        draw_samples_rt = dirichlet.draw_samples(F=mx.nd, variables=variables, num_samples=num_samples)

        assert np.issubdtype(draw_samples_rt.dtype, dtype)
        # The leading axis must follow the requested sample count rather than
        # a literal, so the test stays valid for any parametrization.
        assert draw_samples_rt.shape == (num_samples,) + rv_shape
Example #2
0
    def test_draw_samples_no_broadcast(self, dtype, a, a_is_samples, rv_shape, num_samples):
        """Compare Dirichlet samples with a NumPy reference built from gamma draws.

        The gamma draws generated here are handed to the factor through
        ``MockMXNetRandomGenerator`` (presumably replayed in order by the mock;
        confirm against its definition), and the same draws are normalized in
        NumPy to form the expected result.
        """
        alpha_nd = mx.nd.array(a, dtype=dtype)
        if not a_is_samples:
            alpha_nd = add_sample_dimension(mx.nd, alpha_nd)

        sample_shape = (num_samples,) + rv_shape
        gamma_draws = np.random.gamma(shape=a, scale=np.ones(a.shape), size=sample_shape)
        # Normalizes by the sum over the entire array (no axis argument).
        # NOTE(review): a per-vector normalization would sum over the last
        # axis only; confirm this matches the factor's implementation.
        expected = gamma_draws / gamma_draws.sum()
        mock_gen = MockMXNetRandomGenerator(mx.nd.array(gamma_draws.flatten(), dtype=dtype))

        factor = Dirichlet.define_variable(
            alpha=Variable(), shape=rv_shape, dtype=dtype, rand_gen=mock_gen
        ).factor
        inputs = {factor.alpha.uuid: alpha_nd}
        samples = factor.draw_samples(F=mx.nd, variables=inputs, num_samples=num_samples)

        assert np.issubdtype(samples.dtype, dtype)
        assert np.allclose(expected, samples.asnumpy())
Example #3
0
    def test_log_pdf_with_broadcast(self, dtype, a, a_is_samples, rv, rv_is_samples, num_samples):
        """Check Dirichlet ``log_pdf`` against SciPy when inputs need broadcasting.

        ``a`` and ``rv`` are converted to MXNet arrays and given a leading
        sample dimension when the matching ``*_is_samples`` flag is false. The
        factor's result is then compared with ``scipy_dirichlet.logpdf``
        evaluated one vector at a time.
        """
        # Add sample dimension if variable is not samples
        a_mx = mx.nd.array(a, dtype=dtype)
        if not a_is_samples:
            a_mx = add_sample_dimension(mx.nd, a_mx)
        a_with_samples = a_mx.asnumpy()

        rv_mx = mx.nd.array(rv, dtype=dtype)
        if not rv_is_samples:
            rv_mx = add_sample_dimension(mx.nd, rv_mx)
        rv_with_samples = rv_mx.asnumpy()

        is_samples_any = a_is_samples or rv_is_samples
        # Shape of the random variable without its leading sample axis.
        rv_shape = rv_with_samples.shape[1:]

        n_dim = rv_with_samples.ndim
        if is_samples_any and not rv_is_samples:
            n_dim += 1
        # Broadcast alpha to the full sample shape derived from the inputs
        # instead of a literal (num_samples, 3, 2).
        a_np = np.broadcast_to(a_with_samples, (num_samples,) + rv_shape)
        rv_np = numpy_array_reshape(rv_with_samples, is_samples_any, n_dim)

        # Initialize rand_gen
        rand = np.random.rand(num_samples, *rv_shape)
        rand_gen = MockMXNetRandomGenerator(mx.nd.array(rand.flatten(), dtype=dtype))

        # Calculate correct Dirichlet logpdf. Each vector is renormalized to
        # sum to one before being passed to SciPy. The two nested levels
        # assume rv_np is indexed as (sample, row, vector).
        log_pdf_np = np.array([
            [
                scipy_dirichlet.logpdf(rv_vec / sum(rv_vec), alpha_vec)
                for rv_vec, alpha_vec in zip(rv_rows, alpha_rows)
            ]
            for rv_rows, alpha_rows in zip(rv_np, a_np)
        ])

        dirichlet = Dirichlet.define_variable(alpha=Variable(), shape=rv_shape, dtype=dtype, rand_gen=rand_gen).factor
        variables = {dirichlet.alpha.uuid: a_mx, dirichlet.random_variable.uuid: rv_mx}
        log_pdf_rt = dirichlet.log_pdf(F=mx.nd, variables=variables)

        assert np.issubdtype(log_pdf_rt.dtype, dtype)
        assert array_has_samples(mx.nd, log_pdf_rt) == is_samples_any
        if is_samples_any:
            assert get_num_samples(mx.nd, log_pdf_rt) == num_samples, (get_num_samples(mx.nd, log_pdf_rt), num_samples)
        assert np.allclose(log_pdf_np, log_pdf_rt.asnumpy())