Code example #1
File: bn.py  Project: hmartelb/zhusuan-mnist-cvae
    def normal(self,
               name,
               mean=0.,
               _sentinel=None,
               std=None,
               logstd=None,
               group_ndims=0,
               n_samples=None,
               is_reparameterized=True,
               check_numerics=False,
               **kwargs):
        """
        Add a stochastic node in this :class:`BayesianNet` that follows the
        Normal distribution.

        :param name: The name of the stochastic node. Must be unique in a
            :class:`BayesianNet`.

        See
        :class:`~zhusuan.distributions.univariate.Normal` for more information
        about the other arguments.

        :return: A :class:`StochasticTensor` instance.
        """
        dist = distributions.Normal(mean,
                                    _sentinel=_sentinel,
                                    std=std,
                                    logstd=logstd,
                                    group_ndims=group_ndims,
                                    is_reparameterized=is_reparameterized,
                                    check_numerics=check_numerics,
                                    **kwargs)
        return self.stochastic(name, dist, n_samples=n_samples, **kwargs)
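A usage sketch of the normal() helper above. The shapes, variable names (z_dim, n_z) and the direct construction of zs.BayesianNet are illustrative assumptions, not code from the project.

import tensorflow as tf
import zhusuan as zs

def build_latent(batch_size, z_dim=40, n_z=None):
    # Register a standard-Normal latent node 'z' on the net. std and logstd
    # are mutually exclusive, so only std is passed here.
    bn = zs.BayesianNet()
    z_mean = tf.zeros([batch_size, z_dim])
    z = bn.normal('z', mean=z_mean, std=1., group_ndims=1, n_samples=n_z)
    return bn, z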
Code example #2
    def __init__(self,
                 mean,
                 std=None,
                 logstd=None,
                 is_reparameterized=True,
                 check_numerics=None):
        """
        Construct the :class:`Normal`.

        Args:
            mean: A `float` tensor, the mean of the Normal distribution.
                Should be broadcastable against `std` / `logstd`.
            std: A `float` tensor, the standard deviation of the Normal
                distribution.  Should be positive, and broadcastable against
                `mean`.  One and only one of `std` or `logstd` should be
                specified.
            logstd: A `float` tensor, the log standard deviation of the Normal
                distribution.  Should be broadcastable against `mean`.
            is_reparameterized (bool): Whether or not gradients can be
                propagated through the parameters. (default :obj:`True`)
            check_numerics (bool): Whether or not to check numerical issues.
                Default to ``tfsnippet.settings.check_numerics``.
        """
        if check_numerics is None:
            check_numerics = settings.check_numerics
        super(TruncatedNormal, self).__init__(
            zd.Normal(
                mean=mean,
                std=std,
                logstd=logstd,
                is_reparameterized=is_reparameterized,
                check_numerics=check_numerics,
            ))
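A minimal sketch of the std / logstd contract described in the docstring. The import of the Normal class defined above is omitted and assumed; the values are illustrative.

import numpy as np

sigma = 0.5
n_from_std = Normal(mean=0., std=sigma)                 # parameterized by std
n_from_logstd = Normal(mean=0., logstd=np.log(sigma))   # same distribution, since std = exp(logstd)
# Passing both std and logstd (or neither) violates the "one and only one"
# rule and is expected to raise an error in the underlying zhusuan Normal.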
Code example #3
 def test_is_reparameterized(distrib_flag, sample_flag=None):
     normal = zd.Normal(mean=x, std=1., is_reparameterized=distrib_flag)
     distrib = ZhuSuanDistribution(normal)
     samples = distrib.sample(is_reparameterized=sample_flag)
     grads = tf.gradients(samples, x)
     if sample_flag is True or (sample_flag is None and distrib_flag):
         self.assertIsNotNone(grads[0])
     else:
         self.assertIsNone(grads[0])
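The gradient check above relies on the reparameterization trick. A standalone TF1-style sketch of the idea (not the library's internal code): with a reparameterized draw z = mean + std * eps, gradients flow back to the parameters; with the gradient path cut, they do not.

import tensorflow as tf

mean = tf.get_variable('mean', shape=(), dtype=tf.float32)
eps = tf.random_normal(shape=())
z_rep = mean + 1. * eps               # reparameterized sample
z_norep = tf.stop_gradient(z_rep)     # simulates a non-reparameterized draw
print(tf.gradients(z_rep, [mean]))    # [<gradient tensor>]
print(tf.gradients(z_norep, [mean]))  # [None]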
Code example #4
    def test_sample(self):
        # sampling re-parameterized samples from a non-reparameterized
        # distribution should raise an error
        with pytest.raises(RuntimeError, match='.* is not re-parameterized'):
            d = ZhuSuanDistribution(
                Mock(spec=zd.Normal,
                     wraps=zd.Normal(mean=0., std=1.),
                     is_reparameterized=False))
            self.assertFalse(d.is_reparameterized)
            _ = d.sample(is_reparameterized=True)

        # test sampling with default is_reparameterized = True
        samples = tf.constant(12345678.)
        d = ZhuSuanDistribution(
            Mock(spec=zd.Normal,
                 wraps=zd.Normal(mean=0., std=1.),
                 is_reparameterized=True,
                 sample=Mock(return_value=samples)))
        t = d.sample()
        self.assertIsInstance(t, StochasticTensor)
        self.assertIsNone(t.n_samples)
        self.assertEqual(t.group_ndims, 0)
        self.assertTrue(t.is_reparameterized)
        with self.test_session():
            np.testing.assert_equal(d.sample().eval(), samples.eval())

        # test sampling with default is_reparameterized = False
        self.assertFalse(
            ZhuSuanDistribution(
                Mock(spec=zd.Normal,
                     wraps=zd.Normal(mean=0., std=1.),
                     is_reparameterized=False,
                     sample=Mock(
                         return_value=samples))).sample().is_reparameterized)

        # test sampling with n_samples
        t = d.sample(n_samples=2)
        self.assertEqual(t.n_samples, 2)

        # test sampling with overridden is_reparameterized attribute
        t = d.sample(is_reparameterized=False)
        self.assertFalse(t.is_reparameterized)
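A rough outline of the flag resolution this test exercises; this is an inference from the assertions above, not the actual tfsnippet source. A sample-time is_reparameterized may only keep or downgrade the distribution's flag, never upgrade it.

def resolve_reparameterized(distribution_flag, sample_flag):
    # sample_flag=None falls back to the distribution's own setting.
    if sample_flag is None:
        return distribution_flag
    # Requesting reparameterized samples from a non-reparameterized
    # distribution is an error; downgrading to False is allowed.
    if sample_flag and not distribution_flag:
        raise RuntimeError('distribution is not re-parameterized')
    return sample_flag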
Code example #5
File: stochastic.py  Project: wmyw96/ZhuSuan
 def __init__(self,
              name,
              mean=0.,
              logstd=0.,
              n_samples=None,
              group_event_ndims=0,
              is_reparameterized=True,
              check_numerics=False):
     norm = distributions.Normal(mean,
                                 logstd,
                                 group_event_ndims=group_event_ndims,
                                 is_reparameterized=is_reparameterized,
                                 check_numerics=check_numerics)
     super(Normal, self).__init__(name, norm, n_samples)
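A usage sketch of this older ZhuSuan stochastic-node API, assuming the fork follows the mainline pattern of registering nodes inside a BayesianNet context; shapes and names are illustrative.

import tensorflow as tf
import zhusuan as zs

with zs.BayesianNet() as model:
    # A [64, 40] latent code with zero log-std (unit std), grouped over
    # the last dimension, mirroring the node signature above.
    z_mean = tf.zeros([64, 40])
    z = zs.Normal('z', mean=z_mean, logstd=0., group_event_ndims=1)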
Code example #6
    def __init__(self, mean, std=None, logstd=None, check_numerics=False):
        """
        Construct the :class:`Normal`.

        Args:
            mean: A `float` tensor, the mean of the Normal distribution.
                Should be broadcastable against `std` / `logstd`.
            std: A `float` tensor, the standard deviation of the Normal
                distribution.  Should be positive, and broadcastable against
                `mean`.  One and only one of `std` or `logstd` should be
                specified.
            logstd: A `float` tensor, the log standard deviation of the Normal
                distribution.  Should be broadcastable against `mean`.
            check_numerics (bool): Whether or not to check numerical issues.
        """
        super(Normal, self).__init__(zd.Normal(
            mean=mean, std=std, logstd=logstd, check_numerics=check_numerics))
Code example #7
File: stochastic.py  Project: zhanghang123123/zhusuan
 def __init__(self,
              name,
              mean=0.,
              _sentinel=None,
              std=None,
              logstd=None,
              n_samples=None,
              group_ndims=0,
              is_reparameterized=True,
              check_numerics=False,
              **kwargs):
     norm = distributions.Normal(mean,
                                 _sentinel=_sentinel,
                                 std=std,
                                 logstd=logstd,
                                 group_ndims=group_ndims,
                                 is_reparameterized=is_reparameterized,
                                 check_numerics=check_numerics,
                                 **kwargs)
     super(Normal, self).__init__(name, norm, n_samples)
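The _sentinel=None slot between mean and std is a keyword-only guard: any value passed positionally after mean lands in _sentinel and can be rejected. A standalone sketch of the pattern (independent of ZhuSuan):

def normal_args(mean, _sentinel=None, std=None, logstd=None):
    if _sentinel is not None:
        raise TypeError('std/logstd must be passed as keyword arguments')
    if (std is None) == (logstd is None):
        raise ValueError('one and only one of std, logstd should be specified')
    return std if std is not None else logstd

normal_args(0., std=1.)    # ok
# normal_args(0., 1.)      # TypeError: positional std is rejected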
Code example #8
 def test_zs_distribution(self):
     normal = zd.Normal(mean=0., std=1.)
     distrib = as_distribution(normal)
     self.assertIsInstance(distrib, Distribution)
     self.assertIsInstance(distrib, ZhuSuanDistribution)
     self.assertIs(distrib._distribution, normal)
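A rough sketch of the dispatch this test implies for as_distribution: native Distribution instances pass through, ZhuSuan distributions get wrapped. This is inferred from the assertions, not the actual implementation, and assumes zd.Distribution is the ZhuSuan base class.

def as_distribution_sketch(d):
    if isinstance(d, Distribution):
        return d
    if isinstance(d, zd.Distribution):
        return ZhuSuanDistribution(d)
    raise TypeError('cannot convert {!r} to a Distribution'.format(d))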
Code example #9
    def test_prob_and_log_prob(self):
        x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) / 24.
        normal = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]))
        normal1 = zd.Normal(mean=tf.zeros([3, 4]),
                            std=tf.ones([3, 4]),
                            group_ndims=1)

        # test with default group_ndims
        distrib = ZhuSuanDistribution(normal)
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x).get_shape(),
                normal.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x).get_shape(),
                normal.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x).eval(),
                normal.log_prob(x).eval())
            np.testing.assert_allclose(
                distrib.prob(x).eval(),
                normal.prob(x).eval())

        # test with static group_ndims
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x, group_ndims=1).get_shape(),
                normal1.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x, group_ndims=1).get_shape(),
                normal1.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x, group_ndims=1).eval(),
                normal1.log_prob(x).eval())
            np.testing.assert_allclose(
                distrib.prob(x, group_ndims=1).eval(),
                normal1.prob(x).eval())

        # test with dynamic group_ndims
        group_ndims = tf.constant(1, dtype=tf.int32)
        normal1d = zd.Normal(mean=normal.mean,
                             std=normal.std,
                             group_ndims=group_ndims)
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x, group_ndims=group_ndims).get_shape(),
                normal1d.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x, group_ndims=group_ndims).get_shape(),
                normal1d.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x, group_ndims=group_ndims).eval(),
                normal1d.log_prob(x).eval())
            np.testing.assert_allclose(
                distrib.prob(x, group_ndims=group_ndims).eval(),
                normal1d.prob(x).eval())

        # test with bad dynamic group_ndims
        group_ndims = tf.constant(-1, dtype=tf.int32)
        with self.test_session():
            with pytest.raises(Exception,
                               match='group_ndims must be non-negative'):
                _ = distrib.log_prob(x, group_ndims=group_ndims).eval()
            with pytest.raises(Exception,
                               match='group_ndims must be non-negative'):
                _ = distrib.prob(x, group_ndims=group_ndims).eval()

        # test override the default group_ndims in ZhuSuan distribution
        distrib = ZhuSuanDistribution(normal1)
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x).get_shape(),
                normal.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x).get_shape(),
                normal.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x).eval(),
                normal.log_prob(x).eval())
            np.testing.assert_allclose(
                distrib.prob(x).eval(),
                normal.prob(x).eval())
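For reference, group_ndims=1 in the checks above sums the log-density over the last (event) dimension, so a [2, 3, 4] log_prob collapses to [2, 3]. The reduction below mirrors the equality the test asserts; the shapes are taken from the test itself.

import tensorflow as tf
import zhusuan.distributions as zd

x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) / 24.
normal = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]))
# Equivalent to normal.log_prob(x) with group_ndims=1: shape [2, 3, 4] -> [2, 3].
lp_grouped = tf.reduce_sum(normal.log_prob(x), axis=-1)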
Code example #10
 def test_repr(self):
     distrib = ZhuSuanDistribution(
         Mock(spec=zd.Normal,
              wraps=zd.Normal(mean=0., std=1.),
              __repr__=Mock(return_value='repr_output')))
     self.assertEqual(repr(distrib), 'Distribution(repr_output)')
Code example #11
    def test_prob_and_log_prob(self):
        x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) / 24.
        normal = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]))
        normal1 = zd.Normal(mean=tf.zeros([3, 4]),
                            std=tf.ones([3, 4]),
                            group_ndims=1)

        # test with default group_ndims
        distrib = ZhuSuanDistribution(normal)
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x).get_shape(),
                normal.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x).get_shape(),
                normal.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x).eval(),
                normal.log_prob(x).eval())
            np.testing.assert_allclose(
                distrib.prob(x).eval(),
                normal.prob(x).eval())

        # test with static group_ndims
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x, group_ndims=1).get_shape(),
                normal1.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x, group_ndims=1).get_shape(),
                normal1.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x, group_ndims=1).eval(),
                normal1.log_prob(x).eval())
            np.testing.assert_allclose(distrib.prob(x, group_ndims=1).eval(),
                                       normal1.prob(x).eval(),
                                       rtol=1e-5)

        # test with dynamic group_ndims
        group_ndims = tf.constant(1, dtype=tf.int32)
        normal1d = zd.Normal(mean=normal.mean,
                             std=normal.std,
                             group_ndims=group_ndims)
        with self.test_session():
            # Note: Because we added auxiliary asserts to reduce_mean in our
            # log_prob, the following two static shapes will not be equal.
            #
            # self.assertEqual(
            #     distrib.log_prob(x, group_ndims=group_ndims).get_shape(),
            #     normal1d.log_prob(x).get_shape()
            # )
            # self.assertEqual(
            #     distrib.prob(x, group_ndims=group_ndims).get_shape(),
            #     normal1d.prob(x).get_shape()
            # )
            np.testing.assert_allclose(
                distrib.log_prob(x, group_ndims=group_ndims).eval(),
                normal1d.log_prob(x).eval())
            np.testing.assert_allclose(distrib.prob(
                x, group_ndims=group_ndims).eval(),
                                       normal1d.prob(x).eval(),
                                       rtol=1e-5)

        # test with bad dynamic group_ndims
        group_ndims = tf.constant(-1, dtype=tf.int32)
        with self.test_session():
            with pytest.raises(Exception,
                               match='group_ndims must be non-negative'):
                _ = distrib.log_prob(x, group_ndims=group_ndims).eval()
            with pytest.raises(Exception,
                               match='group_ndims must be non-negative'):
                _ = distrib.prob(x, group_ndims=group_ndims).eval()

        # test override the default group_ndims in ZhuSuan distribution
        distrib = ZhuSuanDistribution(normal1)
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x).get_shape(),
                normal.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x).get_shape(),
                normal.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x).eval(),
                normal.log_prob(x).eval())
            np.testing.assert_allclose(distrib.prob(x).eval(),
                                       normal.prob(x).eval(),
                                       rtol=1e-5)

        # test compute_density
        distrib = ZhuSuanDistribution(normal1)
        t = distrib.sample()
        self.assertIsNone(t._self_log_prob)
        t = distrib.sample(compute_density=False)
        self.assertIsNone(t._self_log_prob)
        t = distrib.sample(compute_density=True)
        self.assertIsNotNone(t._self_log_prob)
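The compute_density checks at the end suggest that the log-density is computed eagerly and cached on the returned StochasticTensor when requested, and left as None otherwise. A hypothetical sketch of that caching step, with the attribute name taken from the test rather than the library source, and assuming the StochasticTensor is tensor-convertible:

def attach_density(stochastic_tensor, distribution, compute_density):
    # Cache log_prob on the tensor only when compute_density is truthy.
    stochastic_tensor._self_log_prob = (
        distribution.log_prob(stochastic_tensor) if compute_density else None)
    return stochastic_tensor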