def testSerializeManyDeserialize(self):
  test_cases = (
      (),
      sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
      ((), sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
  )
  for expected in test_cases:
    classes = sparse.get_classes(expected)
    shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), classes)
    types = nest.map_structure(lambda _: dtypes.int32, classes)
    actual = sparse.deserialize_sparse_tensors(
        sparse.serialize_many_sparse_tensors(expected), types, shapes,
        sparse.get_classes(expected))
    nest.assert_same_structure(expected, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
      self.assertSparseValuesEqual(a, e)
def testSerializeDeserialize(self):
  test_cases = (
      (),
      sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
      ((), sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
  )
  for expected in test_cases:
    classes = sparse.get_classes(expected)
    shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), classes)
    types = nest.map_structure(lambda _: dtypes.int32, classes)
    actual = sparse.deserialize_sparse_tensors(
        sparse.serialize_sparse_tensors(expected), types, shapes,
        sparse.get_classes(expected))
    nest.assert_same_structure(expected, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
      self.assertSparseValuesEqual(a, e)
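A minimal round-trip sketch of the helpers exercised by the tests above, assuming the same internal `sparse`, `nest`, `tensor_shape`, and `dtypes` modules are in scope; it simply mirrors one iteration of the test loop.

# Illustrative only: serialize a SparseTensor to its dense variant encoding
# and deserialize it back using the metadata derived from get_classes().
st = sparse_tensor.SparseTensor(
    indices=[[0, 0]], values=[1], dense_shape=[1, 1])
classes = sparse.get_classes(st)
shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None), classes)
types = nest.map_structure(lambda _: dtypes.int32, classes)
round_tripped = sparse.deserialize_sparse_tensors(
    sparse.serialize_sparse_tensors(st), types, shapes, classes)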
def testSerializeManyDeserialize(self, input_fn):
  test_case = input_fn()
  classes = sparse.get_classes(test_case)
  shapes = nest.map_structure(
      lambda _: tensor_shape.TensorShape(None), classes)
  types = nest.map_structure(lambda _: dtypes.int32, classes)
  actual = sparse.deserialize_sparse_tensors(
      sparse.serialize_many_sparse_tensors(test_case), types, shapes,
      sparse.get_classes(test_case))
  nest.assert_same_structure(test_case, actual)
  for a, e in zip(nest.flatten(actual), nest.flatten(test_case)):
    self.assertSparseValuesEqual(a, e)
def tf_finalize_func(*args):
  """A wrapper for Defun that facilitates shape inference."""
  for arg, shape in zip(
      args,
      nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
    arg.set_shape(shape)

  nested_args = nest.pack_sequence_as(self._state_types, args)
  nested_args = sparse.deserialize_sparse_tensors(
      nested_args, self._state_types, self._state_shapes,
      self._state_classes)

  ret = finalize_func(nested_args)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  self._output_classes = sparse.get_classes(ret)
  self._output_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in nest.flatten(ret)])
  self._output_types = nest.pack_sequence_as(
      ret, [t.dtype for t in nest.flatten(ret)])

  dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

  # Serialize any sparse tensors.
  ret = nest.pack_sequence_as(
      ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
  return nest.flatten(ret)
def tf_init_func(key):
  """A wrapper for Defun that facilitates shape inference."""
  key.set_shape([])
  ret = init_func(key)
  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  self._state_classes = sparse.get_classes(ret)
  self._state_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in nest.flatten(ret)])
  self._state_types = nest.pack_sequence_as(
      ret, [t.dtype for t in nest.flatten(ret)])

  dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

  # Serialize any sparse tensors.
  ret = nest.pack_sequence_as(
      ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
  return nest.flatten(ret)
def from_value(value):
  """Returns an `Optional` that wraps the given value.

  Args:
    value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

  Returns:
    An `Optional` that wraps `value`.
  """
  # TODO(b/110122868): Consolidate this destructuring logic with the
  # similar code in `Dataset.from_tensors()`.
  with ops.name_scope("optional") as scope:
    with ops.name_scope("value"):
      value = nest.pack_sequence_as(value, [
          sparse_tensor_lib.SparseTensor.from_value(t)
          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
              t, name="component_%d" % i)
          for i, t in enumerate(nest.flatten(value))
      ])

    encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
    output_classes = sparse.get_classes(value)
    output_shapes = nest.pack_sequence_as(
        value, [t.get_shape() for t in nest.flatten(value)])
    output_types = nest.pack_sequence_as(
        value, [t.dtype for t in nest.flatten(value)])

  return _OptionalImpl(
      gen_dataset_ops.optional_from_value(encoded_value, name=scope),
      output_shapes, output_types, output_classes)
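A hedged usage sketch for `from_value` above, assuming it is invoked through the `Optional` class that defines it (exposed in later releases as `tf.data.experimental.Optional`) and that `tf` is the usual TensorFlow import; the accessor names below are the standard `Optional` API, not taken from this snippet.

# Illustrative only.
opt = Optional.from_value((tf.constant(37), tf.constant([1, 2])))
has_value = opt.has_value()  # Scalar tf.bool tensor.
value = opt.get_value()      # The wrapped nested structure, deserialized.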
def tf_finalize_func(*args):
  """A wrapper for Defun that facilitates shape inference."""
  for arg, shape in zip(
      args,
      nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
    arg.set_shape(shape)

  nested_args = nest.pack_sequence_as(self._state_types, args)
  nested_args = sparse.deserialize_sparse_tensors(
      nested_args, self._state_types, self._state_shapes,
      self._state_classes)

  ret = finalize_func(nested_args)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  self._output_classes = sparse.get_classes(ret)
  self._output_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in nest.flatten(ret)])
  self._output_types = nest.pack_sequence_as(
      ret, [t.dtype for t in nest.flatten(ret)])

  # Serialize any sparse tensors.
  ret = nest.pack_sequence_as(
      ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
  return nest.flatten(ret)
def testGetClasses(self):
  s = sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1])
  d = ops.Tensor
  t = sparse_tensor.SparseTensor
  test_cases = (
      {"classes": (), "expected": ()},
      {"classes": s, "expected": t},
      {"classes": constant_op.constant([1]), "expected": d},
      {"classes": (s), "expected": (t)},
      {"classes": (constant_op.constant([1])), "expected": (d)},
      {"classes": (s, ()), "expected": (t, ())},
      {"classes": ((), s), "expected": ((), t)},
      {"classes": (constant_op.constant([1]), ()), "expected": (d, ())},
      {"classes": ((), constant_op.constant([1])), "expected": ((), d)},
      {"classes": (s, (), constant_op.constant([1])), "expected": (t, (), d)},
      {"classes": ((), s, ()), "expected": ((), t, ())},
      {"classes": ((), constant_op.constant([1]), ()),
       "expected": ((), d, ())},
  )
  for test_case in test_cases:
    self.assertEqual(
        sparse.get_classes(test_case["classes"]), test_case["expected"])
def tf_init_func(key):
  """A wrapper for Defun that facilitates shape inference."""
  key.set_shape([])
  ret = init_func(key)
  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  self._state_classes = sparse.get_classes(ret)
  self._state_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in nest.flatten(ret)])
  self._state_types = nest.pack_sequence_as(
      ret, [t.dtype for t in nest.flatten(ret)])

  # Serialize any sparse tensors.
  ret = nest.pack_sequence_as(
      ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
  return nest.flatten(ret)
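For context, a minimal usage sketch of the transformation these `tf_init_func`/`tf_finalize_func` wrappers implement, assuming the tf.contrib.data-era `Reducer` and `group_by_reducer` API and the usual `tf`/`np` imports; the reducer functions below are illustrative, not taken from the code above.

# Illustrative only: count elements per parity bucket.
reducer = tf.contrib.data.Reducer(
    init_func=lambda key: np.int64(0),       # wrapped by tf_init_func
    reduce_func=lambda state, _: state + 1,
    finalize_func=lambda state: state)        # wrapped by tf_finalize_func
dataset = tf.data.Dataset.range(10).apply(
    tf.contrib.data.group_by_reducer(
        key_func=lambda x: x % 2, reducer=reducer))
# Yields one element per key: 5 (even) and 5 (odd).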
def __init__(self, input_dataset, initial_state, scan_func):
  """See `scan()` for details."""
  super(_ScanDataset, self).__init__()
  self._input_dataset = input_dataset

  with ops.name_scope("initial_state"):
    # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
    # values to tensors.
    self._initial_state = nest.pack_sequence_as(initial_state, [
        sparse_tensor.SparseTensor.from_value(t)
        if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
            t, name="component_%d" % i)
        for i, t in enumerate(nest.flatten(initial_state))
    ])

  # Compute initial values for the state classes, shapes and types based on
  # the initial state. The shapes may be refined by running `tf_scan_func` one
  # or more times below.
  self._state_classes = sparse.get_classes(self._initial_state)
  self._state_shapes = nest.pack_sequence_as(
      self._initial_state,
      [t.get_shape() for t in nest.flatten(self._initial_state)])
  self._state_types = nest.pack_sequence_as(
      self._initial_state,
      [t.dtype for t in nest.flatten(self._initial_state)])

  # Will be populated by calling `tf_scan_func`.
  self._output_classes = None
  self._output_shapes = None
  self._output_types = None

  # Iteratively rerun the scan function until reaching a fixed point on
  # `self._state_shapes`.
  need_to_rerun = True
  while need_to_rerun:

    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        scan_func,
        "tf.contrib.data.scan()",
        input_classes=(self._state_classes, input_dataset.output_classes),
        input_shapes=(self._state_shapes, input_dataset.output_shapes),
        input_types=(self._state_types, input_dataset.output_types),
        add_to_graph=False)
    if not (isinstance(wrapped_func.output_types, collections.Sequence) and
            len(wrapped_func.output_types) == 2):
      raise TypeError("The scan function must return a pair comprising the "
                      "new state and the output value.")

    new_state_classes, self._output_classes = wrapped_func.output_classes

    # Extract and validate class information from the returned values.
    for new_state_class, state_class in zip(
        nest.flatten(new_state_classes),
        nest.flatten(self._state_classes)):
      if not issubclass(new_state_class, state_class):
        raise TypeError(
            "The element classes for the new state must match the initial "
            "state. Expected %s; got %s." %
            (self._state_classes, new_state_classes))

    # Extract and validate type information from the returned values.
    new_state_types, self._output_types = wrapped_func.output_types
    for new_state_type, state_type in zip(
        nest.flatten(new_state_types), nest.flatten(self._state_types)):
      if new_state_type != state_type:
        raise TypeError(
            "The element types for the new state must match the initial "
            "state. Expected %s; got %s." %
            (self._state_types, new_state_types))

    # Extract shape information from the returned values.
    new_state_shapes, self._output_shapes = wrapped_func.output_shapes

    flat_state_shapes = nest.flatten(self._state_shapes)
    flat_new_state_shapes = nest.flatten(new_state_shapes)
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      self._state_shapes = nest.pack_sequence_as(self._state_shapes,
                                                 weakened_state_shapes)

  self._scan_func = wrapped_func.function
  self._scan_func.add_to_graph(ops.get_default_graph())
def __init__(self, input_dataset, initial_state, scan_func):
  """See `scan()` for details."""
  super(_ScanDataset, self).__init__()
  self._input_dataset = input_dataset

  with ops.name_scope("initial_state"):
    # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
    # values to tensors.
    self._initial_state = nest.pack_sequence_as(initial_state, [
        sparse_tensor.SparseTensor.from_value(t)
        if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
            t, name="component_%d" % i)
        for i, t in enumerate(nest.flatten(initial_state))
    ])

  # Compute initial values for the state classes, shapes and types based on
  # the initial state. The shapes may be refined by running `tf_scan_func` one
  # or more times below.
  self._state_classes = sparse.get_classes(self._initial_state)
  self._state_shapes = nest.pack_sequence_as(
      self._initial_state,
      [t.get_shape() for t in nest.flatten(self._initial_state)])
  self._state_types = nest.pack_sequence_as(
      self._initial_state,
      [t.dtype for t in nest.flatten(self._initial_state)])

  # Will be populated by calling `tf_scan_func`.
  self._output_classes = None
  self._output_shapes = None
  self._output_types = None

  # Iteratively rerun the scan function until reaching a fixed point on
  # `self._state_shapes`.
  need_to_rerun = True
  while need_to_rerun:
    # Create a list in which `tf_scan_func` will store the new shapes.
    flat_new_state_shapes = []

    @function.Defun(*(nest.flatten(
        sparse.as_dense_types(self._state_types, self._state_classes)) +
                      nest.flatten(
                          sparse.as_dense_types(
                              input_dataset.output_types,
                              input_dataset.output_classes))))
    def tf_scan_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the state and input_dataset.
      for arg, shape in zip(
          args,
          nest.flatten(
              sparse.as_dense_shapes(self._state_shapes,
                                     self._state_classes)) +
          nest.flatten(
              sparse.as_dense_shapes(input_dataset.output_shapes,
                                     input_dataset.output_classes))):
        arg.set_shape(shape)

      pivot = len(nest.flatten(self._state_shapes))
      nested_state_args = nest.pack_sequence_as(self._state_types,
                                                args[:pivot])
      nested_state_args = sparse.deserialize_sparse_tensors(
          nested_state_args, self._state_types, self._state_shapes,
          self._state_classes)
      nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                                args[pivot:])
      nested_input_args = sparse.deserialize_sparse_tensors(
          nested_input_args, input_dataset.output_types,
          input_dataset.output_shapes, input_dataset.output_classes)

      ret = scan_func(nested_state_args, nested_input_args)
      if not isinstance(ret, collections.Sequence) or len(ret) != 2:
        raise TypeError("The scan function must return a pair comprising the "
                        "new state and the output value.")

      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
          for t in nest.flatten(ret)
      ])
      new_state, output_value = ret

      # Extract and validate class information from the returned values.
      for t, clazz in zip(
          nest.flatten(new_state), nest.flatten(self._state_classes)):
        if not isinstance(t, clazz):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._state_classes,
               nest.pack_sequence_as(
                   self._state_types,
                   [type(t) for t in nest.flatten(new_state)])))
      self._output_classes = sparse.get_classes(output_value)

      # Extract shape information from the returned values.
      flat_new_state_shapes.extend(
          [t.get_shape() for t in nest.flatten(new_state)])
      self._output_shapes = nest.pack_sequence_as(
          output_value, [t.get_shape() for t in nest.flatten(output_value)])

      # Extract and validate type information from the returned values.
      for t, dtype in zip(
          nest.flatten(new_state), nest.flatten(self._state_types)):
        if t.dtype != dtype:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._state_types,
               nest.pack_sequence_as(
                   self._state_types,
                   [t.dtype for t in nest.flatten(new_state)])))
      self._output_types = nest.pack_sequence_as(
          output_value, [t.dtype for t in nest.flatten(output_value)])

      dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

      # Serialize any sparse tensors.
      new_state = nest.pack_sequence_as(new_state, [
          t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state))
      ])
      output_value = nest.pack_sequence_as(output_value, [
          t for t in nest.flatten(
              sparse.serialize_sparse_tensors(output_value))
      ])
      return nest.flatten(new_state) + nest.flatten(output_value)

    # Use the private method that will execute `tf_scan_func` but delay
    # adding it to the graph in case we need to rerun the function.
    tf_scan_func._create_definition_if_needed()  # pylint: disable=protected-access

    flat_state_shapes = nest.flatten(self._state_shapes)
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun
      # `tf_scan_func`.
      self._state_shapes = nest.pack_sequence_as(self._state_shapes,
                                                 weakened_state_shapes)

  self._scan_func = tf_scan_func
  self._scan_func.add_to_graph(ops.get_default_graph())
def __init__(self, input_dataset, initial_state, scan_func):
  """See `scan()` for details."""
  self._input_dataset = input_dataset

  with ops.name_scope("initial_state"):
    self._initial_state = structure.normalize_tensors(initial_state)

  # Compute initial values for the state classes, shapes and types based on
  # the initial state. The shapes may be refined by running `tf_scan_func` one
  # or more times below.
  self._state_classes = sparse.get_classes(self._initial_state)
  self._state_shapes = nest.pack_sequence_as(
      self._initial_state,
      [t.get_shape() for t in nest.flatten(self._initial_state)])
  self._state_types = nest.pack_sequence_as(
      self._initial_state,
      [t.dtype for t in nest.flatten(self._initial_state)])

  # Will be populated by calling `tf_scan_func`.
  self._output_classes = None
  self._output_shapes = None
  self._output_types = None

  # Iteratively rerun the scan function until reaching a fixed point on
  # `self._state_shapes`.
  need_to_rerun = True
  while need_to_rerun:

    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        scan_func,
        self._transformation_name(),
        input_classes=(self._state_classes, input_dataset.output_classes),
        input_shapes=(self._state_shapes, input_dataset.output_shapes),
        input_types=(self._state_types, input_dataset.output_types),
        add_to_graph=False)
    if not (isinstance(wrapped_func.output_types, collections.Sequence) and
            len(wrapped_func.output_types) == 2):
      raise TypeError("The scan function must return a pair comprising the "
                      "new state and the output value.")

    new_state_classes, self._output_classes = wrapped_func.output_classes

    # Extract and validate class information from the returned values.
    for new_state_class, state_class in zip(
        nest.flatten(new_state_classes),
        nest.flatten(self._state_classes)):
      if not issubclass(new_state_class, state_class):
        raise TypeError(
            "The element classes for the new state must match the initial "
            "state. Expected %s; got %s." %
            (self._state_classes, new_state_classes))

    # Extract and validate type information from the returned values.
    new_state_types, self._output_types = wrapped_func.output_types
    for new_state_type, state_type in zip(
        nest.flatten(new_state_types), nest.flatten(self._state_types)):
      if new_state_type != state_type:
        raise TypeError(
            "The element types for the new state must match the initial "
            "state. Expected %s; got %s." %
            (self._state_types, new_state_types))

    # Extract shape information from the returned values.
    new_state_shapes, self._output_shapes = wrapped_func.output_shapes

    flat_state_shapes = nest.flatten(self._state_shapes)
    flat_new_state_shapes = nest.flatten(new_state_shapes)
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      self._state_shapes = nest.pack_sequence_as(self._state_shapes,
                                                 weakened_state_shapes)

  self._scan_func = wrapped_func
  self._scan_func.function.add_to_graph(ops.get_default_graph())

  # pylint: disable=protected-access
  variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
      self._input_dataset._variant_tensor,
      self._state_structure._to_tensor_list(self._initial_state),
      self._scan_func.function.captured_inputs,
      f=self._scan_func.function,
      preserve_cardinality=True,
      **dataset_ops.flat_structure(self))
  super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
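For context, a minimal usage sketch of the transformation `_ScanDataset` backs, assuming the tf.contrib.data-era `scan()` entry point (later `tf.data.experimental.scan`) and the usual `tf`/`np` imports; the scan function below is illustrative, not taken from the code above.

# Illustrative only: emit the running sum of the input dataset.
dataset = tf.data.Dataset.range(5).apply(
    tf.contrib.data.scan(
        initial_state=np.int64(0),
        scan_func=lambda state, x: (state + x, state + x)))
# Produces the elements 0, 1, 3, 6, 10.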
def testGetClasses(self, classes_fn, expected_fn):
  classes = classes_fn()
  expected = expected_fn()
  self.assertEqual(sparse.get_classes(classes), expected)