def test_expand_first_dimension(self):
  """Checks that a [6, 2] tensor is reshaped into [3, 2, 2] via dims=[3, 2].

  The six input rows are expected to be grouped, in order, into three
  consecutive pairs.
  """
  flat_rows = tf.constant(
      [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype=tf.int32)
  expected_output = [
      [[1, 2], [3, 4]],
      [[5, 6], [7, 8]],
      [[9, 10], [11, 12]],
  ]
  expanded = shape_utils.expand_first_dimension(flat_rows, [3, 2])
  with self.test_session() as sess:
    result = sess.run(expanded)
    self.assertAllEqual(expected_output, result)
def test_expand_first_dimension_with_incompatible_dims(self):
  """Checks that mismatched dims raise InvalidArgumentError when run.

  The input's leading dimension is 3, but dims=[3, 2] multiplies to 6. The
  input is routed through a placeholder whose static shape is [None, 1, 2],
  so the mismatch is only expected to surface when the graph is executed.
  """
  known_inputs = tf.constant([[[1, 2]], [[3, 4]], [[5, 6]]], dtype=tf.int32)
  # Hide the static leading dimension so no shape check can fire while the
  # graph is being built; the failure must come from sess.run below.
  dynamic_inputs = tf.placeholder_with_default(known_inputs, [None, 1, 2])
  expanded = shape_utils.expand_first_dimension(dynamic_inputs, [3, 2])
  with self.test_session() as sess, self.assertRaises(
      tf.errors.InvalidArgumentError):
    sess.run(expanded)
def graph_fn():
  """Returns a [6, 2] int32 constant expanded with dims=[3, 2]."""
  rows = tf.constant(
      [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype=tf.int32)
  return shape_utils.expand_first_dimension(rows, [3, 2])