Example 1
def testLogits(self):
    logits = [-42., 42.]
    dist = bernoulli.Bernoulli(logits=logits)
    self.assertAllClose(logits, self.evaluate(dist.logits))

    # scipy.special may be unavailable in some environments; skip these checks if so.
    if not special:
        return

    self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))

    p = [0.01, 0.99, 0.42]
    dist = bernoulli.Bernoulli(probs=p)
    self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
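The logits and probs parameterizations used above are related by the sigmoid (scipy's expit) and its inverse (logit). A minimal NumPy/SciPy round-trip sketch, independent of TensorFlow and not part of the original test:

import numpy as np
from scipy import special

logits = np.array([-2., 0., 3.])
probs = special.expit(logits)       # sigmoid: 1 / (1 + exp(-logits))
np.testing.assert_allclose(special.logit(probs), logits)  # inverse recovers the logits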
Example 2
 def sample(self, time, outputs, state, name=None):
     with ops.name_scope(name, "%sSample" % type(self).__name__,
                         (time, outputs, state)):
         sampler = bernoulli.Bernoulli(probs=self._sampling_probability)
         return math_ops.cast(
             sampler.sample(sample_shape=self.batch_size, seed=self._seed),
             dtypes.bool)
Example 3
  def __init__(self, logits, targets=None, seed=None):
    dist = bernoulli.Bernoulli(logits=logits)
    self._logits = logits
    self._probs = dist.probs

    super(MultiBernoulliNegativeLogProbLoss, self).__init__(
        dist, targets=targets, seed=seed)
Example 4
 def testEntropyWithBatch(self):
     p = [[0.1, 0.7], [0.2, 0.6]]
     dist = bernoulli.Bernoulli(probs=p, validate_args=False)
     with self.test_session():
         self.assertAllClose(dist.entropy().eval(),
                             [[entropy(0.1), entropy(0.7)],
                              [entropy(0.2), entropy(0.6)]])
Example 5
  def testPmfShapes(self):
    with self.cached_session():
      p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))

      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))

      dist = bernoulli.Bernoulli(probs=0.5)
      self.assertEqual((), dist.log_prob(1).get_shape())
      self.assertEqual((1,), dist.log_prob([1]).get_shape())
      self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())

      dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
      self.assertEqual((2, 1), dist.log_prob(1).get_shape())
Example 6
    def testLogits(self):
        logits = [-42., 42.]
        dist = bernoulli.Bernoulli(logits=logits)
        with self.test_session():
            self.assertAllClose(logits, dist.logits.eval())

        if not special:
            return

        with self.test_session():
            self.assertAllClose(special.expit(logits), dist.probs.eval())

        p = [0.01, 0.99, 0.42]
        dist = bernoulli.Bernoulli(probs=p)
        with self.test_session():
            self.assertAllClose(special.logit(p), dist.logits.eval())
Example 7
 def sample(self, time, outputs, state, name=None):
   with ops.name_scope(name, "ScheduledOutputTrainingHelperSample",
                       [time, outputs, state]):
     sampler = bernoulli.Bernoulli(probs=self._sampling_probability)
     return math_ops.cast(
         sampler.sample(sample_shape=self.batch_size, seed=self._seed),
         dtypes.bool)
Example 8
    def testBernoulliBernoulliKL(self):
        batch_size = 6
        a_p = np.array([0.5] * batch_size, dtype=np.float32)
        b_p = np.array([0.4] * batch_size, dtype=np.float32)

        a = bernoulli.Bernoulli(probs=a_p)
        b = bernoulli.Bernoulli(probs=b_p)

        kl = kullback_leibler.kl_divergence(a, b)
        kl_val = self.evaluate(kl)

        kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
            (1. - a_p) / (1. - b_p)))

        self.assertEqual(kl.get_shape(), (batch_size, ))
        self.assertAllClose(kl_val, kl_expected)
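The expected value above is the closed-form KL divergence between two Bernoulli distributions, p_a * log(p_a / p_b) + (1 - p_a) * log((1 - p_a) / (1 - p_b)). A standalone sketch, not part of the test, checking that formula against the direct sum over the two outcomes:

import numpy as np

p_a, p_b = 0.5, 0.4
closed_form = p_a * np.log(p_a / p_b) + (1. - p_a) * np.log((1. - p_a) / (1. - p_b))
direct = sum(pa_x * np.log(pa_x / pb_x)
             for pa_x, pb_x in [(1. - p_a, 1. - p_b), (p_a, p_b)])
np.testing.assert_allclose(closed_form, direct)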
Example 9
 def testPmfInvalid(self):
     p = [0.1, 0.2, 0.7]
     dist = bernoulli.Bernoulli(probs=p, validate_args=True)
     with self.assertRaisesOpError("must be non-negative."):
         self.evaluate(dist.prob([1, 1, -1]))
     with self.assertRaisesOpError("Elements cannot exceed 1."):
         self.evaluate(dist.prob([2, 0, 1]))
Example 10
  def testInvalidP(self):
    invalid_ps = [1.01, 2.]
    for p in invalid_ps:
      with self.assertRaisesOpError("probs has components greater than 1"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        self.evaluate(dist.probs)

    invalid_ps = [-0.01, -3.]
    for p in invalid_ps:
      with self.assertRaisesOpError("Condition x >= 0"):
        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
        self.evaluate(dist.probs)

    valid_ps = [0.0, 0.5, 1.0]
    for p in valid_ps:
      dist = bernoulli.Bernoulli(probs=p)
      self.assertEqual(p, self.evaluate(dist.probs))  # Should not fail
Example 11
 def testPmfWithFloatArgReturnsXEntropy(self):
   p = [[0.2], [0.4], [0.3], [0.6]]
   samps = [0, 0.1, 0.8]
   self.assertAllClose(
       np.float32(samps) * np.log(np.float32(p)) +
       (1 - np.float32(samps)) * np.log(1 - np.float32(p)),
       self.evaluate(
           bernoulli.Bernoulli(probs=p, validate_args=False).log_prob(samps)))
Example 12
 def testPmfInvalid(self):
     p = [0.1, 0.2, 0.7]
     with self.test_session():
         dist = bernoulli.Bernoulli(probs=p, validate_args=True)
         with self.assertRaisesOpError("must be non-negative."):
             dist.prob([1, 1, -1]).eval()
         with self.assertRaisesOpError("is not less than or equal to 1."):
             dist.prob([2, 0, 1]).eval()
Example 13
 def testNotReparameterized(self):
     p = constant_op.constant([0.2, 0.6])
     with backprop.GradientTape() as tape:
         tape.watch(p)
         dist = bernoulli.Bernoulli(probs=p)
         samples = dist.sample(100)
     grad_p = tape.gradient(samples, p)
     self.assertIsNone(grad_p)
Example 14
 def testBroadcasting(self):
     with self.test_session():
         p = array_ops.placeholder(dtypes.float32)
         dist = bernoulli.Bernoulli(probs=p)
         self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
         self.assertAllClose(np.log([0.5, 0.5, 0.5]),
                             dist.log_prob([1, 1, 1]).eval({p: 0.5}))
         self.assertAllClose(np.log([0.5, 0.5, 0.5]),
                             dist.log_prob(1).eval({p: [0.5, 0.5, 0.5]}))
Example 15
 def testSampleN(self):
   p = [0.2, 0.6]
   dist = bernoulli.Bernoulli(probs=p)
   n = 100000
   samples = dist.sample(n)
   samples.set_shape([n, 2])
   self.assertEqual(samples.dtype, dtypes.int32)
   sample_values = self.evaluate(samples)
   self.assertTrue(np.all(sample_values >= 0))
   self.assertTrue(np.all(sample_values <= 1))
   # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
   # n). This means that the tolerance is very sensitive to the value of p
   # as well as n.
   self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
   self.assertEqual(set([0, 1]), set(sample_values.flatten()))
   # In this test we're just interested in verifying there isn't a crash
   # owing to mismatched types. b/30940152
   dist = bernoulli.Bernoulli(np.log([.2, .4]))
   self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
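To make the tolerance comment above concrete: with n = 100000, the standard error of the sample mean is sqrt(p * (1 - p) / n), roughly 0.0013 for p = 0.2 and 0.0015 for p = 0.6, so atol=1e-2 is several standard errors wide. A quick standalone check, not part of the test:

import numpy as np

n = 100000
for p in [0.2, 0.6]:
    se = np.sqrt(p * (1. - p) / n)
    print(p, se, 1e-2 / se)  # the atol of 1e-2 is many standard errors wide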
Example 16
 def testPmfCorrectBroadcastDynamicShape(self):
     with self.test_session():
         p = array_ops.placeholder(dtype=dtypes.float32)
         dist = bernoulli.Bernoulli(probs=p)
         event1 = [1, 0, 1]
         event2 = [[1, 0, 1]]
         self.assertAllClose(
             dist.prob(event1).eval({p: [0.2, 0.3, 0.4]}), [0.2, 0.7, 0.4])
         self.assertAllClose(
             dist.prob(event2).eval({p: [0.2, 0.3, 0.4]}),
             [[0.2, 0.7, 0.4]])
Example 17
 def sample(self, time, outputs, state, name=None):
     with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
                         [time, outputs, state]):
         # Return -1s where we did not sample, and sample_ids elsewhere
         select_sampler = bernoulli.Bernoulli(
             probs=self._sampling_probability, dtype=dtypes.bool)
         select_sample = select_sampler.sample(sample_shape=self.batch_size,
                                               seed=self._scheduling_seed)
         sample_id_sampler = categorical.Categorical(logits=outputs)
         return array_ops.where(select_sample,
                                sample_id_sampler.sample(seed=self._seed),
                                gen_array_ops.fill([self.batch_size], -1))
Example 18
 def testSampleActsLikeSampleN(self):
   with self.cached_session() as sess:
     p = [0.2, 0.6]
     dist = bernoulli.Bernoulli(probs=p)
     n = 1000
     seed = 42
     self.assertAllEqual(
         self.evaluate(dist.sample(n, seed)),
         self.evaluate(dist.sample(n, seed)))
     n = array_ops.placeholder(dtypes.int32)
     sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
                                 feed_dict={n: 1000})
     self.assertAllEqual(sample1, sample2)
Example 19
 def testVarianceAndStd(self):
     var = lambda p: p * (1. - p)
     p = [[0.2, 0.7], [0.5, 0.4]]
     dist = bernoulli.Bernoulli(probs=p)
     self.assertAllClose(
         self.evaluate(dist.variance()),
         np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
                  dtype=np.float32))
     self.assertAllClose(
         self.evaluate(dist.stddev()),
         np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
                   [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
                  dtype=np.float32))
Example 20
 def testSampleAndLogProbShapesBroadcastMix(self):
     mix_probs = np.float32([.3, .7])
     bern_probs = np.float32([[.4, .6], [.25, .75]])
     bm = tfd.MixtureSameFamily(
         mixture_distribution=categorical_lib.Categorical(probs=mix_probs),
         components_distribution=bernoulli_lib.Bernoulli(probs=bern_probs))
     x = bm.sample([4, 5], seed=42)
     log_prob_x = bm.log_prob(x)
     x_ = self.evaluate(x)
     self.assertEqual([4, 5, 2], x.shape)
     self.assertEqual([4, 5, 2], log_prob_x.shape)
     self.assertAllEqual(np.ones_like(x_, dtype=np.bool),
                         np.logical_or(x_ == 0., x_ == 1.))
Example 21
    def _testMnistLike(self, static_shape):
        sample_shape = [4, 5]
        batch_shape = [10]
        image_shape = [28, 28, 1]
        logits = 3 * self._rng.random_sample(batch_shape + image_shape).astype(
            np.float32) - 1

        def expected_log_prob(x, logits):
            return (x * logits -
                    np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1)

        with self.test_session() as sess:
            logits_ph = array_ops.placeholder(
                dtypes.float32, shape=logits.shape if static_shape else None)
            ind = independent_lib.Independent(
                distribution=bernoulli_lib.Bernoulli(logits=logits_ph))
            x = ind.sample(sample_shape)
            log_prob_x = ind.log_prob(x)
            [
                x_,
                actual_log_prob_x,
                ind_batch_shape,
                ind_event_shape,
                x_shape,
                log_prob_x_shape,
            ] = sess.run([
                x,
                log_prob_x,
                ind.batch_shape_tensor(),
                ind.event_shape_tensor(),
                array_ops.shape(x),
                array_ops.shape(log_prob_x),
            ],
                         feed_dict={logits_ph: logits})

            if static_shape:
                ind_batch_shape = ind.batch_shape
                ind_event_shape = ind.event_shape
                x_shape = x.shape
                log_prob_x_shape = log_prob_x.shape

            self.assertAllEqual(batch_shape, ind_batch_shape)
            self.assertAllEqual(image_shape, ind_event_shape)
            self.assertAllEqual(sample_shape + batch_shape + image_shape,
                                x_shape)
            self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape)
            self.assertAllClose(expected_log_prob(x_, logits),
                                actual_log_prob_x,
                                rtol=1e-6,
                                atol=0.)
Example 22
    def get_next_input(inp, out):
      next_input = inp.read(time)
      if self._prenet is not None:
        next_input = self._prenet(next_input)
        out = self._prenet(out)
      if self._sampling_prob > 0.:
        next_input = tf.stop_gradient(next_input)
        out = tf.stop_gradient(out)
        select_sampler = bernoulli.Bernoulli(
            probs=self._sampling_prob, dtype=dtypes.bool
        )
        select_sample = select_sampler.sample(
            sample_shape=(self.batch_size, 1), seed=self._seed
        )
        select_sample = tf.tile(select_sample, [1, self._last_dim])
        sample_ids = array_ops.where(
            select_sample, out,
            gen_array_ops.fill(
                [self.batch_size, self._last_dim],
                tf.cast(-20., self._dtype)
            )
        )
        where_sampling = math_ops.cast(
            array_ops.where(sample_ids > -20), dtypes.int32
        )
        where_not_sampling = math_ops.cast(
            array_ops.where(sample_ids <= -20), dtypes.int32
        )
        sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(
            next_input, where_not_sampling
        )
        sampled_next_inputs = sample_ids_sampling
        base_shape = array_ops.shape(next_input)

        next_input = (
            array_ops.scatter_nd(
                indices=where_sampling,
                updates=sampled_next_inputs,
                shape=base_shape
            ) + array_ops.scatter_nd(
                indices=where_not_sampling,
                updates=inputs_not_sampling,
                shape=base_shape
            )
        )
      return next_input
Example 23
    def _testMnistLike(self, static_shape):
        sample_shape = [4, 5]
        batch_shape = [10]
        image_shape = [28, 28, 1]
        logits = 3 * self._rng.random_sample(batch_shape + image_shape).astype(
            np.float32) - 1

        def expected_log_prob(x, logits):
            return (x * logits -
                    np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1)

        logits_ph = tf.placeholder_with_default(
            input=logits, shape=logits.shape if static_shape else None)
        ind = tfd.Independent(distribution=bernoulli_lib.Bernoulli(
            logits=logits_ph))
        x = ind.sample(sample_shape, seed=42)
        log_prob_x = ind.log_prob(x)
        [
            x_,
            actual_log_prob_x,
            ind_batch_shape,
            ind_event_shape,
            x_shape,
            log_prob_x_shape,
        ] = self.evaluate([
            x,
            log_prob_x,
            ind.batch_shape_tensor(),
            ind.event_shape_tensor(),
            tf.shape(x),
            tf.shape(log_prob_x),
        ])

        if static_shape:
            ind_batch_shape = ind.batch_shape
            ind_event_shape = ind.event_shape
            x_shape = x.shape
            log_prob_x_shape = log_prob_x.shape

        self.assertAllEqual(batch_shape, ind_batch_shape)
        self.assertAllEqual(image_shape, ind_event_shape)
        self.assertAllEqual(sample_shape + batch_shape + image_shape, x_shape)
        self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape)
        self.assertAllClose(expected_log_prob(x_, logits),
                            actual_log_prob_x,
                            rtol=1e-6,
                            atol=0.)
Example 24
	def sample(self, time, outputs, state, name=None):
		with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
							[time, outputs, state]):
			# Return -1s where we did not sample, and sample_ids elsewhere
			select_sampler = bernoulli.Bernoulli(
				probs=self._sampling_probability, dtype=dtypes.bool)
			select_sample = select_sampler.sample(
				sample_shape=self.batch_size, seed=self._scheduling_seed)
			
# 			self.logs = tf.Print(select_sample, [select_sample])
# 			sample_id_sampler = categorical.Categorical(logits=outputs)
			sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32)
# 			select_sample = tf.ones(shape=(self.batch_size,), dtype=dtypes.bool, name="test")
			return array_ops.where(
				select_sample,
				sample_ids,
				gen_array_ops.fill([self.batch_size], -1))
Example 25
def scheduled_sampling(hps, sampling_probability, output, embedding, inp):
    vocab_size = embedding.get_shape()[0].value
    with variable_scope.variable_scope("ScheduleEmbedding"):
        select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
        select_sample = select_sampler.sample(sample_shape=hps.batch_size)
        sample_id_sampler = categorical.Categorical(probs=output)
        sample_ids = array_ops.where(select_sample, sample_id_sampler.sample(seed=123), gen_array_ops.fill([hps.batch_size],-1))
        where_sampling = math_ops.cast(array_ops.where(sample_ids > -1), tf.int32)
        where_not_sampling = math_ops.cast(array_ops.where(sample_ids <= -1), tf.int32)
        sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
        cond = tf.less(sample_ids_sampling, vocab_size)
        sample_ids_sampling = tf.cast(cond, tf.int32) * sample_ids_sampling
        inputs_not_sampling = array_ops.gather_nd(inp, where_not_sampling)
        sampling_next_inputs = tf.nn.embedding_lookup(embedding, sample_ids_sampling)
        result1 = array_ops.scatter_nd(indices=where_sampling, updates=sampling_next_inputs, shape=array_ops.shape(inp))
        result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=inputs_not_sampling, shape=array_ops.shape(inp))
        return result1 + result2
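The scheduled-sampling helpers here and in the later examples all merge sampled and unsampled positions the same way: scatter each subset into a zero tensor of the full input shape, then add the two results. A minimal NumPy sketch of that pattern (a simplification for illustration, not the original TensorFlow code; the row indices and values below are made up):

import numpy as np

batch, dim = 4, 3
inputs = np.arange(batch * dim, dtype=np.float32).reshape(batch, dim)
sampled_rows = np.array([1, 3])        # rows where we sampled
not_sampled_rows = np.array([0, 2])    # rows where we keep the true input
sampled_values = -np.ones((2, dim), dtype=np.float32)  # stand-in sampled inputs

result1 = np.zeros_like(inputs)
result1[sampled_rows] = sampled_values
result2 = np.zeros_like(inputs)
result2[not_sampled_rows] = inputs[not_sampled_rows]
next_input = result1 + result2         # the scattered rows are disjoint, so the sum is a merge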
Example 26
  def scheduled_sampling(self, batch_size, sampling_probability, true, estimate):
    with variable_scope.variable_scope("ScheduledEmbedding"):
      # Return -1s where we do not sample, and sample_ids elsewhere
      select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
      select_sample = select_sampler.sample(sample_shape=batch_size)
      sample_ids = array_ops.where(
                  select_sample,
                  tf.range(batch_size),
                  gen_array_ops.fill([batch_size], -1))
      where_sampling = math_ops.cast(
          array_ops.where(sample_ids > -1), tf.int32)
      where_not_sampling = math_ops.cast(
          array_ops.where(sample_ids <= -1), tf.int32)
      _estimate = array_ops.gather_nd(estimate, where_sampling)
      _true = array_ops.gather_nd(true, where_not_sampling)

      base_shape = array_ops.shape(true)
      result1 = array_ops.scatter_nd(indices=where_sampling, updates=_estimate, shape=base_shape)
      result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=_true, shape=base_shape)
      return result1 + result2
Example 27
  def _testPmf(self, **kwargs):
    dist = bernoulli.Bernoulli(**kwargs)
    # pylint: disable=bad-continuation
    xs = [
        0,
        [1],
        [1, 0],
        [[1, 0]],
        [[1, 0], [1, 1]],
    ]
    expected_pmfs = [
        [[0.8, 0.6], [0.7, 0.4]],
        [[0.2, 0.4], [0.3, 0.6]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.4]],
        [[0.2, 0.6], [0.3, 0.6]],
    ]
    # pylint: enable=bad-continuation

    for x, expected_pmf in zip(xs, expected_pmfs):
      self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
      self.assertAllClose(self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
Example 28
def scheduled_sampling_vocab_dist(hps, sampling_probability, output, embedding, inp, alpha = 0):
  # borrowed ideas from https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper

  def soft_argmax(alpha, output):
    #alpha_exp = tf.exp(alpha * output) # (batch_size, vocab_size)
    #one_hot_scores = alpha_exp / tf.reshape(tf.reduce_sum(alpha_exp, axis=1),[-1,1]) #(batch_size, vocab_size)
    one_hot_scores = tf.nn.softmax(alpha * output)
    return one_hot_scores

  def soft_top_k(alpha, output, K):
    copy = tf.identity(output)
    p = []
    arg_top_k = []
    for k in range(K):
      sargmax = soft_argmax(alpha, copy)
      copy = (1-sargmax)* copy
      p.append(tf.reduce_sum(sargmax * output, axis=1))
      arg_top_k.append(sargmax)

    return tf.stack(p, axis=1), tf.stack(arg_top_k)

  with variable_scope.variable_scope("ScheduledEmbedding"):
    # Return -1s where we did not sample, and sample_ids elsewhere
    select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
    select_sample = select_sampler.sample(sample_shape=hps.batch_size)
    sample_id_sampler = categorical.Categorical(probs=output) # equivalent to argmax{ Multinomial(output, total_count=1) }, our greedy search selection
    sample_ids = array_ops.where(
            select_sample,
            sample_id_sampler.sample(seed=123),
            gen_array_ops.fill([hps.batch_size], -1))

    where_sampling = math_ops.cast(
        array_ops.where(sample_ids > -1), tf.int32)
    where_not_sampling = math_ops.cast(
        array_ops.where(sample_ids <= -1), tf.int32)

    if hps.greedy_scheduled_sampling:
      sample_ids = tf.argmax(output, axis=1, output_type=tf.int32)

    sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
    inputs_not_sampling = array_ops.gather_nd(inp, where_not_sampling)

    if hps.E2EBackProp:
      if hps.hard_argmax:
        greedy_search_prob, greedy_search_sample = tf.nn.top_k(output, k=hps.k) # (batch_size, k)
        greedy_search_prob_normalized = greedy_search_prob/tf.reshape(tf.reduce_sum(greedy_search_prob,axis=1),[-1,1])
        greedy_embedding = tf.nn.embedding_lookup(embedding, greedy_search_sample)
        normalized_embedding = tf.multiply(tf.reshape(greedy_search_prob_normalized,[hps.batch_size,hps.k,1]), greedy_embedding)
        e2e_embedding = tf.reduce_mean(normalized_embedding,axis=1)
      else:
        e = []
        greedy_search_prob, greedy_search_sample = soft_top_k(alpha, output,
                                                              K=hps.k)  # (batch_size, k), (k, batch_size, vocab_size)
        greedy_search_prob_normalized = greedy_search_prob / tf.reshape(tf.reduce_sum(greedy_search_prob, axis=1),
                                                                        [-1, 1])

        for _ in range(hps.k):
          a_k = greedy_search_sample[_]
          e_k = tf.matmul(tf.reshape(greedy_search_prob_normalized[:,_],[-1,1]) * a_k, embedding)
          e.append(e_k)
        e2e_embedding = tf.reduce_sum(e, axis=0) # (batch_size, emb_dim)
      sampled_next_inputs = array_ops.gather_nd(e2e_embedding, where_sampling)
    else:
      if hps.hard_argmax:
        sampled_next_inputs = tf.nn.embedding_lookup(embedding, sample_ids_sampling)
      else: # using soft argmax (greedy) proposed in: https://arxiv.org/abs/1704.06970
        #alpha_exp = tf.exp(alpha * (output_not_extended + G)) # (batch_size, vocab_size)
        #one_hot_scores = alpha_exp / tf.reduce_sum(alpha_exp, axis=1) #(batch_size, vocab_size)
        one_hot_scores = soft_argmax(alpha, output) #(batch_size, vocab_size)
        soft_argmax_embedding = tf.matmul(one_hot_scores, embedding) #(batch_size, emb_size)
        sampled_next_inputs = array_ops.gather_nd(soft_argmax_embedding, where_sampling)

    base_shape = array_ops.shape(inp)
    result1 = array_ops.scatter_nd(indices=where_sampling, updates=sampled_next_inputs, shape=base_shape)
    result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=inputs_not_sampling, shape=base_shape)
    return result1 + result2
Example 29
 def testP(self):
     p = [0.2, 0.4]
     dist = bernoulli.Bernoulli(probs=p)
     with self.test_session():
         self.assertAllClose(p, dist.probs.eval())
Example 30
def make_bernoulli(batch_shape, dtype=dtypes.int32):
    p = np.random.uniform(size=list(batch_shape))
    p = constant_op.constant(p, dtype=dtypes.float32)
    return bernoulli.Bernoulli(probs=p, dtype=dtype)
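A hedged usage sketch of the helper above (it relies on the same imports as the snippet and is not from the original source):

dist = make_bernoulli(batch_shape=[3, 2])
samples = dist.sample(5, seed=42)  # int32 tensor of shape [5, 3, 2] with values in {0, 1}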