def testSampleCPU(self):
  """Sampling under a CPU device scope should report the CPU runtime."""
  with tf.device('CPU'):
    # Build (graph) or run (eager) the sampler inside the CPU scope so the
    # implementation selector sees the CPU placement.
    sample_and_runtime = gamma_lib.random_gamma_with_runtime(
        shape=tf.constant([], dtype=tf.int32),
        concentration=tf.constant(1.),
        seed=test_util.test_seed())
    unused_samples, runtime = self.evaluate(sample_and_runtime)
  self.assertEqual(implementation_selection._RUNTIME_CPU, runtime)
def testSampleGPU(self):
  """Sampling under a GPU device scope should report the default runtime.

  Skipped when TensorFlow sees no physical GPU.
  """
  # `tf.test.is_gpu_available` is deprecated; TensorFlow recommends
  # `tf.config.list_physical_devices('GPU')`, which returns an empty list
  # when no GPU is visible and does not initialize every local device.
  if not tf.config.list_physical_devices('GPU'):
    self.skipTest('no GPU')
  with tf.device('GPU'):
    _, runtime = self.evaluate(gamma_lib.random_gamma_with_runtime(
        shape=tf.constant([], dtype=tf.int32),
        concentration=tf.constant(1.),
        seed=test_util.test_seed()))
  self.assertEqual(implementation_selection._RUNTIME_DEFAULT, runtime)
def testSampleXLA(self):
  """Gamma sampling should compile under XLA and report the default runtime.

  Exercises both `tfd.Gamma(...).sample` and the low-level
  `gamma_lib.random_gamma_with_runtime` sampler, each wrapped in
  `tf.function(..., jit_compile=True)`.
  """
  self.skip_if_no_xla()
  if not tf.executing_eagerly():
    # Report a skip rather than silently passing: jit_compile is eager-only.
    self.skipTest('jit_compile is eager-only.')
  # exp(U[0, 1)) keeps both parameters in [1, e), i.e. strictly positive.
  # NOTE(review): these inputs come from the unseeded global numpy RNG; the
  # test only checks that compilation succeeds, so values are not asserted.
  concentration = np.exp(np.random.rand(4, 3).astype(np.float32))
  rate = np.exp(np.random.rand(4, 3).astype(np.float32))
  dist = tfd.Gamma(concentration=concentration, rate=rate, validate_args=True)
  # Verify the compile succeeds going all the way through the distribution.
  self.evaluate(
      tf.function(lambda: dist.sample(5, seed=test_util.test_seed()),
                  jit_compile=True)())

  # Also test the low-level sampler and verify the XLA-friendly variant.
  # TODO(bjp): functools.partial, after eliminating PY2 which breaks
  # tf_inspect in interesting ways:
  # ValueError: Some arguments ['concentration', 'rate'] do not have default
  # value, but they are positioned after those with default values. This can
  # not be expressed with ArgSpec.
  scalar_gamma = tf.function(
      lambda **kwds: gamma_lib.random_gamma_with_runtime(shape=[], **kwds),
      jit_compile=True)
  _, runtime = self.evaluate(
      scalar_gamma(
          concentration=tf.constant(1.), seed=test_util.test_seed()))
  self.assertEqual(implementation_selection._RUNTIME_DEFAULT, runtime)