def _ParseSingleSequenceExampleShape(op):  # pylint: disable=invalid-name
  """Shape function for the ParseSingleSequenceExample op.

  Args:
    op: The operation to compute output shapes for. Input 0 is checked
      against rank 0 and input 1 (feature_list_dense_missing_assumed_empty)
      against rank 1.

  Returns:
    A flat list of output shapes, grouped in this order: context sparse
    indices, values and shapes; context dense; feature-list sparse indices,
    values and shapes; feature-list dense.
  """
  # Rank checks only: the shapes returned by `with_rank` are not needed.
  op.inputs[0].get_shape().with_rank(0)  # input
  # feature_list_dense_missing_assumed_empty
  op.inputs[1].get_shape().with_rank(1)

  n_ctx_sparse = op.get_attr("Ncontext_sparse")
  n_ctx_dense = op.get_attr("Ncontext_dense")
  n_fl_dense = op.get_attr("Nfeature_list_dense")
  ctx_dense_attr = op.get_attr("context_dense_shapes")
  n_fl_sparse = op.get_attr("Nfeature_list_sparse")
  fl_dense_attr = op.get_attr("feature_list_dense_shapes")

  output_groups = [
      # Context sparse outputs: indices, values, shape.
      [tensor_shape.matrix(None, 1) for _ in range(n_ctx_sparse)],
      [tensor_shape.vector(None) for _ in range(n_ctx_sparse)],
      [tensor_shape.vector(1) for _ in range(n_ctx_sparse)],
      # Context dense outputs take their shape straight from the attr.
      [tensor_shape.TensorShape(shape) for shape in ctx_dense_attr],
      # Feature-list sparse outputs: indices, values, shape.
      [tensor_shape.matrix(None, 2) for _ in range(n_fl_sparse)],
      [tensor_shape.vector(None) for _ in range(n_fl_sparse)],
      [tensor_shape.vector(2) for _ in range(n_fl_sparse)],
      # Feature-list dense outputs get an unknown leading dimension.
      [tensor_shape.vector(None).concatenate(shape)
       for shape in fl_dense_attr],
  ]
  assert n_ctx_dense == len(output_groups[3])
  assert n_fl_dense == len(output_groups[7])

  flat_shapes = []
  for group in output_groups:
    flat_shapes.extend(group)
  return flat_shapes
def testShapes(self):
  """Checks input/output shapes produced for different `input_shapes`."""
  fdef = self._build_function_def()
  to_graph = function_def_to_graph.function_def_to_graph

  # Without input_shapes every input and output has unknown dims.
  g = to_graph(fdef)
  for tensor in (g.inputs[0], g.inputs[1], g.outputs[0], g.outputs[1]):
    self.assertIsNone(tensor.shape.dims)

  # Two vector input shapes: all inputs and outputs report [5].
  g = to_graph(
      fdef, input_shapes=[tensor_shape.vector(5), tensor_shape.vector(5)])
  for tensor in (g.inputs[0], g.inputs[1], g.outputs[0], g.outputs[1]):
    self.assertSequenceEqual(tensor.shape.dims, [5])

  # A None entry leaves that input unknown; the others report [5, 7].
  g = to_graph(fdef, input_shapes=[None, tensor_shape.matrix(5, 7)])
  self.assertIsNone(g.inputs[0].shape.dims)
  for tensor in (g.inputs[1], g.outputs[0], g.outputs[1]):
    self.assertSequenceEqual(tensor.shape.dims, [5, 7])

  # Should raise a ValueError if the length of input_shapes does not match
  # the number of input args in FunctionDef.signature.input_arg.
  with self.assertRaises(ValueError):
    g = to_graph(fdef, input_shapes=[tensor_shape.matrix(5, 7)])
def _ParseSingleSequenceExampleShape(op):
  """Shape function for the ParseSingleSequenceExample op.

  Args:
    op: The operation to compute output shapes for. Input 0 is checked
      against rank 0 and input 1 (feature_list_dense_missing_assumed_empty)
      against rank 1.

  Returns:
    The concatenation, in order, of the shape lists for: context sparse
    indices, values and shapes; context dense; feature-list sparse indices,
    values and shapes; feature-list dense.
  """
  # The `with_rank` calls are used purely as rank checks.
  op.inputs[0].get_shape().with_rank(0)  # input
  # feature_list_dense_missing_assumed_empty
  op.inputs[1].get_shape().with_rank(1)

  sparse_ctx_count = op.get_attr("Ncontext_sparse")
  dense_ctx_count = op.get_attr("Ncontext_dense")
  dense_fl_count = op.get_attr("Nfeature_list_dense")
  dense_ctx_attr = op.get_attr("context_dense_shapes")
  sparse_fl_count = op.get_attr("Nfeature_list_sparse")
  dense_fl_attr = op.get_attr("feature_list_dense_shapes")

  ctx_indices = [
      tensor_shape.matrix(None, 1) for _ in range(sparse_ctx_count)]
  ctx_values = [tensor_shape.vector(None) for _ in range(sparse_ctx_count)]
  ctx_shapes = [tensor_shape.vector(1) for _ in range(sparse_ctx_count)]
  ctx_dense = [tensor_shape.TensorShape(shape) for shape in dense_ctx_attr]

  fl_indices = [tensor_shape.matrix(None, 2) for _ in range(sparse_fl_count)]
  fl_values = [tensor_shape.vector(None) for _ in range(sparse_fl_count)]
  fl_shapes = [tensor_shape.vector(2) for _ in range(sparse_fl_count)]
  # Each feature-list dense shape gets an unknown leading dimension.
  fl_dense = [
      tensor_shape.vector(None).concatenate(shape) for shape in dense_fl_attr]

  assert dense_ctx_count == len(ctx_dense)
  assert dense_fl_count == len(fl_dense)

  ordered_groups = [
      ctx_indices, ctx_values, ctx_shapes, ctx_dense,
      fl_indices, fl_values, fl_shapes, fl_dense,
  ]
  return sum(ordered_groups, [])
def _CTCGreedyDecoderShape(op):
  """Shape function for the CTCGreedyDecoder op.

  Args:
    op: The operation to compute output shapes for. Input 0 must have rank 3
      and input 1 rank 1; dimension 1 of input 0 and dimension 0 of input 1
      are merged to obtain the batch size.

  Returns:
    A list of four shapes, in order: decoded_indices, decoded_values,
    decoded_shape, log_probability.
  """
  inputs_shape = op.inputs[0].get_shape().with_rank(3)
  sequence_length_shape = op.inputs[1].get_shape().with_rank(1)

  # `merge_with` returns the merged dimension instead of updating either
  # operand, so keep its result: this way a batch size that is only known
  # from `sequence_length` still reaches the log_probability shape.
  batch_size = inputs_shape[1].merge_with(sequence_length_shape[0])

  # decoded_indices, decoded_values, decoded_shape, log_probability
  return [
      tensor_shape.matrix(None, 2),
      tensor_shape.vector(None),
      tensor_shape.vector(2),
      tensor_shape.matrix(batch_size, 1),
  ]
def testBroadcast_many_dimensions(self):
  """Checks broadcasting between fully known scalar/vector/matrix shapes."""
  unknown = tensor_shape.unknown_shape()
  shape_0 = tensor_shape.scalar()
  shape_1 = tensor_shape.vector(1)
  shape_4 = tensor_shape.vector(4)
  shape_1x4 = tensor_shape.matrix(1, 4)
  shape_4x1 = tensor_shape.matrix(4, 1)
  shape_3x4 = tensor_shape.matrix(3, 4)
  shape_4x3 = tensor_shape.matrix(4, 3)
  shape_4x4 = tensor_shape.matrix(4, 4)
  non_identity = (shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3)

  # Tensors with same shape should have the same broadcast result.
  for shape in (shape_0, shape_1) + non_identity:
    self._assert_broadcast(expected=shape, shape1=shape, shape2=shape)

  # [] and [1] act like identity.
  for identity in (shape_0, shape_1):
    for shape in non_identity:
      self._assert_broadcast(expected=shape, shape1=identity, shape2=shape)

  # Unknown in, unknown out.
  for shape in non_identity:
    self._assert_broadcast(expected=unknown, shape1=shape, shape2=unknown)

  # (expected, shape1, shape2); expected=None marks an incompatible pair.
  pairwise_cases = (
      (shape_1x4, shape_4, shape_1x4),
      (shape_4x4, shape_4, shape_4x1),
      (shape_3x4, shape_4, shape_3x4),
      (None, shape_4, shape_4x3),
      (shape_4x4, shape_1x4, shape_4x1),
      (shape_3x4, shape_1x4, shape_3x4),
      (None, shape_1x4, shape_4x3),
      (None, shape_4x1, shape_3x4),
      (shape_4x3, shape_4x1, shape_4x3),
      (None, shape_3x4, shape_4x3),
  )
  for expected, lhs, rhs in pairwise_cases:
    if expected is None:
      self._assert_incompatible_broadcast(shape1=lhs, shape2=rhs)
    else:
      self._assert_broadcast(expected=expected, shape1=lhs, shape2=rhs)
def _SparseConcatShape(op):
  """Shape function for SparseConcat op.

  Args:
    op: The operation to compute output shapes for. Its "N" attr gives the
      number of sparse inputs; the inputs are laid out as N index tensors,
      then N value tensors, then the shape tensors.

  Returns:
    A list of three shapes: output indices, output values, output shape.
  """
  n = int(op.get_attr("N"))

  # TF flattens and concatenates all list inputs, so reconstruct the lists
  # here.
  index_shapes = [t.get_shape().with_rank(2) for t in op.inputs[0:n]]
  value_shapes = [t.get_shape().with_rank(1) for t in op.inputs[n:2 * n]]
  dense_shape_shapes = [t.get_shape().with_rank(1) for t in op.inputs[2 * n:]]

  total_rows = tensor_shape.Dimension(0)
  index_cols = tensor_shape.Dimension(None)
  total_values = tensor_shape.Dimension(0)
  merged_shape_shape = tensor_shape.TensorShape(None)

  for idx_shape, val_shape, shp_shape in zip(
      index_shapes, value_shapes, dense_shape_shapes):
    # Rows of the index matrix and length of the value vector are merged
    # into one per-input element count.
    elems = idx_shape[0].merge_with(val_shape[0])
    total_rows += elems
    index_cols = index_cols.merge_with(idx_shape[1])
    total_values += elems
    merged_shape_shape = merged_shape_shape.merge_with(shp_shape)

  return [
      tensor_shape.matrix(total_rows, index_cols),
      tensor_shape.vector(total_values),
      merged_shape_shape,
  ]
def _SerializeManySparseShape(op):  # pylint: disable=invalid-name
  """Shape function for SerializeManySparse op.

  Args:
    op: The operation to compute output shapes for. Inputs 0, 1 and 2 are
      checked against ranks 2, 1 and 1 respectively.

  Returns:
    A single-element list holding a matrix shape with an unknown number of
    rows and 3 columns.
  """
  # (input index, required rank) pairs; the calls act as rank checks only.
  for input_index, required_rank in ((0, 2), (1, 1), (2, 1)):
    op.inputs[input_index].get_shape().with_rank(required_rank)

  serialized_shape = tensor_shape.matrix(None, 3)
  return [serialized_shape]
def _reverse_seq(input_seq, lengths):
  """Reverse a list of Tensors up to specified lengths.

  Args:
    input_seq: Sequence of seq_len tensors of dimension (batch_size, depth)
    lengths:   A tensor of dimension batch_size, containing lengths for each
               sequence in the batch. If "None" is specified, simply reverses
               the list.

  Returns:
    time-reversed sequence
  """
  if lengths is None:
    return list(reversed(input_seq))

  input_shape = tensor_shape.matrix(None, None)
  for input_ in input_seq:
    # `merge_with` returns the merged shape rather than updating
    # `input_shape` in place, so rebind it; otherwise the static shape
    # information gathered from the inputs is thrown away.
    input_shape = input_shape.merge_with(input_.get_shape())
    input_.set_shape(input_shape)

  # Join into (time, batch_size, depth)
  s_joined = array_ops.pack(input_seq)

  # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32
  # `lengths` cannot be None here because of the early return above.
  lengths = math_ops.to_int64(lengths)

  # Reverse along dimension 0
  s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)

  # Split again into list
  result = array_ops.unpack(s_reversed)
  for r in result:
    r.set_shape(input_shape)
  return result
def _ComputeAccidentalHitsShape(op):
  """Shape function for the ComputeAccidentalHits op.

  Args:
    op: The operation to compute output shapes for. Input 0 is merged with a
      matrix shape whose second dimension is the "num_true" attr.

  Returns:
    A list of three vector shapes of unknown length.
  """
  num_true = op.get_attr("num_true")

  # Validate that the input shape matches the attrs, even though it
  # does not influence the shape of the output.
  op.inputs[0].get_shape().merge_with(tensor_shape.matrix(None, num_true))

  output_shape = tensor_shape.vector(None)
  return [output_shape] * 3
def _SparseConcatShape(op):
  """Shape function for SparseConcat op.

  Args:
    op: The operation to compute output shapes for. Its "N" attr gives the
      number of sparse inputs; the inputs are laid out as N index tensors,
      then N value tensors, then the shape tensors.

  Returns:
    A list of three shapes: output indices, output values, output shape.
  """
  count = int(op.get_attr("N"))
  flat_inputs = op.inputs

  # TF flattens and concatenates all list inputs, so reconstruct the lists
  # here.
  indices_part = flat_inputs[0:count]
  values_part = flat_inputs[count:2 * count]
  shapes_part = flat_inputs[2 * count:]
  ind_shapes = [tensor.get_shape().with_rank(2) for tensor in indices_part]
  val_shapes = [tensor.get_shape().with_rank(1) for tensor in values_part]
  shape_shapes = [tensor.get_shape().with_rank(1) for tensor in shapes_part]

  rows = tensor_shape.Dimension(0)
  cols = tensor_shape.Dimension(None)
  n_values = tensor_shape.Dimension(0)
  out_shape_shape = tensor_shape.TensorShape(None)

  for ind_shape, val_shape, shape_shape in zip(
      ind_shapes, val_shapes, shape_shapes):
    # The per-input element count is the merge of the index row count and
    # the value vector length.
    n_elems = ind_shape[0].merge_with(val_shape[0])
    rows += n_elems
    cols = cols.merge_with(ind_shape[1])
    n_values += n_elems
    out_shape_shape = out_shape_shape.merge_with(shape_shape)

  out_ind_shape = tensor_shape.matrix(rows, cols)
  out_val_shape = tensor_shape.vector(n_values)
  return [out_ind_shape, out_val_shape, out_shape_shape]
def testHelpers(self):
  """Checks scalar/vector/matrix helpers against explicit TensorShapes."""
  # (explicit dimension list, shape built by the helper)
  cases = (
      ([], tensor_shape.scalar()),
      ([37], tensor_shape.vector(37)),
      ([94, 43], tensor_shape.matrix(94, 43)),
  )
  for dims, helper_shape in cases:
    explicit_shape = tensor_shape.TensorShape(dims)
    explicit_shape.assert_is_compatible_with(helper_shape)
def testHelpers(self):
  """Checks that each shape helper is compatible with the explicit form."""
  scalar_shape = tensor_shape.TensorShape([])
  scalar_shape.assert_is_compatible_with(tensor_shape.scalar())

  vector_shape = tensor_shape.TensorShape([37])
  vector_shape.assert_is_compatible_with(tensor_shape.vector(37))

  matrix_shape = tensor_shape.TensorShape([94, 43])
  matrix_shape.assert_is_compatible_with(tensor_shape.matrix(94, 43))
def _CandidateSamplerShape(op):
  """Shape function shared by candidate sampler ops.

  Args:
    op: The operation to compute output shapes for. Input 0 must have rank 2;
      its first dimension is used as the batch size. The "num_sampled" and
      "num_true" attrs size the outputs.

  Returns:
    A list of three shapes: a vector of length num_sampled, a
    (batch_size, num_true) matrix, and another vector of length num_sampled.
  """
  batch_size = op.inputs[0].get_shape().with_rank(2)[0]
  num_sampled = op.get_attr("num_sampled")
  num_true = op.get_attr("num_true")

  sampled_shape = tensor_shape.vector(num_sampled)
  true_shape = tensor_shape.matrix(batch_size, num_true)
  return [sampled_shape, true_shape, tensor_shape.vector(num_sampled)]
def _ComputeAccidentalHitsShape(op):
  """Shape function for the ComputeAccidentalHits op.

  Args:
    op: The operation to compute output shapes for. Input 0 is merged with a
      matrix shape whose second dimension is the "num_true" attr.

  Returns:
    A list of three vector shapes of unknown length.
  """
  num_true = op.get_attr("num_true")

  # Validate that the input shape matches the attrs, even though it
  # does not influence the shape of the output.
  expected_input_shape = tensor_shape.matrix(None, num_true)
  op.inputs[0].get_shape().merge_with(expected_input_shape)

  output_shape = tensor_shape.vector(None)
  return [output_shape] * 3
def _DeserializeSparseShape(op):  # pylint: disable=invalid-name
  """Shape function for DeserializeManySparse op.

  Args:
    op: The operation to compute output shapes for. Input 0 must have rank 2
      and be mergeable with a shape of [unknown, 3].

  Returns:
    A list of three shapes: a matrix with both dimensions unknown followed by
    two vectors of unknown length.
  """
  serialized_shape = op.inputs[0].get_shape().with_rank(2)
  # Merge is performed only as a compatibility check; its result is unused.
  expected_serialized = tensor_shape.TensorShape([None, 3])
  serialized_shape.merge_with(expected_serialized)

  indices_shape = tensor_shape.matrix(None, None)
  values_shape = tensor_shape.vector(None)
  dense_shape_shape = tensor_shape.vector(None)
  return [indices_shape, values_shape, dense_shape_shape]
def _SparseTensorDenseMatMulShape(op):  # pylint: disable=invalid-name
  """Shape function for SparseTensorDenseMatMul op.

  Args:
    op: The operation to compute output shapes for. Inputs are a_indices
      (rank 2), a_values (rank 1), a_shape (vector of length 2) and b
      (rank 2). The "adjoint_b" attr selects which dimension of b supplies
      the output's second dimension.

  Returns:
    A single-element list holding a matrix shape with an unknown first
    dimension.
  """
  use_adjoint_b = op.get_attr("adjoint_b")

  a_indices, a_values, a_shape, b = (
      op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[3])
  a_indices.get_shape().assert_has_rank(2)
  a_values.get_shape().assert_has_rank(1)
  a_shape.get_shape().merge_with(tensor_shape.vector(2))
  b_shape = b.get_shape().with_rank(2)

  if use_adjoint_b:
    output_cols = b_shape[0]
  else:
    output_cols = b_shape[1]
  return [tensor_shape.matrix(None, output_cols)]
def _CTCBeamSearchDecoderShape(op):
  """Shape function for the CTCBeamSearchDecoder op.

  Args:
    op: The operation to compute output shapes for. Input 0 must have rank 3
      and input 1 rank 1; dimension 1 of input 0 and dimension 0 of input 1
      are merged to obtain the batch size. The "top_paths" attr gives the
      number of decoded paths.

  Returns:
    A list of 3 * top_paths + 1 shapes: top_paths decoded-index shapes, then
    top_paths decoded-value shapes, then top_paths decoded-shape shapes, and
    finally the (batch_size, top_paths) log_probability shape.
  """
  inputs_shape = op.inputs[0].get_shape().with_rank(3)
  sequence_length_shape = op.inputs[1].get_shape().with_rank(1)

  # `merge_with` returns the merged dimension instead of updating either
  # operand, so keep its result: this way a batch size that is only known
  # from `sequence_length` still reaches the log_probability shape.
  batch_size = inputs_shape[1].merge_with(sequence_length_shape[0])
  top_paths = op.get_attr("top_paths")

  # first the decoded indices
  output_shapes = [tensor_shape.matrix(None, 2) for _ in range(top_paths)]
  # next the decoded values
  output_shapes.extend([tensor_shape.vector(None) for _ in range(top_paths)])
  # the shapes of the decoded values
  output_shapes.extend([tensor_shape.vector(2)] * top_paths)
  # the log_probability matrix
  output_shapes.append(tensor_shape.matrix(batch_size, top_paths))
  return output_shapes
def testStr(self):
  """Checks the string form of unknown, partially known and known shapes."""
  # (expected string, shape) pairs, checked in order.
  cases = (
      ("<unknown>", tensor_shape.unknown_shape()),
      ("(?,)", tensor_shape.unknown_shape(ndims=1)),
      ("(?, ?)", tensor_shape.unknown_shape(ndims=2)),
      ("(?, ?, ?)", tensor_shape.unknown_shape(ndims=3)),
      ("()", tensor_shape.scalar()),
      ("(7,)", tensor_shape.vector(7)),
      ("(3, 8)", tensor_shape.matrix(3, 8)),
      ("(4, 5, 2)", tensor_shape.TensorShape([4, 5, 2])),
      ("(32, ?, 1, 9)", tensor_shape.TensorShape([32, None, 1, 9])),
  )
  for expected, shape in cases:
    self.assertEqual(expected, str(shape))
def _ParseExampleShape(op):
  """Shape function for the ParseExample op.

  Args:
    op: The operation to compute output shapes for. Inputs 0 and 1 (names)
      are both checked against rank 1; the shape of input 0 is prepended to
      every dense output shape.

  Returns:
    A flat list of shapes: sparse indices, sparse values, sparse shapes, then
    dense outputs.
  """
  batch_shape = op.inputs[0].get_shape().with_rank(1)
  op.inputs[1].get_shape().with_rank(1)  # names

  sparse_count = op.get_attr("Nsparse")
  dense_count = op.get_attr("Ndense")
  dense_attr = op.get_attr("dense_shapes")

  index_shapes = [tensor_shape.matrix(None, 2) for _ in range(sparse_count)]
  value_shapes = [tensor_shape.vector(None) for _ in range(sparse_count)]
  shape_shapes = [tensor_shape.vector(2) for _ in range(sparse_count)]

  assert dense_count == len(dense_attr)
  dense_outputs = [batch_shape.concatenate(shape) for shape in dense_attr]

  all_shapes = []
  for group in (index_shapes, value_shapes, shape_shapes, dense_outputs):
    all_shapes.extend(group)
  return all_shapes
def testBroadcast_many_dimensions(self):
  """Checks broadcasting between fully known scalar/vector/matrix shapes."""
  unknown = tensor_shape.unknown_shape()
  shape_0 = tensor_shape.scalar()
  shape_1 = tensor_shape.vector(1)
  shape_4 = tensor_shape.vector(4)
  shape_1x4 = tensor_shape.matrix(1, 4)
  shape_4x1 = tensor_shape.matrix(4, 1)
  shape_3x4 = tensor_shape.matrix(3, 4)
  shape_4x3 = tensor_shape.matrix(4, 3)
  shape_4x4 = tensor_shape.matrix(4, 4)

  identities = [shape_0, shape_1]
  others = [shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3]

  def check(expected, shape1, shape2):
    # A missing expectation means the two shapes must not broadcast.
    if expected is None:
      self._assert_incompatible_broadcast(shape1=shape1, shape2=shape2)
    else:
      self._assert_broadcast(expected=expected, shape1=shape1, shape2=shape2)

  # Tensors with same shape should have the same broadcast result.
  for shape in identities + others:
    check(shape, shape, shape)

  # [] and [1] act like identity.
  for identity in identities:
    for shape in others:
      check(shape, identity, shape)

  # Unknown in, unknown out.
  for shape in others:
    check(unknown, shape, unknown)

  check(shape_1x4, shape_4, shape_1x4)
  check(shape_4x4, shape_4, shape_4x1)
  check(shape_3x4, shape_4, shape_3x4)
  check(None, shape_4, shape_4x3)
  check(shape_4x4, shape_1x4, shape_4x1)
  check(shape_3x4, shape_1x4, shape_3x4)
  check(None, shape_1x4, shape_4x3)
  check(None, shape_4x1, shape_3x4)
  check(shape_4x3, shape_4x1, shape_4x3)
  check(None, shape_3x4, shape_4x3)
def _ParseExampleShape(op):
  """Shape function for the ParseExample op.

  Args:
    op: The operation to compute output shapes for. Input 0 must have rank 1;
      its shape is prepended to every dense output shape. Each entry of the
      "dense_shapes" attr is read through its `dim` items' `size` fields.

  Returns:
    A flat list of shapes: sparse indices, sparse values, sparse shapes, then
    dense outputs.
  """
  batch_shape = op.inputs[0].get_shape().with_rank(1)
  sparse_count = op.get_attr("Nsparse")
  dense_count = op.get_attr("Ndense")
  dense_attr = op.get_attr("dense_shapes")

  index_shapes = [tensor_shape.matrix(None, 2) for _ in range(sparse_count)]
  value_shapes = [tensor_shape.vector(None) for _ in range(sparse_count)]
  shape_shapes = [tensor_shape.vector(2) for _ in range(sparse_count)]

  assert dense_count == len(dense_attr)
  dense_outputs = []
  for entry in dense_attr:
    # The sizes are handed over as a generator, one per `dim` item.
    sizes = (d.size for d in entry.dim)
    dense_outputs.append(batch_shape.concatenate(sizes))

  return index_shapes + value_shapes + shape_shapes + dense_outputs
def _ParseExampleShape(op):  # pylint: disable=invalid-name
  """Shape function for the ParseExample op.

  Args:
    op: The operation to compute output shapes for. Inputs 0 and 1 (names)
      are both checked against rank 1; the shape of input 0 is prepended to
      every dense output shape.

  Returns:
    A flat list of shapes: sparse indices, sparse values, sparse shapes, then
    dense outputs.
  """
  serialized_shape = op.inputs[0].get_shape().with_rank(1)
  op.inputs[1].get_shape().with_rank(1)  # names

  n_sparse = op.get_attr("Nsparse")
  n_dense = op.get_attr("Ndense")
  dense_shape_attr = op.get_attr("dense_shapes")

  sparse_groups = [
      [tensor_shape.matrix(None, 2) for _ in range(n_sparse)],  # indices
      [tensor_shape.vector(None) for _ in range(n_sparse)],  # values
      [tensor_shape.vector(2) for _ in range(n_sparse)],  # shapes
  ]

  assert n_dense == len(dense_shape_attr)
  dense_group = [
      serialized_shape.concatenate(shape) for shape in dense_shape_attr]

  return sum(sparse_groups + [dense_group], [])
def testStr(self):
  """Checks the string form of unknown, partially known and known shapes."""
  self.assertEqual("<unknown>", str(tensor_shape.unknown_shape()))

  # For shapes with unknown dimensions, "?" is rewritten to "None" before
  # comparing.
  with_unknown_dims = (
      ("(None,)", tensor_shape.unknown_shape(rank=1)),
      ("(None, None)", tensor_shape.unknown_shape(rank=2)),
      ("(None, None, None)", tensor_shape.unknown_shape(rank=3)),
      ("(32, None, 1, 9)", tensor_shape.TensorShape([32, None, 1, 9])),
  )
  for expected, shape in with_unknown_dims:
    self.assertEqual(expected, str(shape).replace("?", "None"))

  # Fully known shapes are compared verbatim.
  fully_known = (
      ("()", tensor_shape.scalar()),
      ("(7,)", tensor_shape.vector(7)),
      ("(3, 8)", tensor_shape.matrix(3, 8)),
      ("(4, 5, 2)", tensor_shape.TensorShape([4, 5, 2])),
  )
  for expected, shape in fully_known:
    self.assertEqual(expected, str(shape))
def testBroadcast_unknown_dims(self):
  """Checks broadcasting between shapes that contain unknown dimensions."""
  unknown = tensor_shape.unknown_shape()
  shape_0 = tensor_shape.scalar()
  shape_1 = tensor_shape.vector(1)
  # pylint: disable=invalid-name
  shape_U = tensor_shape.vector(None)
  shape_1xU = tensor_shape.matrix(1, None)
  shape_Ux1 = tensor_shape.matrix(None, 1)
  shape_4xU = tensor_shape.matrix(4, None)
  shape_Ux4 = tensor_shape.matrix(None, 4)
  shape_UxU = tensor_shape.matrix(None, None)
  # pylint: enable=invalid-name
  shape_4x4 = tensor_shape.matrix(4, 4)
  partial_shapes = (shape_U, shape_1xU, shape_Ux1, shape_4xU, shape_Ux4)

  # Tensors with same shape should have the same broadcast result.
  for shape in partial_shapes:
    self._assert_broadcast_with_unknown_dims(
        expected=shape, shape1=shape, shape2=shape)

  # [] and [1] act like identity.
  for identity in (shape_0, shape_1):
    for shape in partial_shapes:
      self._assert_broadcast_with_unknown_dims(
          expected=shape, shape1=identity, shape2=shape)

  # Unknown in, unknown out.
  for shape in partial_shapes:
    self._assert_broadcast_with_unknown_dims(
        expected=unknown, shape1=shape, shape2=unknown)

  # (expected, shape1, shape2), checked in order.
  pairwise_cases = (
      (shape_1xU, shape_U, shape_1xU),
      (shape_UxU, shape_U, shape_Ux1),
      (shape_4xU, shape_U, shape_4xU),
      (shape_Ux4, shape_U, shape_Ux4),
      (shape_UxU, shape_1xU, shape_Ux1),
      (shape_4xU, shape_1xU, shape_4xU),
      (shape_Ux4, shape_1xU, shape_Ux4),
      (shape_4xU, shape_Ux1, shape_4xU),
      (shape_Ux4, shape_Ux1, shape_Ux4),
      (shape_4x4, shape_4xU, shape_Ux4),
  )
  for expected, lhs, rhs in pairwise_cases:
    self._assert_broadcast_with_unknown_dims(
        expected=expected, shape1=lhs, shape2=rhs)
def _PadShape(op):
  """Shape function for the Pad op.

  This op has two inputs:

  * input: A rank-N tensor.
  * paddings: An N-by-2 matrix, in which the i^th row contains the
    number of padding elements to add before and after `input` in the
    i^th dimension.

  It has one output, which has the same rank as input, and additional
  elements according to the values in paddings.

  Args:
    op: A Pad Operation.

  Returns:
    A single-element list containing the shape of the output.

  Raises:
    ValueError: If the input shapes are incompatible.
  """
  paddings_shape = op.inputs[1].get_shape().with_rank(2)
  data_shape = op.inputs[0].get_shape()

  is_legacy_scalar = data_shape.ndims == 0 and paddings_shape[0].value == 1
  if is_legacy_scalar:
    # TODO(irving): Remove once !kAllowLegacyScalars.
    data_shape = tensor_shape.TensorShape([1])
  else:
    data_shape = data_shape.with_rank(paddings_shape[0].value)

  # Compatibility check of paddings against an (ndims, 2) matrix; the merged
  # shape itself is not used afterwards.
  paddings_shape.merge_with(tensor_shape.matrix(data_shape.ndims, 2))

  padding_values = tensor_util.ConstantValue(op.inputs[1])
  if padding_values is None:
    # Without constant paddings only the rank of the output is reported.
    return [tensor_shape.unknown_shape(ndims=data_shape.ndims)]

  padded_dims = []
  for axis, dim in enumerate(data_shape.dims):
    if padding_values[axis, 0] < 0 or padding_values[axis, 1] < 0:
      raise ValueError("paddings must be non-negative")
    padded_dims.append(dim + padding_values[axis, 0] + padding_values[axis, 1])
  return [tensor_shape.TensorShape(padded_dims)]
def reverse_seq(input_seq, lengths):
  """Reverses a list of tensors, optionally only up to given lengths.

  Args:
    input_seq: Sequence of tensors; each is constrained to rank 2 via
      `set_shape`, and the sequence is packed along a new leading dimension.
    lengths: Per-sequence lengths passed (after an int64 cast) to
      `reverse_sequence`. If None, the list is simply reversed.

  Returns:
    A list with the reversed tensors.
  """
  if lengths is None:
    return list(reversed(input_seq))

  input_shape = tensor_shape.matrix(None, None)
  for input_ in input_seq:
    # `merge_with` returns the merged shape rather than updating
    # `input_shape` in place, so rebind it; otherwise the static shape
    # information gathered from the inputs is thrown away.
    input_shape = input_shape.merge_with(input_.get_shape())
    input_.set_shape(input_shape)

  # Join into (time, batch_size, depth)
  s_joined = array_ops.pack(input_seq)

  # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32
  # `lengths` cannot be None here because of the early return above.
  lengths = math_ops.to_int64(lengths)

  # Reverse along dimension 0
  s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)

  # Split again into list
  result = array_ops.unpack(s_reversed)
  for r in result:
    r.set_shape(input_shape)
  return result
def _MultinomialShape(op):  # pylint: disable=invalid-name
  """Shape function for the Multinomial op.

  Args:
    op: The operation to compute output shapes for. Input 0 must have rank 2;
      its first dimension is used as the batch size. Input 1 supplies the
      second output dimension when `constant_value` can resolve it.

  Returns:
    A single-element list holding the (batch_size, num_samples) matrix shape;
    num_samples is whatever `constant_value` returned for input 1.
  """
  batch_size = op.inputs[0].get_shape().with_rank(2)[0]
  # Passed straight through to the matrix shape, whatever the lookup yields.
  num_samples = tensor_util.constant_value(op.inputs[1])
  output_shape = tensor_shape.matrix(batch_size, num_samples)
  return [output_shape]
def _SparseFeatureCrossShape(unused_op):  # pylint: disable=invalid-name
  """Shape function for the sparse feature cross op.

  Args:
    unused_op: The operation; it is not inspected.

  Returns:
    A list of three shapes: an (unknown, 2) matrix, a vector of unknown
    length, and a vector of length 2.
  """
  indices_shape = tensor_shape.matrix(None, 2)
  values_shape = tensor_shape.vector(None)
  dense_shape_shape = tensor_shape.vector(2)
  return [indices_shape, values_shape, dense_shape_shape]
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
                    test_util.TensorFlowTestCase):
  # Exercises the `structure` helpers (type-spec inference, flattening,
  # compatibility, batching) on tensors, tensor arrays, sparse tensors,
  # ragged tensors and nested combinations of them.
  #
  # Test values are wrapped in lambdas so they are only constructed when the
  # test method itself runs.

  # pylint: disable=g-long-lambda,protected-access
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0), tensor_spec.TensorSpec,
       [dtypes.float32], [[]]),
      ("TensorArray",
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       tensor_array_ops.TensorArraySpec, [dtypes.variant], [[]]),
      ("SparseTensor",
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensorSpec, [dtypes.variant], [None]),
      ("RaggedTensor",
       lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
       ragged_tensor.RaggedTensorSpec, [dtypes.variant], [None]),
      ("Nested_0",
       lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
       tuple, [dtypes.float32, dtypes.int32], [[], [3]]),
      ("Nested_1",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": constant_op.constant([1, 2, 3])
       },
       dict, [dtypes.float32, dtypes.int32], [[], [3]]),
      ("Nested_2",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": (sparse_tensor.SparseTensor(
               indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                 sparse_tensor.SparseTensor(
                     indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
       },
       dict, [dtypes.float32, dtypes.variant, dtypes.variant],
       [[], None, None]),
  )
  def testFlatStructure(self, value_fn, expected_structure, expected_types,
                        expected_shapes):
    # The spec inferred from a value must have the expected type and the
    # expected flat dtypes / shapes. An expected shape of None means the
    # flat shape must have unknown rank.
    element = value_fn()
    spec = structure.type_spec_from_value(element)
    self.assertIsInstance(spec, expected_structure)

    actual_types = structure.get_flat_tensor_types(spec)
    self.assertEqual(expected_types, actual_types)

    actual_shapes = structure.get_flat_tensor_shapes(spec)
    self.assertLen(actual_shapes, len(expected_shapes))
    for expected, actual in zip(expected_shapes, actual_shapes):
      if expected is None:
        self.assertEqual(actual.ndims, None)
      else:
        self.assertEqual(actual.as_list(), expected)

  @parameterized.named_parameters(
      ("Tensor",
       lambda: constant_op.constant(37.0),
       lambda: [
           constant_op.constant(38.0),
           array_ops.placeholder(dtypes.float32),
           variables.Variable(100.0),
           42.0,
           np.array(42.0, dtype=np.float32)
       ],
       lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]),
      ("TensorArray",
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: [
           tensor_array_ops.TensorArray(
               dtype=dtypes.float32, element_shape=(3,), size=0),
           tensor_array_ops.TensorArray(
               dtype=dtypes.float32, element_shape=(3,), size=10)
       ],
       lambda: [
           tensor_array_ops.TensorArray(
               dtype=dtypes.int32, element_shape=(3,), size=0),
           tensor_array_ops.TensorArray(
               dtype=dtypes.float32, element_shape=(), size=0)
       ]),
      ("SparseTensor",
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: [
           sparse_tensor.SparseTensor(
               indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
           sparse_tensor.SparseTensorValue(
               indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
           array_ops.sparse_placeholder(dtype=dtypes.int32),
           array_ops.sparse_placeholder(
               dtype=dtypes.int32, shape=[None, None])
       ],
       lambda: [
           constant_op.constant(37, shape=[4, 5]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
           array_ops.sparse_placeholder(
               dtype=dtypes.int32, shape=[None, None, None]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
       ]),
      ("RaggedTensor",
       lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
       lambda: [
           ragged_factory_ops.constant([[1, 2], [3, 4], []]),
           ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
       ],
       lambda: [
           ragged_factory_ops.constant(1),
           ragged_factory_ops.constant([1, 2]),
           ragged_factory_ops.constant([[1], [2]]),
           ragged_factory_ops.constant([["a", "b"]]),
       ]),
      ("Nested",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": constant_op.constant([1, 2, 3])
       },
       lambda: [{
           "a": constant_op.constant(15.0),
           "b": constant_op.constant([4, 5, 6])
       }],
       lambda: [{
           "a": constant_op.constant(15.0),
           "b": constant_op.constant([4, 5, 6, 7])
       }, {
           "a": constant_op.constant(15),
           "b": constant_op.constant([4, 5, 6])
       }, {
           "a": constant_op.constant(15),
           "b": sparse_tensor.SparseTensor(
               indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
       }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
  )
  @test_util.run_deprecated_v1
  def testIsCompatibleWithStructure(self, original_value_fn,
                                    compatible_values_fn,
                                    incompatible_values_fn):
    # The spec of the original value must be compatible with the spec of
    # every "compatible" value and with none of the "incompatible" ones.
    original_value = original_value_fn()
    compatible_values = compatible_values_fn()
    incompatible_values = incompatible_values_fn()

    original_spec = structure.type_spec_from_value(original_value)

    for candidate in compatible_values:
      candidate_spec = structure.type_spec_from_value(candidate)
      self.assertTrue(structure.are_compatible(original_spec, candidate_spec))

    for candidate in incompatible_values:
      candidate_spec = structure.type_spec_from_value(candidate)
      self.assertFalse(structure.are_compatible(original_spec, candidate_spec))

  @parameterized.named_parameters(
      ("Tensor",
       lambda: constant_op.constant(37.0),
       lambda: constant_op.constant(42.0),
       lambda: constant_op.constant([5])),
      ("TensorArray",
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.int32, element_shape=(), size=0)),
      ("SparseTensor",
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3]], values=[-1], dense_shape=[5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])),
      ("RaggedTensor",
       lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
       lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]),
       lambda: ragged_factory_ops.constant([[[1]], [[2], [3]]],
                                           ragged_rank=1),
       lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
       lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])),
      ("Nested",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": constant_op.constant([1, 2, 3])
       },
       lambda: {
           "a": constant_op.constant(42.0),
           "b": constant_op.constant([4, 5, 6])
       },
       lambda: {
           "a": constant_op.constant([1, 2, 3]),
           "b": constant_op.constant(37.0)
       }),
  )  # pyformat: disable
  def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                     *not_equal_value_fns):
    # The first two values must yield equal specs (and equal component
    # hashes); every further value must yield a spec unequal to both.
    # The `==` / `!=` operators are used directly on purpose so that both
    # __eq__ and __ne__ are exercised.
    # pylint: disable=g-generic-assert
    s1 = structure.type_spec_from_value(value1_fn())
    s2 = structure.type_spec_from_value(value2_fn())

    self.assertEqual(s1, s1)  # check __eq__ operator.
    self.assertEqual(s1, s2)  # check __eq__ operator.
    self.assertFalse(s1 != s1)  # check __ne__ operator.
    self.assertFalse(s1 != s2)  # check __ne__ operator.

    for c1, c2 in zip(nest.flatten(s1), nest.flatten(s2)):
      self.assertEqual(hash(c1), hash(c1))
      self.assertEqual(hash(c1), hash(c2))

    for value_fn in not_equal_value_fns:
      s3 = structure.type_spec_from_value(value_fn())
      self.assertNotEqual(s1, s3)  # check __ne__ operator.
      self.assertNotEqual(s2, s3)  # check __ne__ operator.
      self.assertFalse(s1 == s3)  # check __eq__ operator.
      self.assertFalse(s2 == s3)  # check __eq__ operator.

  @parameterized.named_parameters(
      ("RaggedTensor_RaggedRank",
       structure.RaggedTensorStructure(dtypes.int32, None, 1),
       structure.RaggedTensorStructure(dtypes.int32, None, 2)),
      ("RaggedTensor_Shape",
       structure.RaggedTensorStructure(dtypes.int32, [3, None], 1),
       structure.RaggedTensorStructure(dtypes.int32, [5, None], 1)),
      ("RaggedTensor_DType",
       structure.RaggedTensorStructure(dtypes.int32, None, 1),
       structure.RaggedTensorStructure(dtypes.float32, None, 1)),
  )
  def testRaggedStructureInequality(self, s1, s2):
    # Ragged structures differing in ragged rank, shape or dtype are unequal.
    # pylint: disable=g-generic-assert
    self.assertNotEqual(s1, s2)  # check __ne__ operator.
    self.assertFalse(s1 == s2)  # check __eq__ operator.

  @parameterized.named_parameters(
      ("Tensor",
       lambda: constant_op.constant(37.0),
       lambda: constant_op.constant(42.0),
       lambda: constant_op.constant([5])),
      ("TensorArray",
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.int32, element_shape=(), size=0)),
      ("SparseTensor",
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3]], values=[-1], dense_shape=[5])),
      ("Nested",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": constant_op.constant([1, 2, 3])
       },
       lambda: {
           "a": constant_op.constant(42.0),
           "b": constant_op.constant([4, 5, 6])
       },
       lambda: {
           "a": constant_op.constant([1, 2, 3]),
           "b": constant_op.constant(37.0)
       }),
  )
  def testHash(self, value1_fn, value2_fn, value3_fn):
    # Components of the first two specs hash equally; the third differs.
    s1 = structure.type_spec_from_value(value1_fn())
    s2 = structure.type_spec_from_value(value2_fn())
    s3 = structure.type_spec_from_value(value3_fn())
    flat_components = zip(
        nest.flatten(s1), nest.flatten(s2), nest.flatten(s3))
    for c1, c2, c3 in flat_components:
      self.assertEqual(hash(c1), hash(c1))
      self.assertEqual(hash(c1), hash(c2))
      self.assertNotEqual(hash(c1), hash(c3))
      self.assertNotEqual(hash(c2), hash(c3))

  @parameterized.named_parameters(
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      ),
      ("TensorArray",
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
      ),
      (
          "Nested_0",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3])
          },
      ),
      (
          "Nested_1",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": (sparse_tensor.SparseTensor(
                  indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                    sparse_tensor.SparseTensor(
                        indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
          },
      ),
  )
  def testRoundTripConversion(self, value_fn):
    # Flattening a value to a tensor list and rebuilding it must yield a
    # value that evaluates to the same contents.
    value = value_fn()
    spec = structure.type_spec_from_value(value)

    def maybe_stack_ta(v):
      # TensorArrays are stacked into a single tensor before evaluation.
      if isinstance(v, tensor_array_ops.TensorArray):
        return v.stack()
      else:
        return v

    round_tripped = structure.from_tensor_list(
        spec, structure.to_tensor_list(spec, value))
    before = self.evaluate(maybe_stack_ta(value))
    after = self.evaluate(maybe_stack_ta(round_tripped))

    for b, a in zip(nest.flatten(before), nest.flatten(after)):
      if isinstance(b, sparse_tensor.SparseTensorValue):
        # Sparse values are compared component by component.
        self.assertAllEqual(b.indices, a.indices)
        self.assertAllEqual(b.values, a.values)
        self.assertAllEqual(b.dense_shape, a.dense_shape)
      elif isinstance(
          b,
          (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
        self.assertAllEqual(b, a)
      else:
        self.assertAllEqual(b, a)

  # pylint: enable=g-long-lambda

  # NOTE(review): this method name does not start with "test", so the
  # unittest runner does not collect it and these assertions never run.
  # Confirm whether that is intentional before renaming it; the expected
  # static shapes below (e.g. [None]) may not hold in every execution mode.
  def preserveStaticShape(self):
    rt = ragged_factory_ops.constant([[1, 2], [], [3]])
    rt_s = structure.type_spec_from_value(rt)
    rt_after = structure.from_tensor_list(
        rt_s, structure.to_tensor_list(rt_s, rt))
    self.assertEqual(rt_after.row_splits.shape.as_list(),
                     rt.row_splits.shape.as_list())
    self.assertEqual(rt_after.values.shape.as_list(), [None])

    st = sparse_tensor.SparseTensor(
        indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
    st_s = structure.type_spec_from_value(st)
    st_after = structure.from_tensor_list(
        st_s, structure.to_tensor_list(st_s, st))
    self.assertEqual(st_after.indices.shape.as_list(), [None, 2])
    self.assertEqual(st_after.values.shape.as_list(), [None])
    self.assertEqual(st_after.dense_shape.shape.as_list(),
                     st.dense_shape.shape.as_list())

  def testIncompatibleStructure(self):
    # Define three mutually incompatible values/structures, and assert that:
    # 1. Using one structure to flatten a value with an incompatible
    #    structure fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.
    value_tensor = constant_op.constant(42.0)
    s_tensor = structure.type_spec_from_value(value_tensor)
    flat_tensor = structure.to_tensor_list(s_tensor, value_tensor)

    value_sparse_tensor = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    s_sparse_tensor = structure.type_spec_from_value(value_sparse_tensor)
    flat_sparse_tensor = structure.to_tensor_list(s_sparse_tensor,
                                                  value_sparse_tensor)

    value_nest = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_nest = structure.type_spec_from_value(value_nest)
    flat_nest = structure.to_tensor_list(s_nest, value_nest)

    # Part 1: flattening a value with the wrong structure.
    with self.assertRaisesRegexp(
        ValueError, r"SparseTensor.* is not convertible to a tensor with "
        r"dtype.*float32.* and shape \(\)"):
      structure.to_tensor_list(s_tensor, value_sparse_tensor)
    with self.assertRaisesRegexp(
        ValueError,
        "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_tensor, value_nest)

    with self.assertRaisesRegexp(
        TypeError, "Neither a SparseTensor nor SparseTensorValue"):
      structure.to_tensor_list(s_sparse_tensor, value_tensor)

    with self.assertRaisesRegexp(
        ValueError,
        "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_sparse_tensor, value_nest)

    with self.assertRaisesRegexp(
        ValueError,
        "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_nest, value_tensor)

    with self.assertRaisesRegexp(
        ValueError,
        "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_nest, value_sparse_tensor)

    # Part 2: rebuilding from a tensor list made for another structure.
    with self.assertRaisesRegexp(ValueError, r"Incompatible input:"):
      structure.from_tensor_list(s_tensor, flat_sparse_tensor)

    with self.assertRaisesRegexp(ValueError, "Expected 1 tensors but got 2."):
      structure.from_tensor_list(s_tensor, flat_nest)

    with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
      structure.from_tensor_list(s_sparse_tensor, flat_tensor)

    with self.assertRaisesRegexp(ValueError, "Expected 1 tensors but got 2."):
      structure.from_tensor_list(s_sparse_tensor, flat_nest)

    with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 1."):
      structure.from_tensor_list(s_nest, flat_tensor)

    with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 1."):
      structure.from_tensor_list(s_nest, flat_sparse_tensor)

  def testIncompatibleNestedStructure(self):
    # Define three mutually incompatible nested values/structures, and assert
    # that:
    # 1. Using one structure to flatten a value with an incompatible
    #    structure fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.

    value_0 = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_0 = structure.type_spec_from_value(value_0)
    flat_s_0 = structure.to_tensor_list(s_0, value_0)

    # `value_1` has compatible nested structure with `value_0`, but different
    # classes.
    value_1 = {
        "a": constant_op.constant(37.0),
        "b": sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    }
    s_1 = structure.type_spec_from_value(value_1)
    flat_s_1 = structure.to_tensor_list(s_1, value_1)

    # `value_2` has incompatible nested structure with `value_0` and
    # `value_1`.
    value_2 = {
        "a": constant_op.constant(37.0),
        "b": (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
    }
    s_2 = structure.type_spec_from_value(value_2)
    flat_s_2 = structure.to_tensor_list(s_2, value_2)

    with self.assertRaisesRegexp(
        ValueError, r"SparseTensor.* is not convertible to a tensor with "
        r"dtype.*int32.* and shape \(3,\)"):
      structure.to_tensor_list(s_0, value_1)

    with self.assertRaisesRegexp(
        ValueError,
        "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_0, value_2)

    with self.assertRaisesRegexp(
        TypeError, "Neither a SparseTensor nor SparseTensorValue"):
      structure.to_tensor_list(s_1, value_0)

    with self.assertRaisesRegexp(
        ValueError,
        "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_1, value_2)

    # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
    # needs to account for "a" coming before or after "b". It might be worth
    # adding a deterministic repr for these error messages (among other
    # improvements). (The patterns used below match a fixed message and do
    # not themselves depend on the dictionary ordering.)
    with self.assertRaisesRegexp(
        ValueError,
        "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_2, value_0)

    with self.assertRaisesRegexp(
        ValueError,
        "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_2, value_1)

    with self.assertRaisesRegexp(ValueError, r"Incompatible input:"):
      structure.from_tensor_list(s_0, flat_s_1)

    with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 3."):
      structure.from_tensor_list(s_0, flat_s_2)

    with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
      structure.from_tensor_list(s_1, flat_s_0)

    with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 3."):
      structure.from_tensor_list(s_1, flat_s_2)

    with self.assertRaisesRegexp(ValueError, "Expected 3 tensors but got 2."):
      structure.from_tensor_list(s_2, flat_s_0)

    with self.assertRaisesRegexp(ValueError, "Expected 3 tensors but got 2."):
      structure.from_tensor_list(s_2, flat_s_1)

  @parameterized.named_parameters(
      ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2),
       sparse_tensor.SparseTensor,
       structure.SparseTensorStructure(dtypes.int32, [2, 2])),
      ("TensorArray_0", dtypes.int32,
       tensor_shape.as_shape([None, True, 2, 2]),
       tensor_array_ops.TensorArray,
       structure.TensorArrayStructure(
           dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)),
      ("TensorArray_1", dtypes.int32,
       tensor_shape.as_shape([True, None, 2, 2]),
       tensor_array_ops.TensorArray,
       structure.TensorArrayStructure(
           dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)),
      ("TensorArray_2", dtypes.int32,
       tensor_shape.as_shape([True, False, 2, 2]),
       tensor_array_ops.TensorArray,
       structure.TensorArrayStructure(
           dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
      ("RaggedTensor", dtypes.int32, tensor_shape.matrix(2, None),
       structure.RaggedTensorStructure(dtypes.int32, [2, None], 1),
       structure.RaggedTensorStructure(dtypes.int32, [2, None], 1)),
      ("Nested", {
          "a": dtypes.float32,
          "b": (dtypes.int32, dtypes.string)
      }, {
          "a": tensor_shape.scalar(),
          "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
      }, {
          "a": ops.Tensor,
          "b": (sparse_tensor.SparseTensor, ops.Tensor)
      }, {
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                structure.TensorStructure(dtypes.string, []))
      }),
  )
  def testConvertLegacyStructure(self, output_types, output_shapes,
                                 output_classes, expected_structure):
    # Legacy (types, shapes, classes) triples convert to the expected
    # structure.
    converted = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    self.assertEqual(converted, expected_structure)

  def testNestedNestedStructure(self):
    # A tuple nested inside a tuple flattens to the very same tensor objects
    # and is rebuilt from them without copying.
    s = (structure.TensorStructure(dtypes.int64, []),
         (structure.TensorStructure(dtypes.float32, []),
          structure.TensorStructure(dtypes.string, [])))

    int64_t = constant_op.constant(37, dtype=dtypes.int64)
    float32_t = constant_op.constant(42.0)
    string_t = constant_op.constant("Foo")

    nested_tensors = (int64_t, (float32_t, string_t))

    tensor_list = structure.to_tensor_list(s, nested_tensors)
    for expected, actual in zip([int64_t, float32_t, string_t], tensor_list):
      self.assertIs(expected, actual)

    (actual_int64_t,
     (actual_float32_t,
      actual_string_t)) = structure.from_tensor_list(s, tensor_list)
    self.assertIs(int64_t, actual_int64_t)
    self.assertIs(float32_t, actual_float32_t)
    self.assertIs(string_t, actual_string_t)

    (actual_int64_t,
     (actual_float32_t,
      actual_string_t)) = (structure.from_compatible_tensor_list(
          s, tensor_list))
    self.assertIs(int64_t, actual_int64_t)
    self.assertIs(float32_t, actual_float32_t)
    self.assertIs(string_t, actual_string_t)

  @parameterized.named_parameters(
      ("Tensor", structure.TensorStructure(dtypes.float32, []), 32,
       structure.TensorStructure(dtypes.float32, [32])),
      ("TensorUnknown", structure.TensorStructure(dtypes.float32, []), None,
       structure.TensorStructure(dtypes.float32, [None])),
      ("SparseTensor", structure.SparseTensorStructure(dtypes.float32, [None]),
       32, structure.SparseTensorStructure(dtypes.float32, [32, None])),
      ("SparseTensorUnknown",
       structure.SparseTensorStructure(dtypes.float32, [4]), None,
       structure.SparseTensorStructure(dtypes.float32, [None, 4])),
      ("RaggedTensor",
       structure.RaggedTensorStructure(dtypes.float32, [2, None], 1), 32,
       structure.RaggedTensorStructure(dtypes.float32, [32, 2, None], 2)),
      ("RaggedTensorUnknown",
       structure.RaggedTensorStructure(dtypes.float32, [4, None], 1), None,
       structure.RaggedTensorStructure(dtypes.float32, [None, 4, None], 2)),
      ("Nested", {
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                structure.TensorStructure(dtypes.string, []))
      }, 128, {
          "a": structure.TensorStructure(dtypes.float32, [128]),
          "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                structure.TensorStructure(dtypes.string, [128]))
      }),
  )
  def testBatch(self, element_structure, batch_size,
                expected_batched_structure):
    # Calling `_batch(batch_size)` on every component gives the expected
    # structure.
    batched_structure = nest.map_structure(
        lambda component_spec: component_spec._batch(batch_size),
        element_structure)
    self.assertEqual(batched_structure, expected_batched_structure)

  @parameterized.named_parameters(
      ("Tensor", structure.TensorStructure(dtypes.float32, [32]),
       structure.TensorStructure(dtypes.float32, [])),
      ("TensorUnknown", structure.TensorStructure(dtypes.float32, [None]),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor",
       structure.SparseTensorStructure(dtypes.float32, [32, None]),
       structure.SparseTensorStructure(dtypes.float32, [None])),
      ("SparseTensorUnknown",
       structure.SparseTensorStructure(dtypes.float32, [None, 4]),
       structure.SparseTensorStructure(dtypes.float32, [4])),
      ("RaggedTensor",
       structure.RaggedTensorStructure(dtypes.float32, [32, None, None], 2),
       structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)),
      ("RaggedTensorUnknown",
       structure.RaggedTensorStructure(dtypes.float32, [None, None, None], 2),
       structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)),
      ("Nested", {
          "a": structure.TensorStructure(dtypes.float32, [128]),
          "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                structure.TensorStructure(dtypes.string, [None]))
      }, {
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                structure.TensorStructure(dtypes.string, []))
      }),
  )
  def testUnbatch(self, element_structure, expected_unbatched_structure):
    # Calling `_unbatch()` on every component gives the expected structure.
    unbatched_structure = nest.map_structure(
        lambda component_spec: component_spec._unbatch(), element_structure)
    self.assertEqual(unbatched_structure, expected_unbatched_structure)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor",
       lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
       lambda: constant_op.constant([1.0, 2.0])),
      ("SparseTensor",
       lambda: sparse_tensor.SparseTensor(
           indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0]], values=[13], dense_shape=[2])),
      ("RaggedTensor",
       lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
       lambda: ragged_factory_ops.constant([[1]])),
      ("Nest",
       lambda: (constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
                sparse_tensor.SparseTensor(
                    indices=[[0, 0], [1, 1]], values=[13, 27],
                    dense_shape=[2, 2])),
       lambda: (constant_op.constant([1.0, 2.0]),
                sparse_tensor.SparseTensor(
                    indices=[[0]], values=[13], dense_shape=[2]))),
  )
  def testToBatchedTensorList(self, value_fn, element_0_fn):
    # Converts a batched value to a batched tensor list, then checks the
    # batch dimension and the contents of element 0.
    batched_value = value_fn()
    batched_spec = structure.type_spec_from_value(batched_value)
    batched_tensor_list = structure.to_batched_tensor_list(
        batched_spec, batched_value)

    # The batch dimension is 2 for all of the test cases.
    # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
    # tensors in which we store sparse tensors.
    for t in batched_tensor_list:
      if t.dtype != dtypes.variant:
        self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

    # Test that the 0th element from the unbatched tensor is equal to the
    # expected value.
    expected_element_0 = self.evaluate(element_0_fn())
    unbatched_spec = nest.map_structure(
        lambda component_spec: component_spec._unbatch(), batched_spec)
    actual_element_0 = structure.from_tensor_list(
        unbatched_spec, [t[0] for t in batched_tensor_list])

    for expected, actual in zip(
        nest.flatten(expected_element_0), nest.flatten(actual_element_0)):
      self.assertValuesEqual(expected, actual)
def output_shapes(self):
  """Returns the static shapes of the three output components.

  Returns:
    A 3-tuple of shapes `([?, r], [?], [r])` where `r` is
    `self._row_shape.shape[0] + 1` and `?` is an unknown dimension
    (presumably the indices / values / dense_shape components of a sparse
    tensor with one extra leading dimension -- confirm against the
    enclosing class).
  """
  unknown_count = tensor_shape.Dimension(None)
  indices_shape = tensor_shape.matrix(unknown_count,
                                      self._row_shape.shape[0] + 1)
  values_shape = tensor_shape.vector(unknown_count)
  dense_shape_shape = tensor_shape.vector(self._row_shape.shape[0] + 1)
  return (indices_shape, values_shape, dense_shape_shape)
def _to_legacy_output_shapes(self):
  """Returns the legacy output shape for this structure.

  The first two dimensions of the returned shape are not real dimensions:
  they carry `self._dynamic_size` and `self._infer_shape`, and are followed
  by the dimensions of `self._element_shape`.
  """
  # Sneak the dynamic_size and infer_shape values into the legacy shape.
  flag_dims = tensor_shape.matrix(self._dynamic_size, self._infer_shape)
  return flag_dims.concatenate(self._element_shape)
class StructureTest(test.TestCase, parameterized.TestCase): # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they # will be executed before the (eager- or graph-mode) test environment has been # set up. # pylint: disable=g-long-lambda,protected-access @parameterized.parameters( (lambda: constant_op.constant(37.0), structure.TensorStructure, [dtypes.float32], [[]]), (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), structure.SparseTensorStructure, [dtypes.variant], [None]), (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), (lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3] ]), (lambda: { "a": constant_op.constant(37.0), "b": (sparse_tensor. SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) }, structure.NestedStructure, [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None])) def testFlatStructure(self, value_fn, expected_structure, expected_types, expected_shapes): value = value_fn() s = structure.Structure.from_value(value) self.assertIsInstance(s, expected_structure) self.assertEqual(expected_types, s._flat_types) for expected, actual in zip(expected_shapes, s._flat_shapes): self.assertTrue(actual.is_compatible_with(expected)) self.assertTrue( tensor_shape.as_shape(expected).is_compatible_with(actual)) @parameterized.parameters( (lambda: constant_op.constant(37.0), lambda: [ constant_op.constant(38.0), array_ops.placeholder(dtypes.float32), variables.Variable(100.0), 42.0, np.array(42.0, dtype=np.float32) ], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]), (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [ 
sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), array_ops.sparse_placeholder(dtype=dtypes.int32), array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None]) ], lambda: [ constant_op.constant(37, shape=[4, 5]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[5, 6]), array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None, None]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5]) ]), (lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, lambda: [{ "a": constant_op.constant(15.0), "b": constant_op.constant([4, 5, 6]) }], lambda: [{ "a": constant_op.constant(15.0), "b": constant_op.constant([4, 5, 6, 7]) }, { "a": constant_op.constant(15), "b": constant_op.constant([4, 5, 6]) }, { "a": constant_op.constant(15), "b": sparse_tensor.SparseTensor( indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3]) }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]), ) def testIsCompatibleWithStructure(self, original_value_fn, compatible_values_fn, incompatible_values_fn): original_value = original_value_fn() compatible_values = compatible_values_fn() incompatible_values = incompatible_values_fn() s = structure.Structure.from_value(original_value) for compatible_value in compatible_values: self.assertTrue( s.is_compatible_with( structure.Structure.from_value(compatible_value))) for incompatible_value in incompatible_values: self.assertFalse( s.is_compatible_with( structure.Structure.from_value(incompatible_value))) @parameterized.parameters( (lambda: constant_op.constant(37.0), ), (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), ), (lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, ), (lambda: { "a": constant_op.constant(37.0), "b": (sparse_tensor. 
SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) }, ), ) def testRoundTripConversion(self, value_fn): value = value_fn() s = structure.Structure.from_value(value) before = self.evaluate(value) after = self.evaluate(s._from_tensor_list(s._to_tensor_list(value))) flat_before = nest.flatten(before) flat_after = nest.flatten(after) for b, a in zip(flat_before, flat_after): if isinstance(b, sparse_tensor.SparseTensorValue): self.assertAllEqual(b.indices, a.indices) self.assertAllEqual(b.values, a.values) self.assertAllEqual(b.dense_shape, a.dense_shape) else: self.assertAllEqual(b, a) # pylint: enable=g-long-lambda def testIncompatibleStructure(self): # Define three mutually incompatible values/structures, and assert that: # 1. Using one structure to flatten a value with an incompatible structure # fails. # 2. Using one structure to restructre a flattened value with an # incompatible structure fails. 
value_tensor = constant_op.constant(42.0) s_tensor = structure.Structure.from_value(value_tensor) flat_tensor = s_tensor._to_tensor_list(value_tensor) value_sparse_tensor = sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]) s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor) flat_sparse_tensor = s_sparse_tensor._to_tensor_list( value_sparse_tensor) value_nest = { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) } s_nest = structure.Structure.from_value(value_nest) flat_nest = s_nest._to_tensor_list(value_nest) with self.assertRaisesRegexp( ValueError, r"SparseTensor.* is not convertible to a tensor with " r"dtype.*float32.* and shape \(\)"): s_tensor._to_tensor_list(value_sparse_tensor) with self.assertRaisesRegexp( ValueError, r"Value \{.*\} is not convertible to a tensor with " r"dtype.*float32.* and shape \(\)"): s_tensor._to_tensor_list(value_nest) with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"): s_sparse_tensor._to_tensor_list(value_tensor) with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"): s_sparse_tensor._to_tensor_list(value_nest) with self.assertRaisesRegexp( ValueError, "Tensor.* not compatible with the nested structure " ".*TensorStructure.*TensorStructure"): s_nest._to_tensor_list(value_tensor) with self.assertRaisesRegexp( ValueError, "SparseTensor.* not compatible with the nested structure " ".*TensorStructure.*TensorStructure"): s_nest._to_tensor_list(value_sparse_tensor) with self.assertRaisesRegexp( ValueError, r"Cannot convert.*with dtype.*float32.* and shape \(\)"): s_tensor._from_tensor_list(flat_sparse_tensor) with self.assertRaisesRegexp( ValueError, "TensorStructure corresponds to a single tf.Tensor."): s_tensor._from_tensor_list(flat_nest) with self.assertRaisesRegexp( ValueError, "SparseTensorStructure corresponds to a single tf.variant " "vector of length 3."): s_sparse_tensor._from_tensor_list(flat_tensor) with 
self.assertRaisesRegexp( ValueError, "SparseTensorStructure corresponds to a single tf.variant " "vector of length 3."): s_sparse_tensor._from_tensor_list(flat_nest) with self.assertRaisesRegexp( ValueError, "Expected 2 flat values in NestedStructure but got 1."): s_nest._from_tensor_list(flat_tensor) with self.assertRaisesRegexp( ValueError, "Expected 2 flat values in NestedStructure but got 1."): s_nest._from_tensor_list(flat_sparse_tensor) def testIncompatibleNestedStructure(self): # Define three mutually incompatible nested values/structures, and assert # that: # 1. Using one structure to flatten a value with an incompatible structure # fails. # 2. Using one structure to restructre a flattened value with an # incompatible structure fails. value_0 = { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) } s_0 = structure.Structure.from_value(value_0) flat_s_0 = s_0._to_tensor_list(value_0) # `value_1` has compatible nested structure with `value_0`, but different # classes. value_1 = { "a": constant_op.constant(37.0), "b": sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]) } s_1 = structure.Structure.from_value(value_1) flat_s_1 = s_1._to_tensor_list(value_1) # `value_2` has incompatible nested structure with `value_0` and `value_1`. 
value_2 = { "a": constant_op.constant(37.0), "b": (sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor(indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) } s_2 = structure.Structure.from_value(value_2) flat_s_2 = s_2._to_tensor_list(value_2) with self.assertRaisesRegexp( ValueError, "SparseTensor.* not compatible with the nested structure " ".*TensorStructure"): s_0._to_tensor_list(value_1) with self.assertRaisesRegexp( ValueError, "SparseTensor.*SparseTensor.* not compatible with the " "nested structure .*TensorStructure"): s_0._to_tensor_list(value_2) with self.assertRaisesRegexp( ValueError, "Tensor.* not compatible with the nested structure " ".*SparseTensorStructure"): s_1._to_tensor_list(value_0) with self.assertRaisesRegexp( ValueError, "SparseTensor.*SparseTensor.* not compatible with the " "nested structure .*TensorStructure"): s_0._to_tensor_list(value_2) # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp # needs to account for "a" coming before or after "b". It might be worth # adding a deterministic repr for these error messages (among other # improvements). 
with self.assertRaisesRegexp( ValueError, "Tensor.*Tensor.* not compatible with the nested structure " ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|" "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)" ): s_2._to_tensor_list(value_0) with self.assertRaisesRegexp( ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* " "not compatible with the nested structure .*" "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|" "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)" ): s_2._to_tensor_list(value_1) with self.assertRaisesRegexp( ValueError, r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"): s_0._from_tensor_list(flat_s_1) with self.assertRaisesRegexp( ValueError, "Expected 2 flat values in NestedStructure but got 3."): s_0._from_tensor_list(flat_s_2) with self.assertRaisesRegexp( ValueError, "SparseTensorStructure corresponds to a single tf.variant " "vector of length 3."): s_1._from_tensor_list(flat_s_0) with self.assertRaisesRegexp( ValueError, "Expected 2 flat values in NestedStructure but got 3."): s_1._from_tensor_list(flat_s_2) with self.assertRaisesRegexp( ValueError, "Expected 3 flat values in NestedStructure but got 2."): s_2._from_tensor_list(flat_s_0) with self.assertRaisesRegexp( ValueError, "Expected 3 flat values in NestedStructure but got 2."): s_2._from_tensor_list(flat_s_1) @parameterized.named_parameters( ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor, structure.TensorStructure(dtypes.float32, [])), ("SparseTensor", dtypes.int32, tensor_shape.matrix( 2, 2), sparse_tensor.SparseTensor, structure.SparseTensorStructure(dtypes.int32, [2, 2])), ("Nest", { "a": dtypes.float32, "b": (dtypes.int32, dtypes.string) }, { "a": tensor_shape.scalar(), "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar()) }, { "a": ops.Tensor, "b": (sparse_tensor.SparseTensor, ops.Tensor) }, structure.NestedStructure({ "a": structure.TensorStructure(dtypes.float32, []), "b": 
(structure.SparseTensorStructure(dtypes.int32, [2, 2]), structure.TensorStructure(dtypes.string, [])) })), ) def testFromLegacyStructure(self, output_types, output_shapes, output_classes, expected_structure): actual_structure = structure.Structure._from_legacy_structure( output_types, output_shapes, output_classes) self.assertTrue( expected_structure.is_compatible_with(actual_structure)) self.assertTrue( actual_structure.is_compatible_with(expected_structure))
def _WhereShape(op):
  """Computes the static output shape of the Where op.

  Args:
    op: The op being shaped; only the static shape of its first input is read.

  Returns:
    A single-element list containing a matrix shape whose row count is unknown
    and whose column count is the number of dimensions of the first input.
  """
  input_rank = op.inputs[0].get_shape().ndims
  output_shape = tensor_shape.matrix(None, input_rank)
  return [output_shape]
def _to_legacy_output_shapes(self):
  """Returns the legacy output shape for this structure.

  The result is a two-entry prefix built from `self._dynamic_size` and
  `self._infer_shape`, followed by `self._element_shape`.
  """
  # Sneak the dynamic_size and infer_shape values into the legacy shape.
  flag_prefix = tensor_shape.matrix(self._dynamic_size, self._infer_shape)
  return flag_prefix.concatenate(self._element_shape)
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `structure.Structure` and its subclasses.

  Covers construction via `Structure.from_value()`, compatibility checks,
  flattening to / restoring from tensor lists, conversion from the legacy
  (types, shapes, classes) triple, and batching / unbatching.
  """

  # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
  # will be executed before the (eager- or graph-mode) test environment has been
  # set up.
  # pylint: disable=g-long-lambda,protected-access
  @parameterized.parameters(
      (lambda: constant_op.constant(37.0), structure.TensorStructure,
       [dtypes.float32], [[]]),
      (lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0),
       structure.TensorArrayStructure, [dtypes.variant], [None, 3]),
      (lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       structure.SparseTensorStructure, [dtypes.variant], [None]),
      (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
       structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": (sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
      }, structure.NestedStructure,
       [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]))
  def testFlatStructure(self, value_fn, expected_structure, expected_types,
                        expected_shapes):
    """Checks the class, flat types and flat shapes of an inferred structure."""
    value = value_fn()
    s = structure.Structure.from_value(value)
    self.assertIsInstance(s, expected_structure)
    self.assertEqual(expected_types, s._flat_types)
    # NOTE(review): `zip` stops at the shorter sequence, so any surplus entries
    # in `expected_shapes` (e.g. the `[None, 3]` case above) are not compared.
    for expected, actual in zip(expected_shapes, s._flat_shapes):
      self.assertTrue(actual.is_compatible_with(expected))
      self.assertTrue(
          tensor_shape.as_shape(expected).is_compatible_with(actual))

  @parameterized.parameters(
      (lambda: constant_op.constant(37.0), lambda: [
          constant_op.constant(38.0),
          array_ops.placeholder(dtypes.float32),
          variables.Variable(100.0), 42.0,
          np.array(42.0, dtype=np.float32)
      ], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]),
      (lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0), lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=10)
          ], lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.int32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(), size=0)
          ]),
      (lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [
              sparse_tensor.SparseTensor(
                  indices=[[1, 1], [3, 4]], values=[10, -1],
                  dense_shape=[4, 5]),
              sparse_tensor.SparseTensorValue(
                  indices=[[1, 1], [3, 4]], values=[10, -1],
                  dense_shape=[4, 5]),
              array_ops.sparse_placeholder(dtype=dtypes.int32),
              array_ops.sparse_placeholder(
                  dtype=dtypes.int32, shape=[None, None])
          ], lambda: [
              constant_op.constant(37, shape=[4, 5]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
              array_ops.sparse_placeholder(
                  dtype=dtypes.int32, shape=[None, None, None]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
          ]),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, lambda: [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6])
      }], lambda: [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6, 7])
      }, {
          "a": constant_op.constant(15),
          "b": constant_op.constant([4, 5, 6])
      }, {
          "a": constant_op.constant(15),
          "b": sparse_tensor.SparseTensor(
              indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
      }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
  )
  @test_util.run_deprecated_v1
  def testIsCompatibleWithStructure(self, original_value_fn,
                                    compatible_values_fn,
                                    incompatible_values_fn):
    """Checks `is_compatible_with` against compatible/incompatible values."""
    original_value = original_value_fn()
    compatible_values = compatible_values_fn()
    incompatible_values = incompatible_values_fn()
    s = structure.Structure.from_value(original_value)
    for compatible_value in compatible_values:
      self.assertTrue(
          s.is_compatible_with(
              structure.Structure.from_value(compatible_value)))
    for incompatible_value in incompatible_values:
      self.assertFalse(
          s.is_compatible_with(
              structure.Structure.from_value(incompatible_value)))

  # Every case below is an explicit 1-tuple so that all of them reach the test
  # as a single positional `value_fn` argument in the same way.
  @parameterized.parameters(
      (lambda: constant_op.constant(37.0),),
      (lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),),
      (lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(), size=1).write(0, 7),),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      },),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": (sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
      },),
  )
  def testRoundTripConversion(self, value_fn):
    """Checks that to-tensor-list followed by from-tensor-list is lossless."""
    value = value_fn()
    s = structure.Structure.from_value(value)

    def maybe_stack_ta(v):
      # TensorArrays are stacked before evaluation; other values pass through.
      if isinstance(v, tensor_array_ops.TensorArray):
        return v.stack()
      else:
        return v

    before = self.evaluate(maybe_stack_ta(value))
    after = self.evaluate(
        maybe_stack_ta(s._from_tensor_list(s._to_tensor_list(value))))

    flat_before = nest.flatten(before)
    flat_after = nest.flatten(after)
    for b, a in zip(flat_before, flat_after):
      if isinstance(b, sparse_tensor.SparseTensorValue):
        self.assertAllEqual(b.indices, a.indices)
        self.assertAllEqual(b.values, a.values)
        self.assertAllEqual(b.dense_shape, a.dense_shape)
      else:
        self.assertAllEqual(b, a)

  # pylint: enable=g-long-lambda

  def testIncompatibleStructure(self):
    """Checks errors raised when mixing incompatible top-level structures."""
    # Define three mutually incompatible values/structures, and assert that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.
    value_tensor = constant_op.constant(42.0)
    s_tensor = structure.Structure.from_value(value_tensor)
    flat_tensor = s_tensor._to_tensor_list(value_tensor)
    value_sparse_tensor = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
    flat_sparse_tensor = s_sparse_tensor._to_tensor_list(value_sparse_tensor)
    value_nest = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_nest = structure.Structure.from_value(value_nest)
    flat_nest = s_nest._to_tensor_list(value_nest)

    with self.assertRaisesRegexp(
        ValueError, r"SparseTensor.* is not convertible to a tensor with "
        r"dtype.*float32.* and shape \(\)"):
      s_tensor._to_tensor_list(value_sparse_tensor)
    with self.assertRaisesRegexp(
        ValueError, r"Value \{.*\} is not convertible to a tensor with "
        r"dtype.*float32.* and shape \(\)"):
      s_tensor._to_tensor_list(value_nest)

    with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
      s_sparse_tensor._to_tensor_list(value_tensor)

    with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
      s_sparse_tensor._to_tensor_list(value_nest)

    with self.assertRaisesRegexp(
        ValueError, "Tensor.* not compatible with the nested structure "
        ".*TensorStructure.*TensorStructure"):
      s_nest._to_tensor_list(value_tensor)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.* not compatible with the nested structure "
        ".*TensorStructure.*TensorStructure"):
      s_nest._to_tensor_list(value_sparse_tensor)

    with self.assertRaisesRegexp(
        ValueError, r"Cannot convert.*with dtype.*float32.* and shape \(\)"):
      s_tensor._from_tensor_list(flat_sparse_tensor)

    with self.assertRaisesRegexp(
        ValueError, "TensorStructure corresponds to a single tf.Tensor."):
      s_tensor._from_tensor_list(flat_nest)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
      s_sparse_tensor._from_tensor_list(flat_tensor)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
      s_sparse_tensor._from_tensor_list(flat_nest)

    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 1."):
      s_nest._from_tensor_list(flat_tensor)

    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 1."):
      s_nest._from_tensor_list(flat_sparse_tensor)

  def testIncompatibleNestedStructure(self):
    """Checks errors raised when mixing incompatible nested structures."""
    # Define three mutually incompatible nested values/structures, and assert
    # that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.

    value_0 = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_0 = structure.Structure.from_value(value_0)
    flat_s_0 = s_0._to_tensor_list(value_0)

    # `value_1` has compatible nested structure with `value_0`, but different
    # classes.
    value_1 = {
        "a": constant_op.constant(37.0),
        "b": sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    }
    s_1 = structure.Structure.from_value(value_1)
    flat_s_1 = s_1._to_tensor_list(value_1)

    # `value_2` has incompatible nested structure with `value_0` and `value_1`.
    value_2 = {
        "a": constant_op.constant(37.0),
        "b": (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
    }
    s_2 = structure.Structure.from_value(value_2)
    flat_s_2 = s_2._to_tensor_list(value_2)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.* not compatible with the nested structure "
        ".*TensorStructure"):
      s_0._to_tensor_list(value_1)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
        "nested structure .*TensorStructure"):
      s_0._to_tensor_list(value_2)

    with self.assertRaisesRegexp(
        ValueError, "Tensor.* not compatible with the nested structure "
        ".*SparseTensorStructure"):
      s_1._to_tensor_list(value_0)

    # TODO(review): the original repeated the `s_0._to_tensor_list(value_2)`
    # check here verbatim; the one flatten pairing left uncovered is
    # `s_1._to_tensor_list(value_2)`, which may be what was intended. Confirm
    # the expected error message before adding it.

    # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
    # needs to account for "a" coming before or after "b". It might be worth
    # adding a deterministic repr for these error messages (among other
    # improvements).
    with self.assertRaisesRegexp(
        ValueError, "Tensor.*Tensor.* not compatible with the nested structure "
        ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
        "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
      s_2._to_tensor_list(value_0)

    with self.assertRaisesRegexp(
        ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
        "not compatible with the nested structure .*"
        "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
        "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
      s_2._to_tensor_list(value_1)

    with self.assertRaisesRegexp(
        ValueError, r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
      s_0._from_tensor_list(flat_s_1)

    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 3."):
      s_0._from_tensor_list(flat_s_2)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
      s_1._from_tensor_list(flat_s_0)

    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 3."):
      s_1._from_tensor_list(flat_s_2)

    with self.assertRaisesRegexp(
        ValueError, "Expected 3 flat values in NestedStructure but got 2."):
      s_2._from_tensor_list(flat_s_0)

    with self.assertRaisesRegexp(
        ValueError, "Expected 3 flat values in NestedStructure but got 2."):
      s_2._from_tensor_list(flat_s_1)

  # In the TensorArray cases the two leading entries of the legacy shape line
  # up with the `dynamic_size` and `infer_shape` arguments of the expected
  # structure; the remaining entries are the element shape.
  @parameterized.named_parameters(
      ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2),
       sparse_tensor.SparseTensor,
       structure.SparseTensorStructure(dtypes.int32, [2, 2])),
      ("TensorArray0", dtypes.int32,
       tensor_shape.as_shape([None, True, 2, 2]),
       tensor_array_ops.TensorArray,
       structure.TensorArrayStructure(
           dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)),
      ("TensorArray1", dtypes.int32,
       tensor_shape.as_shape([True, None, 2, 2]),
       tensor_array_ops.TensorArray,
       structure.TensorArrayStructure(
           dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)),
      ("TensorArray2", dtypes.int32,
       tensor_shape.as_shape([True, False, 2, 2]),
       tensor_array_ops.TensorArray,
       structure.TensorArrayStructure(
           dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
      ("Nest", {
          "a": dtypes.float32,
          "b": (dtypes.int32, dtypes.string)
      }, {
          "a": tensor_shape.scalar(),
          "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
      }, {
          "a": ops.Tensor,
          "b": (sparse_tensor.SparseTensor, ops.Tensor)
      },
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                 structure.TensorStructure(dtypes.string, []))
       })),
  )
  def testConvertLegacyStructure(self, output_types, output_shapes,
                                 output_classes, expected_structure):
    """Checks conversion from legacy (types, shapes, classes) triples."""
    actual_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    self.assertTrue(expected_structure.is_compatible_with(actual_structure))
    self.assertTrue(actual_structure.is_compatible_with(expected_structure))

  def testNestedNestedStructure(self):
    """Checks flattening and restoring with a manually nested structure."""
    # Although `Structure.from_value()` will not construct one, a nested
    # structure containing nested `NestedStructure` objects can occur if a
    # structure is constructed manually.
    s = structure.NestedStructure(
        (structure.TensorStructure(dtypes.int64, []),
         structure.NestedStructure(
             (structure.TensorStructure(dtypes.float32, []),
              structure.TensorStructure(dtypes.string, [])))))

    int64_t = constant_op.constant(37, dtype=dtypes.int64)
    float32_t = constant_op.constant(42.0)
    string_t = constant_op.constant("Foo")

    nested_tensors = (int64_t, (float32_t, string_t))

    tensor_list = s._to_tensor_list(nested_tensors)
    for expected, actual in zip([int64_t, float32_t, string_t], tensor_list):
      self.assertIs(expected, actual)

    (actual_int64_t, (actual_float32_t, actual_string_t)) = (
        s._from_tensor_list(tensor_list))
    self.assertIs(int64_t, actual_int64_t)
    self.assertIs(float32_t, actual_float32_t)
    self.assertIs(string_t, actual_string_t)

    (actual_int64_t, (actual_float32_t, actual_string_t)) = (
        s._from_compatible_tensor_list(tensor_list))
    self.assertIs(int64_t, actual_int64_t)
    self.assertIs(float32_t, actual_float32_t)
    self.assertIs(string_t, actual_string_t)

  @parameterized.named_parameters(
      ("Tensor", structure.TensorStructure(dtypes.float32, []), 32,
       structure.TensorStructure(dtypes.float32, [32])),
      ("TensorUnknown", structure.TensorStructure(dtypes.float32, []), None,
       structure.TensorStructure(dtypes.float32, [None])),
      ("SparseTensor", structure.SparseTensorStructure(dtypes.float32, [None]),
       32, structure.SparseTensorStructure(dtypes.float32, [32, None])),
      ("SparseTensorUnknown",
       structure.SparseTensorStructure(dtypes.float32, [4]), None,
       structure.SparseTensorStructure(dtypes.float32, [None, 4])),
      ("Nest", structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                structure.TensorStructure(dtypes.string, []))
      }), 128, structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, [128]),
          "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                structure.TensorStructure(dtypes.string, [128]))
      })),
  )
  def testBatch(self, element_structure, batch_size,
                expected_batched_structure):
    """Checks that `_batch(batch_size)` yields the expected structure."""
    batched_structure = element_structure._batch(batch_size)
    self.assertTrue(
        batched_structure.is_compatible_with(expected_batched_structure))
    self.assertTrue(
        expected_batched_structure.is_compatible_with(batched_structure))

  @parameterized.named_parameters(
      ("Tensor", structure.TensorStructure(dtypes.float32, [32]),
       structure.TensorStructure(dtypes.float32, [])),
      ("TensorUnknown", structure.TensorStructure(dtypes.float32, [None]),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor",
       structure.SparseTensorStructure(dtypes.float32, [32, None]),
       structure.SparseTensorStructure(dtypes.float32, [None])),
      ("SparseTensorUnknown",
       structure.SparseTensorStructure(dtypes.float32, [None, 4]),
       structure.SparseTensorStructure(dtypes.float32, [4])),
      ("Nest", structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, [128]),
          "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                structure.TensorStructure(dtypes.string, [None]))
      }), structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                structure.TensorStructure(dtypes.string, []))
      })),
  )
  def testUnbatch(self, element_structure, expected_unbatched_structure):
    """Checks that `_unbatch()` yields the expected structure."""
    unbatched_structure = element_structure._unbatch()
    self.assertTrue(
        unbatched_structure.is_compatible_with(expected_unbatched_structure))
    self.assertTrue(
        expected_unbatched_structure.is_compatible_with(unbatched_structure))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
       lambda: constant_op.constant([1.0, 2.0])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0]], values=[13], dense_shape=[2])),
      ("Nest", lambda: (
          constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])),
       lambda: (constant_op.constant([1.0, 2.0]),
                sparse_tensor.SparseTensor(
                    indices=[[0]], values=[13], dense_shape=[2]))),
  )
  def testToBatchedTensorList(self, value_fn, element_0_fn):
    """Checks `_to_batched_tensor_list` and slicing out the 0th element."""
    batched_value = value_fn()
    s = structure.Structure.from_value(batched_value)
    batched_tensor_list = s._to_batched_tensor_list(batched_value)

    # The batch dimension is 2 for all of the test cases.
    # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
    # tensors in which we store sparse tensors.
    for t in batched_tensor_list:
      if t.dtype != dtypes.variant:
        self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

    # Test that the 0th element from the unbatched tensor is equal to the
    # expected value.
    expected_element_0 = self.evaluate(element_0_fn())
    unbatched_s = s._unbatch()
    actual_element_0 = unbatched_s._from_tensor_list(
        [t[0] for t in batched_tensor_list])

    for expected, actual in zip(
        nest.flatten(expected_element_0), nest.flatten(actual_element_0)):
      if sparse_tensor.is_sparse(expected):
        self.assertSparseValuesEqual(expected, actual)
      else:
        self.assertAllEqual(expected, actual)