def _sample_n(self, n, seed=None):
     """Draws `n` category indices for every batch member.

     The per-category log-probabilities from `categorical_log_probs()` are
     flattened to a `[-1, num_categories]` matrix, sampled with
     `samplers.categorical`, and reshaped so the sample axis leads.

     Args:
       n: Number of samples to draw per batch member.
       seed: Seed forwarded unchanged to `samplers.categorical`.

     Returns:
       A tensor of shape `[n] + batch_shape`, requested with dtype
       `self.dtype`.
     """
     # TODO(b/149334734): Consider sampling from the Logistic distribution,
     # and then truncating based on cutpoints. This could be faster than the
     # given sampling scheme.
     log_probs = self.categorical_log_probs()
     flat_log_probs = tf.reshape(log_probs, [-1, self._num_categories()])
     # NOTE(review): `self.dtype` is passed straight through as the sample
     # dtype; if the categorical sampler only supports integer outputs, a
     # floating `dtype` would fail here -- confirm.
     raw_draws = samplers.categorical(
         flat_log_probs, n, dtype=self.dtype, seed=seed)
     # Move the sample axis to the front before restoring the batch shape.
     sample_major = tf.transpose(raw_draws)
     out_shape = ps.concat([[n], self._batch_shape_tensor()], axis=0)
     return tf.reshape(sample_major, shape=out_shape)
Beispiel #2
0
 def _sample_one_batch_member(args):
     """Draws per-class counts for a single batch member.

     `num_samples`, `num_classes` and `dtype` are free variables read from
     the enclosing scope. `args` arrives packed as one tuple, presumably
     because this is applied per batch member by a mapping helper -- verify
     against the caller.

     Args:
       args: A `(logits, num_cat_samples, item_seed)` triple. Per the
         original annotation, `logits` is `[K]` and `num_cat_samples` is a
         scalar equal to `num_samples * num_trials`.

     Returns:
       Counts of shape `[num_samples, num_classes]`, cast to `dtype`.
     """
     row_logits, total_draws, member_seed = args  # [K], []
     # Sampling from a [1, K] logits row yields [1, total_draws] indices.
     flat_draws = samplers.categorical(
         row_logits[tf.newaxis, ...], total_draws, seed=member_seed)
     # Group the draws into num_samples rows: [num_samples, num_trials].
     per_sample = tf.reshape(flat_draws, shape=[num_samples, -1])
     # One-hot gives [num_samples, num_trials, num_classes]; summing out the
     # trial axis leaves per-class counts [num_samples, num_classes].
     # NOTE(review): `tf.one_hot` defaults to float32, so counts accumulate
     # in float32 before the cast; huge trial counts may lose integer
     # precision -- confirm whether that matters.
     encoded = tf.one_hot(per_sample, depth=num_classes)
     counts = tf.reduce_sum(encoded, axis=-2)
     return tf.cast(counts, dtype=dtype)
 def _sample_n(self, n, seed=None):
    """Draws `n` one-hot encoded category samples.

    Args:
      n: Number of samples to draw.
      seed: Seed forwarded to `samplers.categorical`.

    Returns:
      A tensor of shape `[n] + shape(logits)`. Its last axis is the one-hot
      encoding (length `event_size`) of the sampled index, with dtype
      `self.dtype`.
    """
    logits = self._logits_parameter_no_checks()
    out_shape = ps.concat([[n], ps.shape(logits)], 0)
    num_events = self._event_size(logits)
    # Statically rank-2 logits go to the sampler as-is; any other rank
    # (including unknown) is flattened to [-1, event_size] first.
    is_matrix = tensorshape_util.rank(logits.shape) == 2
    flat_logits = (
        logits if is_matrix else tf.reshape(logits, [-1, num_events]))
    indices = tf.transpose(a=samplers.categorical(flat_logits, n, seed=seed))
    encoded = tf.one_hot(indices, num_events, dtype=self.dtype)
    return tf.reshape(encoded, out_shape)
Beispiel #4
0
 def _sample_n(self, n, seed=None):
    """Draws `n` category indices per batch member.

    Indices are sampled in an integer dtype and then cast to `self.dtype`.
    The integer dtype is int64 when `dtype_util.size(self.dtype)` exceeds 4,
    and int32 otherwise.

    Args:
      n: Number of samples to draw.
      seed: `None` or a Python integer routes to `tf.random.categorical`;
        any other seed is handed to `samplers.categorical`.

    Returns:
      A tensor of shape `[n] + batch_shape` with dtype `self.dtype`.
    """
    logits = self._logits_parameter_no_checks()
    flat_logits = tf.reshape(logits, [-1, self._num_categories(logits)])
    if dtype_util.size(self.dtype) > 4:
      int_dtype = tf.int64
    else:
      int_dtype = tf.int32
    # TODO(b/147874898): Remove workaround for seed-sensitive tests.
    use_tf_random = seed is None or isinstance(seed, six.integer_types)
    categorical_fn = (
        tf.random.categorical if use_tf_random else samplers.categorical)
    draws = categorical_fn(flat_logits, n, dtype=int_dtype, seed=seed)
    sample_major = tf.transpose(tf.cast(draws, self.dtype))
    batch_shape = self._batch_shape_tensor(logits=logits)
    return tf.reshape(
        sample_major, shape=ps.concat([[n], batch_shape], axis=0))