Ejemplo n.º 1
0
 def testEmptyStateless(self):
   """Requesting zero samples produces an empty result of shape (42, 0).

   Builds a stateless multinomial op over a [42, 40] tensor of zeros with a
   sample count of 0 and int32 output, runs it with a fixed two-element
   seed, and checks only the shape of the evaluated array.
   """
   seed_value = [0x12345678, 0xabcdef12]
   with self.cached_session() as sess, self.test_scope():
     seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
     logits = array_ops.zeros([42, 40])
     samples = stateless_random_ops.stateless_multinomial(
         logits, 0, seed=seed_t, output_dtype=dtypes.int32)
     result = sess.run(samples, {seed_t: seed_value})
     self.assertEqual(result.shape, (42, 0))
 def testEmptyStateless(self):
     """Check that drawing zero samples gives an output of shape (42, 0).

     The op is built from a [42, 40] zeros tensor, a sample count of 0 and an
     int32 output dtype; the seed is supplied through a placeholder at run
     time. Only the shape of the fetched array is asserted.
     """
     with self.cached_session() as session:
         with self.test_scope():
             seed_ph = array_ops.placeholder(dtypes.int32, shape=[2])
             empty_draw = stateless_random_ops.stateless_multinomial(
                 array_ops.zeros([42, 40]),
                 0,
                 seed=seed_ph,
                 output_dtype=dtypes.int32)
             feed = {seed_ph: [0x12345678, 0xabcdef12]}
             fetched = session.run(empty_draw, feed)
             self.assertEqual(fetched.shape, (42, 0))
Ejemplo n.º 3
0
 def testDeterminismMultinomial(self):
   """Stateless draws should be equal iff the seeds are equal (roughly).

   For each logits case, the op is evaluated once per seed and every pair of
   evaluations is compared: the outputs must be element-wise identical
   exactly when the two seeds are the same tuple.
   """
   num_samples = 10
   with self.cached_session():
     with self.test_scope():
       seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
       # Every (a, b) pair with a, b in 0..4, repeated three times so that
       # identical seeds are also compared against each other.
       seeds = [(a, b) for a in range(5) for b in range(5)] * 3
       logits_cases = (
           [[0.1, 0.25, 0.5, 0.15]],
           [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]],
       )
       for logits in logits_cases:
         pure = stateless_random_ops.stateless_multinomial(
             logits, num_samples, seed=seed_t)
         values = []
         for seed in seeds:
           values.append((seed, pure.eval(feed_dict={seed_t: seed})))
         for s0, v0 in values:
           for s1, v1 in values:
             self.assertEqual(s0 == s1, np.all(v0 == v1))
Ejemplo n.º 4
0
 def testStatelessMultinomialIsInRange(self):
   """Every sampled value must lie in [0, 20) for all dtype combinations.

   For each dtype in self.float_types and each output dtype returned by
   self.output_dtypes(), builds a stateless multinomial op over logits
   created with array_ops.ones(shape=[1, 20]), requests 1000 samples, and
   runs it with a fixed seed. The test then counts how many returned values
   are >= 0 and how many are < 20; both counts must equal 1000.
   """
   num_classes = 20
   num_samples = 1000
   for dtype in self.float_types:
     for output_dtype in self.output_dtypes():
       with self.cached_session() as sess:
         with self.test_scope():
           seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
           x = stateless_random_ops.stateless_multinomial(
               array_ops.ones(shape=[1, num_classes], dtype=dtype),
               num_samples,
               seed_t,
               output_dtype=output_dtype)
         y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
         # assertEqual (rather than assertTrue(a == b)) reports the actual
         # in-range count when the check fails.
         self.assertEqual((y >= 0).sum(), num_samples)
         self.assertEqual((y < num_classes).sum(), num_samples)
Ejemplo n.º 5
0
 def testStatelessMultinomialIsInRange(self):
     """Every sampled value must lie in [0, 20) for all dtype combinations.

     For each dtype in self.float_types and each output dtype returned by
     self.output_dtypes(), builds a stateless multinomial op over logits
     created with array_ops.ones(shape=[1, 20]), requests 1000 samples, and
     runs it with a fixed seed. The test then counts how many returned
     values are >= 0 and how many are < 20; both counts must equal 1000.
     """
     num_classes = 20
     num_samples = 1000
     for dtype in self.float_types:
         for output_dtype in self.output_dtypes():
             with self.cached_session() as sess:
                 with self.test_scope():
                     seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
                     x = stateless_random_ops.stateless_multinomial(
                         array_ops.ones(shape=[1, num_classes], dtype=dtype),
                         num_samples,
                         seed_t,
                         output_dtype=output_dtype)
                 y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
                 # assertEqual (rather than assertTrue(a == b)) reports the
                 # actual in-range count when the check fails.
                 self.assertEqual((y >= 0).sum(), num_samples)
                 self.assertEqual((y < num_classes).sum(), num_samples)
Ejemplo n.º 6
0
 def testDeterminismMultinomial(self):
     """Stateless draws should be equal iff the seeds are equal (roughly).

     Evaluates the op once per seed for each logits case, then compares
     every pair of evaluations: the two outputs must match element-wise
     exactly when the two seed tuples are equal.
     """
     num_samples = 10
     with self.cached_session(), self.test_scope():
         seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
         # All (i, j) pairs with i, j in 0..4, listed three times over so
         # that repeated seeds take part in the pairwise comparison.
         seeds = [(i, j) for i in range(5) for j in range(5)] * 3
         for logits in ([[0.1, 0.25, 0.5, 0.15]],
                        [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
             pure = stateless_random_ops.stateless_multinomial(
                 logits, num_samples, seed=seed_t)
             values = [
                 (seed, pure.eval(feed_dict={seed_t: seed}))
                 for seed in seeds
             ]
             for seed_a, draw_a in values:
                 for seed_b, draw_b in values:
                     seeds_match = seed_a == seed_b
                     draws_match = np.all(draw_a == draw_b)
                     self.assertEqual(seeds_match, draws_match)