def _prefetch_fn(handle):
  """Fetches a single element from `input_iterator` through `handle`.

  The iterator is rebuilt from its string handle using the types, shapes and
  classes of `input_iterator`. The element it yields is returned as a flat
  list, with any sparse components serialized.
  """
  iterator_types = input_iterator.output_types
  iterator_shapes = input_iterator.output_shapes
  iterator_classes = input_iterator.output_classes
  remote_iterator = iterator_ops.Iterator.from_string_handle(
      handle, iterator_types, iterator_shapes, iterator_classes)
  element = remote_iterator.get_next()
  serialized_element = sparse.serialize_sparse_tensors(element)
  return nest.flatten(serialized_element)
def testSerializeDeserialize(self):
  """Round-trips nested (sparse) structures through serialize/deserialize."""
  test_cases = (
      (),
      sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
      # The trailing comma makes this a 1-tuple. Without it the parentheses
      # are only grouping, and this case would repeat the bare SparseTensor
      # case above instead of testing a singleton tuple structure.
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
      ((), sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
  )
  for expected in test_cases:
    classes = sparse.get_classes(expected)
    # Every component is declared int32 with a fully unknown static shape.
    shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                classes)
    types = nest.map_structure(lambda _: dtypes.int32, classes)
    actual = sparse.deserialize_sparse_tensors(
        sparse.serialize_sparse_tensors(expected), types, shapes, classes)
    nest.assert_same_structure(expected, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
      self.assertSparseValuesEqual(a, e)
def tf_finalize_func(*args):
  """A wrapper for Defun that facilitates shape inference.

  Restores static shapes and the nested structure of the reducer state from
  the flat `args`, deserializes sparse components, and applies
  `finalize_func`. The output classes, shapes and dtypes are recorded on
  `self`.

  Returns:
    The result of `finalize_func` as a flat list, with any `SparseTensor`s
    serialized.
  """
  # Give each flat argument the static shape computed for the corresponding
  # state component by `sparse.as_dense_shapes`.
  for arg, shape in zip(
      args,
      nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
    arg.set_shape(shape)

  nested_args = nest.pack_sequence_as(self._state_types, args)
  nested_args = sparse.deserialize_sparse_tensors(
      nested_args, self._state_types, self._state_shapes,
      self._state_classes)
  ret = finalize_func(nested_args)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  flat_ret = nest.flatten(ret)
  self._output_classes = sparse.get_classes(ret)
  self._output_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in flat_ret])
  self._output_types = nest.pack_sequence_as(
      ret, [t.dtype for t in flat_ret])

  # Serialize any sparse tensors. Flattening the serialized structure
  # directly yields the same list as repacking it into `ret` and flattening
  # again.
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def from_value(value):
  """Wraps `value` in an `Optional`.

  Args:
    value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

  Returns:
    An `Optional` that wraps `value`.
  """
  # TODO(b/110122868): Consolidate this destructuring logic with the
  # similar code in `Dataset.from_tensors()`.
  with ops.name_scope("optional") as scope:
    with ops.name_scope("value"):
      # Sparse values become `SparseTensor`s; everything else is converted
      # to a tensor named after its flat position.
      components = []
      for index, component in enumerate(nest.flatten(value)):
        if sparse_tensor_lib.is_sparse(component):
          components.append(
              sparse_tensor_lib.SparseTensor.from_value(component))
        else:
          components.append(
              ops.convert_to_tensor(component, name="component_%d" % index))
      value = nest.pack_sequence_as(value, components)

      encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
      output_classes = sparse.get_classes(value)
      output_shapes = nest.pack_sequence_as(
          value, [component.get_shape() for component in nest.flatten(value)])
      output_types = nest.pack_sequence_as(
          value, [component.dtype for component in nest.flatten(value)])

    optional_variant = gen_dataset_ops.optional_from_value(
        encoded_value, name=scope)
    return _OptionalImpl(optional_variant, output_shapes, output_types,
                         output_classes)
def testSerializeDeserialize(self):
  """Round-trips nested (sparse) structures through serialize/deserialize."""
  test_cases = (
      (),
      sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
      # The trailing comma makes this a 1-tuple. Without it the parentheses
      # are only grouping, and this case would repeat the bare SparseTensor
      # case above instead of testing a singleton tuple structure.
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
      ((), sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
  )
  for expected in test_cases:
    actual = sparse.deserialize_sparse_tensors(
        sparse.serialize_sparse_tensors(expected),
        sparse.get_sparse_types(expected))
    nest.assert_same_structure(expected, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
      self.assertSparseValuesEqual(a, e)
def tf_init_func(key):
  """A wrapper for Defun that facilitates shape inference.

  Calls `init_func` on the scalar `key` and records the resulting state's
  classes, shapes and dtypes on `self`.

  Returns:
    The initial state as a flat list, with any `SparseTensor`s serialized.
  """
  key.set_shape([])
  ret = init_func(key)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  flat_ret = nest.flatten(ret)
  self._state_classes = sparse.get_classes(ret)
  self._state_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in flat_ret])
  self._state_types = nest.pack_sequence_as(
      ret, [t.dtype for t in flat_ret])
  dataset_ops._warn_if_collections(
      "tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

  # Serialize any sparse tensors. Flattening the serialized structure
  # directly yields the same list as repacking it into `ret` and flattening
  # again.
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def tf_finalize_func(*args):
  """A wrapper for Defun that facilitates shape inference.

  Restores static shapes and the nested structure of the reducer state from
  the flat `args`, deserializes sparse components, and applies
  `finalize_func`. The output classes, shapes and dtypes are recorded on
  `self`.

  Returns:
    The result of `finalize_func` as a flat list, with any `SparseTensor`s
    serialized.
  """
  # Give each flat argument the static shape computed for the corresponding
  # state component by `sparse.as_dense_shapes`.
  for arg, shape in zip(
      args,
      nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
    arg.set_shape(shape)

  nested_args = nest.pack_sequence_as(self._state_types, args)
  nested_args = sparse.deserialize_sparse_tensors(
      nested_args, self._state_types, self._state_shapes,
      self._state_classes)
  ret = finalize_func(nested_args)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  flat_ret = nest.flatten(ret)
  self._output_classes = sparse.get_classes(ret)
  self._output_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in flat_ret])
  self._output_types = nest.pack_sequence_as(
      ret, [t.dtype for t in flat_ret])

  # Serialize any sparse tensors. Flattening the serialized structure
  # directly yields the same list as repacking it into `ret` and flattening
  # again.
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def from_value(value):
  """Wraps `value` in an `Optional`.

  Args:
    value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

  Returns:
    An `Optional` that wraps `value`.
  """
  # TODO(b/110122868): Consolidate this destructuring logic with the
  # similar code in `Dataset.from_tensors()`.
  with ops.name_scope("optional") as scope:
    with ops.name_scope("value"):
      # Sparse values become `SparseTensor`s; everything else is converted
      # to a tensor named after its flat position.
      components = []
      for index, component in enumerate(nest.flatten(value)):
        if sparse_tensor_lib.is_sparse(component):
          components.append(
              sparse_tensor_lib.SparseTensor.from_value(component))
        else:
          components.append(
              ops.convert_to_tensor(component, name="component_%d" % index))
      value = nest.pack_sequence_as(value, components)

      encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
      output_classes = sparse.get_classes(value)
      output_shapes = nest.pack_sequence_as(
          value, [component.get_shape() for component in nest.flatten(value)])
      output_types = nest.pack_sequence_as(
          value, [component.dtype for component in nest.flatten(value)])

    optional_variant = gen_dataset_ops.optional_from_value(
        encoded_value, name=scope)
    return _OptionalImpl(optional_variant, output_shapes, output_types,
                         output_classes)
def _prefetch_fn(handle):
  """Fetches a single element through the iterator string `handle`.

  The iterator is rebuilt with this object's output types, shapes and
  classes. The element it yields is returned as a flat list, with any sparse
  components serialized.
  """
  element_types = self.output_types
  element_shapes = self.output_shapes
  element_classes = self.output_classes
  remote_iterator = iterator_ops.Iterator.from_string_handle(
      handle, element_types, element_shapes, element_classes)
  element = remote_iterator.get_next()
  serialized_element = sparse.serialize_sparse_tensors(element)
  return nest.flatten(serialized_element)
def testSerializeDeserialize(self):
  """Round-trips nested (sparse) structures through serialize/deserialize."""
  test_cases = (
      (),
      sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
      # The trailing comma makes this a 1-tuple. Without it the parentheses
      # are only grouping, and this case would repeat the bare SparseTensor
      # case above instead of testing a singleton tuple structure.
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
      ((), sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
  )
  for expected in test_cases:
    classes = sparse.get_classes(expected)
    # Every component is declared int32 with a fully unknown static shape.
    shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                classes)
    types = nest.map_structure(lambda _: dtypes.int32, classes)
    actual = sparse.deserialize_sparse_tensors(
        sparse.serialize_sparse_tensors(expected), types, shapes, classes)
    nest.assert_same_structure(expected, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
      self.assertSparseValuesEqual(a, e)
def tf_finalize_func(*args):
  """A wrapper for Defun that facilitates shape inference.

  Restores static shapes and the nested structure of the reducer state from
  the flat `args`, deserializes sparse components, and applies
  `finalize_func`. The output classes, shapes and dtypes are recorded on
  `self`.

  Returns:
    The result of `finalize_func` as a flat list, with any `SparseTensor`s
    serialized.
  """
  # Give each flat argument the static shape computed for the corresponding
  # state component by `sparse.as_dense_shapes`.
  for arg, shape in zip(
      args,
      nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
    arg.set_shape(shape)

  nested_args = nest.pack_sequence_as(self._state_types, args)
  nested_args = sparse.deserialize_sparse_tensors(
      nested_args, self._state_types, self._state_shapes,
      self._state_classes)
  ret = finalize_func(nested_args)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  flat_ret = nest.flatten(ret)
  self._output_classes = sparse.get_classes(ret)
  self._output_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in flat_ret])
  self._output_types = nest.pack_sequence_as(
      ret, [t.dtype for t in flat_ret])
  dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

  # Serialize any sparse tensors. Flattening the serialized structure
  # directly yields the same list as repacking it into `ret` and flattening
  # again.
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def _as_variant_tensor(self):
  """Builds the `scan_dataset` variant tensor over the input dataset."""
  input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  # The initial state is passed to the op as a flat list, with sparse
  # components serialized.
  flat_initial_state = nest.flatten(
      sparse.serialize_sparse_tensors(self._initial_state))
  scan_kwargs = dict(f=self._scan_func)
  scan_kwargs.update(dataset_ops.flat_structure(self))
  return gen_dataset_ops.scan_dataset(
      input_variant,
      flat_initial_state,
      self._scan_func.captured_inputs,
      **scan_kwargs)
def _as_variant_tensor(self):
  """Builds the `scan_dataset` variant tensor over the input dataset."""
  input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  # The initial state is passed to the op as a flat list, with sparse
  # components serialized.
  flat_initial_state = nest.flatten(
      sparse.serialize_sparse_tensors(self._initial_state))
  scan_kwargs = dict(f=self._scan_func)
  scan_kwargs.update(dataset_ops.flat_structure(self))
  return gen_dataset_ops.scan_dataset(
      input_variant,
      flat_initial_state,
      self._scan_func.captured_inputs,
      **scan_kwargs)
def tf_reduce_func(*args):
  """A wrapper for Defun that facilitates shape inference.

  `args` holds the flattened reducer state followed by the flattened input
  element. Both parts get static shapes, are repacked into their nested
  structures, and are deserialized before `reduce_func(state, element)` is
  called. The new state's shapes are appended to `flat_new_state_shapes`.

  Returns:
    The new state as a flat list, with any `SparseTensor`s serialized.

  Raises:
    TypeError: If a component of the new state has a different dtype than
      the corresponding component of the initial state.
  """
  for arg, shape in zip(
      args,
      nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes)) +
      nest.flatten(
          sparse.as_dense_shapes(input_dataset.output_shapes,
                                 input_dataset.output_classes))):
    arg.set_shape(shape)

  # The first `pivot` flat args belong to the state; the rest to the input.
  pivot = len(nest.flatten(self._state_shapes))
  nested_state_args = nest.pack_sequence_as(self._state_types, args[:pivot])
  nested_state_args = sparse.deserialize_sparse_tensors(
      nested_state_args, self._state_types, self._state_shapes,
      self._state_classes)
  nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                            args[pivot:])
  nested_input_args = sparse.deserialize_sparse_tensors(
      nested_input_args, input_dataset.output_types,
      input_dataset.output_shapes, input_dataset.output_classes)

  ret = reduce_func(nested_state_args, nested_input_args)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  # Extract shape information from the returned values.
  flat_new_state = nest.flatten(ret)
  flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])

  # Extract and validate type information from the returned values.
  for component, expected_dtype in zip(flat_new_state,
                                       nest.flatten(self._state_types)):
    if component.dtype != expected_dtype:
      raise TypeError(
          "The element types for the new state must match the initial "
          "state. Expected %s; got %s." %
          (self._state_types,
           nest.pack_sequence_as(self._state_types,
                                 [t.dtype for t in flat_new_state])))

  # Serialize any sparse tensors. Flattening the serialized structure
  # directly yields the same list as repacking it into `ret` and flattening
  # again.
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def _as_variant_tensor(self):
  """Builds the `experimental_scan_dataset` variant tensor."""
  input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  # The initial state is passed to the op as a flat list, with sparse
  # components serialized.
  flat_initial_state = nest.flatten(
      sparse.serialize_sparse_tensors(self._initial_state))
  captured = self._scan_func.function.captured_inputs
  scan_kwargs = dict(f=self._scan_func.function, preserve_cardinality=True)
  scan_kwargs.update(dataset_ops.flat_structure(self))
  return gen_experimental_dataset_ops.experimental_scan_dataset(
      input_variant, flat_initial_state, captured, **scan_kwargs)
def tf_reduce_func(*args):
  """A wrapper for Defun that facilitates shape inference.

  `args` holds the flattened reducer state followed by the flattened input
  element. Both parts get static shapes, are repacked into their nested
  structures, and are deserialized before `reduce_func(state, element)` is
  called. The new state's shapes are appended to `flat_new_state_shapes`.

  Returns:
    The new state as a flat list, with any `SparseTensor`s serialized.

  Raises:
    TypeError: If a component of the new state has a different dtype than
      the corresponding component of the initial state.
  """
  for arg, shape in zip(
      args,
      nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes)) +
      nest.flatten(
          sparse.as_dense_shapes(input_dataset.output_shapes,
                                 input_dataset.output_classes))):
    arg.set_shape(shape)

  # The first `pivot` flat args belong to the state; the rest to the input.
  pivot = len(nest.flatten(self._state_shapes))
  nested_state_args = nest.pack_sequence_as(self._state_types, args[:pivot])
  nested_state_args = sparse.deserialize_sparse_tensors(
      nested_state_args, self._state_types, self._state_shapes,
      self._state_classes)
  nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                            args[pivot:])
  nested_input_args = sparse.deserialize_sparse_tensors(
      nested_input_args, input_dataset.output_types,
      input_dataset.output_shapes, input_dataset.output_classes)

  ret = reduce_func(nested_state_args, nested_input_args)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  # Extract shape information from the returned values.
  flat_new_state = nest.flatten(ret)
  flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])

  # Extract and validate type information from the returned values.
  for component, expected_dtype in zip(flat_new_state,
                                       nest.flatten(self._state_types)):
    if component.dtype != expected_dtype:
      raise TypeError(
          "The element types for the new state must match the initial "
          "state. Expected %s; got %s." %
          (self._state_types,
           nest.pack_sequence_as(self._state_types,
                                 [t.dtype for t in flat_new_state])))

  dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

  # Serialize any sparse tensors. Flattening the serialized structure
  # directly yields the same list as repacking it into `ret` and flattening
  # again.
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def _as_variant_tensor(self):
  """Builds the `scan_dataset` variant tensor over the input dataset.

  Output types and shapes are passed flat, in their dense form as computed
  by `sparse.as_dense_types` / `sparse.as_dense_shapes`.
  """
  input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  flat_initial_state = nest.flatten(
      sparse.serialize_sparse_tensors(self._initial_state))
  captured = self._scan_func.captured_inputs
  scan_function = self._scan_func
  flat_output_types = nest.flatten(
      sparse.as_dense_types(self.output_types, self.output_classes))
  flat_output_shapes = nest.flatten(
      sparse.as_dense_shapes(self.output_shapes, self.output_classes))
  return gen_dataset_ops.scan_dataset(
      input_variant,
      flat_initial_state,
      captured,
      f=scan_function,
      output_types=flat_output_types,
      output_shapes=flat_output_shapes)
def testSerializeDeserialize(self, input_fn):
  """Checks that `input_fn()` survives a serialize/deserialize round trip."""
  test_case = input_fn()
  classes = sparse.get_classes(test_case)
  # Every component is declared int32 with a fully unknown static shape.
  shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                              classes)
  types = nest.map_structure(lambda _: dtypes.int32, classes)
  serialized = sparse.serialize_sparse_tensors(test_case)
  actual = sparse.deserialize_sparse_tensors(serialized, types, shapes,
                                             classes)
  nest.assert_same_structure(test_case, actual)
  for got, want in zip(nest.flatten(actual), nest.flatten(test_case)):
    self.assertSparseValuesEqual(got, want)
def _as_variant_tensor(self):
  """Builds the `scan_dataset` variant tensor over the input dataset.

  Output types and shapes are passed flat, in their dense form as computed
  by `sparse.as_dense_types` / `sparse.as_dense_shapes`.
  """
  input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  flat_initial_state = nest.flatten(
      sparse.serialize_sparse_tensors(self._initial_state))
  captured = self._scan_func.captured_inputs
  scan_function = self._scan_func
  flat_output_types = nest.flatten(
      sparse.as_dense_types(self.output_types, self.output_classes))
  flat_output_shapes = nest.flatten(
      sparse.as_dense_shapes(self.output_shapes, self.output_classes))
  return gen_dataset_ops.scan_dataset(
      input_variant,
      flat_initial_state,
      captured,
      f=scan_function,
      output_types=flat_output_types,
      output_shapes=flat_output_shapes)
def _next_func(string_handle):
  """Runs `get_next` on the iterator referenced by `string_handle`.

  Args:
    string_handle: An iterator string handle created by _init_func

  Returns:
    The elements generated from `input_dataset`
  """
  # Only the iterator reconstruction is placed on the source device.
  with ops.device(self._source_device_string):
    source_iterator = iterator_ops.Iterator.from_string_handle(
        string_handle, self.output_types, self.output_shapes,
        self.output_classes)
  next_element = source_iterator.get_next()
  serialized_element = sparse.serialize_sparse_tensors(next_element)
  return nest.flatten(serialized_element)
def _next_func(string_handle):
  """Runs `get_next` on the iterator referenced by `string_handle`.

  Args:
    string_handle: An iterator string handle created by _init_func

  Returns:
    The elements generated from `input_dataset`
  """
  # Only the iterator reconstruction is placed on the source device.
  with ops.device(self._source_device_string):
    source_iterator = iterator_ops.Iterator.from_string_handle(
        string_handle, self.output_types, self.output_shapes,
        self.output_classes)
  next_element = source_iterator.get_next()
  serialized_element = sparse.serialize_sparse_tensors(next_element)
  return nest.flatten(serialized_element)
def _prefetch_fn(handle):
  """Prefetches one element from `self._input_iterator` via `handle`.

  The element is returned as a flat list of tensors. Sparse values are first
  normalized to `SparseTensor`s and then serialized.
  """
  source = self._input_iterator
  remote_iterator = iterator_ops.Iterator.from_string_handle(
      handle, source.output_types, source.output_shapes,
      source.output_classes)
  element = remote_iterator.get_next()

  # Convert any `SparseTensorValue`s to `SparseTensor`s.
  normalized = []
  for component in nest.flatten(element):
    if sparse_tensor_lib.is_sparse(component):
      component = sparse_tensor_lib.SparseTensor.from_value(component)
    normalized.append(component)
  element = nest.pack_sequence_as(element, normalized)

  # Serialize any sparse tensors and convert result to tensors.
  serialized = nest.flatten(sparse.serialize_sparse_tensors(element))
  element = nest.pack_sequence_as(
      element, [ops.convert_to_tensor(piece) for piece in serialized])
  return nest.flatten(element)
def tf_init_func(key):
  """A wrapper for Defun that facilitates shape inference.

  Calls `init_func` on the scalar `key` and records the resulting state's
  classes, shapes and dtypes on `self`.

  Returns:
    The initial state as a flat list, with any `SparseTensor`s serialized.
  """
  key.set_shape([])
  ret = init_func(key)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  flat_ret = nest.flatten(ret)
  self._state_classes = sparse.get_classes(ret)
  self._state_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in flat_ret])
  self._state_types = nest.pack_sequence_as(
      ret, [t.dtype for t in flat_ret])

  # Serialize any sparse tensors. Flattening the serialized structure
  # directly yields the same list as repacking it into `ret` and flattening
  # again.
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def tf_finalize_func(*args):
  """A wrapper for Defun that facilitates shape inference.

  Restores static shapes and the nested structure of the reducer state from
  the flat `args`, deserializes sparse components, and applies
  `finalize_func`. The output classes, shapes and dtypes are recorded on
  `self`.

  Returns:
    The result of `finalize_func` as a flat list, with any `SparseTensor`s
    serialized.
  """
  # Give each flat argument the static shape computed for the corresponding
  # state component by `sparse.as_dense_shapes`.
  for arg, shape in zip(
      args,
      nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
    arg.set_shape(shape)

  nested_args = nest.pack_sequence_as(self._state_types, args)
  nested_args = sparse.deserialize_sparse_tensors(
      nested_args, self._state_types, self._state_shapes,
      self._state_classes)
  ret = finalize_func(nested_args)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  flat_ret = nest.flatten(ret)
  self._output_classes = sparse.get_classes(ret)
  self._output_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in flat_ret])
  self._output_types = nest.pack_sequence_as(
      ret, [t.dtype for t in flat_ret])
  dataset_ops._warn_if_collections(
      "tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

  # Serialize any sparse tensors. Flattening the serialized structure
  # directly yields the same list as repacking it into `ret` and flattening
  # again.
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def tf_init_func(key):
  """A wrapper for Defun that facilitates shape inference.

  Calls `init_func` on the scalar `key` and records the resulting state's
  classes, shapes and dtypes on `self`.

  Returns:
    The initial state as a flat list, with any `SparseTensor`s serialized.
  """
  key.set_shape([])
  ret = init_func(key)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  flat_ret = nest.flatten(ret)
  self._state_classes = sparse.get_classes(ret)
  self._state_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in flat_ret])
  self._state_types = nest.pack_sequence_as(
      ret, [t.dtype for t in flat_ret])

  # Serialize any sparse tensors. Flattening the serialized structure
  # directly yields the same list as repacking it into `ret` and flattening
  # again.
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def testSerializeDeserialize(self):
  """Round-trips nested (sparse) structures through serialize/deserialize."""
  test_cases = (
      (),
      sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
      # The trailing comma makes this a 1-tuple. Without it the parentheses
      # are only grouping, and this case would repeat the bare SparseTensor
      # case above instead of testing a singleton tuple structure.
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
      ((), sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
  )
  for expected in test_cases:
    actual = sparse.deserialize_sparse_tensors(
        sparse.serialize_sparse_tensors(expected),
        sparse.get_sparse_types(expected))
    nest.assert_same_structure(expected, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
      self.assertSparseValuesEqual(a, e)
def tf_init_func(key):
  """A wrapper for Defun that facilitates shape inference.

  Calls `init_func` on the scalar `key` and records the resulting state's
  classes, shapes and dtypes on `self`.

  Returns:
    The initial state as a flat list, with any `SparseTensor`s serialized.
  """
  key.set_shape([])
  ret = init_func(key)

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])

  flat_ret = nest.flatten(ret)
  self._state_classes = sparse.get_classes(ret)
  self._state_shapes = nest.pack_sequence_as(
      ret, [t.get_shape() for t in flat_ret])
  self._state_types = nest.pack_sequence_as(
      ret, [t.dtype for t in flat_ret])
  dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

  # Serialize any sparse tensors. Flattening the serialized structure
  # directly yields the same list as repacking it into `ret` and flattening
  # again.
  return nest.flatten(sparse.serialize_sparse_tensors(ret))
def tf_scan_func(*args):
  """A wrapper for Defun that facilitates shape inference.

  `args` holds the flattened scan state followed by the flattened input
  element. Both parts get static shapes, are repacked into their nested
  structures, and are deserialized before `scan_func(state, element)` is
  called. The new state's classes and dtypes are validated against the
  initial state, and its shapes are appended to `flat_new_state_shapes`. The
  output value's classes, shapes and types are recorded on `self`.

  Returns:
    The flattened new state followed by the flattened output value, with any
    `SparseTensor`s serialized.

  Raises:
    TypeError: If `scan_func` does not return a pair, or if the new state's
      classes or dtypes do not match those of the initial state.
  """
  # `collections.Sequence` was removed in Python 3.10; use the `abc`
  # submodule when available and fall back to `collections` on Python 2.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections

  # Pass in shape information from the state and input_dataset.
  for arg, shape in zip(
      args,
      nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes)) +
      nest.flatten(
          sparse.as_dense_shapes(input_dataset.output_shapes,
                                 input_dataset.output_classes))):
    arg.set_shape(shape)

  # The first `pivot` flat args belong to the state; the rest to the input.
  pivot = len(nest.flatten(self._state_shapes))
  nested_state_args = nest.pack_sequence_as(self._state_types, args[:pivot])
  nested_state_args = sparse.deserialize_sparse_tensors(
      nested_state_args, self._state_types, self._state_shapes,
      self._state_classes)
  nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                            args[pivot:])
  nested_input_args = sparse.deserialize_sparse_tensors(
      nested_input_args, input_dataset.output_types,
      input_dataset.output_shapes, input_dataset.output_classes)

  ret = scan_func(nested_state_args, nested_input_args)
  if not isinstance(ret, collections_abc.Sequence) or len(ret) != 2:
    raise TypeError("The scan function must return a pair comprising the "
                    "new state and the output value.")

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])
  new_state, output_value = ret
  flat_new_state = nest.flatten(new_state)
  flat_output_value = nest.flatten(output_value)

  # Extract and validate class information from the returned values.
  for component, clazz in zip(flat_new_state,
                              nest.flatten(self._state_classes)):
    if not isinstance(component, clazz):
      raise TypeError(
          "The element classes for the new state must match the initial "
          "state. Expected %s; got %s." %
          (self._state_classes,
           nest.pack_sequence_as(self._state_types,
                                 [type(t) for t in flat_new_state])))
  self._output_classes = sparse.get_classes(output_value)

  # Extract shape information from the returned values.
  flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])
  self._output_shapes = nest.pack_sequence_as(
      output_value, [t.get_shape() for t in flat_output_value])

  # Extract and validate type information from the returned values.
  for component, expected_dtype in zip(flat_new_state,
                                       nest.flatten(self._state_types)):
    if component.dtype != expected_dtype:
      raise TypeError(
          "The element types for the new state must match the initial "
          "state. Expected %s; got %s." %
          (self._state_types,
           nest.pack_sequence_as(self._state_types,
                                 [t.dtype for t in flat_new_state])))
  self._output_types = nest.pack_sequence_as(
      output_value, [t.dtype for t in flat_output_value])

  dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

  # Serialize any sparse tensors and emit state components before output
  # components.
  return (nest.flatten(sparse.serialize_sparse_tensors(new_state)) +
          nest.flatten(sparse.serialize_sparse_tensors(output_value)))
def tf_scan_func(*args):
  """A wrapper for Defun that facilitates shape inference.

  `args` holds the flattened scan state followed by the flattened input
  element. Both parts get static shapes, are repacked into their nested
  structures, and are deserialized before `scan_func(state, element)` is
  called. The new state's classes and dtypes are validated against the
  initial state, and its shapes are appended to `flat_new_state_shapes`. The
  output value's classes, shapes and types are recorded on `self`.

  Returns:
    The flattened new state followed by the flattened output value, with any
    `SparseTensor`s serialized.

  Raises:
    TypeError: If `scan_func` does not return a pair, or if the new state's
      classes or dtypes do not match those of the initial state.
  """
  # `collections.Sequence` was removed in Python 3.10; use the `abc`
  # submodule when available and fall back to `collections` on Python 2.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections

  # Pass in shape information from the state and input_dataset.
  for arg, shape in zip(
      args,
      nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes)) +
      nest.flatten(
          sparse.as_dense_shapes(input_dataset.output_shapes,
                                 input_dataset.output_classes))):
    arg.set_shape(shape)

  # The first `pivot` flat args belong to the state; the rest to the input.
  pivot = len(nest.flatten(self._state_shapes))
  nested_state_args = nest.pack_sequence_as(self._state_types, args[:pivot])
  nested_state_args = sparse.deserialize_sparse_tensors(
      nested_state_args, self._state_types, self._state_shapes,
      self._state_classes)
  nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                            args[pivot:])
  nested_input_args = sparse.deserialize_sparse_tensors(
      nested_input_args, input_dataset.output_types,
      input_dataset.output_shapes, input_dataset.output_classes)

  ret = scan_func(nested_state_args, nested_input_args)
  if not isinstance(ret, collections_abc.Sequence) or len(ret) != 2:
    raise TypeError("The scan function must return a pair comprising the "
                    "new state and the output value.")

  # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
  # values to tensors.
  ret = nest.pack_sequence_as(ret, [
      sparse_tensor.SparseTensor.from_value(t)
      if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
      for t in nest.flatten(ret)
  ])
  new_state, output_value = ret
  flat_new_state = nest.flatten(new_state)
  flat_output_value = nest.flatten(output_value)

  # Extract and validate class information from the returned values.
  for component, clazz in zip(flat_new_state,
                              nest.flatten(self._state_classes)):
    if not isinstance(component, clazz):
      raise TypeError(
          "The element classes for the new state must match the initial "
          "state. Expected %s; got %s." %
          (self._state_classes,
           nest.pack_sequence_as(self._state_types,
                                 [type(t) for t in flat_new_state])))
  self._output_classes = sparse.get_classes(output_value)

  # Extract shape information from the returned values.
  flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])
  self._output_shapes = nest.pack_sequence_as(
      output_value, [t.get_shape() for t in flat_output_value])

  # Extract and validate type information from the returned values.
  for component, expected_dtype in zip(flat_new_state,
                                       nest.flatten(self._state_types)):
    if component.dtype != expected_dtype:
      raise TypeError(
          "The element types for the new state must match the initial "
          "state. Expected %s; got %s." %
          (self._state_types,
           nest.pack_sequence_as(self._state_types,
                                 [t.dtype for t in flat_new_state])))
  self._output_types = nest.pack_sequence_as(
      output_value, [t.dtype for t in flat_output_value])

  dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

  # Serialize any sparse tensors and emit state components before output
  # components.
  return (nest.flatten(sparse.serialize_sparse_tensors(new_state)) +
          nest.flatten(sparse.serialize_sparse_tensors(output_value)))