def test_flatten_given_dims(self):
    """Check that flattening dims [first=1, last=3) matches an explicit reshape.

    The input has shape [5, 2, 10, 10, 3]. The expected tensor merges the
    2 and 10 axes into a single axis of size 20, giving [5, 20, 10, 3].
    """
    source = tf.random_uniform([5, 2, 10, 10, 3])
    flattened = shape_utils.flatten_dimensions(source, first=1, last=3)
    reference = tf.reshape(source, [5, 20, 10, 3])
    with self.test_session() as session:
        # Fetch both tensors in one run so they are computed from the same
        # random sample of `source`.
        fetched = session.run([flattened, reference])
        flattened_value, reference_value = fetched
        self.assertAllClose(reference_value, flattened_value)
def test_raises_value_error_incorrect_dimensions(self):
    """Check that a `last` index beyond the input's rank raises ValueError.

    The input is rank 5, so `last=6` lies outside its dimensions.
    """
    rank_five_input = tf.random_uniform([5, 2, 10, 10, 3])
    with self.assertRaises(ValueError):
        shape_utils.flatten_dimensions(rank_five_input, first=0, last=6)
def graph_fn():
    """Build the flatten op under test plus its reshape-based reference.

    Returns:
      A 2-tuple of (flattened tensor, expected tensor of shape [5, 20, 10, 3]).
    """
    source = tf.random_uniform([5, 2, 10, 10, 3])
    # Create the op under test first, matching the original graph-construction order.
    flattened = shape_utils.flatten_dimensions(source, first=1, last=3)
    return flattened, tf.reshape(source, [5, 20, 10, 3])