示例#1
0
 def sample_n(self, n, seed=None):
   """Draw `n` samples by gathering from, or tiling, `self._params`.

   When `self.n != 1`, `n` indices are sampled from a `Categorical` whose
   logits are `logit(1 / self.n)` for each of the `self.n` entries, and the
   matching entries of `self._params` are gathered. Otherwise
   `self._params` is handed to `tile` with multiples `[n, 1, ..., 1]`.

   Args:
     n: Number of samples to draw.
     seed: Seed forwarded to the categorical sampler; not used when
       `self.n == 1`.

   Returns:
     The output of `tf.gather` (when `self.n != 1`) or of `tile`.
   """
   if self.n != 1:
     # Same probability, 1 / self.n, for each of the self.n entries.
     ones = tf.ones(self.n, dtype=tf.float32)
     count = tf.cast(self.n, dtype=tf.float32)
     uniform_logits = logit(ones / count)
     index_dist = tf.contrib.distributions.Categorical(logits=uniform_logits)
     drawn = index_dist.sample_n(n, seed)
     return tf.gather(self._params, drawn)

   # Multiples are `n` first, then one 1 per entry of the event shape.
   leading = tf.expand_dims(n, 0)
   event_ones = [1] * len(self.get_event_shape())
   reps = tf.concat(0, [leading, event_ones])
   return tile(self._params, reps)
示例#2
0
 def _sample_n(self, n, seed=None):
     """Draw `n` samples by gathering from, or tiling, `self._params`.

     When `self.n != 1`, `n` indices are sampled from a `Categorical` whose
     logits are `logit(1 / self.n)` for each of the `self.n` entries, and
     the matching entries of `self._params` are gathered. Otherwise
     `self._params` is handed to `tile` with multiples `[n, 1, ..., 1]`.

     Args:
       n: Number of samples to draw.
       seed: Seed forwarded to the categorical sampler; not used when
         `self.n == 1`.

     Returns:
       The output of `tf.gather` (when `self.n != 1`) or of `tile`.
     """
     if self.n != 1:
         # Same probability, 1 / self.n, for each of the self.n entries.
         ones = tf.ones(self.n, dtype=tf.float32)
         count = tf.cast(self.n, dtype=tf.float32)
         uniform_logits = logit(ones / count)
         index_dist = tf.contrib.distributions.Categorical(
             logits=uniform_logits)
         drawn = index_dist._sample_n(n, seed)
         return tf.gather(self._params, drawn)

     # Multiples are `n` first, then one 1 per entry of the event shape.
     leading = tf.expand_dims(n, 0)
     event_ones = [1] * len(self.get_event_shape())
     reps = tf.concat([leading, event_ones], 0)
     return tile(self._params, reps)
示例#3
0
    def sample_n(self, n, seed=None, name="sample_n"):
        """Sample `n` observations from the Point Mass distribution.

        The samples are produced by passing `self._params` to `tile` with
        multiples `[n, 1, ..., 1]` (one 1 per entry of the event shape).

        Args:
          n: `Scalar`, type int32, the number of observations to sample.
          seed: Python integer, the random seed. Not read by this method.
          name: The name to give this op.

        Returns:
          samples: `[n, ...]`, a `Tensor` of `n` samples for each
            of the distributions determined by broadcasting the
            hyperparameters.
        """
        with ops.name_scope(self.name), ops.op_scope([self._params, n], name):
            leading = tf.expand_dims(n, 0)
            event_ones = [1] * len(self.get_event_shape())
            reps = tf.concat(0, [leading, event_ones])
            return tile(self._params, reps)
示例#4
0
    def _test(self, input, multiples):
        """Assert that `tile(input, multiples)` has the expected dimensions.

        The expected dimensions are the elementwise product of the input's
        dimensions and `multiples`, after the shorter of the two has been
        left-padded with 1s to the same length.

        Args:
          input: Value passed to `get_dims` and `tile`.
          multiples: An int or float (treated as a one-element list), a
            tuple (converted to a list), or any other value, which is used
            as-is.

        Raises:
          AssertionError: If the dimensions of the tiled result differ
            from the expected ones.
        """
        # Bring `multiples` into list form when it is a scalar or a tuple.
        if isinstance(multiples, (int, float)):
            reps = [multiples]
        elif isinstance(multiples, tuple):
            reps = list(multiples)
        else:
            reps = multiples

        dims = get_dims(input)
        rank_gap = len(dims) - len(reps)
        if rank_gap < 0:
            # More multiples than input dims: pad the dims on the left.
            dims = [1] * (-rank_gap) + dims
        elif rank_gap > 0:
            # More input dims than multiples: pad the multiples on the left.
            reps = [1] * rank_gap + reps

        expected = [dim * rep for dim, rep in zip(dims, reps)]
        with self.test_session():
            actual = get_dims(tile(input, multiples))
            assert actual == expected
示例#5
0
 def _sample_n(self, n, seed=None):
     """Return `self._params` tiled with multiples `[n, 1, ..., 1]`.

     Args:
       n: Number of samples; becomes the first entry of the multiples.
       seed: Accepted for interface compatibility; not read here.

     Returns:
       The output of `tile(self._params, multiples)`.
     """
     # Multiples are `n` first, then one 1 per entry of the event shape.
     leading = tf.expand_dims(n, 0)
     event_ones = [1] * len(self.get_event_shape())
     reps = tf.concat(0, [leading, event_ones])
     return tile(self._params, reps)
示例#6
0
 def _sample_n(self, n, seed=None):
   """Return `self._params` tiled with multiples `[n, 1, ..., 1]`.

   Args:
     n: Number of samples; becomes the first entry of the multiples.
     seed: Accepted for interface compatibility; not read here.

   Returns:
     The output of `tile(self._params, multiples)`.
   """
   # Multiples are `n` first, then one 1 per entry of the event shape.
   leading = tf.expand_dims(n, 0)
   event_ones = [1] * len(self.get_event_shape())
   reps = tf.concat(0, [leading, event_ones])
   return tile(self._params, reps)