예제 #1
0
    def onehot_categorical(self, name, logits, n_samples=None, group_ndims=0,
                           dtype=tf.int32, **kwargs):
        """
        Register a stochastic node in this :class:`BayesianNet` whose
        distribution is OnehotCategorical.

        :param name: The name of the stochastic node. Must be unique in a
            :class:`BayesianNet`.
        :param logits: Passed as the first positional argument to the
            distribution constructor.
        :param n_samples: Passed to ``self.stochastic`` as ``n_samples``.
        :param group_ndims: Passed to the distribution constructor.
        :param dtype: Passed to the distribution constructor. Defaults to
            ``tf.int32``.
        :param kwargs: Extra keyword arguments, handed to the distribution
            constructor and to ``self.stochastic`` alike.

        See
        :class:`~zhusuan.distributions.multivariate.OnehotCategorical`
        for the meaning of the distribution arguments.

        :return: A :class:`StochasticTensor` instance.
        """
        # NOTE(review): the same **kwargs is forwarded to both calls below;
        # confirm that both callees are meant to accept every extra keyword.
        onehot_dist = distributions.OnehotCategorical(
            logits, group_ndims=group_ndims, dtype=dtype, **kwargs)
        node = self.stochastic(
            name, onehot_dist, n_samples=n_samples, **kwargs)
        return node
예제 #2
0
 def __init__(self, name, logits, n_samples=None, group_ndims=0, dtype=None,
              **kwargs):
     """
     Build the underlying OnehotCategorical distribution and initialise
     the parent class with it.

     :param name: Passed to the parent initialiser as its first argument.
     :param logits: Passed positionally to
         ``distributions.OnehotCategorical``.
     :param n_samples: Passed to the parent initialiser as its third
         positional argument.
     :param group_ndims: Passed to the distribution constructor.
     :param dtype: Passed to the distribution constructor unchanged
         (``None`` is forwarded as-is).
     :param kwargs: Extra keyword arguments for the distribution
         constructor only.
     """
     dist = distributions.OnehotCategorical(
         logits, group_ndims=group_ndims, dtype=dtype, **kwargs)
     super(OnehotCategorical, self).__init__(name, dist, n_samples)
예제 #3
0
    def __init__(self, logits, dtype=None):
        """
        Construct the :class:`OnehotCategorical`.

        The arguments are used to build a ``zd.OnehotCategorical``, which is
        then handed to the parent class initialiser.

        Args:
            logits: An N-D (N >= 1) `float` Tensor of shape
                ``(..., n_categories)``.  Each slice `[i, j,..., k, :]`
                holds the un-normalized log probabilities of all
                categories.  :math:`\\mathrm{logits} \\propto \\log p`
            dtype: The value type of samples from the distribution.
                ``None`` is replaced by ``tf.int32``.
        """
        # Resolve the default lazily so `tf.int32` is only looked up when
        # the caller did not choose a dtype.
        sample_dtype = tf.int32 if dtype is None else dtype
        wrapped = zd.OnehotCategorical(logits=logits, dtype=sample_dtype)
        super(OnehotCategorical, self).__init__(wrapped)
예제 #4
0
 def test_proxied_props_and_methods(self):
     """
     Check that a ``ZhuSuanDistribution`` exposes the same attribute and
     method results as the ``zd.OnehotCategorical`` it wraps.
     """
     wrapped = zd.OnehotCategorical(
         tf.zeros([3, 4, 5]), dtype=tf.int64, group_ndims=1)
     proxy = ZhuSuanDistribution(wrapped)

     # Plain attributes are compared directly.
     for attr_name in ('dtype', 'is_continuous', 'is_reparameterized'):
         expected = getattr(wrapped, attr_name)
         actual = getattr(proxy, attr_name)
         self.assertEqual(
             expected, actual,
             msg='Attribute `{}` does not equal'.format(attr_name))

     # Zero-argument methods are compared by their return values.
     for method_name in ('get_value_shape', 'get_batch_shape'):
         expected = getattr(wrapped, method_name)()
         actual = getattr(proxy, method_name)()
         self.assertEqual(
             expected, actual,
             msg='Output of method `{}` does not equal'.format(method_name))

     # These attributes are evaluated via `.eval()` inside a test session
     # before being compared.
     with self.test_session():
         for attr_name in ('value_shape', 'batch_shape'):
             expected = getattr(wrapped, attr_name).eval()
             actual = getattr(proxy, attr_name).eval()
             np.testing.assert_equal(
                 expected, actual,
                 err_msg='Value of attribute `{}` does not equal'.format(
                     attr_name))