def onehot_categorical(self, name, logits, n_samples=None, group_ndims=0,
                       dtype=tf.int32, **kwargs):
    """
    Add a stochastic node in this :class:`BayesianNet` that follows the
    OnehotCategorical distribution.

    :param name: The name of the stochastic node. Must be unique in a
        :class:`BayesianNet`.

    See :class:`~zhusuan.distributions.multivariate.OnehotCategorical` for
    more information about the other arguments. Note that ``**kwargs`` is
    forwarded both to the distribution constructor and to
    :meth:`stochastic`.

    :return: A :class:`StochasticTensor` instance.
    """
    # Build the distribution first, then register it as a named node.
    onehot_dist = distributions.OnehotCategorical(
        logits,
        group_ndims=group_ndims,
        dtype=dtype,
        **kwargs
    )
    node = self.stochastic(name, onehot_dist, n_samples=n_samples, **kwargs)
    return node
def __init__(self, name, logits, n_samples=None, group_ndims=0, dtype=None,
             **kwargs):
    """
    Construct a stochastic node backed by a OnehotCategorical distribution.

    :param name: The name of the stochastic node, passed to the parent
        constructor.
    :param logits: Un-normalized log probabilities over the categories,
        passed to :class:`distributions.OnehotCategorical`.
    :param n_samples: Number of samples, passed to the parent constructor.
    :param group_ndims: Passed to the distribution.
    :param dtype: The value type of samples. When ``None`` it is not
        forwarded, so the distribution's own default dtype applies.
    :param kwargs: Extra keyword arguments passed to the distribution.
    """
    # Forward ``dtype`` only when the caller actually chose one. Passing an
    # explicit ``None`` would rely on the distribution accepting ``None`` in
    # place of a dtype; omitting it lets the distribution's declared default
    # (presumably ``tf.int32`` -- confirm against distributions module) apply.
    # ``kwargs`` cannot already hold 'dtype' since it is a named parameter.
    if dtype is not None:
        kwargs['dtype'] = dtype
    onehot_cat = distributions.OnehotCategorical(
        logits, group_ndims=group_ndims, **kwargs)
    super(OnehotCategorical, self).__init__(name, onehot_cat, n_samples)
def __init__(self, logits, dtype=None):
    """
    Construct the :class:`OnehotCategorical`.

    Args:
        logits: An N-D (N >= 1) `float` Tensor of shape
            ``(..., n_categories)``. Each slice `[i, j,..., k, :]`
            represents the un-normalized log probabilities for all
            categories. :math:`\\mathrm{logits} \\propto \\log p`
        dtype: The value type of samples from the distribution.
            (default ``tf.int32``)
    """
    # ``None`` means "use the default sample type".
    sample_dtype = tf.int32 if dtype is None else dtype
    wrapped = zd.OnehotCategorical(logits=logits, dtype=sample_dtype)
    super(OnehotCategorical, self).__init__(wrapped)
def test_proxied_props_and_methods(self):
    # Verify that ZhuSuanDistribution exposes the same attributes, shape
    # getters and shape tensors as the ZhuSuan distribution it wraps.
    wrapped = zd.OnehotCategorical(
        tf.zeros([3, 4, 5]), dtype=tf.int64, group_ndims=1)
    proxy = ZhuSuanDistribution(wrapped)

    # Plain attributes must be forwarded unchanged.
    for attr_name in ('dtype', 'is_continuous', 'is_reparameterized'):
        expected = getattr(wrapped, attr_name)
        actual = getattr(proxy, attr_name)
        self.assertEqual(
            expected, actual,
            msg='Attribute `{}` does not equal'.format(attr_name))

    # Static shape getters must return equal results.
    for meth_name in ('get_value_shape', 'get_batch_shape'):
        expected = getattr(wrapped, meth_name)()
        actual = getattr(proxy, meth_name)()
        self.assertEqual(
            expected, actual,
            msg='Output of method `{}` does not equal'.format(meth_name))

    # Dynamic shape tensors must evaluate to equal values.
    with self.test_session():
        for attr_name in ('value_shape', 'batch_shape'):
            expected = getattr(wrapped, attr_name).eval()
            actual = getattr(proxy, attr_name).eval()
            np.testing.assert_equal(
                expected, actual,
                err_msg='Value of attribute `{}` does not equal'.format(
                    attr_name))