Ejemplo n.º 1
0
 def testSerializeManyDeserialize(self):
   """Checks that nested sparse structures survive serialize-many/deserialize.

   Each case is serialized with `sparse.serialize_many_sparse_tensors` and
   restored with `sparse.deserialize_sparse_tensors`, using int32 dtypes and
   fully unknown shapes; the result must match the input in structure and in
   every flattened component.
   """
   test_cases = (
       (),
       sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       # The trailing comma makes this a one-element tuple; without it the
       # parentheses are only grouping and the case duplicates the second one.
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
       ((),
        sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
   )
   for expected in test_cases:
     classes = sparse.get_classes(expected)
     shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                 classes)
     types = nest.map_structure(lambda _: dtypes.int32, classes)
     actual = sparse.deserialize_sparse_tensors(
         sparse.serialize_many_sparse_tensors(expected), types, shapes,
         classes)
     nest.assert_same_structure(expected, actual)
     for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
       self.assertSparseValuesEqual(a, e)
Ejemplo n.º 2
0
 def testSerializeDeserialize(self):
   """Checks that nested sparse structures survive serialize/deserialize.

   Each case is serialized with `sparse.serialize_sparse_tensors` and
   restored with `sparse.deserialize_sparse_tensors`, using int32 dtypes and
   fully unknown shapes; the result must match the input in structure and in
   every flattened component.
   """
   test_cases = (
       (),
       sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       # The trailing comma makes this a one-element tuple; without it the
       # parentheses are only grouping and the case duplicates the second one.
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
       ((), sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
   )
   for expected in test_cases:
     classes = sparse.get_classes(expected)
     shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                 classes)
     types = nest.map_structure(lambda _: dtypes.int32, classes)
     actual = sparse.deserialize_sparse_tensors(
         sparse.serialize_sparse_tensors(expected), types, shapes,
         classes)
     nest.assert_same_structure(expected, actual)
     for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
       self.assertSparseValuesEqual(a, e)
Ejemplo n.º 3
0
 def testSerializeManyDeserialize(self, input_fn):
   """Round-trips the structure built by `input_fn` through serialize-many.

   The structure is serialized with `sparse.serialize_many_sparse_tensors`,
   then restored with `sparse.deserialize_sparse_tensors` using int32 dtypes
   and fully unknown shapes; the restored value must match the original in
   structure and in every flattened component.
   """
   structure = input_fn()
   structure_classes = sparse.get_classes(structure)
   unknown_shapes = nest.map_structure(
       lambda _: tensor_shape.TensorShape(None), structure_classes)
   int32_types = nest.map_structure(lambda _: dtypes.int32, structure_classes)

   serialized = sparse.serialize_many_sparse_tensors(structure)
   restored = sparse.deserialize_sparse_tensors(
       serialized, int32_types, unknown_shapes, structure_classes)

   nest.assert_same_structure(structure, restored)
   flat_restored = nest.flatten(restored)
   flat_original = nest.flatten(structure)
   for restored_part, original_part in zip(flat_restored, flat_original):
     self.assertSparseValuesEqual(restored_part, original_part)
Ejemplo n.º 4
0
    def tf_finalize_func(*args):
      """A wrapper for Defun that facilitates shape inference.

      Restores static shapes on the flat `args` from the recorded state
      shapes/classes, rebuilds the nested state (deserializing sparse
      components), calls `finalize_func` on it, records the output classes,
      shapes and dtypes on `self`, and returns the flattened result with any
      sparse tensors serialized.
      """
      for arg, shape in zip(
          args,
          nest.flatten(
              sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(self._state_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, self._state_types, self._state_shapes,
          self._state_classes)

      ret = finalize_func(nested_args)

      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
          for t in nest.flatten(ret)
      ])

      flat_ret = nest.flatten(ret)
      self._output_classes = sparse.get_classes(ret)
      self._output_shapes = nest.pack_sequence_as(
          ret, [t.get_shape() for t in flat_ret])
      self._output_types = nest.pack_sequence_as(
          ret, [t.dtype for t in flat_ret])

      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

      # Serialize any sparse tensors. Repacking the serialized leaves into
      # `ret`'s structure only to flatten them again is a no-op, so flatten
      # directly.
      return nest.flatten(sparse.serialize_sparse_tensors(ret))
Ejemplo n.º 5
0
        def tf_init_func(key):
            """A wrapper for Defun that facilitates shape inference.

            Calls `init_func` with a scalar-shaped `key`, records the classes,
            shapes and dtypes of the returned state on `self`, and returns the
            state flattened with any sparse tensors serialized.
            """
            key.set_shape([])
            ret = init_func(key)
            # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
            # values to tensors.
            ret = nest.pack_sequence_as(ret, [
                sparse_tensor.SparseTensor.from_value(t)
                if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                for t in nest.flatten(ret)
            ])

            flat_ret = nest.flatten(ret)
            self._state_classes = sparse.get_classes(ret)
            self._state_shapes = nest.pack_sequence_as(
                ret, [t.get_shape() for t in flat_ret])
            self._state_types = nest.pack_sequence_as(
                ret, [t.dtype for t in flat_ret])

            dataset_ops._warn_if_collections(
                "tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

            # Serialize any sparse tensors. Repacking the serialized leaves
            # into `ret`'s structure only to flatten them again is a no-op,
            # so flatten directly.
            return nest.flatten(sparse.serialize_sparse_tensors(ret))
Ejemplo n.º 6
0
  def from_value(value):
    """Builds an `Optional` holding `value`.

    Args:
      value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

    Returns:
      An `Optional` that wraps `value`.
    """
    # TODO(b/110122868): Consolidate this destructuring logic with the
    # similar code in `Dataset.from_tensors()`.
    with ops.name_scope("optional") as scope:
      with ops.name_scope("value"):
        # Normalize each leaf: sparse values become `SparseTensor`s, anything
        # else is converted to a tensor named after its flat position.
        components = []
        for index, component in enumerate(nest.flatten(value)):
          if sparse_tensor_lib.is_sparse(component):
            components.append(
                sparse_tensor_lib.SparseTensor.from_value(component))
          else:
            components.append(
                ops.convert_to_tensor(component, name="component_%d" % index))
        value = nest.pack_sequence_as(value, components)

      flat_value = nest.flatten(value)
      encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
      output_classes = sparse.get_classes(value)
      output_shapes = nest.pack_sequence_as(
          value, [component.get_shape() for component in flat_value])
      output_types = nest.pack_sequence_as(
          value, [component.dtype for component in flat_value])

    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(encoded_value, name=scope),
        output_shapes, output_types, output_classes)
Ejemplo n.º 7
0
        def tf_finalize_func(*args):
            """A wrapper for Defun that facilitates shape inference.

            Restores static shapes on the flat `args` from the recorded state
            shapes/classes, rebuilds the nested state (deserializing sparse
            components), calls `finalize_func` on it, records the output
            classes, shapes and dtypes on `self`, and returns the flattened
            result with any sparse tensors serialized.
            """
            for arg, shape in zip(
                    args,
                    nest.flatten(
                        sparse.as_dense_shapes(self._state_shapes,
                                               self._state_classes))):
                arg.set_shape(shape)

            nested_args = nest.pack_sequence_as(self._state_types, args)
            nested_args = sparse.deserialize_sparse_tensors(
                nested_args, self._state_types, self._state_shapes,
                self._state_classes)

            ret = finalize_func(nested_args)

            # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
            # values to tensors.
            ret = nest.pack_sequence_as(ret, [
                sparse_tensor.SparseTensor.from_value(t)
                if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                for t in nest.flatten(ret)
            ])

            flat_ret = nest.flatten(ret)
            self._output_classes = sparse.get_classes(ret)
            self._output_shapes = nest.pack_sequence_as(
                ret, [t.get_shape() for t in flat_ret])
            self._output_types = nest.pack_sequence_as(
                ret, [t.dtype for t in flat_ret])

            # Serialize any sparse tensors. Repacking the serialized leaves
            # into `ret`'s structure only to flatten them again is a no-op,
            # so flatten directly.
            return nest.flatten(sparse.serialize_sparse_tensors(ret))
Ejemplo n.º 8
0
  def from_value(value):
    """Builds an `Optional` holding `value`.

    Args:
      value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

    Returns:
      An `Optional` that wraps `value`.
    """
    # TODO(b/110122868): Consolidate this destructuring logic with the
    # similar code in `Dataset.from_tensors()`.
    with ops.name_scope("optional") as scope:
      with ops.name_scope("value"):
        # Normalize each leaf: sparse values become `SparseTensor`s, anything
        # else is converted to a tensor named after its flat position.
        components = []
        for index, component in enumerate(nest.flatten(value)):
          if sparse_tensor_lib.is_sparse(component):
            components.append(
                sparse_tensor_lib.SparseTensor.from_value(component))
          else:
            components.append(
                ops.convert_to_tensor(component, name="component_%d" % index))
        value = nest.pack_sequence_as(value, components)

      flat_value = nest.flatten(value)
      encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
      output_classes = sparse.get_classes(value)
      output_shapes = nest.pack_sequence_as(
          value, [component.get_shape() for component in flat_value])
      output_types = nest.pack_sequence_as(
          value, [component.dtype for component in flat_value])

    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(encoded_value, name=scope),
        output_shapes, output_types, output_classes)
Ejemplo n.º 9
0
    def tf_finalize_func(*args):
      """A wrapper for Defun that facilitates shape inference.

      Restores static shapes on the flat `args` from the recorded state
      shapes/classes, rebuilds the nested state (deserializing sparse
      components), calls `finalize_func` on it, records the output classes,
      shapes and dtypes on `self`, and returns the flattened result with any
      sparse tensors serialized.
      """
      for arg, shape in zip(
          args,
          nest.flatten(
              sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(self._state_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, self._state_types, self._state_shapes,
          self._state_classes)

      ret = finalize_func(nested_args)

      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
          for t in nest.flatten(ret)
      ])

      flat_ret = nest.flatten(ret)
      self._output_classes = sparse.get_classes(ret)
      self._output_shapes = nest.pack_sequence_as(
          ret, [t.get_shape() for t in flat_ret])
      self._output_types = nest.pack_sequence_as(
          ret, [t.dtype for t in flat_ret])

      # Serialize any sparse tensors. Repacking the serialized leaves into
      # `ret`'s structure only to flatten them again is a no-op, so flatten
      # directly.
      return nest.flatten(sparse.serialize_sparse_tensors(ret))
Ejemplo n.º 10
0
 def testGetClasses(self):
     """Checks `sparse.get_classes` on tensors, sparse tensors and nests.

     Dense leaves are expected to map to `ops.Tensor`, sparse leaves to
     `sparse_tensor.SparseTensor`, and empty tuples are kept as-is.
     """
     s = sparse_tensor.SparseTensor(indices=[[0]],
                                    values=[1],
                                    dense_shape=[1])
     d = ops.Tensor
     t = sparse_tensor.SparseTensor
     test_cases = (
         {
             "classes": (),
             "expected": ()
         },
         {
             "classes": s,
             "expected": t
         },
         {
             "classes": constant_op.constant([1]),
             "expected": d
         },
         # The trailing commas below make one-element tuples; without them
         # the parentheses are only grouping and these cases would duplicate
         # the two bare-leaf cases above.
         {
             "classes": (s,),
             "expected": (t,)
         },
         {
             "classes": (constant_op.constant([1]),),
             "expected": (d,)
         },
         {
             "classes": (s, ()),
             "expected": (t, ())
         },
         {
             "classes": ((), s),
             "expected": ((), t)
         },
         {
             "classes": (constant_op.constant([1]), ()),
             "expected": (d, ())
         },
         {
             "classes": ((), constant_op.constant([1])),
             "expected": ((), d)
         },
         {
             "classes": (s, (), constant_op.constant([1])),
             "expected": (t, (), d)
         },
         {
             "classes": ((), s, ()),
             "expected": ((), t, ())
         },
         {
             "classes": ((), constant_op.constant([1]), ()),
             "expected": ((), d, ())
         },
     )
     for test_case in test_cases:
         self.assertEqual(sparse.get_classes(test_case["classes"]),
                          test_case["expected"])
Ejemplo n.º 11
0
 def testGetClasses(self):
   """Checks `sparse.get_classes` on tensors, sparse tensors and nests.

   Dense leaves are expected to map to `ops.Tensor`, sparse leaves to
   `sparse_tensor.SparseTensor`, and empty tuples are kept as-is.
   """
   s = sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1])
   d = ops.Tensor
   t = sparse_tensor.SparseTensor
   test_cases = (
       {
           "classes": (),
           "expected": ()
       },
       {
           "classes": s,
           "expected": t
       },
       {
           "classes": constant_op.constant([1]),
           "expected": d
       },
       # The trailing commas below make one-element tuples; without them the
       # parentheses are only grouping and these cases would duplicate the
       # two bare-leaf cases above.
       {
           "classes": (s,),
           "expected": (t,)
       },
       {
           "classes": (constant_op.constant([1]),),
           "expected": (d,)
       },
       {
           "classes": (s, ()),
           "expected": (t, ())
       },
       {
           "classes": ((), s),
           "expected": ((), t)
       },
       {
           "classes": (constant_op.constant([1]), ()),
           "expected": (d, ())
       },
       {
           "classes": ((), constant_op.constant([1])),
           "expected": ((), d)
       },
       {
           "classes": (s, (), constant_op.constant([1])),
           "expected": (t, (), d)
       },
       {
           "classes": ((), s, ()),
           "expected": ((), t, ())
       },
       {
           "classes": ((), constant_op.constant([1]), ()),
           "expected": ((), d, ())
       },
   )
   for test_case in test_cases:
     self.assertEqual(
         sparse.get_classes(test_case["classes"]), test_case["expected"])
Ejemplo n.º 12
0
    def tf_init_func(key):
      """A wrapper for Defun that facilitates shape inference.

      Calls `init_func` with a scalar-shaped `key`, records the classes,
      shapes and dtypes of the returned state on `self`, and returns the
      state flattened with any sparse tensors serialized.
      """
      key.set_shape([])
      ret = init_func(key)
      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
          for t in nest.flatten(ret)
      ])

      flat_ret = nest.flatten(ret)
      self._state_classes = sparse.get_classes(ret)
      self._state_shapes = nest.pack_sequence_as(
          ret, [t.get_shape() for t in flat_ret])
      self._state_types = nest.pack_sequence_as(
          ret, [t.dtype for t in flat_ret])

      # Serialize any sparse tensors. Repacking the serialized leaves into
      # `ret`'s structure only to flatten them again is a no-op, so flatten
      # directly.
      return nest.flatten(sparse.serialize_sparse_tensors(ret))
Ejemplo n.º 13
0
        def tf_finalize_func(*args):
            """A wrapper for Defun that facilitates shape inference.

            Restores static shapes on the flat `args` from the recorded state
            shapes/classes, rebuilds the nested state (deserializing sparse
            components), calls `finalize_func` on it, records the output
            classes, shapes and dtypes on `self`, and returns the flattened
            result with any sparse tensors serialized.
            """
            for arg, shape in zip(
                    args,
                    nest.flatten(
                        sparse.as_dense_shapes(self._state_shapes,
                                               self._state_classes))):
                arg.set_shape(shape)

            nested_args = nest.pack_sequence_as(self._state_types, args)
            nested_args = sparse.deserialize_sparse_tensors(
                nested_args, self._state_types, self._state_shapes,
                self._state_classes)

            ret = finalize_func(nested_args)

            # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
            # values to tensors.
            ret = nest.pack_sequence_as(ret, [
                sparse_tensor.SparseTensor.from_value(t)
                if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                for t in nest.flatten(ret)
            ])

            flat_ret = nest.flatten(ret)
            self._output_classes = sparse.get_classes(ret)
            self._output_shapes = nest.pack_sequence_as(
                ret, [t.get_shape() for t in flat_ret])
            self._output_types = nest.pack_sequence_as(
                ret, [t.dtype for t in flat_ret])

            dataset_ops._warn_if_collections(
                "tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

            # Serialize any sparse tensors. Repacking the serialized leaves
            # into `ret`'s structure only to flatten them again is a no-op,
            # so flatten directly.
            return nest.flatten(sparse.serialize_sparse_tensors(ret))
Ejemplo n.º 14
0
        def tf_init_func(key):
            """A wrapper for Defun that facilitates shape inference.

            Calls `init_func` with a scalar-shaped `key`, records the classes,
            shapes and dtypes of the returned state on `self`, and returns the
            state flattened with any sparse tensors serialized.
            """
            key.set_shape([])
            ret = init_func(key)
            # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
            # values to tensors.
            ret = nest.pack_sequence_as(ret, [
                sparse_tensor.SparseTensor.from_value(t)
                if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                for t in nest.flatten(ret)
            ])

            flat_ret = nest.flatten(ret)
            self._state_classes = sparse.get_classes(ret)
            self._state_shapes = nest.pack_sequence_as(
                ret, [t.get_shape() for t in flat_ret])
            self._state_types = nest.pack_sequence_as(
                ret, [t.dtype for t in flat_ret])

            # Serialize any sparse tensors. Repacking the serialized leaves
            # into `ret`'s structure only to flatten them again is a no-op,
            # so flatten directly.
            return nest.flatten(sparse.serialize_sparse_tensors(ret))
Ejemplo n.º 15
0
    def tf_init_func(key):
      """A wrapper for Defun that facilitates shape inference.

      Calls `init_func` with a scalar-shaped `key`, records the classes,
      shapes and dtypes of the returned state on `self`, and returns the
      state flattened with any sparse tensors serialized.
      """
      key.set_shape([])
      ret = init_func(key)
      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
          for t in nest.flatten(ret)
      ])

      flat_ret = nest.flatten(ret)
      self._state_classes = sparse.get_classes(ret)
      self._state_shapes = nest.pack_sequence_as(
          ret, [t.get_shape() for t in flat_ret])
      self._state_types = nest.pack_sequence_as(
          ret, [t.dtype for t in flat_ret])

      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

      # Serialize any sparse tensors. Repacking the serialized leaves into
      # `ret`'s structure only to flatten them again is a no-op, so flatten
      # directly.
      return nest.flatten(sparse.serialize_sparse_tensors(ret))
Ejemplo n.º 16
0
  def __init__(self, input_dataset, initial_state, scan_func):
    """See `scan()` for details.

    Converts `initial_state` to tensors, derives the state classes, shapes and
    dtypes from it, then repeatedly wraps `scan_func` with
    `StructuredFunctionWrapper` until the state shapes reach a fixed point.
    The final wrapped function is added to the default graph.

    Raises:
      TypeError: if `scan_func` does not return a (new state, output) pair, or
        if the new state's classes or dtypes do not match the initial state.
    """
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in
    # `collections.abc` (fall back to `collections` itself on Python 2).
    try:
      from collections import abc as collections_abc
    except ImportError:
      collections_abc = collections

    super(_ScanDataset, self).__init__()
    self._input_dataset = input_dataset

    with ops.name_scope("initial_state"):
      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      self._initial_state = nest.pack_sequence_as(initial_state, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
              t, name="component_%d" % i)
          for i, t in enumerate(nest.flatten(initial_state))
      ])

    # Compute initial values for the state classes, shapes and types based on
    # the initial state. The shapes may be refined by running `tf_scan_func` one
    # or more times below.
    self._state_classes = sparse.get_classes(self._initial_state)
    self._state_shapes = nest.pack_sequence_as(
        self._initial_state,
        [t.get_shape() for t in nest.flatten(self._initial_state)])
    self._state_types = nest.pack_sequence_as(
        self._initial_state,
        [t.dtype for t in nest.flatten(self._initial_state)])

    # Will be populated by calling `tf_scan_func`.
    self._output_classes = None
    self._output_shapes = None
    self._output_types = None

    # Iteratively rerun the scan function until reaching a fixed point on
    # `self._state_shapes`.
    need_to_rerun = True
    while need_to_rerun:

      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          scan_func, "tf.contrib.data.scan()",
          input_classes=(self._state_classes, input_dataset.output_classes),
          input_shapes=(self._state_shapes, input_dataset.output_shapes),
          input_types=(self._state_types, input_dataset.output_types),
          add_to_graph=False)
      if not (
          isinstance(wrapped_func.output_types, collections_abc.Sequence) and
          len(wrapped_func.output_types) == 2):
        raise TypeError("The scan function must return a pair comprising the "
                        "new state and the output value.")

      new_state_classes, self._output_classes = wrapped_func.output_classes

      # Extract and validate class information from the returned values.
      for new_state_class, state_class in zip(
          nest.flatten(new_state_classes),
          nest.flatten(self._state_classes)):
        if not issubclass(new_state_class, state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._state_classes, new_state_classes))

      # Extract and validate type information from the returned values.
      new_state_types, self._output_types = wrapped_func.output_types
      for new_state_type, state_type in zip(
          nest.flatten(new_state_types), nest.flatten(self._state_types)):
        if new_state_type != state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._state_types, new_state_types))

      # Extract shape information from the returned values.
      new_state_shapes, self._output_shapes = wrapped_func.output_shapes

      flat_state_shapes = nest.flatten(self._state_shapes)
      flat_new_state_shapes = nest.flatten(new_state_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # Rerun only if some known-rank state shape had to be relaxed.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        self._state_shapes = nest.pack_sequence_as(self._state_shapes,
                                                   weakened_state_shapes)

    self._scan_func = wrapped_func.function
    self._scan_func.add_to_graph(ops.get_default_graph())
    def __init__(self, input_dataset, initial_state, scan_func):
        """See `scan()` for details.

        Converts `initial_state` to tensors, derives the state classes, shapes
        and dtypes from it, then repeatedly wraps `scan_func` with
        `StructuredFunctionWrapper` until the state shapes reach a fixed
        point. The final wrapped function is added to the default graph.

        Raises:
            TypeError: if `scan_func` does not return a (new state, output)
                pair, or if the new state's classes or dtypes do not match the
                initial state.
        """
        # `collections.Sequence` was removed in Python 3.10; the ABC lives in
        # `collections.abc` (fall back to `collections` itself on Python 2).
        try:
            from collections import abc as collections_abc
        except ImportError:
            collections_abc = collections

        super(_ScanDataset, self).__init__()
        self._input_dataset = input_dataset

        with ops.name_scope("initial_state"):
            # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
            # values to tensors.
            self._initial_state = nest.pack_sequence_as(
                initial_state, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
                        t, name="component_%d" % i)
                    for i, t in enumerate(nest.flatten(initial_state))
                ])

        # Compute initial values for the state classes, shapes and types based on
        # the initial state. The shapes may be refined by running `tf_scan_func` one
        # or more times below.
        self._state_classes = sparse.get_classes(self._initial_state)
        self._state_shapes = nest.pack_sequence_as(
            self._initial_state,
            [t.get_shape() for t in nest.flatten(self._initial_state)])
        self._state_types = nest.pack_sequence_as(
            self._initial_state,
            [t.dtype for t in nest.flatten(self._initial_state)])

        # Will be populated by calling `tf_scan_func`.
        self._output_classes = None
        self._output_shapes = None
        self._output_types = None

        # Iteratively rerun the scan function until reaching a fixed point on
        # `self._state_shapes`.
        need_to_rerun = True
        while need_to_rerun:

            wrapped_func = dataset_ops.StructuredFunctionWrapper(
                scan_func,
                "tf.contrib.data.scan()",
                input_classes=(self._state_classes,
                               input_dataset.output_classes),
                input_shapes=(self._state_shapes, input_dataset.output_shapes),
                input_types=(self._state_types, input_dataset.output_types),
                add_to_graph=False)
            if not (isinstance(wrapped_func.output_types,
                               collections_abc.Sequence)
                    and len(wrapped_func.output_types) == 2):
                raise TypeError(
                    "The scan function must return a pair comprising the "
                    "new state and the output value.")

            new_state_classes, self._output_classes = wrapped_func.output_classes

            # Extract and validate class information from the returned values.
            for new_state_class, state_class in zip(
                    nest.flatten(new_state_classes),
                    nest.flatten(self._state_classes)):
                if not issubclass(new_state_class, state_class):
                    raise TypeError(
                        "The element classes for the new state must match the initial "
                        "state. Expected %s; got %s." %
                        (self._state_classes, new_state_classes))

            # Extract and validate type information from the returned values.
            new_state_types, self._output_types = wrapped_func.output_types
            for new_state_type, state_type in zip(
                    nest.flatten(new_state_types),
                    nest.flatten(self._state_types)):
                if new_state_type != state_type:
                    raise TypeError(
                        "The element types for the new state must match the initial "
                        "state. Expected %s; got %s." %
                        (self._state_types, new_state_types))

            # Extract shape information from the returned values.
            new_state_shapes, self._output_shapes = wrapped_func.output_shapes

            flat_state_shapes = nest.flatten(self._state_shapes)
            flat_new_state_shapes = nest.flatten(new_state_shapes)
            weakened_state_shapes = [
                original.most_specific_compatible_shape(new) for original, new
                in zip(flat_state_shapes, flat_new_state_shapes)
            ]

            # Rerun only if some known-rank state shape had to be relaxed.
            need_to_rerun = False
            for original_shape, weakened_shape in zip(flat_state_shapes,
                                                      weakened_state_shapes):
                if original_shape.ndims is not None and (
                        weakened_shape.ndims is None or
                        original_shape.as_list() != weakened_shape.as_list()):
                    need_to_rerun = True
                    break

            if need_to_rerun:
                self._state_shapes = nest.pack_sequence_as(
                    self._state_shapes, weakened_state_shapes)

        self._scan_func = wrapped_func.function
        self._scan_func.add_to_graph(ops.get_default_graph())
Ejemplo n.º 18
0
            def tf_scan_func(*args):
                """A wrapper for Defun that facilitates shape inference.

                Restores static shapes on the flat `args` (state components
                first, then input-element components), rebuilds the nested
                state and input element (deserializing sparse components),
                calls `scan_func(state, element)`, records the output classes,
                shapes and dtypes on `self`, appends the new state's shapes to
                `flat_new_state_shapes` (from the enclosing scope), and returns
                the flattened new state followed by the flattened output value,
                with any sparse tensors serialized.

                Raises:
                    TypeError: if `scan_func` does not return a pair, or if the
                        new state's classes or dtypes differ from the state's.
                """
                # `collections.Sequence` was removed in Python 3.10; use the
                # ABC from `collections.abc` (fall back on Python 2).
                try:
                    from collections import abc as collections_abc
                except ImportError:
                    collections_abc = collections

                # Pass in shape information from the state and input_dataset.
                for arg, shape in zip(
                        args,
                        nest.flatten(
                            sparse.as_dense_shapes(self._state_shapes,
                                                   self._state_classes)) +
                        nest.flatten(
                            sparse.as_dense_shapes(
                                input_dataset.output_shapes,
                                input_dataset.output_classes))):
                    arg.set_shape(shape)

                # The first `pivot` flat args are state components; the rest
                # belong to the input element.
                pivot = len(nest.flatten(self._state_shapes))
                nested_state_args = nest.pack_sequence_as(
                    self._state_types, args[:pivot])
                nested_state_args = sparse.deserialize_sparse_tensors(
                    nested_state_args, self._state_types, self._state_shapes,
                    self._state_classes)
                nested_input_args = nest.pack_sequence_as(
                    input_dataset.output_types, args[pivot:])
                nested_input_args = sparse.deserialize_sparse_tensors(
                    nested_input_args, input_dataset.output_types,
                    input_dataset.output_shapes, input_dataset.output_classes)

                ret = scan_func(nested_state_args, nested_input_args)
                if (not isinstance(ret, collections_abc.Sequence) or
                        len(ret) != 2):
                    raise TypeError(
                        "The scan function must return a pair comprising the "
                        "new state and the output value.")

                # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
                # values to tensors.
                ret = nest.pack_sequence_as(ret, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                    for t in nest.flatten(ret)
                ])
                new_state, output_value = ret

                # Extract and validate class information from the returned values.
                for t, clazz in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_classes)):
                    if not isinstance(t, clazz):
                        raise TypeError(
                            "The element classes for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_classes,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [type(t) for t in nest.flatten(new_state)])))
                self._output_classes = sparse.get_classes(output_value)

                # Extract shape information from the returned values.
                flat_new_state_shapes.extend(
                    [t.get_shape() for t in nest.flatten(new_state)])
                self._output_shapes = nest.pack_sequence_as(
                    output_value,
                    [t.get_shape() for t in nest.flatten(output_value)])

                # Extract and validate type information from the returned values.
                for t, dtype in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_types)):
                    if t.dtype != dtype:
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [t.dtype for t in nest.flatten(new_state)])))
                self._output_types = nest.pack_sequence_as(
                    output_value,
                    [t.dtype for t in nest.flatten(output_value)])

                dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

                # Serialize any sparse tensors. Repacking the serialized leaves
                # into their original structures only to flatten them again is
                # a no-op, so flatten directly (new state first, then output).
                return (
                    nest.flatten(sparse.serialize_sparse_tensors(new_state)) +
                    nest.flatten(sparse.serialize_sparse_tensors(output_value)))
Ejemplo n.º 19
0
    def __init__(self, input_dataset, initial_state, scan_func):
        """See `scan()` for details.

        Traces `scan_func` with `function.Defun`, re-tracing until the state
        shapes reach a fixed point, then adds the final function definition to
        the default graph.

        Args:
            input_dataset: The dataset whose elements are passed to `scan_func`
                as its second argument.
            initial_state: A nested structure of `SparseTensorValue`s and/or
                values accepted by `ops.convert_to_tensor`, giving the initial
                scan state.
            scan_func: A function mapping `(state, input_element)` to a pair
                `(new_state, output_element)`.

        Raises:
            TypeError: While tracing `scan_func`, if it does not return a pair,
                or if the class or dtype of a new-state component does not match
                the corresponding initial-state component.
        """
        super(_ScanDataset, self).__init__()
        self._input_dataset = input_dataset

        # `collections.Sequence` was removed in Python 3.10; use the
        # `collections.abc` location when available and fall back to the old
        # alias on Python 2, where `collections.abc` does not exist.
        try:
            from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
        except ImportError:
            collections_abc = collections

        with ops.name_scope("initial_state"):
            # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
            # values to tensors.
            self._initial_state = nest.pack_sequence_as(
                initial_state, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
                        t, name="component_%d" % i)
                    for i, t in enumerate(nest.flatten(initial_state))
                ])

        # Compute initial values for the state classes, shapes and types based on
        # the initial state. The shapes may be refined by running `tf_scan_func` one
        # or more times below.
        self._state_classes = sparse.get_classes(self._initial_state)
        self._state_shapes = nest.pack_sequence_as(
            self._initial_state,
            [t.get_shape() for t in nest.flatten(self._initial_state)])
        self._state_types = nest.pack_sequence_as(
            self._initial_state,
            [t.dtype for t in nest.flatten(self._initial_state)])

        # Will be populated by calling `tf_scan_func`.
        self._output_classes = None
        self._output_shapes = None
        self._output_types = None

        # Iteratively rerun the scan function until reaching a fixed point on
        # `self._state_shapes`.
        need_to_rerun = True
        while need_to_rerun:

            # Create a list in which `tf_scan_func` will store the new shapes.
            # It is recreated on every iteration so each trace only reports the
            # shapes it produced itself.
            flat_new_state_shapes = []

            @function.Defun(*(nest.flatten(
                sparse.as_dense_types(
                    self._state_types, self._state_classes)) + nest.flatten(
                        sparse.as_dense_types(input_dataset.output_types,
                                              input_dataset.output_classes))))
            def tf_scan_func(*args):
                """A wrapper for Defun that facilitates shape inference.

                The flat `args` hold the state components first, then the
                input components. Sparse components are deserialized before
                calling `scan_func`. The new state shapes are appended to
                `flat_new_state_shapes`, and the output classes, shapes and
                types are recorded on `self`.

                Returns:
                    The flattened new state followed by the flattened output
                    value, with any sparse tensors serialized.
                """
                # Pass in shape information from the state and input_dataset.
                for arg, shape in zip(
                        args,
                        nest.flatten(
                            sparse.as_dense_shapes(self._state_shapes,
                                                   self._state_classes)) +
                        nest.flatten(
                            sparse.as_dense_shapes(
                                input_dataset.output_shapes,
                                input_dataset.output_classes))):
                    arg.set_shape(shape)

                # The first `pivot` flat arguments belong to the state; the
                # remainder belong to the input element.
                pivot = len(nest.flatten(self._state_shapes))
                nested_state_args = nest.pack_sequence_as(
                    self._state_types, args[:pivot])
                nested_state_args = sparse.deserialize_sparse_tensors(
                    nested_state_args, self._state_types, self._state_shapes,
                    self._state_classes)
                nested_input_args = nest.pack_sequence_as(
                    input_dataset.output_types, args[pivot:])
                nested_input_args = sparse.deserialize_sparse_tensors(
                    nested_input_args, input_dataset.output_types,
                    input_dataset.output_shapes, input_dataset.output_classes)

                ret = scan_func(nested_state_args, nested_input_args)
                if (not isinstance(ret, collections_abc.Sequence) or
                        len(ret) != 2):
                    raise TypeError(
                        "The scan function must return a pair comprising the "
                        "new state and the output value.")

                # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
                # values to tensors.
                ret = nest.pack_sequence_as(ret, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                    for t in nest.flatten(ret)
                ])
                new_state, output_value = ret

                # Extract and validate class information from the returned values.
                for t, clazz in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_classes)):
                    if not isinstance(t, clazz):
                        raise TypeError(
                            "The element classes for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_classes,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [type(t) for t in nest.flatten(new_state)])))
                self._output_classes = sparse.get_classes(output_value)

                # Extract shape information from the returned values.
                flat_new_state_shapes.extend(
                    [t.get_shape() for t in nest.flatten(new_state)])
                self._output_shapes = nest.pack_sequence_as(
                    output_value,
                    [t.get_shape() for t in nest.flatten(output_value)])

                # Extract and validate type information from the returned values.
                for t, dtype in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_types)):
                    if t.dtype != dtype:
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [t.dtype for t in nest.flatten(new_state)])))
                self._output_types = nest.pack_sequence_as(
                    output_value,
                    [t.dtype for t in nest.flatten(output_value)])

                dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

                # Serialize any sparse tensors.
                new_state = nest.pack_sequence_as(new_state, [
                    t for t in nest.flatten(
                        sparse.serialize_sparse_tensors(new_state))
                ])
                output_value = nest.pack_sequence_as(output_value, [
                    t for t in nest.flatten(
                        sparse.serialize_sparse_tensors(output_value))
                ])
                return nest.flatten(new_state) + nest.flatten(output_value)

            # Use the private method that will execute `tf_scan_func` but delay
            # adding it to the graph in case we need to rerun the function.
            tf_scan_func._create_definition_if_needed()  # pylint: disable=protected-access

            # Relax each state shape so that it is compatible with both the
            # current state shape and the shape the scan function produced.
            flat_state_shapes = nest.flatten(self._state_shapes)
            weakened_state_shapes = [
                original.most_specific_compatible_shape(new) for original, new
                in zip(flat_state_shapes, flat_new_state_shapes)
            ]

            # Retrace only if a state component with known rank lost its rank
            # or had a dimension relaxed by the weakening above.
            need_to_rerun = False
            for original_shape, weakened_shape in zip(flat_state_shapes,
                                                      weakened_state_shapes):
                if original_shape.ndims is not None and (
                        weakened_shape.ndims is None or
                        original_shape.as_list() != weakened_shape.as_list()):
                    need_to_rerun = True
                    break

            if need_to_rerun:
                # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun
                # `tf_scan_func`.
                self._state_shapes = nest.pack_sequence_as(
                    self._state_shapes, weakened_state_shapes)

        self._scan_func = tf_scan_func
        self._scan_func.add_to_graph(ops.get_default_graph())
Ejemplo n.º 20
0
    def __init__(self, input_dataset, initial_state, scan_func):
        """See `scan()` for details.

        Traces `scan_func` with `dataset_ops.StructuredFunctionWrapper`,
        re-tracing until the state shapes reach a fixed point, then builds the
        scan dataset's variant tensor and initializes the base dataset with it.

        Args:
            input_dataset: The dataset whose elements are passed to `scan_func`
                as its second argument.
            initial_state: A nested structure giving the initial scan state; it
                is normalized with `structure.normalize_tensors`.
            scan_func: A function mapping `(state, input_element)` to a pair
                `(new_state, output_element)`.

        Raises:
            TypeError: If `scan_func` does not return a pair, or if the class or
                dtype of a new-state component does not match the corresponding
                initial-state component.
        """
        self._input_dataset = input_dataset

        # `collections.Sequence` was removed in Python 3.10; use the
        # `collections.abc` location when available and fall back to the old
        # alias on Python 2, where `collections.abc` does not exist.
        try:
            from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
        except ImportError:
            collections_abc = collections

        with ops.name_scope("initial_state"):
            self._initial_state = structure.normalize_tensors(initial_state)

        # Compute initial values for the state classes, shapes and types based on
        # the initial state. The shapes may be refined by running `tf_scan_func` one
        # or more times below.
        self._state_classes = sparse.get_classes(self._initial_state)
        self._state_shapes = nest.pack_sequence_as(
            self._initial_state,
            [t.get_shape() for t in nest.flatten(self._initial_state)])
        self._state_types = nest.pack_sequence_as(
            self._initial_state,
            [t.dtype for t in nest.flatten(self._initial_state)])

        # Will be populated by calling `tf_scan_func`.
        self._output_classes = None
        self._output_shapes = None
        self._output_types = None

        # Iteratively rerun the scan function until reaching a fixed point on
        # `self._state_shapes`.
        need_to_rerun = True
        while need_to_rerun:

            # Trace without adding to the graph yet; a later iteration may need
            # to retrace with relaxed state shapes.
            wrapped_func = dataset_ops.StructuredFunctionWrapper(
                scan_func,
                self._transformation_name(),
                input_classes=(self._state_classes,
                               input_dataset.output_classes),
                input_shapes=(self._state_shapes, input_dataset.output_shapes),
                input_types=(self._state_types, input_dataset.output_types),
                add_to_graph=False)
            if not (isinstance(wrapped_func.output_types,
                               collections_abc.Sequence)
                    and len(wrapped_func.output_types) == 2):
                raise TypeError(
                    "The scan function must return a pair comprising the "
                    "new state and the output value.")

            new_state_classes, self._output_classes = wrapped_func.output_classes

            # Extract and validate class information from the returned values.
            # `issubclass` accepts a subclass of the initial state's class.
            for new_state_class, state_class in zip(
                    nest.flatten(new_state_classes),
                    nest.flatten(self._state_classes)):
                if not issubclass(new_state_class, state_class):
                    raise TypeError(
                        "The element classes for the new state must match the initial "
                        "state. Expected %s; got %s." %
                        (self._state_classes, new_state_classes))

            # Extract and validate type information from the returned values.
            new_state_types, self._output_types = wrapped_func.output_types
            for new_state_type, state_type in zip(
                    nest.flatten(new_state_types),
                    nest.flatten(self._state_types)):
                if new_state_type != state_type:
                    raise TypeError(
                        "The element types for the new state must match the initial "
                        "state. Expected %s; got %s." %
                        (self._state_types, new_state_types))

            # Extract shape information from the returned values.
            new_state_shapes, self._output_shapes = wrapped_func.output_shapes

            # Relax each state shape so that it is compatible with both the
            # current state shape and the shape the scan function produced.
            flat_state_shapes = nest.flatten(self._state_shapes)
            flat_new_state_shapes = nest.flatten(new_state_shapes)
            weakened_state_shapes = [
                original.most_specific_compatible_shape(new) for original, new
                in zip(flat_state_shapes, flat_new_state_shapes)
            ]

            # Retrace only if a state component with known rank lost its rank
            # or had a dimension relaxed by the weakening above.
            need_to_rerun = False
            for original_shape, weakened_shape in zip(flat_state_shapes,
                                                      weakened_state_shapes):
                if original_shape.ndims is not None and (
                        weakened_shape.ndims is None or
                        original_shape.as_list() != weakened_shape.as_list()):
                    need_to_rerun = True
                    break

            if need_to_rerun:
                self._state_shapes = nest.pack_sequence_as(
                    self._state_shapes, weakened_state_shapes)

        # The shapes have converged, so the last trace is final; register it.
        self._scan_func = wrapped_func
        self._scan_func.function.add_to_graph(ops.get_default_graph())
        # pylint: disable=protected-access
        variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
            self._input_dataset._variant_tensor,
            self._state_structure._to_tensor_list(self._initial_state),
            self._scan_func.function.captured_inputs,
            f=self._scan_func.function,
            preserve_cardinality=True,
            **dataset_ops.flat_structure(self))
        super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
Ejemplo n.º 21
0
      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference.

        The flat `args` hold the state components first, then the input
        components. Sparse components are deserialized before calling
        `scan_func`. The new state shapes are appended to the enclosing
        `flat_new_state_shapes` list, and the output classes, shapes and types
        are recorded on `self`.

        Returns:
          The flattened new state followed by the flattened output value, with
          any sparse tensors serialized.

        Raises:
          TypeError: If `scan_func` does not return a pair, or if the class or
            dtype of a new-state component does not match the corresponding
            initial-state component.
        """
        # `collections.Sequence` was removed in Python 3.10; use the
        # `collections.abc` location when available and fall back to the old
        # alias on Python 2, where `collections.abc` does not exist.
        try:
          from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
        except ImportError:
          collections_abc = collections

        # Pass in shape information from the state and input_dataset.
        for arg, shape in zip(
            args,
            nest.flatten(
                sparse.as_dense_shapes(self._state_shapes, self._state_classes))
            + nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes))):
          arg.set_shape(shape)

        # The first `pivot` flat arguments belong to the state; the remainder
        # belong to the input element.
        pivot = len(nest.flatten(self._state_shapes))
        nested_state_args = nest.pack_sequence_as(self._state_types,
                                                  args[:pivot])
        nested_state_args = sparse.deserialize_sparse_tensors(
            nested_state_args, self._state_types, self._state_shapes,
            self._state_classes)
        nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                                  args[pivot:])
        nested_input_args = sparse.deserialize_sparse_tensors(
            nested_input_args, input_dataset.output_types,
            input_dataset.output_shapes, input_dataset.output_classes)

        ret = scan_func(nested_state_args, nested_input_args)
        if not isinstance(ret, collections_abc.Sequence) or len(ret) != 2:
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")

        # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
        # values to tensors.
        ret = nest.pack_sequence_as(ret, [
            sparse_tensor.SparseTensor.from_value(t)
            if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
            for t in nest.flatten(ret)
        ])
        new_state, output_value = ret

        # Extract and validate class information from the returned values.
        for t, clazz in zip(
            nest.flatten(new_state), nest.flatten(self._state_classes)):
          if not isinstance(t, clazz):
            raise TypeError(
                "The element classes for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_classes,
                 nest.pack_sequence_as(
                     self._state_types,
                     [type(t) for t in nest.flatten(new_state)])))
        self._output_classes = sparse.get_classes(output_value)

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend(
            [t.get_shape() for t in nest.flatten(new_state)])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.get_shape() for t in nest.flatten(output_value)])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(
            nest.flatten(new_state), nest.flatten(self._state_types)):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(
                     self._state_types,
                     [t.dtype for t in nest.flatten(new_state)])))
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in nest.flatten(output_value)])

        dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

        # Serialize any sparse tensors.
        new_state = nest.pack_sequence_as(new_state, [
            t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state))
        ])
        output_value = nest.pack_sequence_as(output_value, [
            t for t in nest.flatten(
                sparse.serialize_sparse_tensors(output_value))
        ])
        return nest.flatten(new_state) + nest.flatten(output_value)
Ejemplo n.º 22
0
  def __init__(self, input_dataset, initial_state, scan_func):
    """See `scan()` for details.

    Traces `scan_func` with `function.Defun`, re-tracing until the state shapes
    reach a fixed point, then adds the final function definition to the
    default graph.

    Args:
      input_dataset: The dataset whose elements are passed to `scan_func` as
        its second argument.
      initial_state: A nested structure of `SparseTensorValue`s and/or values
        accepted by `ops.convert_to_tensor`, giving the initial scan state.
      scan_func: A function mapping `(state, input_element)` to a pair
        `(new_state, output_element)`.

    Raises:
      TypeError: While tracing `scan_func`, if it does not return a pair, or if
        the class or dtype of a new-state component does not match the
        corresponding initial-state component.
    """
    super(_ScanDataset, self).__init__()
    self._input_dataset = input_dataset

    # `collections.Sequence` was removed in Python 3.10; use the
    # `collections.abc` location when available and fall back to the old alias
    # on Python 2, where `collections.abc` does not exist.
    try:
      from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
    except ImportError:
      collections_abc = collections

    with ops.name_scope("initial_state"):
      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      self._initial_state = nest.pack_sequence_as(initial_state, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
              t, name="component_%d" % i)
          for i, t in enumerate(nest.flatten(initial_state))
      ])

    # Compute initial values for the state classes, shapes and types based on
    # the initial state. The shapes may be refined by running `tf_scan_func` one
    # or more times below.
    self._state_classes = sparse.get_classes(self._initial_state)
    self._state_shapes = nest.pack_sequence_as(
        self._initial_state,
        [t.get_shape() for t in nest.flatten(self._initial_state)])
    self._state_types = nest.pack_sequence_as(
        self._initial_state,
        [t.dtype for t in nest.flatten(self._initial_state)])

    # Will be populated by calling `tf_scan_func`.
    self._output_classes = None
    self._output_shapes = None
    self._output_types = None

    # Iteratively rerun the scan function until reaching a fixed point on
    # `self._state_shapes`.
    need_to_rerun = True
    while need_to_rerun:

      # Create a list in which `tf_scan_func` will store the new shapes. It is
      # recreated on every iteration so each trace only reports its own shapes.
      flat_new_state_shapes = []

      @function.Defun(*(nest.flatten(
          sparse.as_dense_types(
              self._state_types, self._state_classes)) + nest.flatten(
                  sparse.as_dense_types(input_dataset.output_types,
                                        input_dataset.output_classes))))
      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference.

        The flat `args` hold the state components first, then the input
        components. Sparse components are deserialized before calling
        `scan_func`. The new state shapes are appended to
        `flat_new_state_shapes`, and the output classes, shapes and types are
        recorded on `self`.

        Returns:
          The flattened new state followed by the flattened output value, with
          any sparse tensors serialized.
        """
        # Pass in shape information from the state and input_dataset.
        for arg, shape in zip(
            args,
            nest.flatten(
                sparse.as_dense_shapes(self._state_shapes, self._state_classes))
            + nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes))):
          arg.set_shape(shape)

        # The first `pivot` flat arguments belong to the state; the remainder
        # belong to the input element.
        pivot = len(nest.flatten(self._state_shapes))
        nested_state_args = nest.pack_sequence_as(self._state_types,
                                                  args[:pivot])
        nested_state_args = sparse.deserialize_sparse_tensors(
            nested_state_args, self._state_types, self._state_shapes,
            self._state_classes)
        nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                                  args[pivot:])
        nested_input_args = sparse.deserialize_sparse_tensors(
            nested_input_args, input_dataset.output_types,
            input_dataset.output_shapes, input_dataset.output_classes)

        ret = scan_func(nested_state_args, nested_input_args)
        if not isinstance(ret, collections_abc.Sequence) or len(ret) != 2:
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")

        # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
        # values to tensors.
        ret = nest.pack_sequence_as(ret, [
            sparse_tensor.SparseTensor.from_value(t)
            if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
            for t in nest.flatten(ret)
        ])
        new_state, output_value = ret

        # Extract and validate class information from the returned values.
        for t, clazz in zip(
            nest.flatten(new_state), nest.flatten(self._state_classes)):
          if not isinstance(t, clazz):
            raise TypeError(
                "The element classes for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_classes,
                 nest.pack_sequence_as(
                     self._state_types,
                     [type(t) for t in nest.flatten(new_state)])))
        self._output_classes = sparse.get_classes(output_value)

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend(
            [t.get_shape() for t in nest.flatten(new_state)])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.get_shape() for t in nest.flatten(output_value)])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(
            nest.flatten(new_state), nest.flatten(self._state_types)):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(
                     self._state_types,
                     [t.dtype for t in nest.flatten(new_state)])))
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in nest.flatten(output_value)])

        dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

        # Serialize any sparse tensors.
        new_state = nest.pack_sequence_as(new_state, [
            t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state))
        ])
        output_value = nest.pack_sequence_as(output_value, [
            t for t in nest.flatten(
                sparse.serialize_sparse_tensors(output_value))
        ])
        return nest.flatten(new_state) + nest.flatten(output_value)

      # Use the private method that will execute `tf_scan_func` but delay
      # adding it to the graph in case we need to rerun the function.
      tf_scan_func._create_definition_if_needed()  # pylint: disable=protected-access

      # Relax each state shape so that it is compatible with both the current
      # state shape and the shape the scan function produced.
      flat_state_shapes = nest.flatten(self._state_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # Retrace only if a state component with known rank lost its rank or had
      # a dimension relaxed by the weakening above.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun
        # `tf_scan_func`.
        self._state_shapes = nest.pack_sequence_as(self._state_shapes,
                                                   weakened_state_shapes)

    self._scan_func = tf_scan_func
    self._scan_func.add_to_graph(ops.get_default_graph())
Ejemplo n.º 23
0
 def testGetClasses(self, classes_fn, expected_fn):
     classes = classes_fn()
     expected = expected_fn()
     self.assertEqual(sparse.get_classes(classes), expected)