Beispiel #1
0
 def testSampleCPU(self):
   """Sampling under a CPU device scope reports the CPU runtime."""
   sampler = gamma_lib.random_gamma_with_runtime
   with tf.device('CPU'):
     sample_and_runtime = sampler(
         shape=tf.constant([], dtype=tf.int32),
         concentration=tf.constant(1.),
         seed=test_util.test_seed())
     _, runtime = self.evaluate(sample_and_runtime)
   self.assertEqual(implementation_selection._RUNTIME_CPU, runtime)
Beispiel #2
0
 def testSampleGPU(self):
   """Sampling under a GPU device scope reports the default runtime.

   Skipped when no GPU device is available.
   """
   # `tf.test.is_gpu_available` is deprecated. Query the logical GPU devices
   # instead: these are what the `tf.device('GPU')` scope below can use.
   if not tf.config.list_logical_devices('GPU'):
     self.skipTest('no GPU')
   with tf.device('GPU'):
     _, runtime = self.evaluate(gamma_lib.random_gamma_with_runtime(
         shape=tf.constant([], dtype=tf.int32),
         concentration=tf.constant(1.),
         seed=test_util.test_seed()))
   self.assertEqual(implementation_selection._RUNTIME_DEFAULT, runtime)
Beispiel #3
0
 def testSampleXLA(self):
   """Gamma sampling compiles under XLA and reports the default runtime.

   Exercises both `tfd.Gamma.sample` and the low-level
   `gamma_lib.random_gamma_with_runtime` inside `jit_compile=True` functions.
   """
   self.skip_if_no_xla()
   if not tf.executing_eagerly():
     # Skip explicitly instead of returning early; an early `return` would
     # report this test as passed in graph mode without exercising anything.
     self.skipTest('jit_compile is eager-only.')
   # np.random.rand draws from [0, 1), so exp() yields values in [1, e):
   # strictly positive, as Gamma parameters must be.
   concentration = np.exp(np.random.rand(4, 3).astype(np.float32))
   rate = np.exp(np.random.rand(4, 3).astype(np.float32))
   dist = tfd.Gamma(concentration=concentration, rate=rate, validate_args=True)
   # Verify the compile succeeds going all the way through the distribution.
   self.evaluate(
       tf.function(lambda: dist.sample(5, seed=test_util.test_seed()),
                   jit_compile=True)())
   # Also test the low-level sampler and verify the XLA-friendly variant.
   # TODO(bjp): functools.partial, after eliminating PY2 which breaks
   # tf_inspect in interesting ways:
   # ValueError: Some arguments ['concentration', 'rate'] do not have default
   # value, but they are positioned after those with default values. This can
   # not be expressed with ArgSpec.
   # NOTE(review): `jit_compile` requires a Py3-only TensorFlow, so the PY2
   # blocker may no longer apply; revisit whether functools.partial works now.
   scalar_gamma = tf.function(
       lambda **kwds: gamma_lib.random_gamma_with_runtime(shape=[], **kwds),
       jit_compile=True)
   _, runtime = self.evaluate(
       scalar_gamma(
           concentration=tf.constant(1.),
           seed=test_util.test_seed()))
   self.assertEqual(implementation_selection._RUNTIME_DEFAULT, runtime)