Example #1
 def _make_ops(self, num_samples, seed=None):
   prob_dist = tf.constant([[0.15, 0.5, 0.3, 0.05]])
   logits = tf.log(prob_dist)
   # Two independent sets of samples from the same distribution
   sample_op1 = random_ops.multinomial(logits, num_samples, seed)
   sample_op2 = random_ops.multinomial(logits, num_samples, seed)
   return (sample_op1, sample_op2)
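
For reference, a minimal standalone sketch of the seeded-sampling pattern this helper exercises (assuming TensorFlow 1.x, where tf.multinomial is the public alias of random_ops.multinomial):

import tensorflow as tf  # assumed: TensorFlow 1.x API

probs = tf.constant([[0.15, 0.5, 0.3, 0.05]])
logits = tf.log(probs)  # multinomial expects unnormalized log-probabilities
samples = tf.multinomial(logits, num_samples=10, seed=42)  # shape [1, 10]
with tf.Session() as sess:
    print(sess.run(samples))  # class indices in [0, 4); a fixed seed makes reruns match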
Example #3
  def sample_n(self, n, seed=None, name="sample_n"):
    """Sample `n` observations from the Categorical distribution.

    Args:
      n: `Scalar` `Tensor` of type `int32` or `int64`, the number of
        observations to sample.
      seed: Random seed (optional).
      name: A name for this operation (optional).

    Returns:
      An `int64` `Tensor` with shape `[n, batch_shape, event_shape]`
    """
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=[self.logits, n]):
        n = ops.convert_to_tensor(n, name="n")
        logits_2d = array_ops.reshape(
            self.logits, array_ops.pack([-1, self.num_classes]))
        samples = random_ops.multinomial(logits_2d, n, seed=seed)
        samples = math_ops.cast(samples, self._dtype)
        ret = array_ops.reshape(
            array_ops.transpose(samples),
            array_ops.concat(0, ([n], self.batch_shape())))
        ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
                      .concatenate(self.get_batch_shape()))
        return ret
Example #4
 def testMatchStatefulMultinomial(self):
     # Stateless ops should be the same as stateful ops on the first call
     # after seed scrambling.
     key = 0x3ec8f720, 0x02461e29
     num_samples = 4
     for logits_dtype in np.float16, np.float32, np.float64:
         for output_dtype in dtypes.int32, dtypes.int64:
             for seed in (7, 17), (11, 5), (2, 3):
                 preseed = invert_philox(
                     key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
                 preseed = preseed[::2] | preseed[1::2] << 32
                 random_seed.set_random_seed(seed[0])
                 with self.test_session(use_gpu=True):
                     for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5],
                                                               [0.8, 0.2],
                                                               [0.25,
                                                                0.75]]):
                         logits_t = constant_op.constant(logits,
                                                         dtype=logits_dtype)
                         stateful = random_ops.multinomial(
                             logits_t,
                             num_samples,
                             seed=seed[1],
                             output_dtype=output_dtype)
                         pure = stateless.stateless_multinomial(
                             logits_t,
                             num_samples,
                             seed=preseed,
                             output_dtype=output_dtype)
                         self.assertAllEqual(stateful.eval(), pure.eval())
Example #5
 def _sample_n(self, n, seed=None):
     n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
     if self.n.get_shape().ndims is not None:
         if self.n.get_shape().ndims != 0:
             raise NotImplementedError(
                 "Sample only supported for scalar number of draws.")
     elif self.validate_args:
         is_scalar = check_ops.assert_rank(
             n_draws,
             0,
             message="Sample only supported for scalar number of draws.")
         n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
     k = self.event_shape()[0]
     unnormalized_logits = array_ops.reshape(
         math_ops.log(random_ops.random_gamma(
             shape=[n],
             alpha=self.alpha,
             dtype=self.dtype,
             seed=seed)),
         shape=[-1, k])
     draws = random_ops.multinomial(logits=unnormalized_logits,
                                    num_samples=n_draws,
                                    seed=distribution_util.gen_new_seed(
                                        seed, salt="dirichlet_multinomial"))
     x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
                             reduction_indices=-2)
     final_shape = array_ops.concat([[n], self.batch_shape(), [k]], 0)
     return array_ops.reshape(x, final_shape)
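
The gamma-then-multinomial construction above works because normalized gamma draws are Dirichlet samples, and taking their log gives valid unnormalized logits. A numpy sketch of the same two-stage scheme, with made-up alpha and n_draws:

import numpy as np

alpha = np.array([1.0, 2.0, 3.0])  # assumed Dirichlet concentration
n_draws = 10                       # assumed scalar number of draws

gamma = np.random.gamma(alpha)          # one gamma variate per component
probs = gamma / gamma.sum()             # normalized gammas ~ Dirichlet(alpha)
counts = np.random.multinomial(n_draws, probs)  # shape [3], sums to n_draws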
Example #6
        def body(i, prev_c, prev_h, actions, log_probs):
            # pylint: disable=g-long-lambda
            signal = control_flow_ops.cond(
                math_ops.equal(i, 0), lambda: array_ops.tile(
                    device_go_embedding, [self.hparams.num_children, 1]),
                lambda: embedding_ops.embedding_lookup(device_embeddings,
                                                       actions.read(i - 1)))
            if self.hparams.keep_prob is not None:
                signal = nn_ops.dropout(signal,
                                        rate=(1 - self.hparams.keep_prob))
            next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
            query = math_ops.matmul(next_h, attn_w_2)
            query = array_ops.reshape(
                query,
                [self.hparams.num_children, 1, self.hparams.hidden_size])
            query = math_ops.tanh(query + attn_mem)
            query = array_ops.reshape(query, [
                self.hparams.num_children * self.num_groups,
                self.hparams.hidden_size
            ])
            query = math_ops.matmul(query, attn_v)
            query = array_ops.reshape(
                query, [self.hparams.num_children, self.num_groups])
            query = nn_ops.softmax(query)
            query = array_ops.reshape(
                query, [self.hparams.num_children, self.num_groups, 1])
            query = math_ops.reduce_sum(attn_mem * query, axis=1)
            query = array_ops.concat([next_h, query], axis=1)
            logits = math_ops.matmul(query, device_softmax)
            logits /= self.hparams.temperature
            if self.hparams.tanh_constant > 0:
                logits = math_ops.tanh(logits) * self.hparams.tanh_constant
            if self.hparams.logits_std_noise > 0:
                num_in_logits = math_ops.cast(array_ops.size(logits),
                                              dtype=dtypes.float32)
                avg_norm = math_ops.divide(linalg_ops.norm(logits),
                                           math_ops.sqrt(num_in_logits))
                logits_noise = random_ops.random_normal(
                    array_ops.shape(logits),
                    stddev=self.hparams.logits_std_noise * avg_norm)
                logits = control_flow_ops.cond(
                    self.global_step > self.hparams.stop_noise_step,
                    lambda: logits, lambda: logits + logits_noise)

            if mode == "sample":
                next_y = random_ops.multinomial(logits,
                                                1,
                                                seed=self.hparams.seed)
            elif mode == "greedy":
                next_y = math_ops.argmax(logits, 1)
            elif mode == "target":
                next_y = array_ops.slice(y, [0, i], [-1, 1])
            else:
                raise NotImplementedError
            next_y = math_ops.cast(next_y, dtypes.int32)
            next_y = array_ops.reshape(next_y, [self.hparams.num_children])
            actions = actions.write(i, next_y)
            log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=next_y)
            return i + 1, next_c, next_h, actions, log_probs
Example #7
    def _do_sampling(self, logits, num_samples):
        """Categorical samples from given input.

    Args:
      logits: Numpy ndarray of shape [batch_size, num_classes].
      num_samples: Int; number of samples to draw.

    Returns:
      Frequencies from sampled classes; shape [batch_size, num_classes].
    """
        with self.cached_session() as sess, self.test_scope():
            random_seed.set_random_seed(1618)
            op = random_ops.multinomial(logits,
                                        num_samples,
                                        output_dtype=dtypes.int32)
            d = self.evaluate(op)

        batch_size, num_classes = logits.shape
        freqs_mat = []
        for i in range(batch_size):
            cnts = dict(collections.Counter(d[i, :]))

            # Requires drawn class labels be in range.
            self.assertLess(max(cnts.keys()), num_classes)
            self.assertGreaterEqual(min(cnts.keys()), 0)

            freqs = [(cnts[k] * 1. / num_samples if k in cnts else 0)
                     for k in range(num_classes)]
            freqs_mat.append(freqs)

        return freqs_mat
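
The frequency bookkeeping after the session exit is plain Python; a self-contained version with a hypothetical draws array:

import collections
import numpy as np

d = np.array([[0, 1, 1, 2], [3, 3, 3, 3]])  # hypothetical sampled class ids
num_samples = d.shape[1]
num_classes = 4
freqs_mat = []
for row in d:
    cnts = collections.Counter(row.tolist())
    freqs_mat.append([cnts.get(k, 0) * 1. / num_samples for k in range(num_classes)])
# freqs_mat[i][k] estimates P(class k) for batch row i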
Example #8
  def sample(self, n, seed=None, name="sample"):
    """Sample `n` observations from the Categorical distribution.

    Args:
      n: 0-D.  Number of independent samples to draw for each distribution.
      seed: Random seed (optional).
      name: A name for this operation (optional).

    Returns:
      An `int64` `Tensor` with shape `[n, batch_shape, event_shape]`
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self.logits, n], name):
        n = ops.convert_to_tensor(n, name="n")
        logits_2d = array_ops.reshape(
            self.logits, array_ops.pack([-1, self.num_classes]))
        samples = random_ops.multinomial(logits_2d, n, seed=seed)
        samples = math_ops.cast(samples, self._dtype)
        ret = array_ops.reshape(
            array_ops.transpose(samples),
            array_ops.concat(
                0, [array_ops.expand_dims(n, 0), self.batch_shape()]))
        ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
                      .concatenate(self.get_batch_shape()))
        return ret
Example #9
    def sample_n(self, n, seed=None, name="sample_n"):
        """Sample `n` observations from the Categorical distribution.

    Args:
      n: 0-D.  Number of independent samples to draw for each distribution.
      seed: Random seed (optional).
      name: A name for this operation (optional).

    Returns:
      An `int64` `Tensor` with shape `[n, batch_shape, event_shape]`
    """
        with ops.name_scope(self.name):
            with ops.name_scope(name, values=[self.logits, n]):
                n = ops.convert_to_tensor(n, name="n")
                logits_2d = array_ops.reshape(
                    self.logits, array_ops.pack([-1, self.num_classes]))
                samples = random_ops.multinomial(logits_2d, n, seed=seed)
                samples = math_ops.cast(samples, self._dtype)
                ret = array_ops.reshape(
                    array_ops.transpose(samples),
                    array_ops.concat(0, ([n], self.batch_shape())))
                ret.set_shape(
                    tensor_shape.vector(
                        tensor_util.constant_value(n)).concatenate(
                            self.get_batch_shape()))
                return ret
Example #10
 def testEmpty(self):
   with self.cached_session():
     with self.test_scope():
       x = random_ops.multinomial(
           array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
       y = self.evaluate(x)
       self.assertEqual(y.shape, (42, 0))
Example #11
    def sample(self, n, seed=None, name="sample"):
        """Generate `n` samples.

    Args:
      n: scalar.  Number of samples to draw from each distribution.
      seed: Python integer seed for RNG.
      name: name to give to the op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape` with values of type
          `self.dtype`.
    """
        with ops.name_scope(self.name):
            with ops.op_scope([self.p, n], name):
                n = ops.convert_to_tensor(n, name="n")
                p_2d = array_ops.reshape(self.p, array_ops.pack([-1, 1]))
                q_2d = 1. - p_2d
                probs = array_ops.concat(1, [q_2d, p_2d])
                samples = random_ops.multinomial(math_ops.log(probs),
                                                 n,
                                                 seed=seed)
                ret = array_ops.reshape(
                    array_ops.transpose(samples),
                    array_ops.concat(
                        0, [array_ops.expand_dims(n, 0),
                            self.batch_shape()]))
                ret.set_shape(
                    tensor_shape.vector(
                        tensor_util.constant_value(n)).concatenate(
                            self.get_batch_shape()))
                return math_ops.cast(ret, self.dtype)
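
The reshaping above reduces Bernoulli sampling to a two-class categorical: each row of probs is [1 - p, p], so a sampled index of 1 is a success. A minimal sketch of the same idea using the newer concat argument order (assuming TensorFlow 1.x):

import tensorflow as tf  # assumed: TensorFlow 1.x API

p = tf.constant([[0.2], [0.9]])             # success prob per distribution
probs = tf.concat([1. - p, p], axis=1)      # shape [2, 2]: rows are [q, p]
samples = tf.multinomial(tf.log(probs), 5)  # shape [2, 5], values in {0, 1}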
Example #12
  def _do_sampling(self, logits, num_samples):
    """Categorical samples from given input.

    Args:
      logits: Numpy ndarray of shape [batch_size, num_classes].
      num_samples: Int; number of samples to draw.

    Returns:
      Frequencies from sampled classes; shape [batch_size, num_classes].
    """
    with self.cached_session(), self.test_scope():
      random_seed.set_random_seed(1618)
      op = random_ops.multinomial(logits, num_samples,
                                  output_dtype=dtypes.int32)
      d = self.evaluate(op)

    batch_size, num_classes = logits.shape
    freqs_mat = []
    for i in range(batch_size):
      cnts = dict(collections.Counter(d[i, :]))

      # Requires drawn class labels be in range.
      self.assertLess(max(cnts.keys()), num_classes)
      self.assertGreaterEqual(min(cnts.keys()), 0)

      freqs = [(cnts[k] * 1. / num_samples if k in cnts else 0)
               for k in range(num_classes)]
      freqs_mat.append(freqs)

    return freqs_mat
Example #13
  def sample(self, time, outputs, state, name=None):
    """sample for SampledEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits; use random_ops.multinomial to sample ids.
    if not isinstance(outputs, ops.Tensor):
      raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                      type(outputs))

    ### Original: plain temperature scaling ###
    # outputs2 = math_ops.div(outputs, self.temp)

    # Own method: damp the top logit by a third of its lead over the
    # runner-up, so sampling is less peaked around the argmax.
    max_values = tf.math.argmax(outputs, axis=1)
    one_hot_max_values = tf.one_hot(indices=max_values, depth=outputs.shape[1])
    second_highest = tf.math.top_k(outputs, k=3)  # only the top two values are used
    rel_indices = second_highest.values[:, 0] - second_highest.values[:, 1]
    rel_indices_new = tf.reshape(rel_indices, [-1, 1])
    outputs2 = outputs - rel_indices_new * one_hot_max_values / 3

    sample_ids2 = math_ops.cast(
        random_ops.multinomial(outputs2, 1), dtypes.int32)
    sample_ids = array_ops.reshape(sample_ids2, [-1])

    return sample_ids
Example #14
 def _sample_n(self, n, seed=None):
     n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
     if self.total_count.get_shape().ndims is not None:
         if self.total_count.get_shape().ndims != 0:
             raise NotImplementedError(
                 "Sample only supported for scalar number of draws.")
     elif self.validate_args:
         is_scalar = check_ops.assert_rank(
             n_draws,
             0,
             message="Sample only supported for scalar number of draws.")
         n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
     k = self.event_shape_tensor()[0]
     # Flatten batch dims so logits has shape [B, k],
     # where B = reduce_prod(self.batch_shape_tensor()).
     x = random_ops.multinomial(
         logits=array_ops.reshape(self.logits, [-1, k]),
         num_samples=n * n_draws,
         seed=seed)
     x = array_ops.reshape(x, shape=[-1, n, n_draws])
     x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k),
                             axis=-2)  # shape: [B, n, k]
     x = array_ops.transpose(x, perm=[1, 0, 2])
     final_shape = array_ops.concat(
         [[n], self.batch_shape_tensor(), [k]], 0)
     x = array_ops.reshape(x, final_shape)
     return math_ops.cast(x, self.dtype)
Example #15
 def testMatchStatefulMultinomial(self):
   # Stateless ops should be the same as stateful ops on the first call
   # after seed scrambling.
   key = 0x3ec8f720, 0x02461e29
   num_samples = 4
   for logits_dtype in np.float16, np.float32, np.float64:
     for output_dtype in dtypes.int32, dtypes.int64:
       for seed in (7, 17), (11, 5), (2, 3):
         preseed = invert_philox(key,
                                 (seed[0], 0, seed[1], 0)).astype(np.uint64)
         preseed = preseed[::2] | preseed[1::2] << 32
         random_seed.set_random_seed(seed[0])
         with self.test_session(use_gpu=True):
           for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
                                                     [0.25, 0.75]]):
             logits_t = constant_op.constant(logits, dtype=logits_dtype)
             stateful = random_ops.multinomial(
                 logits_t,
                 num_samples,
                 seed=seed[1],
                 output_dtype=output_dtype)
             pure = stateless.stateless_multinomial(
                 logits_t,
                 num_samples,
                 seed=preseed,
                 output_dtype=output_dtype)
             self.assertAllEqual(stateful.eval(), pure.eval())
Example #16
 def _sample_n(self, n, seed=None):
   n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
   if self.total_count.get_shape().ndims is not None:
     if self.total_count.get_shape().ndims != 0:
       raise NotImplementedError(
           "Sample only supported for scalar number of draws.")
   elif self.validate_args:
     is_scalar = check_ops.assert_rank(
         n_draws, 0,
         message="Sample only supported for scalar number of draws.")
     n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
   k = self.event_shape_tensor()[0]
   # Flatten batch dims so logits has shape [B, k],
   # where B = reduce_prod(self.batch_shape_tensor()).
   x = random_ops.multinomial(
       logits=array_ops.reshape(self.logits, [-1, k]),
       num_samples=n * n_draws,
       seed=seed)
   x = array_ops.reshape(x, shape=[-1, n, n_draws])
   x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k),
                           axis=-2)  # shape: [B, n, k]
   x = array_ops.transpose(x, perm=[1, 0, 2])
   final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
   x = array_ops.reshape(x, final_shape)
   return math_ops.cast(x, self.dtype)
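
The one_hot plus reduce_sum(axis=-2) step is a per-sample histogram of the drawn class ids; numpy's bincount computes the same counts (values here are illustrative):

import numpy as np

k = 4
draws = np.array([2, 0, 2, 3, 2])         # class ids from one sample
counts = np.bincount(draws, minlength=k)  # [1, 0, 3, 1]; matches one_hot(...).sum(-2)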
Example #17
 def testNegativeMinLogits(self):
   random_seed.set_random_seed(78844)
   with test_util.use_gpu():
     logits = constant_op.constant([[np.finfo(np.float32).min] * 1023 + [0]])
     num_samples = 1000
     samples = self.evaluate(random_ops.multinomial(logits, num_samples))
     self.assertAllEqual([[1023] * num_samples], samples)
Example #18
  def sample(self, n, seed=None, name="sample"):
    """Generate `n` samples.

    Args:
      n: scalar.  Number of samples to draw from each distribution.
      seed: Python integer seed for RNG.
      name: name to give to the op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape` with values of type
          `self.dtype`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self.p, n], name):
        n = ops.convert_to_tensor(n, name="n")
        p_2d = array_ops.reshape(self.p, array_ops.pack([-1, 1]))
        q_2d = 1. - p_2d
        probs = array_ops.concat(1, [q_2d, p_2d])
        samples = random_ops.multinomial(math_ops.log(probs), n, seed=seed)
        ret = array_ops.reshape(
            array_ops.transpose(samples),
            array_ops.concat(0,
                             [array_ops.expand_dims(n, 0), self.batch_shape()]))
        ret.set_shape(tensor_shape.vector(tensor_util.constant_value(n))
                      .concatenate(self.get_batch_shape()))
        return math_ops.cast(ret, self.dtype)
Example #20
 def _sample_n(self, n, seed=None):
   n_draws = math_ops.cast(self.n, dtype=dtypes.int32)
   if self.n.get_shape().ndims is not None:
     if self.n.get_shape().ndims != 0:
       raise NotImplementedError(
           "Sample only supported for scalar number of draws.")
   elif self.validate_args:
     is_scalar = check_ops.assert_rank(
         n_draws, 0,
         message="Sample only supported for scalar number of draws.")
     n_draws = control_flow_ops.with_dependencies([is_scalar], n_draws)
   k = self.event_shape()[0]
   unnormalized_logits = array_ops.reshape(
       math_ops.log(random_ops.random_gamma(
           shape=[n],
           alpha=self.alpha,
           dtype=self.dtype,
           seed=seed)),
       shape=[-1, k])
   draws = random_ops.multinomial(
       logits=unnormalized_logits,
       num_samples=n_draws,
       seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
   x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k),
                           reduction_indices=-2)
   final_shape = array_ops.concat([[n], self.batch_shape(), [k]], 0)
   return array_ops.reshape(x, final_shape)
Example #21
def _get_batch(per_class_queues, probs, batch_size):
  """Generates batches according to per-class-probabilities."""
  num_classes = probs.size
  # Number of examples per class is governed by a multinomial distribution.
  # Note: multinomial takes unnormalized log probabilities for its first
  # argument, of dimension [batch_size, num_classes].
  examples = random_ops.multinomial(
      np.expand_dims(np.log(probs), 0), batch_size)

  # Prepare the data and label batches.
  val_list = []
  label_list = []
  for i in range(num_classes):
    num_examples = math_ops.reduce_sum(
        math_ops.cast(math_ops.equal(examples, i), dtypes.int32))
    val_list.append(per_class_queues[i].dequeue_many(num_examples))
    label_list.append(array_ops.ones([num_examples], dtype=dtypes.int32) * i)

  # Create a tensor of labels.
  batch_labels = array_ops.concat(0, label_list)
  batch_labels.set_shape([batch_size])

  # Debug instrumentation.
  sample_tags = ['stratified_sample/samples_class%i' % i for i in
                 range(num_classes)]
  logging_ops.scalar_summary(sample_tags, math_ops.reduce_sum(
      array_ops.one_hot(batch_labels, num_classes), 0))

  return array_ops.concat(0, val_list), batch_labels
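
For intuition, the per-class slot counts that drive the dequeues above can be reproduced outside the graph (probs is an assumed class distribution):

import numpy as np

probs = np.array([0.5, 0.3, 0.2])  # assumed per-class target probabilities
batch_size = 8
classes = np.random.choice(len(probs), size=batch_size, p=probs)
num_examples = [(classes == i).sum() for i in range(len(probs))]
# num_examples[i] is how many batch slots class i receives this round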
Example #22
 def testNegativeMinLogits(self):
   random_seed.set_random_seed(78844)
   with self.test_session(use_gpu=True):
     logits = constant_op.constant([[np.finfo(np.float32).min] * 1023 + [0]])
     num_samples = 1000
     samples = random_ops.multinomial(logits, num_samples).eval()
     self.assertAllEqual([[1023] * num_samples], samples)
Example #23
 def _sample_single(args):
   logits, n_draw = args[0], args[1]  # [K], []
   x = random_ops.multinomial(logits[array_ops.newaxis, ...], n_draw,
                              seed)  # [1, n*n_draw]
   x = array_ops.reshape(x, shape=[n, -1])  # [n, n_draw]
   x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2)  # [n, k]
   return x
Example #24
    def __call__(self, inputs, state, scope=None):
        batch_size = array_ops.shape(inputs)[0]

        ### Unpack state
        # last_betas: bs x V
        # last_index : bs
        last_beta, last_index = array_ops.split(
            state, [self._transitions.num_tags, 1], axis=1)
        last_index = math_ops.cast(last_index, dtypes.int32)

        ### Unpack inputs
        # unary:   bs x V
        # last_beta bs x V
        shape = [
            self._transitions.num_tags, self._transitions.num_tags,
            self._transitions._total_nr_parameters
        ]
        unary, beta, pairwise_flat = array_ops.split(inputs, shape, axis=1)

        ### Construct logits of this timestep according to (6) in the paper
        batch_indices = array_ops.reshape(math_ops.range(batch_size), [-1, 1])

        pairwise = self._transitions.get_pairwise_given_start(
            pairwise_flat, last_index)
        # bs x V
        last_beta = array_ops.gather_nd(
            last_beta, array_ops.concat([batch_indices, last_index], axis=1))
        last_beta = array_ops.reshape(last_beta, [-1, 1])
        # NOTE: this is only valid for the index we are gathering from.
        logits = pairwise + unary + beta - last_beta

        # bs x V
        log_probs = nn_ops.log_softmax(logits)
        # bs x 1
        entropy = -math_ops.reduce_sum(
            log_probs * math_ops.exp(log_probs), axis=1, keepdims=True)

        ### Sample the next symbol
        # bs x 1
        new_indices = random_ops.multinomial(logits, 1)
        new_indices = math_ops.to_int32(new_indices)

        ### Gather the logits of the new symbol to return the sequence probability
        gather_indices = array_ops.concat(
            [batch_indices,
             array_ops.reshape(new_indices, [-1, 1])], axis=1)

        # bs x 1
        output_logits = array_ops.gather_nd(logits, gather_indices)
        output_logits = array_ops.reshape(output_logits, [-1, 1])

        ### Pack the new state
        new_state = array_ops.concat(
            [beta, math_ops.to_float(new_indices)], axis=1)

        output = array_ops.concat(
            [math_ops.to_float(new_indices), output_logits, entropy, logits],
            axis=1)

        return output, new_state
Example #25
 def testEmpty(self):
     with self.cached_session() as sess:
         with self.test_scope():
             x = random_ops.multinomial(array_ops.zeros([42, 40]),
                                        0,
                                        output_dtype=dtypes.int32)
             y = sess.run(x)
             self.assertEqual(y.shape, (42, 0))
Example #26
 def testSmallEntropy(self):
   random_seed.set_random_seed(1618)
   with self.test_session(use_gpu=self.use_gpu):
     # A logit value of -10 corresponds to a probability of ~5e-5.
     logits = constant_op.constant([[-10., 10., -10.], [-10., -10., 10.]])
     num_samples = 1000
     samples = random_ops.multinomial(logits, num_samples).eval()
     self.assertAllEqual([[1] * num_samples, [2] * num_samples], samples)
Example #27
 def testEmpty(self):
   classes = 5
   with self.test_session(use_gpu=True):
     for batch in 0, 3:
       for samples in 0, 7:
         x = random_ops.multinomial(
             array_ops.zeros([batch, classes]), samples).eval()
         self.assertEqual(x.shape, (batch, samples))
Example #28
 def testSmallEntropy(self):
   random_seed.set_random_seed(1618)
   with self.test_session(use_gpu=True):
     # A logit value of -10 corresponds to a probability of ~5e-5.
     logits = constant_op.constant([[-10., 10., -10.], [-10., -10., 10.]])
     num_samples = 1000
     samples = random_ops.multinomial(logits, num_samples).eval()
     self.assertAllEqual([[1] * num_samples, [2] * num_samples], samples)
Example #29
 def _sample_n(self, n, seed=None):
     logits_2d = array_ops.reshape(self.logits,
                                   array_ops.pack([-1, self.num_classes]))
     samples = random_ops.multinomial(logits_2d, n, seed=seed)
     samples = math_ops.cast(samples, self.dtype)
     ret = array_ops.reshape(array_ops.transpose(samples),
                             array_ops.concat(0, ([n], self.batch_shape())))
     return ret
Example #30
 def _sample_n(self, n, seed=None):
     if self.logits.get_shape().ndims == 2:
         logits_2d = self.logits
     else:
         logits_2d = array_ops.reshape(self.logits, [-1, self.num_classes])
     samples = random_ops.multinomial(logits_2d, n, seed=seed)
     samples = math_ops.cast(samples, self.dtype)
     ret = array_ops.reshape(
         array_ops.transpose(samples),
         array_ops.concat(0, ([n], self.batch_shape())))
     return ret
Example #31
 def testEmpty(self):
   classes = 5
   with test_util.use_gpu():
     for batch in 0, 3:
       for samples in 0, 7:
         x = self.evaluate(
             random_ops.multinomial(
                 array_ops.zeros([batch, classes]), samples))
         self.assertEqual(x.shape, (batch, samples))
Example #32
 def testCategoricalIsInRange(self):
   for dtype in [dtypes.float32, dtypes.float64]:
     with self.test_session() as sess:
       with self.test_scope():
         x = random_ops.multinomial(
             array_ops.ones(shape=[1, 20], dtype=dtype), 1000)
       y = sess.run(x)
       self.assertTrue((y >= 0).sum() == 1000)
       self.assertTrue((y < 20).sum() == 1000)
Example #33
 def testEmpty(self):
     classes = 5
     with test_util.use_gpu():
         for batch in 0, 3:
             for samples in 0, 7:
                 x = self.evaluate(
                     random_ops.multinomial(
                         array_ops.zeros([batch, classes]), samples))
                 self.assertEqual(x.shape, (batch, samples))
Example #34
     def body(i, prev_c, prev_h, actions, log_probs):
      # pylint: disable=g-long-lambda
      signal = control_flow_ops.cond(
          math_ops.equal(i, 0),
          lambda: array_ops.tile(device_go_embedding,
                                 [self.hparams.num_children, 1]),
          lambda: embedding_ops.embedding_lookup(device_embeddings,
                                                 actions.read(i - 1))
      )
      if self.hparams.keep_prob is not None:
        signal = nn_ops.dropout(signal, self.hparams.keep_prob)
      next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
      query = math_ops.matmul(next_h, attn_w_2)
      query = array_ops.reshape(
          query, [self.hparams.num_children, 1, self.hparams.hidden_size])
      query = math_ops.tanh(query + attn_mem)
      query = array_ops.reshape(query, [
          self.hparams.num_children * self.num_groups, self.hparams.hidden_size
      ])
      query = math_ops.matmul(query, attn_v)
      query = array_ops.reshape(query,
                                [self.hparams.num_children, self.num_groups])
      query = nn_ops.softmax(query)
      query = array_ops.reshape(query,
                                [self.hparams.num_children, self.num_groups, 1])
      query = math_ops.reduce_sum(attn_mem * query, axis=1)
      query = array_ops.concat([next_h, query], axis=1)
      logits = math_ops.matmul(query, device_softmax)
      logits /= self.hparams.temperature
      if self.hparams.tanh_constant > 0:
        logits = math_ops.tanh(logits) * self.hparams.tanh_constant
      if self.hparams.logits_std_noise > 0:
        num_in_logits = math_ops.cast(
            array_ops.size(logits), dtype=dtypes.float32)
        avg_norm = math_ops.divide(
            linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
        logits_noise = random_ops.random_normal(
            array_ops.shape(logits),
            stddev=self.hparams.logits_std_noise * avg_norm)
        logits = control_flow_ops.cond(
            self.global_step > self.hparams.stop_noise_step, lambda: logits,
            lambda: logits + logits_noise)

      if mode == "sample":
        next_y = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
      elif mode == "greedy":
        next_y = math_ops.argmax(logits, 1)
      elif mode == "target":
        next_y = array_ops.slice(y, [0, i], [-1, 1])
      else:
        raise NotImplementedError
      next_y = math_ops.to_int32(next_y)
      next_y = array_ops.reshape(next_y, [self.hparams.num_children])
      actions = actions.write(i, next_y)
      log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=next_y)
      return i + 1, next_c, next_h, actions, log_probs
Example #35
 def _sample_n(self, n, seed=None):
     if self.logits.get_shape().ndims == 2:
         logits_2d = self.logits
     else:
         logits_2d = array_ops.reshape(self.logits, [-1, self.num_classes])
     samples = random_ops.multinomial(logits_2d, n, seed=seed)
     samples = math_ops.cast(samples, self.dtype)
     ret = array_ops.reshape(array_ops.transpose(samples),
                             array_ops.concat(([n], self.batch_shape()), 0))
     return ret
Example #36
 def _sample_n(self, n, seed=None):
   if self.logits.get_shape().ndims == 2:
     logits_2d = self.logits
   else:
     logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
   draws = random_ops.multinomial(logits_2d, n, seed=seed)
   draws = array_ops.reshape(
       array_ops.transpose(draws),
       array_ops.concat([[n], self.batch_shape_tensor()], 0))
   return math_ops.cast(draws, self.dtype)
Example #37
 def _sample_n(self, n, seed=None):
     if self.logits.get_shape().ndims == 2:
         logits_2d = self.logits
     else:
         logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
     draws = random_ops.multinomial(logits_2d, n, seed=seed)
     draws = array_ops.reshape(
         array_ops.transpose(draws),
         array_ops.concat([[n], self.batch_shape_tensor()], 0))
     return math_ops.cast(draws, self.dtype)
Example #38
 def testLargeLogits(self):
   for neg in [True, False]:
     with self.test_session(use_gpu=True):
       logits = np.array([[1000.] * 5])
       if neg:
         logits *= -1
       samples = random_ops.multinomial(logits, 10).eval()
     # Sampled classes should be in-range.
     self.assertTrue((samples >= 0).all())
     self.assertTrue((samples < 5).all())
Example #39
 def testSmallEntropy(self):
   random_seed.set_random_seed(1618)
   for output_dtype in [np.int32, np.int64]:
     with test_util.device(use_gpu=True):
       # A logit value of -10 corresponds to a probability of ~5e-5.
       logits = constant_op.constant([[-10., 10., -10.], [-10., -10., 10.]])
       num_samples = 1000
       samples = self.evaluate(random_ops.multinomial(
           logits, num_samples, output_dtype=output_dtype))
       self.assertAllEqual([[1] * num_samples, [2] * num_samples], samples)
Example #40
 def testLargeLogits(self):
     for neg in [True, False]:
         with test_util.use_gpu():
             logits = np.array([[1000.] * 5])
             if neg:
                 logits *= -1
             samples = self.evaluate(random_ops.multinomial(logits, 10))
         # Sampled classes should be in-range.
         self.assertTrue((samples >= 0).all())
         self.assertTrue((samples < 5).all())
Example #41
    def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
        with ops.name_scope(
                name, "attention_decoder_fn_inference",
            [time, cell_state, cell_input, cell_output, context_state]):
            if cell_input is not None:
                raise ValueError(
                    "Expected cell_input to be None, but saw: %s" % cell_input)
            if cell_output is None:
                # invariant that this is time == 0
                next_input_id = array_ops.ones([
                    batch_size,
                ], dtype=dtype) * (start_of_sequence_id)
                done = array_ops.zeros([
                    batch_size,
                ], dtype=dtypes.bool)
                cell_state = encoder_state
                cell_output = array_ops.zeros([num_decoder_symbols],
                                              dtype=dtypes.float32)
                cell_input = array_ops.gather(embeddings, next_input_id)

                # init attention
                attention = _init_attention(encoder_state)
            else:
                # construct attention
                attention = attention_construct_fn(cell_output, attention_keys,
                                                   attention_values)
                cell_output = attention

                # sampled decoder
                cell_output = output_fn(cell_output)  # logits
                if temperature:
                    # multinomial expects unnormalized logits, so sample
                    # directly from the temperature-scaled logits (no
                    # softmax needed).
                    temperature_cell_output = math_ops.divide(
                        cell_output, temperature)
                    sampled_cell_output = random_ops.multinomial(
                        temperature_cell_output, 1)
                    sampled_cell_output = array_ops.reshape(
                        sampled_cell_output, [-1])
                else:
                    sampled_cell_output = math_ops.argmax(cell_output, 1)
                next_input_id = math_ops.cast(sampled_cell_output, dtype=dtype)
                done = math_ops.equal(next_input_id, end_of_sequence_id)
                cell_input = array_ops.gather(embeddings, next_input_id)

            # combine cell_input and attention
            next_input = array_ops.concat([cell_input, attention], 1)

            # if time > maxlen, return all true vector
            done = control_flow_ops.cond(
                math_ops.greater(time, maximum_length),
                lambda: array_ops.ones([
                    batch_size,
                ], dtype=dtypes.bool), lambda: done)
            return (done, cell_state, next_input, cell_output, context_state)
Example #42
 def testCategoricalIsInRange(self):
   for dtype in self.float_types:
     for output_dtype in self.output_dtypes():
       with self.cached_session():
         with self.test_scope():
           x = random_ops.multinomial(
               array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
               output_dtype=output_dtype)
         y = self.evaluate(x)
         self.assertTrue((y >= 0).sum() == 1000)
         self.assertTrue((y < 20).sum() == 1000)
Example #43
 def _sample_n(self, n, seed=None):
   sample_shape = array_ops.concat(([n], array_ops.shape(self.logits)), 0)
   logits = self.logits
   if logits.get_shape().ndims == 2:
     logits_2d = logits
   else:
     logits_2d = array_ops.reshape(logits, [-1, self.num_classes])
   samples = random_ops.multinomial(logits_2d, n, seed=seed)
   samples = array_ops.transpose(samples)
   samples = array_ops.one_hot(samples, self.num_classes, dtype=self.dtype)
   ret = array_ops.reshape(samples, sample_shape)
   return ret
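
The transpose-then-one_hot step above turns sampled indices into one-hot event vectors; the numpy equivalent is an identity-matrix lookup (hypothetical values):

import numpy as np

num_classes = 4
idx = np.array([2, 0, 3])           # sampled class indices
one_hot = np.eye(num_classes)[idx]  # shape [3, 4], one-hot rows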
Example #44
 def testCategoricalIsInRange(self):
     for dtype in self.float_types:
         for output_dtype in self.output_dtypes():
             with self.cached_session() as sess:
                 with self.test_scope():
                      x = random_ops.multinomial(
                          array_ops.ones(shape=[1, 20], dtype=dtype),
                          1000,
                          output_dtype=output_dtype)
                 y = self.evaluate(x)
                 self.assertTrue((y >= 0).sum() == 1000)
                 self.assertTrue((y < 20).sum() == 1000)
Example #45
  def _sample_n(self, n, seed=None):
     sample_shape = array_ops.concat([[n], array_ops.shape(self.logits)], 0)
     logits = self.logits
     if logits.get_shape().ndims == 2:
         logits_2d = logits
     else:
         logits_2d = array_ops.reshape(logits, [-1, self.event_size])
     samples = random_ops.multinomial(logits_2d, n, seed=seed)
     samples = array_ops.transpose(samples)
     samples = array_ops.one_hot(samples, self.event_size, dtype=self.dtype)
     ret = array_ops.reshape(samples, sample_shape)
     return ret
Example #46
    def sample(self, time, outputs, state, name=None):
        """sample for SampledEmbeddingHelper."""
        del time, state  # unused by sample_fn
        # Outputs are logits, use random_ops.multinomial to sample ids
        if not isinstance(outputs, ops.Tensor):
            raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                            type(outputs))

        sample_ids2 = math_ops.cast(random_ops.multinomial(outputs, 1),
                                    dtypes.int32)
        sample_ids = array_ops.reshape(sample_ids2, [-1])

        return sample_ids
Example #47
 def _sample_n(n):
   """Sample vector of categoricals."""
   if logits.shape.ndims == 2:
     logits_2d = logits
   else:
     logits_2d = array_ops.reshape(logits, [-1, event_size])
   sample_dtype = dtypes.int64 if logits.dtype.size > 4 else dtypes.int32
   draws = random_ops.multinomial(
       logits_2d, n, seed=seed, output_dtype=sample_dtype)
   draws = array_ops.reshape(
       array_ops.transpose(draws),
       array_ops.concat([[n], batch_shape_tensor], 0))
   return math_ops.cast(draws, dtype)
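
The output_dtype selection keys off the byte width of the logits dtype: 8-byte float64 logits get int64 sample indices, narrower dtypes get int32. A small sketch of the rule (assuming TF's internal dtypes module, which the snippets above already import):

from tensorflow.python.framework import dtypes

def sample_dtype_for(logits_dtype):
    # Wider-than-4-byte logits get 64-bit sample indices, else 32-bit.
    return dtypes.int64 if logits_dtype.size > 4 else dtypes.int32

assert sample_dtype_for(dtypes.float64) == dtypes.int64
assert sample_dtype_for(dtypes.float16) == dtypes.int32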
Example #49
def _get_batch_from_per_class_queues(per_class_queues, probs, batch_size):
    """Generates batches according to per-class-probabilities."""
    num_classes = probs.get_shape().num_elements()
    # Number of examples per class is governed by a multinomial distribution.
    # Note: multinomial takes unnormalized log probabilities for its first
    # argument, of dimension [batch_size, num_classes].
    examples = random_ops.multinomial(
        array_ops.expand_dims(math_ops.log(probs), 0), batch_size)

    # Prepare the data and label batches.
    val_list = []
    label_list = []
    for i in range(num_classes):
        num_examples = math_ops.reduce_sum(
            math_ops.cast(math_ops.equal(examples, i), dtypes.int32))
        tensors = per_class_queues[i].dequeue_many(num_examples)

        # If you enqueue a list with a single tensor, only a single tensor is
        # returned. If you enqueue a list with multiple tensors, then a list is
        # returned. We want to handle both cases, so reduce the case of the single
        # tensor to the case of multiple tensors.
        if not isinstance(tensors, list):
            tensors = [tensors]

        val_list.append(tensors)
        label_list.append(
            array_ops.ones([num_examples], dtype=dtypes.int32) * i)

    # Create a list of tensor of values. val_list is of dimension
    # [num_classes x len(tensors)]. We want list_batch_vals to be of dimension
    # [len(tensors)].
    num_data = len(val_list[0])
    list_batch_vals = [
        array_ops.concat(0, [val_list[i][j] for i in range(num_classes)])
        for j in range(num_data)
    ]

    # Create a tensor of labels.
    batch_labels = array_ops.concat(0, label_list)
    batch_labels.set_shape([batch_size])

    # Debug instrumentation.
    sample_tags = [
        'stratified_sample/%s/samples_class%i' % (batch_labels.name, i)
        for i in range(num_classes)
    ]
    logging_ops.scalar_summary(
        sample_tags,
        math_ops.reduce_sum(array_ops.one_hot(batch_labels, num_classes), 0))

    return list_batch_vals, batch_labels
Example #50
    def make_grouping_predictions(self, input_layer, reuse=None):
        """model that predicts grouping (grouping_actions).

    Args:
      input_layer: group_input_layer
      reuse: reuse

    Returns:
       grouping_actions: actions
       grouping_log_probs: log probabilities corresponding to actions
    """
        with variable_scope.variable_scope(self.hparams.name, reuse=True):
            # input_layer: tensor of size [1, num_ops, hidden_size]
            w_grouping_ff = variable_scope.get_variable("w_grouping_ff")
            w_grouping_softmax = variable_scope.get_variable(
                "w_grouping_softmax")

        batch_size = array_ops.shape(input_layer)[0]
        embedding_dim = array_ops.shape(input_layer)[2]

        reshaped = array_ops.reshape(
            input_layer, [batch_size * self.num_ops, embedding_dim])
        ff_output = math_ops.matmul(reshaped, w_grouping_ff)
        logits = math_ops.matmul(ff_output, w_grouping_softmax)
        if self.hparams.logits_std_noise > 0:
            num_in_logits = math_ops.cast(array_ops.size(logits),
                                          dtype=dtypes.float32)
            avg_norm = math_ops.divide(linalg_ops.norm(logits),
                                       math_ops.sqrt(num_in_logits))
            logits_noise = random_ops.random_normal(
                array_ops.shape(logits),
                stddev=self.hparams.logits_std_noise * avg_norm)
            logits = control_flow_ops.cond(
                self.global_step > self.hparams.stop_noise_step,
                lambda: logits, lambda: logits + logits_noise)
        logits = array_ops.reshape(
            logits, [batch_size * self.num_ops, self.num_groups])
        actions = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
        actions = math_ops.to_int32(actions)
        actions = array_ops.reshape(actions, [batch_size, self.num_ops])
        action_label = array_ops.reshape(actions, [-1])
        log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=action_label)
        log_probs = array_ops.reshape(log_probs, [batch_size, -1])
        log_probs = math_ops.reduce_sum(log_probs, 1)
        grouping_actions = actions
        grouping_log_probs = log_probs
        return grouping_actions, grouping_log_probs
Example #51
  def make_grouping_predictions(self, input_layer, reuse=None):
    """model that predicts grouping (grouping_actions).

    Args:
      input_layer: group_input_layer
      reuse: reuse

    Returns:
       grouping_actions: actions
       grouping_log_probs: log probabilities corresponding to actions
    """
    with variable_scope.variable_scope(self.hparams.name, reuse=True):
      # input_layer: tensor of size [1, num_ops, hidden_size]
      w_grouping_ff = variable_scope.get_variable("w_grouping_ff")
      w_grouping_softmax = variable_scope.get_variable("w_grouping_softmax")

    batch_size = array_ops.shape(input_layer)[0]
    embedding_dim = array_ops.shape(input_layer)[2]

    reshaped = array_ops.reshape(input_layer,
                                 [batch_size * self.num_ops, embedding_dim])
    ff_output = math_ops.matmul(reshaped, w_grouping_ff)
    logits = math_ops.matmul(ff_output, w_grouping_softmax)
    if self.hparams.logits_std_noise > 0:
      num_in_logits = math_ops.cast(
          array_ops.size(logits), dtype=dtypes.float32)
      avg_norm = math_ops.divide(
          linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
      logits_noise = random_ops.random_normal(
          array_ops.shape(logits),
          stddev=self.hparams.logits_std_noise * avg_norm)
      logits = control_flow_ops.cond(
          self.global_step > self.hparams.stop_noise_step, lambda: logits,
          lambda: logits + logits_noise)
    logits = array_ops.reshape(logits,
                               [batch_size * self.num_ops, self.num_groups])
    actions = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
    actions = math_ops.to_int32(actions)
    actions = array_ops.reshape(actions, [batch_size, self.num_ops])
    action_label = array_ops.reshape(actions, [-1])
    log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=action_label)
    log_probs = array_ops.reshape(log_probs, [batch_size, -1])
    log_probs = math_ops.reduce_sum(log_probs, 1)
    grouping_actions = actions
    grouping_log_probs = log_probs
    return grouping_actions, grouping_log_probs
Example #52
 def testLargeDynamicRange2(self):
   random_seed.set_random_seed(10)
   counts_by_indices = {}
   with self.test_session(use_gpu=True):
     samples = random_ops.multinomial(
         constant_op.constant([[0, -30]], dtype=dtypes.float32),
         num_samples=1000000,
         seed=15)
     for _ in range(100):
       x = self.evaluate(samples)
       indices, counts = np.unique(x, return_counts=True)
       for index, count in zip(indices, counts):
         if index in counts_by_indices.keys():
           counts_by_indices[index] += count
         else:
           counts_by_indices[index] = count
   self.assertEqual(counts_by_indices[0], 100000000)
Example #53
 def _sample_n(self, n, seed=None):
   n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
   k = self.event_shape_tensor()[0]
   unnormalized_logits = array_ops.reshape(
       math_ops.log(random_ops.random_gamma(
           shape=[n],
           alpha=self.concentration,
           dtype=self.dtype,
           seed=seed)),
       shape=[-1, k])
   draws = random_ops.multinomial(
       logits=unnormalized_logits,
       num_samples=n_draws,
       seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
   x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
   final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
   return array_ops.reshape(x, final_shape)
Example #54
def _get_batch_from_per_class_queues(per_class_queues, probs, batch_size):
  """Generates batches according to per-class-probabilities."""
  num_classes = probs.get_shape().num_elements()
  # Number of examples per class is governed by a multinomial distribution.
  # Note: multinomial takes unnormalized log probabilities for its first
  # argument, of dimension [batch_size, num_classes].
  examples = random_ops.multinomial(
      array_ops.expand_dims(math_ops.log(probs), 0), batch_size)

  # Prepare the data and label batches.
  val_list = []
  label_list = []
  for i in range(num_classes):
    num_examples = math_ops.reduce_sum(
        math_ops.cast(math_ops.equal(examples, i), dtypes.int32))
    tensors = per_class_queues[i].dequeue_many(num_examples)

    # If you enqueue a list with a single tensor, only a single tensor is
    # returned. If you enqueue a list with multiple tensors, then a list is
    # returned. We want to handle both cases, so reduce the case of the single
    # tensor to the case of multiple tensors.
    if not isinstance(tensors, list):
      tensors = [tensors]

    val_list.append(tensors)
    label_list.append(array_ops.ones([num_examples], dtype=dtypes.int32) * i)

  # Create a list of tensor of values. val_list is of dimension
  # [num_classes x len(tensors)]. We want list_batch_vals to be of dimension
  # [len(tensors)].
  num_data = len(val_list[0])
  list_batch_vals = [array_ops.concat(
      0, [val_list[i][j] for i in range(num_classes)]) for j in range(num_data)]

  # Create a tensor of labels.
  batch_labels = array_ops.concat(0, label_list)
  batch_labels.set_shape([batch_size])

  # Debug instrumentation.
  sample_tags = ['stratified_sample/%s/samples_class%i' % (batch_labels.name, i)
                 for i in range(num_classes)]
  logging_ops.scalar_summary(sample_tags, math_ops.reduce_sum(
      array_ops.one_hot(batch_labels, num_classes), 0))

  return list_batch_vals, batch_labels
Example #55
  def testLargeDynamicRange3(self):
    random_seed.set_random_seed(10)
    counts_by_indices = {}
    # Here the CPU undersamples and won't pass this test either.
    with self.test_session(use_gpu=True):
      samples = random_ops.multinomial(
          constant_op.constant([[0, -17]], dtype=dtypes.float32),
          num_samples=1000000,
          seed=22)

      # we'll run out of memory if we try to draw 1e9 samples directly
      # really should fit in 12GB of memory...
      for _ in range(100):
        x = self.evaluate(samples)
        indices, counts = np.unique(x, return_counts=True)
        for index, count in zip(indices, counts):
          if index in counts_by_indices.keys():
            counts_by_indices[index] += count
          else:
            counts_by_indices[index] = count
    self.assertGreater(counts_by_indices[1], 0)
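
As a sanity check on why this test only asserts a positive count: with logits [0, -17] the tail class has probability exp(-17) / (1 + exp(-17)), so 100 runs of 1e6 samples yield only a handful of expected hits. Quick arithmetic in plain Python:

import math

p1 = math.exp(-17) / (1 + math.exp(-17))  # ~ 4.14e-08
expected = p1 * 100 * 1000000             # ~ 4.1 expected draws of class 1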