def _test_unbatch_combinations():
  """Builds test combinations pairing a batched spec with its unbatched spec.

  Returns:
    A list of combinations, each with `element_structure` (the batched spec)
    and `expected_unbatched_structure` (the spec after removing the outer
    dimension).
  """
  f32 = dtypes.float32
  cases = [
      (tensor_spec.TensorSpec([32], f32), tensor_spec.TensorSpec([], f32)),
      (tensor_spec.TensorSpec([None], f32), tensor_spec.TensorSpec([], f32)),
      (sparse_tensor.SparseTensorSpec([32, None], f32),
       sparse_tensor.SparseTensorSpec([None], f32)),
      (sparse_tensor.SparseTensorSpec([None, 4], f32),
       sparse_tensor.SparseTensorSpec([4], f32)),
      (ragged_tensor.RaggedTensorSpec([32, None, None], f32, 2),
       ragged_tensor.RaggedTensorSpec([None, None], f32, 1)),
      (ragged_tensor.RaggedTensorSpec([None, None, None], f32, 2),
       ragged_tensor.RaggedTensorSpec([None, None], f32, 1)),
      ({
          "a": tensor_spec.TensorSpec([128], f32),
          "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                tensor_spec.TensorSpec([None], dtypes.string)),
      }, {
          "a": tensor_spec.TensorSpec([], f32),
          "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                tensor_spec.TensorSpec([], dtypes.string)),
      }),
  ]

  all_combinations = []
  for batched, unbatched in cases:
    all_combinations += combinations.combine(
        element_structure=batched, expected_unbatched_structure=unbatched)
  return all_combinations
def testTypeSpec(self, distribution, enable_get_next_as_optional):
  """Checks the `_type_spec` of an iterator over a distributed dataset.

  Builds a dataset with "dense", "ragged" and "sparse" features derived from
  one RaggedTensor, distributes it with `distribution`, and asserts that the
  iterator's type spec records the iterator's input workers and a
  `PerReplicaSpec` for every feature.

  Args:
    distribution: The distribution strategy under test (supplied by the test
      parameterization — presumably a two-replica strategy, given the two
      component specs per `PerReplicaSpec` below; confirm against the
      combinations).
    enable_get_next_as_optional: Value assigned to
      `distribution.extended.experimental_enable_get_next_as_optional`.
  """
  if not tf2.enabled():
    self.skipTest("DistributedIterator has CompositeTensor support in "
                  "TF 2.0 only.")
  ctx = distribute_lib.InputContext()
  batch_size = ctx.get_per_replica_batch_size(8)
  # Use 20 which isn't divisible by 8 to test partial batch behavior.
  # Row lengths cycle through 0, 1, 2, 3, so the densified tensor has 3
  # columns, which is where the `(None, 3)` shapes asserted below come from.
  row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
  # Note: this local name shadows the usual module name; the module is
  # referenced as `ragged_tensor_lib` throughout this block.
  ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
      np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
  dataset = dataset_ops.DatasetV2.from_tensor_slices({
      "dense": ragged_tensor.to_tensor(),
      "ragged": ragged_tensor,
      "sparse": ragged_tensor.to_sparse(),
  })
  dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  dataset = dataset.batch(batch_size)

  distribution.extended.experimental_enable_get_next_as_optional = (
      enable_get_next_as_optional)

  dist_dataset = distribution.experimental_distribute_dataset(dataset)
  with distribution.scope():
    iterator = iter(dist_dataset)
    _check_type_spec_structure(iterator)

  spec = iterator._type_spec
  self.assertEqual(spec._input_workers, iterator._input_workers)
  # Each feature's per-replica spec has an unknown batch dimension because
  # the final batch is partial.
  self.assertEqual(
      spec._element_spec, {
          "sparse":
              values.PerReplicaSpec(
                  sparse_tensor.SparseTensorSpec(
                      tensor_shape.TensorShape([None, 3]), dtypes.float32),
                  sparse_tensor.SparseTensorSpec(
                      tensor_shape.TensorShape([None, 3]), dtypes.float32)),
          "dense":
              values.PerReplicaSpec(
                  tensor_spec.TensorSpec(
                      shape=(None, 3), dtype=dtypes.float32, name=None),
                  tensor_spec.TensorSpec(
                      shape=(None, 3), dtype=dtypes.float32, name=None)),
          "ragged":
              values.PerReplicaSpec(
                  ragged_tensor_lib.RaggedTensorSpec(
                      tensor_shape.TensorShape([None, None]), dtypes.float32,
                      1, dtypes.int64),
                  ragged_tensor_lib.RaggedTensorSpec(
                      tensor_shape.TensorShape([None, None]), dtypes.float32,
                      1, dtypes.int64))
      })
def _test_convert_legacy_structure_combinations():
  """Builds combinations mapping legacy (types, shapes, classes) to a spec.

  Returns:
    A list of combinations with `output_types`, `output_shapes`,
    `output_classes` and the `expected_structure` they should convert to.
  """
  ta_spec = tensor_array_ops.TensorArraySpec
  cases = [
      (dtypes.float32, tensor_shape.TensorShape([]), ops.Tensor,
       tensor_spec.TensorSpec([], dtypes.float32)),
      (dtypes.int32, tensor_shape.TensorShape([2, 2]),
       sparse_tensor.SparseTensor,
       sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)),
      (dtypes.int32, tensor_shape.TensorShape([None, True, 2, 2]),
       tensor_array_ops.TensorArray,
       ta_spec([2, 2], dtypes.int32, dynamic_size=None, infer_shape=True)),
      (dtypes.int32, tensor_shape.TensorShape([True, None, 2, 2]),
       tensor_array_ops.TensorArray,
       ta_spec([2, 2], dtypes.int32, dynamic_size=True, infer_shape=None)),
      (dtypes.int32, tensor_shape.TensorShape([True, False, 2, 2]),
       tensor_array_ops.TensorArray,
       ta_spec([2, 2], dtypes.int32, dynamic_size=True, infer_shape=False)),
      (dtypes.int32, tensor_shape.TensorShape([2, None]),
       ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1)),
      ({
          "a": dtypes.float32,
          "b": (dtypes.int32, dtypes.string),
      }, {
          "a": tensor_shape.TensorShape([]),
          "b": (tensor_shape.TensorShape([2, 2]), tensor_shape.TensorShape([])),
      }, {
          "a": ops.Tensor,
          "b": (sparse_tensor.SparseTensor, ops.Tensor),
      }, {
          "a": tensor_spec.TensorSpec([], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                tensor_spec.TensorSpec([], dtypes.string)),
      }),
  ]

  result = []
  for types, shapes, classes, expected in cases:
    result += combinations.combine(
        output_types=types,
        output_shapes=shapes,
        output_classes=classes,
        expected_structure=expected)
  return result
def testNestedRaggedMapWithFnOutputSignature(self):
  """Nested ragged map_fn calls respect their explicit output signatures."""
  inner_signature = ragged_tensor.RaggedTensorSpec([None], dtypes.int32)
  outer_signature = ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
  rt = ragged_factory_ops.constant([[1, 2, 3, 4], [1]])

  def repeat_row(row):
    # Emit a copy of `row` once for every element it contains.
    return map_fn_lib.map_fn(
        lambda _: row, row, fn_output_signature=inner_signature)

  result = map_fn_lib.map_fn(
      repeat_row, rt, fn_output_signature=outer_signature)
  expected = [[[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]],
              [[1]]]
  self.assertAllEqual(result, expected)
def test_save_uses_sanitized_signature_name(self):
  """A ':' in a signature key is sanitized in the generated placeholder names."""

  @def_function.function(
      input_signature=[ragged_tensor.RaggedTensorSpec(ragged_rank=2)])
  def f(x):
    return {"output_key": x}

  # Colons are not usable as name scopes.
  raw_signature_key = "foo:bar"
  export_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save.save(
      tracking.AutoTrackable(),
      export_dir,
      signatures={raw_signature_key: f.get_concrete_function()})

  v1_graph = ops.Graph()
  with v1_graph.as_default(), session_lib.Session() as sess:
    meta_graph = loader.load(sess, [tag_constants.SERVING], export_dir)
    signature_inputs = meta_graph.signature_def[raw_signature_key].inputs
    placeholder_names = [
        sess.graph.get_tensor_by_name(signature_inputs[input_key].name).name
        for input_key in signature_inputs
    ]
    # The placeholder names will have the sanitized version.
    self.assertCountEqual(
        placeholder_names, ["foo_bar_x:0", "foo_bar_x_1:0", "foo_bar_x_2:0"])
def testPartitionOuterDims(self):
  """Partitions a StructuredTensor's outer dimension with several partitions."""
  if not context.executing_eagerly():
    return  # Only exercised in eager mode.
  # a = {x: 1, y: [1, 2]}, b = {x: 2, y: [3, 4]}, c = ..., d = {x: 4, y: [7, 8]}
  a, b, c, d = [dict(x=i + 1, y=[2 * i + 1, 2 * i + 2]) for i in range(4)]
  flat = StructuredTensor.from_pyval([a, b, c, d])

  by_splits = flat.partition_outer_dimension(
      row_partition.RowPartition.from_row_splits([0, 2, 2, 3, 4]))
  self.assertAllEqual(by_splits, [[a, b], [], [c], [d]])

  by_lengths = by_splits.partition_outer_dimension(
      row_partition.RowPartition.from_row_lengths([1, 0, 3, 0]))
  self.assertAllEqual(by_lengths, [[[a, b]], [], [[], [c], [d]], []])

  # With a uniform row length, `x` is partitioned into a dense Tensor rather
  # than a RaggedTensor.
  uniform = flat.partition_outer_dimension(
      row_partition.RowPartition.from_uniform_row_length(
          uniform_row_length=2, nvals=4, nrows=2))
  expected_spec = structured_tensor.StructuredTensorSpec(
      [2, 2], {
          "x": tensor_spec.TensorSpec([2, 2], dtypes.int32),
          "y": ragged_tensor.RaggedTensorSpec([2, 2, None], dtypes.int32)
      })
  expected = structured_tensor.StructuredTensor.from_pyval(
      [[a, b], [c, d]], expected_spec)
  self.assertAllEqual(uniform, expected)
def common_spec(x, y):
  """Returns a spec of the same kind as `x` whose shape is common to `x`, `y`.

  The shape is computed by `get_common_shape(x.shape, y.shape)`. The kind of
  spec (sparse, ragged or dense) and its dtype are taken from `x`.

  Args:
    x: A `SparseTensorSpec`, `RaggedTensorSpec` or `TensorSpec`.
    y: A spec whose shape is combined with `x.shape`.

  Returns:
    A spec of `x`'s kind with the common shape.
  """
  common_shape = get_common_shape(x.shape, y.shape)
  if isinstance(x, sparse_tensor.SparseTensorSpec):
    return sparse_tensor.SparseTensorSpec(common_shape, x.dtype)
  elif isinstance(x, ragged_tensor.RaggedTensorSpec):
    # Carry over x's ragged structure explicitly. Without it, RaggedTensorSpec
    # derives ragged_rank from the shape's rank (wrong for e.g. a
    # [None, None, 3] spec with ragged_rank=1, and an error when the common
    # shape has unknown rank) and resets row_splits_dtype to its default.
    return ragged_tensor.RaggedTensorSpec(common_shape, x.dtype, x.ragged_rank,
                                          x.row_splits_dtype)
  return tensor_spec.TensorSpec(common_shape, x.dtype, x.name)
def testCompositeAndSpec(self):
  """A RaggedTensor value and its matching spec produce the same trace type."""
  rt_value = ragged_tensor.RaggedTensor.from_row_splits(
      values=[1, 2, 3], row_splits=[0, 2, 3])
  rt_spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)
  value_type = trace_type.from_object(rt_value)
  spec_type = trace_type.from_object(rt_spec)
  self.assertEqual(value_type, spec_type)
def testCompositeAndSpec(self):
  """A RaggedTensor value and its matching spec produce the same arg spec."""

  def arg_spec(obj):
    return function_trace_type.get_arg_spec(obj, False, False, True)

  rt_value = ragged_tensor.RaggedTensor.from_row_splits(
      values=[1, 2, 3], row_splits=[0, 2, 3])
  rt_spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)
  self.assertEqual(arg_spec(rt_value), arg_spec(rt_spec))
def testCompositeAndSpec(self):
  """A RaggedTensor value and its matching spec yield the same signature."""
  rt_value = ragged_tensor.RaggedTensor.from_row_splits(
      values=[1, 2, 3], row_splits=[0, 2, 3])
  rt_spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)
  value_sig, spec_sig = (
      make_function_signature_with_context(obj) for obj in (rt_value, rt_spec))
  self.assertEqual(value_sig, spec_sig)
def testEncodeDataSetSpec(self):
  """A DatasetSpec over mixed component specs survives an encode/decode."""
  element_spec = {
      "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
      "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
      "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
  }
  structure = [dataset_ops.DatasetSpec(element_spec)]
  self.assertTrue(self._coder.can_encode(structure))
  round_tripped = self._coder.decode_proto(
      self._coder.encode_structure(structure))
  self.assertEqual(structure, round_tripped)
def get_selection_mask(self, input_ids, axis):
  """Randomly selects among the selectable items and returns a boolean mask.

  Selectable positions (as reported by the parent class's `get_selectable`)
  are shuffled per row with `self._shuffle_fn`. Of those, the first
  `ceil(row_length * self.selection_rate)` are kept, capped at
  `self.max_selections_per_batch`.

  Args:
    input_ids: A RaggedTensor; its `flat_values`, `with_flat_values` and
      `dtype` are used below.
    axis: The axis at which items are selected. When `axis > 1`, dimensions
      1 through `axis` are merged before selecting.

  Returns:
    A boolean RaggedTensor with `input_ids`' row partitions. It is True at the
    selected positions. When `axis < input_ids.ragged_rank`, the inner ragged
    dimensions are collapsed with `reduce_all`.
  """
  selectable = super(RandomItemSelector, self).get_selectable(input_ids, axis)

  # Run the selection algorithm on positions RT
  # Each flat value is replaced by its flat index, so selected entries can
  # later be scattered back into `input_ids.flat_values`.
  positions_flat = math_ops.range(array_ops.size(input_ids.flat_values))
  positions = input_ids.with_flat_values(positions_flat)
  # Mask out positions that are not selectable
  positions = ragged_array_ops.boolean_mask(positions, selectable)

  # merge to the desired axis
  positions = positions.merge_dims(1, axis) if axis > 1 else positions

  # Figure out how many we are going to select
  num_to_select = math_ops.ceil(
      math_ops.cast(positions.row_lengths(), dtypes.float32) *
      self.selection_rate)
  num_to_select = math_ops.minimum(num_to_select,
                                   self.max_selections_per_batch)
  num_to_select = math_ops.cast(num_to_select, dtypes.int64)

  # Shuffle and trim to items that are going to be selected
  def _shuffle_and_trim(x):
    # `positions` here is one row of the outer `positions` (shadowed name);
    # `top_n` is that row's entry of `num_to_select`.
    positions, top_n = x
    if isinstance(positions, ragged_tensor.RaggedTensor):
      # Shuffle whole sub-rows rather than individual items.
      positions_at_axis = math_ops.range(positions.nrows())
      chosen_positions_at_axis = self._shuffle_fn(
          positions_at_axis)[:top_n]
      return array_ops.gather(positions, chosen_positions_at_axis)
    else:
      shuffled = self._shuffle_fn(positions)
      return shuffled[:top_n]

  # The output signature describes one mapped row, hence ragged_rank - 1.
  selected_for_mask = map_fn.map_fn(
      _shuffle_and_trim, (positions, num_to_select),
      fn_output_signature=ragged_tensor.RaggedTensorSpec(
          ragged_rank=positions.ragged_rank - 1, dtype=positions.dtype))
  selected_for_mask.flat_values.set_shape([None])

  # Construct the result which is a boolean RT
  # Scatter 1's to positions that have been selected_for_mask
  update_values = array_ops.ones_like(selected_for_mask.flat_values)
  update_values = math_ops.cast(update_values, input_ids.dtype)
  update_indices = selected_for_mask.flat_values
  update_indices = array_ops.expand_dims(update_indices, -1)
  # NOTE(review): indices are cast to input_ids.dtype. tensor_scatter_update
  # expects int32/int64 indices, so this presumably assumes integer ids —
  # confirm.
  update_indices = math_ops.cast(update_indices, input_ids.dtype)

  results_flat = array_ops.zeros_like(input_ids.flat_values)
  results_flat = gen_array_ops.tensor_scatter_update(
      results_flat, update_indices, update_values)
  results = math_ops.cast(
      input_ids.with_flat_values(results_flat), dtypes.bool)

  # Collapse dimensions inside `axis` so the mask is True only where every
  # inner item was selected.
  if axis < results.ragged_rank:
    reduce_axis = list(range(results.ragged_rank, axis, -1))
    results = math_ops.reduce_all(results, reduce_axis)
  return results
def testRaggedBadReturnTypeExpectedRaggedReturnedTensor(self):
  """Returning a dense Tensor where Tout declares a ragged spec is an error."""
  mismatch_pattern = (
      "py_function: func=.* returned .* which did not match Tout=.*")
  ragged_tout = ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)
  with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
                              mismatch_pattern):
    dense_input = constant_op.constant([[1, 2, 3]])
    result = script_ops.eager_py_func(
        func=lambda x: x, inp=[dense_input], Tout=[ragged_tout])
    self.evaluate(result)
def testDatasetSpecConstructor(self):
  """DatasetSpec stores its element spec and normalizes its dataset shape."""
  element_spec = {
      "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
      "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
      "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
  }
  ds_struct = dataset_ops.DatasetSpec(element_spec, [5])
  self.assertEqual(ds_struct._element_spec, element_spec)
  # The shape passed as a list is converted to a TensorShape.
  expected_shape = tensor_shape.TensorShape([5])
  self.assertEqual(ds_struct._dataset_shape, expected_shape)
def _ragged_type_to_spec(t):
  """Converts a `RaggedTensorType` into the per-element `RaggedTensorSpec`.

  Any other value is returned unchanged.
  """
  if not isinstance(t, ragged_tensor.RaggedTensorType):
    return t
  # RaggedTensorType describes the stacked result of the mapped `fn` outputs,
  # while the RaggedTensorSpec describes a single `fn` output, so the ragged
  # rank drops by one.
  return ragged_tensor.RaggedTensorSpec(
      shape=None,
      dtype=t.dtype,
      ragged_rank=t.ragged_rank - 1,
      row_splits_dtype=t.row_splits_dtype)
def to_ragged_spec(spec):
  """Wraps a non-scalar `TensorSpec` as a `RaggedTensorSpec` with ragged_rank 0.

  Other specs (including scalar TensorSpecs) are returned unchanged. The
  `row_splits_dtype` comes from the enclosing scope.
  """
  wraps_as_ragged = (
      isinstance(spec, tensor_spec.TensorSpec) and spec.shape.ndims != 0)
  if not wraps_as_ragged:
    return spec
  return ragged_tensor.RaggedTensorSpec(
      shape=spec.shape,
      dtype=spec.dtype,
      ragged_rank=0,
      row_splits_dtype=row_splits_dtype)
def testRaggedToStringUnknownRank(self):
  """ragged_tensor_to_string rejects inputs whose static rank is unknown."""
  unknown_rank_spec = ragged_tensor.RaggedTensorSpec(ragged_rank=1)

  @def_function.function(input_signature=[unknown_rank_spec])
  def f(rt):
    return ragged_string_ops.ragged_tensor_to_string(rt)

  expected_message = ('RaggedTensor to_string requires '
                      'that rt.shape.rank is not None')
  with self.assertRaisesRegex(ValueError, expected_message):
    f(ragged_factory_ops.constant([[1, 2], [3]]))
def _matmul_3d_with_map_fn(a, b, **kwargs):
  """Multiplies batches of 2D matrices using map_fn.

  `output[n, i, k] = sum_j (a[n, i, j] * b[n, j, k])` (for all `n`, `i`, `k`).

  Requires that `a[n, i].nrows()` == `b[n].nrows()` (for all `n` and `i`).

  Args:
    a: A 3D Tensor or RaggedTensor with `shape=[B, I, J]`, where dimensions `I`
      and `J` may be ragged.
    b: A 3D Tensor or RaggedTensor with `shape=[B, J, K]`, where dimensions `J`
      and `K` may be ragged.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a). An
      `output_type` entry is optional; when it is missing or None, `a.dtype`
      is used as the output dtype.

  Returns:
    A 3D RaggedTensor with `shape=[B, (I), (K)]`.
  """
  # If `b` has ragged_rank 2, each per-batch product is converted to a
  # RaggedTensor, so the stacked result has ragged_rank 2 as well.
  if isinstance(b, ragged_tensor.RaggedTensor) and b.ragged_rank == 2:
    output_ragged_rank = 2
  else:
    output_ragged_rank = 1

  def single_batch_matmul(x):
    out = _matmul_2d(x[0], x[1], **kwargs)
    if output_ragged_rank == 2:
      out = ragged_tensor.RaggedTensor.from_tensor(out)
    return out

  # The per-batch output shape is left unspecified here; the static shape is
  # restored after map_fn below.
  fn_out_shape = None
  # At least one input is expected to be ragged; take its row_splits dtype.
  row_splits_dtype = (
      a.row_splits.dtype
      if isinstance(a, ragged_tensor.RaggedTensor) else b.row_splits.dtype)
  # `output_type` may be omitted by callers; default to the input dtype.
  output_type = kwargs.get('output_type')
  if output_type is None:
    output_type = a.dtype
  # The signature describes one mapped batch element, hence ragged_rank - 1.
  spec = ragged_tensor.RaggedTensorSpec(
      shape=fn_out_shape,
      dtype=output_type,
      ragged_rank=output_ragged_rank - 1,
      row_splits_dtype=row_splits_dtype)
  result = map_fn.map_fn(
      single_batch_matmul, elems=(a, b), fn_output_signature=spec)

  # map_fn loses shape information; restore it, where possible.
  # pylint: disable=protected-access
  if kwargs.get('transpose_a') or kwargs.get('adjoint_a'):
    result._set_shape(a.shape[:-2] + a.shape[-1:] + [None])
  else:
    result._set_shape(a.shape[:-2] + a.shape[-2:-1] + [None])
  if kwargs.get('transpose_b') or kwargs.get('adjoint_b'):
    result._set_shape(b.shape[:-2] + [None] + b.shape[-2:-1])
  else:
    result._set_shape(b.shape[:-2] + [None] + b.shape[-1:])

  return result
def testRaggedTensorReturn(self):
  """eager_py_func can return a RaggedTensor when Tout is a RaggedTensorSpec."""

  def build_ragged(flat_values, row_lengths):
    return ragged_tensor.RaggedTensor.from_row_lengths(flat_values, row_lengths)

  flat = [1, 2, 3, 4, 5, 6]
  lengths = constant_op.constant([3, 1, 2], dtypes.int64)
  tout = [ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]
  (result,) = script_ops.eager_py_func(build_ragged, [flat, lengths], tout)
  self.assertIsInstance(result, ragged_tensor.RaggedTensor)
  self.assertAllEqual(result, [[1, 2, 3], [4], [5, 6]])
def testMapRaggedTensor(self):
  """map_fn over a RaggedTensor with a ragged output signature."""
  # Note: there are additional tests in ragged/ragged_map_fn_op_test.py
  with self.cached_session():
    rt = ragged_factory_ops.constant([[1, 2], [3]])
    row_signature = ragged_tensor.RaggedTensorSpec([None], rt.dtype)
    mapped = map_fn.map_fn(
        lambda row: row + 1, rt, fn_output_signature=row_signature)
    self.assertAllEqual([[2, 3], [4]], mapped)
    self.assertEqual([2, None], mapped.shape.as_list())
def testNestedRagged(self):
  """A nested ragged_rank-0 RaggedTensorSpec is compatible with a TensorSpec.

  TwoCompositesSpecs are compatible when one nests a RaggedTensorSpec with
  ragged_rank=0 and the other nests a TensorSpec of the same shape.
  """
  int32 = dtypes.int32
  ragged_rank0_pair = TwoCompositesSpec(
      ragged_tensor.RaggedTensorSpec([10], int32, ragged_rank=0),
      tensor_spec.TensorSpec(None, int32))
  same_shape_pair = TwoCompositesSpec(
      tensor_spec.TensorSpec([10], int32), tensor_spec.TensorSpec(None, int32))
  other_shape_pair = TwoCompositesSpec(
      tensor_spec.TensorSpec([12], int32), tensor_spec.TensorSpec(None, int32))
  self.assertTrue(ragged_rank0_pair.is_compatible_with(same_shape_pair))
  self.assertFalse(ragged_rank0_pair.is_compatible_with(other_shape_pair))
def testFromGeneratorRaggedTensor(self):
  """Dataset.from_generator yields RaggedTensors for a ragged signature."""

  def generator():
    yield ragged_factory_ops.constant([[1, 2], [3]])

  signature = ragged_tensor.RaggedTensorSpec(shape=(2, None), dtype=dtypes.int32)
  dataset = dataset_ops.Dataset.from_generator(
      generator, output_signature=signature)
  next_element = self.getNext(dataset)()
  self.assertIsInstance(next_element, ragged_tensor.RaggedTensor)
  self.assertAllEqual([[1, 2], [3]], next_element)
def testUnknownRank(self):
  """Ragged matmul needs at least one input with a known rank."""
  no_rank_spec = ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1)
  rank_only_spec = ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32, 1)

  def traced_matmul(spec_a, spec_b):
    return def_function.function(input_signature=[spec_a, spec_b])(
        ragged_math_ops.matmul)

  only_b_unknown = traced_matmul(rank_only_spec, no_rank_spec)
  only_a_unknown = traced_matmul(no_rank_spec, rank_only_spec)
  both_unknown = traced_matmul(no_rank_spec, no_rank_spec)

  lhs = ragged_factory_ops.constant([[1, 2]])
  rhs = ragged_factory_ops.constant([[3], [4]])
  self.assertAllEqual(only_b_unknown(lhs, rhs), [[11]])
  self.assertAllEqual(only_a_unknown(lhs, rhs), [[11]])
  with self.assertRaisesRegex(
      ValueError, 'matmul requires at least one input to have known '
      'rank if either input is ragged.'):
    both_unknown(lhs, rhs)
def _most_general_compatible_type(spec):
  """Returns the most general TypeSpec compatible with `spec`.

  Dense, ragged and sparse tensor specs lose their shape (and, for dense
  specs, their name). Ragged specs keep their dtype, ragged_rank and
  row_splits_dtype. Any other spec is returned as-is.
  """
  # TODO(edloper): Consider adding most_general_compatible_type to TypeSpec API
  if isinstance(spec, tensor_spec.TensorSpec):
    return tensor_spec.TensorSpec(None, spec.dtype)
  if isinstance(spec, ragged_tensor.RaggedTensorSpec):
    return ragged_tensor.RaggedTensorSpec(None, spec.dtype, spec.ragged_rank,
                                          spec.row_splits_dtype)
  if isinstance(spec, sparse_tensor.SparseTensorSpec):
    return sparse_tensor.SparseTensorSpec(None, spec.dtype)
  return spec
def testUniformSplitDynamicShape(self, rt_shape):
  """ragged_array_ops.split into equal parts works with a dynamic input shape."""
  rt = ragged_tensor.RaggedTensor.from_row_lengths([1.0, 2.0, 3.0, 4.0],
                                                   [3, 1])
  signature = [ragged_tensor.RaggedTensorSpec(rt_shape, ragged_rank=1)]

  @def_function.function(input_signature=signature)
  def split_tensors(rt):
    return ragged_array_ops.split(rt, 2)

  expected_parts = [
      ragged_tensor.RaggedTensor.from_row_lengths([1.0, 2.0, 3.0], [3]),
      ragged_tensor.RaggedTensor.from_row_lengths([4.0], [1]),
  ]
  for actual, expected in zip(split_tensors(rt), expected_parts):
    self.assertAllEqual(actual, expected)
def to_ragged_spec(spec):
  """Returns the new spec based on RaggedTensors.

  A `TensorSpec` with known rank but some unknown dimensions becomes a
  `RaggedTensorSpec` whose ragged_rank is the innermost unknown axis. Every
  other spec is returned unchanged. `row_splits_dtype` comes from the
  enclosing scope.
  """
  needs_ragged = (
      isinstance(spec, tensor_spec.TensorSpec) and
      spec.shape.rank is not None and not spec.shape.is_fully_defined())
  if not needs_ragged:
    return spec
  dims = spec.shape.as_list()
  innermost_unknown = max(
      axis for axis, size in enumerate(dims) if size is None)
  return ragged_tensor.RaggedTensorSpec(
      shape=spec.shape,
      dtype=spec.dtype,
      ragged_rank=innermost_unknown,
      row_splits_dtype=row_splits_dtype)
def test_save_composite_tensor_signature(self):
  """A ragged signature can be saved and then run through V2 and V1 APIs."""

  @def_function.function(
      input_signature=[ragged_tensor.RaggedTensorSpec(ragged_rank=2)])
  def f(x):
    return {"output_key": x}

  root = tracking.AutoTrackable()
  export_dir = os.path.join(self.get_temp_dir(), "saved_model")
  expected_ragged = ragged_factory_ops.constant([[[1.0, 2.0], [3.0]], [[5.]]])
  # Flattened components of `expected_ragged`: its flat values followed by the
  # row splits of each ragged dimension, outermost first.
  flat_inputs = {
      "x": constant_op.constant([1., 2., 3., 5]),
      "x_1": constant_op.constant([0, 2, 3], dtype=dtypes.int64),
      "x_2": constant_op.constant([0, 2, 3, 4], dtype=dtypes.int64),
  }
  save.save(root, export_dir, signatures={"key": f.get_concrete_function()})

  # Test that the ragged signature can be loaded back into Python with V2 APIs
  loaded = load.load(export_dir)
  v2_output = loaded.signatures["key"](**flat_inputs)["output_key"]
  self.assertAllEqual(expected_ragged, v2_output)

  # Try running the signature with V1 APIs.
  v1_graph = ops.Graph()
  with v1_graph.as_default(), session_lib.Session() as sess:
    meta_graph = loader.load(sess, [tag_constants.SERVING], export_dir)
    signature = meta_graph.signature_def["key"]
    feed_dict = {
        sess.graph.get_tensor_by_name(signature.inputs[name].name):
            flat_inputs[name].numpy() for name in flat_inputs
    }
    # Get composite tensor components
    output_components = (
        signature.outputs["output_key"].composite_tensor.components)
    fetches = {
        key: sess.graph.get_tensor_by_name(info.name)
        for key, info in zip(["x", "x_1", "x_2"], output_components)
    }
    outputs = sess.run(fetches, feed_dict)
    self.assertAllClose(flat_inputs, outputs)
def testRaggedSplitDynamicShape(self, rt_shape, lengths_shape):
  """Splitting by explicit lengths works with dynamic input/length shapes."""
  signature = [
      ragged_tensor.RaggedTensorSpec(rt_shape, ragged_rank=1),
      tensor_spec.TensorSpec(lengths_shape, dtype=dtypes.int32),
  ]

  @def_function.function(input_signature=signature)
  def split_tensors(rt, split_lengths):
    return ragged_array_ops.split(rt, split_lengths, num=2)

  rt = ragged_tensor.RaggedTensor.from_row_lengths([1.0, 2.0, 3.0, 4.0],
                                                   [3, 1])
  # The number of split lengths matches `num` at runtime.
  parts = split_tensors(rt, [1, 1])
  expected_parts = [
      ragged_tensor.RaggedTensor.from_row_lengths([1.0, 2.0, 3.0], [3]),
      ragged_tensor.RaggedTensor.from_row_lengths([4.0], [1]),
  ]
  for actual, expected in zip(parts, expected_parts):
    self.assertAllEqual(actual, expected)
def testEncodeDecodeRaggedTensorSpec(self):
  """A RaggedTensorSpec encodes to the expected proto and decodes back."""
  structure = [
      ragged_tensor.RaggedTensorSpec([1, 2, 3], dtypes.int64, 2, dtypes.int32)
  ]
  self.assertTrue(self._coder.can_encode(structure))
  encoded = self._coder.encode_structure(structure)
  # Text-format `#` comments run to end of line, so every field must start on
  # a new line after its comment; otherwise the comment swallows the fields and
  # closing braces that follow it and the text fails to parse.
  expected_pbtxt = r"""
    list_value {
      values {
        type_spec_value {
          type_spec_class: RAGGED_TENSOR_SPEC
          type_spec_class_name: 'RaggedTensorSpec'
          num_flat_components: 3
          type_state {
            tuple_value {
              # spec._shape
              values {
                tensor_shape_value {
                  dim { size: 1 }
                  dim { size: 2 }
                  dim { size: 3 }
                }
              }
              # spec._dtype
              values { tensor_dtype_value: DT_INT64 }
              # spec._ragged_rank
              values { int64_value: 2 }
              # spec._row_splits_dtype
              values { tensor_dtype_value: DT_INT32 }
            }
          }
        }
      }
    }
  """
  expected = struct_pb2.StructuredValue()
  text_format.Parse(expected_pbtxt, expected)
  self.assertEqual(expected, encoded)
  decoded = self._coder.decode_proto(encoded)
  self.assertEqual(structure, decoded)
def _test_ragged_structure_inequality_combinations():
  """Builds combinations of RaggedTensorSpec pairs that must compare unequal.

  The pairs differ in ragged_rank, in static shape, or in dtype.
  """
  spec = ragged_tensor.RaggedTensorSpec
  cases = [
      (spec(None, dtypes.int32, 1), spec(None, dtypes.int32, 2)),
      (spec([3, None], dtypes.int32, 1), spec([5, None], dtypes.int32, 1)),
      (spec(None, dtypes.int32, 1), spec(None, dtypes.float32, 1)),
  ]

  result = []
  for lhs, rhs in cases:
    result += combinations.combine(spec1=lhs, spec2=rhs)
  return result