Пример #1
0
    def test_sample(self):
        """FlowDistribution.sample: reparameterization, shape and log-prob.

        Samples are drawn through a QuadraticFlow on top of a float64 Normal.
        Gradients must reach the base mean only for reparameterized samples,
        and the log-prob attached to a sample must equal the base log-prob of
        the inverse-transformed sample plus the inverse log-det.
        """
        tf.set_random_seed(123456)

        base_mean = tf.constant([0., 1., 2.], dtype=tf.float64)
        base_std = tf.constant(1., dtype=tf.float64)
        base = Normal(mean=base_mean, std=base_std)
        quad_flow = QuadraticFlow(2., 5.)
        flow_distrib = FlowDistribution(base, quad_flow)

        # default sampling: `is_reparameterized` left unspecified
        samples = flow_distrib.sample(n_samples=5)
        self.assertTrue(samples.is_reparameterized)
        mean_grad = tf.gradients(samples * 1., base_mean)[0]
        self.assertIsNotNone(mean_grad)
        self.assertEqual(get_static_shape(samples), (5, 3))
        # the sample should already carry its log-prob
        self.assertIsNotNone(samples._self_log_prob)

        base_x, inv_log_det = quad_flow.inverse_transform(samples)
        expected_log_prob = base.log_prob(base_x) + inv_log_det

        with self.test_session() as sess:
            expected, actual = sess.run(
                [expected_log_prob, samples.log_prob()])
            np.testing.assert_allclose(expected, actual, rtol=1e-5)

        # stop-gradient sampling: `is_reparameterized` explicitly False
        samples = flow_distrib.sample(n_samples=5, is_reparameterized=False)
        self.assertFalse(samples.is_reparameterized)
        self.assertIsNone(tf.gradients(samples * 1., base_mean)[0])
Пример #2
0
    def test_invert_flow(self):
        """`BaseFlow.invert()` wraps a flow in `InvertFlow` and round-trips.

        Checks the inverted flow's metadata, that using it builds the wrapped
        flow lazily, that its inverse matches the numpy reference of the
        forward quadratic transform, that inverting twice gives back the
        original flow, and that it can serve as a FlowDistribution's flow.
        """
        with self.test_session() as sess:
            # --- inverting an ordinary flow ---
            forward = QuadraticFlow(2., 5.)
            inverted = forward.invert()

            self.assertIsInstance(inverted, InvertFlow)
            self.assertEqual(inverted.x_value_ndims, 0)
            self.assertEqual(inverted.y_value_ndims, 0)
            self.assertFalse(inverted.require_batch_dims)

            x_np = np.arange(12, dtype=np.float32) + 1.
            y_np, log_det_np = quadratic_transform(npyops, x_np, 2., 5.)

            # the wrapped flow is not built until the inverted flow is used
            self.assertFalse(forward._has_built)
            out_y, out_log_det = inverted.inverse_transform(tf.constant(x_np))
            self.assertTrue(forward._has_built)

            np.testing.assert_allclose(sess.run(out_y), y_np)
            np.testing.assert_allclose(sess.run(out_log_det), log_det_np)
            invertible_flow_standard_check(self, inverted, sess, y_np)

            # --- inverting an InvertFlow gives back the original flow ---
            self.assertIs(inverted.invert(), forward)

            # --- an inverted flow inside FlowDistribution ---
            base = Normal(mean=1., std=2.)
            distrib = FlowDistribution(base, QuadraticFlow(2., 5.).invert())
            distrib_log_prob = distrib.log_prob(x_np)
            np.testing.assert_allclose(*sess.run(
                [distrib_log_prob,
                 base.log_prob(y_np) + log_det_np]))
Пример #3
0
    def test_log_prob(self):
        """`log_prob` / `prob` of FlowDistribution and their `flow_origin`.

        Both results must be FlowDistributionDerivedTensor instances.
        - Their `flow_origin` must be a StochasticTensor of the base
          distribution holding ``x = flow.inverse_transform(y)``.
        - Their values must equal the base density of `x` corrected by the
          inverse log-det.
        """
        base_mean = tf.constant([0., 1., 2.], dtype=tf.float64)
        base = Normal(mean=base_mean, std=tf.constant(1., dtype=tf.float64))
        quad_flow = QuadraticFlow(2., 5.)
        quad_flow.build(tf.constant(0., dtype=tf.float64))
        distrib = FlowDistribution(base, quad_flow)

        y = tf.constant([1., -1., 2.], dtype=tf.float64)
        x, inv_log_det = quad_flow.inverse_transform(y)
        expected_log_prob = base.log_prob(x) + inv_log_det
        expected_prob = tf.exp(expected_log_prob)

        derived = []
        for density_fn in (distrib.log_prob, distrib.prob):
            out = density_fn(y)
            self.assertIsInstance(out, FlowDistributionDerivedTensor)
            self.assertIsInstance(out.flow_origin, StochasticTensor)
            self.assertIs(out.flow_origin.distribution, base)
            derived.append(out)
        log_prob, prob = derived

        with self.test_session() as sess:
            # the origin tensor holds the inverse-transformed input
            np.testing.assert_allclose(
                *sess.run([log_prob.flow_origin, x]), rtol=1e-5)
            for expected, actual in ((expected_log_prob, log_prob),
                                     (expected_prob, prob)):
                np.testing.assert_allclose(
                    *sess.run([expected, actual]), rtol=1e-5)
Пример #4
0
    def test_log_prob(self):
        """FlowDistribution.log_prob equals base log-prob plus inverse log-det."""
        dtype = tf.float64
        base = Normal(mean=tf.constant([0., 1., 2.], dtype=dtype),
                      std=tf.constant(1., dtype=dtype))
        quad_flow = QuadraticFlow(2., 5.)
        quad_flow.build(tf.constant(0., dtype=dtype))
        distrib = FlowDistribution(base, quad_flow)

        y = tf.constant([1., -1., 2.], dtype=dtype)
        x, inv_log_det = quad_flow.inverse_transform(y)
        expected = base.log_prob(x) + inv_log_det

        with self.test_session() as sess:
            want, got = sess.run([expected, distrib.log_prob(y)])
            np.testing.assert_allclose(want, got, rtol=1e-5)
Пример #5
0
    def test_property(self):
        """Attributes exposed by FlowDistribution.

        The test checks `flow`, `base_distribution`, `dtype`,
        `is_continuous`, `is_reparameterized` and `value_ndims`.
        `is_reparameterized` is re-checked with a non-reparameterized base,
        and `value_ndims` with a flow that reshapes the value to two dims.
        """
        base = Normal(mean=tf.constant([0., 1., 2.], dtype=tf.float64),
                      std=tf.constant(1., dtype=tf.float64))
        quad_flow = QuadraticFlow(2., 5., value_ndims=1)
        distrib = FlowDistribution(base, quad_flow)

        self.assertIs(distrib.flow, quad_flow)
        self.assertIs(distrib.base_distribution, base)
        self.assertEqual(distrib.dtype, tf.float64)
        self.assertTrue(distrib.is_continuous)
        self.assertTrue(distrib.is_reparameterized)
        self.assertEqual(distrib.value_ndims, 1)

        # TODO: shape checks against the base distribution (`get_value_shape`,
        # `get_batch_shape`, `value_shape`, `batch_shape`) are disabled here.

        # a non-reparameterized base makes the flow distribution
        # non-reparameterized too
        non_reparam = Normal(mean=[0., 1., 2.], std=1.,
                             is_reparameterized=False)
        self.assertFalse(
            FlowDistribution(non_reparam, quad_flow).is_reparameterized)

        # ReshapeFlow(1, [-1, 1]) yields a value with two dims
        reshaped = FlowDistribution(non_reparam, ReshapeFlow(1, [-1, 1]))
        self.assertEqual(reshaped.value_ndims, 2)
Пример #6
0
    def test_errors(self):
        """Errors raised by FlowDistribution's constructor and `sample`."""
        base = Normal(mean=tf.zeros([3]), std=1.)

        # --- constructor ---
        with pytest.raises(TypeError, match='`flow` is not an instance of '
                                            '`BaseFlow`: 123'):
            _ = FlowDistribution(base, 123)

        vector_flow = QuadraticFlow(2., 5., value_ndims=1)
        # (factory for an unsuitable base, expected message pattern)
        rejected_bases = [
            (lambda: Categorical(logits=[0., 1., 2.]),
             'cannot be transformed by a flow, because '
             'it is not continuous'),
            (lambda: Mock(base, dtype=tf.int32),
             'cannot be transformed by a flow, because '
             'its data type is not float'),
            (lambda: Mock(base, value_ndims=2),
             'cannot be transformed by flow .*, because '
             'distribution.value_ndims is larger than '
             'flow.x_value_ndims'),
        ]
        for make_base, pattern in rejected_bases:
            with pytest.raises(ValueError, match=pattern):
                _ = FlowDistribution(make_base(), vector_flow)

        # --- sample ---
        distrib = FlowDistribution(base, vector_flow)
        with pytest.raises(RuntimeError,
                           match='`FlowDistribution` requires `compute_prob` '
                                 'not to be False'):
            _ = distrib.sample(compute_density=False)
Пример #7
0
    def test_add_with_flow(self):
        """`BayesianNet.add(..., flow=...)` wraps the distribution in a flow.

        Covers three cases:
        - a sampled node;
        - rejection of a flow that is not explicitly invertible for an
          observed node;
        - an observed node with a valid flow.
        """
        base = Normal(mean=tf.constant([0., 1., 2.]), std=1.)
        quad_flow = QuadraticFlow(2., 5.)

        def assert_wrapped(node):
            self.assertIsInstance(node.distribution, FlowDistribution)
            self.assertIs(node.distribution.flow, quad_flow)

        # sampled (unobserved) node
        assert_wrapped(BayesianNet().add('x', base, flow=quad_flow))

        class _Flow(BaseFlow):
            @property
            def explicitly_invertible(self):
                return False

        # an observed node rejects a flow that is not explicitly invertible
        observed_net = BayesianNet({'x': tf.zeros([5, 3])})
        with pytest.raises(TypeError,
                           match='The observed variable \'x\' expects `flow` '
                           'to be explicitly invertible, but it is not'):
            _ = observed_net.add('x', base, flow=_Flow(x_value_ndims=0))

        # after the rejected call, adding 'x' with a valid flow succeeds
        assert_wrapped(observed_net.add('x', base, flow=quad_flow))
Пример #8
0
    def test_sequential_with_quadratic_flows(self):
        """SequentialFlow of QuadraticFlows agrees with MultiLayerQuadraticFlow.

        `flow1` is the reference multi-layer flow. `flow2` chains
        ``QuadraticFlow(i + 1., 2 * i + 1.)`` layers explicitly and is
        expected to compute the same mapping. Both directions
        (`transform` and `inverse_transform`) of the two flows must agree.
        """
        n_layers = 3
        flow1 = MultiLayerQuadraticFlow(n_layers)
        flow2 = SequentialFlow([
            QuadraticFlow(i + 1., i * 2. + 1.)
            for i in range(n_layers)
        ])
        self.assertTrue(flow2.explicitly_invertible)
        self.assertEqual(len(flow2.flows), n_layers)
        for i, layer in enumerate(flow2.flows):
            self.assertEqual(layer.a, i + 1.)
            self.assertEqual(layer.b, i * 2. + 1.)

        x = tf.range(12, dtype=tf.float32) + 1.

        with self.test_session() as sess:
            invertible_flow_standard_check(self, flow2, sess, x)

            # transform
            y1, log_det_y1 = flow1.transform(x)
            y2, log_det_y2 = flow2.transform(x)
            np.testing.assert_allclose(*sess.run([y1, y2]))
            np.testing.assert_allclose(*sess.run([log_det_y1, log_det_y2]))

            # inverse transform: each flow inverts its own output, so that
            # flow2's inverse is compared against the reference flow1
            x1, log_det_x1 = flow1.inverse_transform(y1)
            x2, log_det_x2 = flow2.inverse_transform(y2)
            np.testing.assert_allclose(*sess.run([x1, x2]))
            np.testing.assert_allclose(*sess.run([log_det_x1, log_det_x2]))
Пример #9
0
    def test_sample_value_and_group_ndims(self):
        """Sample log-probs for combinations of `value_ndims`/`group_ndims`.

        For each combination, the log-prob attached to a sample `y` must
        equal the base log-prob of ``flow.inverse_transform(y)`` plus the
        inverse log-det, reduced over the value and group dimensions.
        """
        tf.set_random_seed(123456)

        base_mean = tf.constant([0., 1., 2.], dtype=tf.float64)
        base = Normal(mean=base_mean, std=tf.constant(1., dtype=tf.float64))

        def sample_and_invert(flow, value_ndims, group_ndims, log_det_shape):
            """Draw 5 samples through `flow`, invert them, check shapes."""
            distrib = FlowDistribution(base, flow)
            self.assertEqual(distrib.value_ndims, value_ndims)

            y = distrib.sample(n_samples=5, group_ndims=group_ndims)
            self.assertTupleEqual(get_static_shape(y), (5, 3))
            x, log_det = flow.inverse_transform(y)
            self.assertTupleEqual(get_static_shape(x), (5, 3))
            self.assertTupleEqual(get_static_shape(log_det), log_det_shape)
            return y, x, log_det

        with self.test_session() as sess:
            def assert_log_prob(y, log_py):
                np.testing.assert_allclose(*sess.run([y.log_prob(), log_py]),
                                           rtol=1e-5)

            # value_ndims = 0, group_ndims = 1
            y, x, log_det = sample_and_invert(
                QuadraticFlow(2., 5.), 0, 1, (5, 3))
            assert_log_prob(
                y, tf.reduce_sum(base.log_prob(x) + log_det, axis=-1))

            # value_ndims = 1, group_ndims = 0
            y, x, log_det = sample_and_invert(
                QuadraticFlow(2., 5., value_ndims=1), 1, 0, (5,))
            assert_log_prob(
                y, log_det + tf.reduce_sum(base.log_prob(x), axis=-1))

            # value_ndims = 1, group_ndims = 1
            y, x, log_det = sample_and_invert(
                QuadraticFlow(2., 5., value_ndims=1), 1, 1, (5,))
            assert_log_prob(
                y,
                tf.reduce_sum(
                    log_det + tf.reduce_sum(base.log_prob(x), axis=-1)))
Пример #10
0
    def test_property(self):
        """SequentialFlow derives its ndims and invertibility from its members.

        `x_value_ndims` is taken from the first flow and `y_value_ndims` from
        the last. A single member that is not explicitly invertible makes
        the whole chain not explicitly invertible.
        """
        class _Flow(BaseFlow):
            @property
            def explicitly_invertible(self):
                return False

        chain = SequentialFlow([
            _Flow(x_value_ndims=1, y_value_ndims=2),
            _Flow(x_value_ndims=2, y_value_ndims=3),
        ])
        self.assertFalse(chain.explicitly_invertible)
        self.assertEqual(chain.x_value_ndims, 1)
        self.assertEqual(chain.y_value_ndims, 3)

        mixed_chain = SequentialFlow([QuadraticFlow(2., 3.),
                                      _Flow(x_value_ndims=0)])
        self.assertFalse(mixed_chain.explicitly_invertible)
Пример #11
0
    def test_errors(self):
        """FlowDistribution rejects bad flows/bases and `compute_density=False`."""
        base = Normal(mean=0., std=1.)
        with pytest.raises(TypeError,
                           match='`flow` is not an instance of '
                           '`BaseFlow`: 123'):
            _ = FlowDistribution(base, 123)

        quad_flow = QuadraticFlow(2., 5.)
        # (factory for an unsuitable base, expected message pattern)
        bad_bases = (
            (lambda: Categorical(logits=[0., 1., 2.]),
             'cannot be transformed by a flow, because '
             'it is not continuous'),
            (lambda: Mock(base, dtype=tf.int32),
             'cannot be transformed by a flow, because '
             'its data type is not float'),
        )
        for make_base, message in bad_bases:
            with pytest.raises(ValueError, match=message):
                _ = FlowDistribution(make_base(), quad_flow)

        distrib = FlowDistribution(base, quad_flow)
        with pytest.raises(RuntimeError,
                           match='`FlowDistribution` requires `compute_prob` '
                           'not to be False'):
            _ = distrib.sample(compute_density=False)
Пример #12
0
    def test_errors(self):
        """Validation errors raised by SplitFlow.

        The errors are grouped by where they surface: the constructor,
        `build` and `transform`.
        """
        def reshaping_left():
            # a fresh flow whose x- and y-side value ndims differ
            return SequentialFlow([
                QuadraticFlow(2., 3., value_ndims=2),
                ReshapeFlow(2, [-1])
            ])

        # --- constructor ---
        with pytest.raises(TypeError,
                           match='`left` must be an instance of '
                           '`BaseFlow`'):
            _ = SplitFlow(-1, object())

        with pytest.raises(TypeError,
                           match='`right` must be an instance of '
                           '`BaseFlow`'):
            _ = SplitFlow(-1, QuadraticFlow(2., 3.), right=object())

        with pytest.raises(ValueError,
                           match='`left` and `right` must have same `x_value_'
                           'ndims` and `y_value_ndims`'):
            _ = SplitFlow(-1,
                          left=QuadraticFlow(2., 3., value_ndims=2),
                          right=QuadraticFlow(2., 3., value_ndims=3))

        with pytest.raises(ValueError,
                           match='`x_value_ndims` != `y_value_ndims`, thus '
                           '`join_axis` must be specified.'):
            _ = SplitFlow(-2, left=reshaping_left())

        with pytest.raises(ValueError,
                           match='`x_value_ndims` != `y_value_ndims`, thus '
                           '`right` must be specified.'):
            _ = SplitFlow(-2, left=reshaping_left(), join_axis=-1)

        # --- build ---
        split_axis_msg = ('`split_axis` out of range, or not covered '
                          'by `x_value_ndims`')
        for bad_axis in (-3, -5):
            with pytest.raises(ValueError, match=split_axis_msg):
                flow = SplitFlow(split_axis=bad_axis,
                                 left=QuadraticFlow(2., 3., value_ndims=2))
                flow.build(tf.zeros([3, 4, 5, 12]))

        with pytest.raises(ValueError,
                           match='The split axis of `input` must '
                           'be at least 2'):
            flow = SplitFlow(split_axis=-1,
                             left=QuadraticFlow(2., 3., value_ndims=1))
            flow.build(tf.zeros([3, 4, 5, 1]))

        # --- transform ---
        with pytest.raises(RuntimeError,
                           match='`y_left.ndims` != `y_right.ndims`'):
            mislabeled = ReshapeFlow(x_value_ndims=1, y_value_shape=[-1, 1])
            # overwrite the declared y ndims so the constructor's ndims
            # consistency check passes and the mismatch surfaces later
            mislabeled._y_value_ndims = 1
            flow = SplitFlow(split_axis=-1,
                             join_axis=-1,
                             left=QuadraticFlow(2., 3., value_ndims=1),
                             right=SequentialFlow(
                                 [QuadraticFlow(1.5, 3., value_ndims=1),
                                  mislabeled]))
            flow.transform(tf.zeros([3, 4, 5, 12]))

        with pytest.raises(ValueError,
                           match='`join_axis` out of range, or not '
                           'covered by `y_value_ndims`'):
            flow = SplitFlow(split_axis=-1,
                             join_axis=-5,
                             left=QuadraticFlow(2., 3., value_ndims=1))
            flow.transform(tf.zeros([3, 4, 5, 12]))
Пример #13
0
    def test_different_value_ndims(self):
        """SplitFlow whose branches map 3 value dims to 4 value dims.

        Each branch applies a QuadraticFlow over 3 value dims, followed by a
        ReshapeFlow to ``[4, -1, 2, 6]``. The split flow's output and log-det
        are compared with a numpy reference. The generic invertibility checks
        are then run on a fully dynamic-shape placeholder input.
        """
        def reshape_tail(x, value_ndims, shape):
            """Reshape the trailing `value_ndims` axes of `x` into `shape`."""
            batch_shape = x.shape
            if value_ndims > 0:
                batch_shape = batch_shape[:-value_ndims]
            return np.reshape(x, batch_shape + tuple(shape))

        def split_transform(x,
                            split_axis,
                            join_axis,
                            x_value_ndims,
                            y_shape,
                            a1,
                            b1,
                            a2=None,
                            b2=None):
            """Numpy reference for a SplitFlow with quadratic branches.

            Steps:
            1. Split `x` along `split_axis`; the left part gets
               ``size // 2`` elements.
            2. Apply the quadratic transform ``(a1, b1)`` to the left part
               and ``(a2, b2)`` to the right part. The right part is the
               identity with zero log-det when `a2` is None.
            3. Reshape each part's value axes to `y_shape` and concatenate
               along `join_axis`.
            4. Sum the log-dets over the x value axes.

            Returns:
                (y, log_det)
            """
            n1 = x.shape[split_axis] // 2
            x1, x2 = np.split(x, [n1], axis=split_axis)
            y1, log_det1 = quadratic_transform(npyops, x1, a1, b1)
            if a2 is not None:
                y2, log_det2 = quadratic_transform(npyops, x2, a2, b2)
            else:
                # identity right branch: zero log-det shaped like its own
                # input `x2` (not the whole `x`)
                y2, log_det2 = x2, np.zeros_like(x2, dtype=np.float64)

            y1 = reshape_tail(y1, x_value_ndims, y_shape)
            y2 = reshape_tail(y2, x_value_ndims, y_shape)
            y = np.concatenate([y1, y2], axis=join_axis)

            if x_value_ndims > 0:
                reduce_axis = tuple(range(-x_value_ndims, 0))
                log_det1 = np.sum(log_det1, axis=reduce_axis)
                log_det2 = np.sum(log_det2, axis=reduce_axis)
            log_det = log_det1 + log_det2

            return y, log_det

        with self.test_session() as sess:
            np.random.seed(1234)
            x = 10. * np.random.normal(size=[3, 4, 5, 12]).astype(np.float64)

            # 2 -> 3, x_value_ndims = 3, y_value_ndims = 4
            x_ph = tf.placeholder(dtype=tf.float64, shape=[None] * 4)
            flow = SplitFlow(split_axis=-2,
                             join_axis=2,
                             left=SequentialFlow([
                                 QuadraticFlow(2., 5., value_ndims=3),
                                 ReshapeFlow(3, [4, -1, 2, 6]),
                             ]),
                             right=SequentialFlow([
                                 QuadraticFlow(1.5, 3., value_ndims=3),
                                 ReshapeFlow(3, [4, -1, 2, 6]),
                             ]))
            self.assertEqual(flow.x_value_ndims, 3)
            self.assertEqual(flow.y_value_ndims, 4)

            # the output has 1 batch axis + 4 value axes (rank 5), so axis 2
            # of the output is the same axis as -3 in the reference
            y, log_det = split_transform(x,
                                         split_axis=-2,
                                         join_axis=-3,
                                         x_value_ndims=3,
                                         y_shape=[4, -1, 2, 6],
                                         a1=2.,
                                         b1=5.,
                                         a2=1.5,
                                         b2=3.)
            y_out, log_det_out = sess.run(flow.transform(x_ph),
                                          feed_dict={x_ph: x})

            np.testing.assert_allclose(y_out, y)
            np.testing.assert_allclose(log_det_out, log_det)

            invertible_flow_standard_check(self,
                                           flow,
                                           sess,
                                           x_ph,
                                           feed_dict={x_ph: x})
Пример #14
0
    def test_equal_value_ndims(self):
        """SplitFlow whose branches keep the number of value dims.

        Three configurations are checked against a numpy reference and by
        the generic invertibility checks:
        - static input with an identity right branch;
        - dynamic input with an identity right branch;
        - dynamic input with an explicit right branch.
        """
        def split_transform(x,
                            split_axis,
                            value_ndims,
                            a1,
                            b1,
                            a2=None,
                            b2=None):
            """Numpy reference for a SplitFlow with quadratic branches.

            Steps:
            1. Split `x` along `split_axis`; the left part gets
               ``size // 2`` elements.
            2. Apply the quadratic transform ``(a1, b1)`` to the left part
               and ``(a2, b2)`` to the right part. The right part is the
               identity with zero log-det when `a2` is None.
            3. Concatenate back along `split_axis`.
            4. Sum the log-dets over the value axes.

            Returns:
                (y, log_det)
            """
            n1 = x.shape[split_axis] // 2
            x1, x2 = np.split(x, [n1], axis=split_axis)
            y1, log_det1 = quadratic_transform(npyops, x1, a1, b1)
            if a2 is not None:
                y2, log_det2 = quadratic_transform(npyops, x2, a2, b2)
            else:
                # identity right branch: zero log-det shaped like its own
                # input `x2` (not the whole `x`)
                y2, log_det2 = x2, np.zeros_like(x2, dtype=np.float64)
            y = np.concatenate([y1, y2], axis=split_axis)
            if value_ndims > 0:
                reduce_axis = tuple(range(-value_ndims, 0))
                log_det1 = np.sum(log_det1, axis=reduce_axis)
                log_det2 = np.sum(log_det2, axis=reduce_axis)
            log_det = log_det1 + log_det2
            return y, log_det

        with self.test_session() as sess:
            np.random.seed(1234)
            x = 10. * np.random.normal(size=[3, 4, 5, 6]).astype(np.float64)

            # static input, split_axis = -1, value_ndims = 1, right = None
            flow = SplitFlow(-1, QuadraticFlow(2., 5., value_ndims=1))
            self.assertEqual(flow.x_value_ndims, 1)
            self.assertEqual(flow.y_value_ndims, 1)

            y, log_det = split_transform(x,
                                         split_axis=-1,
                                         value_ndims=1,
                                         a1=2.,
                                         b1=5.)
            y_out, log_det_out = sess.run(flow.transform(x))

            np.testing.assert_allclose(y_out, y)
            np.testing.assert_allclose(log_det_out, log_det)

            invertible_flow_standard_check(self,
                                           flow,
                                           sess,
                                           x,
                                           rtol=1e-4,
                                           atol=1e-5)

            # dynamic input, split_axis = -2, value_ndims = 2, right = None
            x_ph = tf.placeholder(dtype=tf.float64, shape=[None] * 4)
            flow = SplitFlow(-2, QuadraticFlow(2., 5., value_ndims=2))
            self.assertEqual(flow.x_value_ndims, 2)
            self.assertEqual(flow.y_value_ndims, 2)

            y, log_det = split_transform(x,
                                         split_axis=-2,
                                         value_ndims=2,
                                         a1=2.,
                                         b1=5.)
            y_out, log_det_out = sess.run(flow.transform(x_ph),
                                          feed_dict={x_ph: x})

            np.testing.assert_allclose(y_out, y)
            np.testing.assert_allclose(log_det_out, log_det)

            invertible_flow_standard_check(self,
                                           flow,
                                           sess,
                                           x_ph,
                                           feed_dict={x_ph: x})

            # dynamic input, split_axis = 2, value_ndims = 3
            x_ph = tf.placeholder(dtype=tf.float64, shape=[None] * 4)
            flow = SplitFlow(split_axis=2,
                             left=QuadraticFlow(2., 5., value_ndims=3),
                             right=QuadraticFlow(1.5, 3., value_ndims=3))
            self.assertEqual(flow.x_value_ndims, 3)
            self.assertEqual(flow.y_value_ndims, 3)

            y, log_det = split_transform(x,
                                         split_axis=2,
                                         value_ndims=3,
                                         a1=2.,
                                         b1=5.,
                                         a2=1.5,
                                         b2=3.)
            y_out, log_det_out = sess.run(flow.transform(x_ph),
                                          feed_dict={x_ph: x})

            np.testing.assert_allclose(y_out, y)
            np.testing.assert_allclose(log_det_out, log_det)

            invertible_flow_standard_check(self,
                                           flow,
                                           sess,
                                           x_ph,
                                           feed_dict={x_ph: x})
Пример #15
0
    def test_log_prob_value_and_group_ndims(self):
        """`log_prob(y, group_ndims=...)` for several `value_ndims` settings.

        FlowDistribution.log_prob is compared against the base log-prob of
        ``flow.inverse_transform(y)`` plus the inverse log-det, reduced over
        the value and group dimensions.
        """
        tf.set_random_seed(123456)

        base_mean = tf.constant([0., 1., 2.], dtype=tf.float64)
        base = Normal(mean=base_mean, std=tf.constant(1., dtype=tf.float64))
        y = tf.random_normal(shape=[2, 5, 3], dtype=tf.float64)

        def wrap_and_invert(flow, value_ndims, log_det_shape):
            """Build `flow`, wrap `base` with it and invert `y`."""
            flow.build(tf.zeros([2, 5, 3], dtype=tf.float64))
            distrib = FlowDistribution(base, flow)
            self.assertEqual(distrib.value_ndims, value_ndims)

            x, log_det = flow.inverse_transform(y)
            self.assertTupleEqual(get_static_shape(x), (2, 5, 3))
            self.assertTupleEqual(get_static_shape(log_det), log_det_shape)
            return distrib, x, log_det

        with self.test_session() as sess:
            def assert_log_prob(distrib, group_ndims, expected):
                np.testing.assert_allclose(
                    *sess.run([distrib.log_prob(y, group_ndims=group_ndims),
                               expected]),
                    rtol=1e-5
                )

            # value_ndims = 0, group_ndims = 1
            distrib, x, log_det = wrap_and_invert(
                QuadraticFlow(2., 5.), 0, (2, 5, 3))
            assert_log_prob(
                distrib, 1,
                tf.reduce_sum(base.log_prob(x) + log_det, axis=-1))

            # value_ndims = 1, group_ndims = 0
            distrib, x, log_det = wrap_and_invert(
                QuadraticFlow(2., 5., value_ndims=1), 1, (2, 5))
            assert_log_prob(
                distrib, 0, base.log_prob(x, group_ndims=1) + log_det)

            # value_ndims = 1, group_ndims = 2
            distrib, x, log_det = wrap_and_invert(
                QuadraticFlow(2., 5., value_ndims=1), 1, (2, 5))
            assert_log_prob(
                distrib, 2,
                tf.reduce_sum(
                    log_det + tf.reduce_sum(base.log_prob(x), axis=-1)))