Example #1
0
 def test_maybe_tile_tensor_for_mc_tiles_scalar_tensor(self):
     num_samples = 16
     x = 3.14
     tensor_in = tf.constant(x, dtype=tf.float32)
     tensor_out = monte_carlo_manager._maybe_tile_tensor_for_mc(
         tensor_in, num_samples)
     self.assertEqual(tensor_out.shape.as_list(), [num_samples])
     with self.session() as session:
         tensor_out_eval = session.run(tensor_out)
     self.assertAllEqual(tensor_out_eval,
                         np.asarray([x] * num_samples, dtype=np.float32))
Example #2
0
 def test_maybe_tile_tensor_for_mc_does_not_tile_2d_tensor(self):
   """A rank-2 input is returned with the same shape and values.

   The input is drawn from a local, seeded `RandomState` instead of reseeding
   NumPy's global RNG, so this test does not change the random state seen by
   other tests running in the same process.
   """
   num_samples = 16
   num_dims = 7
   # Produces the same values as `np.random.seed(0); np.random.uniform(...)`
   # but without the process-global side effect.
   rng = np.random.RandomState(0)
   x = rng.uniform(size=[num_samples, num_dims])
   tensor_in = tf.constant(x)
   tensor_out = monte_carlo_manager._maybe_tile_tensor_for_mc(
       tensor_in, num_samples)
   self.assertAllEqual(tensor_out.shape.as_list(), [num_samples, num_dims])
   with self.session() as session:
     tensor_out_eval = session.run(tensor_out)
   self.assertAllEqual(tensor_out_eval, x)
Example #3
0
 def test_reshape_initial_state_for_mc_reshapes_tuple(self):
   """Each tuple element matches what `_maybe_tile_tensor_for_mc` yields.

   The inputs mix rank-1 and rank-0 tensors. The expected outputs are built by
   applying `_maybe_tile_tensor_for_mc` to each input on its own. Actual and
   expected results are fetched in a single `session.run` call, so both are
   computed from the same sampled values of the `tf.random_uniform` inputs.
   """
   num_samples = 16
   # No NumPy seeding: the inputs come from `tf.random_uniform`, which
   # `np.random.seed` does not affect. Determinism is also unnecessary because
   # actual and expected values are compared within one `session.run` call.
   tensors_in = (tf.random_uniform(shape=[3]), tf.random_uniform(shape=[7]),
                 tf.random_uniform(shape=()))
   tensors_out = monte_carlo_manager._reshape_initial_state_for_mc(
       tensors_in, num_samples)
   expected_tensors_out = tuple(
       monte_carlo_manager._maybe_tile_tensor_for_mc(t, num_samples)
       for t in tensors_in)
   # Check the lengths explicitly: zip() below would otherwise silently ignore
   # extra outputs (or stop early on missing ones).
   self.assertEqual(len(tensors_out), len(expected_tensors_out))
   for out, expected in zip(tensors_out, expected_tensors_out):
     self.assertAllEqual(out.shape.as_list(), expected.shape.as_list())
   with self.session() as session:
     tensors_out_eval, expected_tensors_out_eval = session.run(
         (tensors_out, expected_tensors_out))
   for out_eval, expected_eval in zip(tensors_out_eval,
                                      expected_tensors_out_eval):
     self.assertAllEqual(out_eval, expected_eval)