コード例 #1
0
 def test_flatten_given_dims(self):
   """Checks flatten_dimensions(first=1, last=3) against a plain reshape.

   The [5, 2, 10, 10, 3] input is expected to come back equal to the same
   tensor reshaped to [5, 20, 10, 3].
   """
   tensor = tf.random_uniform([5, 2, 10, 10, 3])
   flattened = shape_utils.flatten_dimensions(tensor, first=1, last=3)
   reference = tf.reshape(tensor, [5, 20, 10, 3])
   # Both tensors are fetched in a single run call; presumably so that they
   # are derived from the same random draw of `tensor`.
   with self.test_session() as sess:
     flattened_np, reference_np = sess.run([flattened, reference])
   self.assertAllClose(reference_np, flattened_np)
コード例 #2
0
 def test_raises_value_error_incorrect_dimensions(self):
      """Checks that flatten_dimensions raises ValueError for first=0, last=6.

      The input has shape [5, 2, 10, 10, 3]; presumably last=6 is rejected
      because it lies beyond the input's five dimensions.
      """
      tensor = tf.random_uniform([5, 2, 10, 10, 3])
      # Callable form of assertRaises: the function is invoked with the
      # trailing positional and keyword arguments.
      self.assertRaises(
          ValueError,
          shape_utils.flatten_dimensions, tensor, first=0, last=6)
コード例 #3
0
    def graph_fn():
      """Returns the flattened tensor paired with its reshape reference.

      The first element is flatten_dimensions(first=1, last=3) applied to a
      random [5, 2, 10, 10, 3] tensor; the second is that same tensor
      reshaped to [5, 20, 10, 3].
      """
      tensor = tf.random_uniform([5, 2, 10, 10, 3])
      flattened = shape_utils.flatten_dimensions(tensor, first=1, last=3)
      return flattened, tf.reshape(tensor, [5, 20, 10, 3])