コード例 #1
0
 def _prefetch_fn(handle):
   """Fetches the next element via an iterator rebuilt from `handle`.

   Args:
     handle: Value forwarded to `Iterator.from_string_handle`.

   Returns:
     The flattened components of the next element, after
     `sparse.serialize_sparse_tensors` has been applied to it.
   """
   element_types = input_iterator.output_types
   element_shapes = input_iterator.output_shapes
   element_classes = input_iterator.output_classes
   rebuilt_iterator = iterator_ops.Iterator.from_string_handle(
       handle, element_types, element_shapes, element_classes)
   next_element = rebuilt_iterator.get_next()
   serialized = sparse.serialize_sparse_tensors(next_element)
   return nest.flatten(serialized)
コード例 #2
0
 def testSerializeDeserialize(self):
   """Checks that serialize followed by deserialize preserves each structure."""

   def _unit_sparse():
     # Fresh sparse tensor with value 1 at [0, 0] and dense shape [1, 1].
     return sparse_tensor.SparseTensor(
         indices=[[0, 0]], values=[1], dense_shape=[1, 1])

   test_cases = (
       (),
       _unit_sparse(),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       # Parentheses without a trailing comma do not form a tuple, so this
       # entry is again a bare sparse tensor.
       (_unit_sparse()),
       (_unit_sparse(), ()),
       ((), _unit_sparse()),
   )
   for expected in test_cases:
     structure_classes = sparse.get_classes(expected)
     unknown_shapes = nest.map_structure(
         lambda _: tensor_shape.TensorShape(None), structure_classes)
     int_types = nest.map_structure(lambda _: dtypes.int32, structure_classes)
     serialized = sparse.serialize_sparse_tensors(expected)
     actual = sparse.deserialize_sparse_tensors(
         serialized, int_types, unknown_shapes, sparse.get_classes(expected))
     nest.assert_same_structure(expected, actual)
     flat_pairs = zip(nest.flatten(actual), nest.flatten(expected))
     for actual_component, expected_component in flat_pairs:
       self.assertSparseValuesEqual(actual_component, expected_component)
コード例 #3
0
ファイル: grouping.py プロジェクト: Jackiefan/tensorflow
    def tf_finalize_func(*args):
      """Rebuilds the nested state from flat `args` and finalizes it.

      Sets static shapes on `args`, restores the nested state (deserializing
      sparse components), calls `finalize_func`, records the classes, shapes
      and types of the result on `self`, and returns the result flattened
      with sparse tensors serialized.
      """
      dense_state_shapes = nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))
      for flat_arg, static_shape in zip(args, dense_state_shapes):
        flat_arg.set_shape(static_shape)

      packed_state = nest.pack_sequence_as(self._state_types, args)
      state = sparse.deserialize_sparse_tensors(
          packed_state, self._state_types, self._state_shapes,
          self._state_classes)

      ret = finalize_func(state)

      # `SparseTensorValue`s become `SparseTensor`s; every other value is
      # converted to a tensor.
      normalized = []
      for component in nest.flatten(ret):
        if sparse_tensor.is_sparse(component):
          normalized.append(sparse_tensor.SparseTensor.from_value(component))
        else:
          normalized.append(ops.convert_to_tensor(component))
      ret = nest.pack_sequence_as(ret, normalized)

      # Record the structure of the finalized output.
      self._output_classes = sparse.get_classes(ret)
      flat_ret = nest.flatten(ret)
      self._output_shapes = nest.pack_sequence_as(
          ret, [component.get_shape() for component in flat_ret])
      self._output_types = nest.pack_sequence_as(
          ret, [component.dtype for component in flat_ret])

      # Serialize any sparse tensors before handing back the flat list.
      serialized = nest.pack_sequence_as(
          ret, list(nest.flatten(sparse.serialize_sparse_tensors(ret))))
      return nest.flatten(serialized)
コード例 #4
0
ファイル: optional_ops.py プロジェクト: zpdcqu/tensorflow
  def from_value(value):
    """Returns an `Optional` that wraps the given value.

    Args:
      value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

    Returns:
      An `Optional` that wraps `value`.
    """
    # TODO(b/110122868): Consolidate this destructuring logic with the
    # similar code in `Dataset.from_tensors()`.
    with ops.name_scope("optional") as scope:
      with ops.name_scope("value"):
        # Sparse values become `SparseTensor`s; everything else is converted
        # to a tensor named after its position in the flattened structure.
        components = []
        for index, component in enumerate(nest.flatten(value)):
          if sparse_tensor_lib.is_sparse(component):
            components.append(
                sparse_tensor_lib.SparseTensor.from_value(component))
          else:
            components.append(
                ops.convert_to_tensor(component, name="component_%d" % index))
        value = nest.pack_sequence_as(value, components)

      encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
      output_classes = sparse.get_classes(value)
      flat_value = nest.flatten(value)
      output_shapes = nest.pack_sequence_as(
          value, [component.get_shape() for component in flat_value])
      output_types = nest.pack_sequence_as(
          value, [component.dtype for component in flat_value])

    variant = gen_dataset_ops.optional_from_value(encoded_value, name=scope)
    return _OptionalImpl(variant, output_shapes, output_types, output_classes)
コード例 #5
0
ファイル: sparse_test.py プロジェクト: zxthunter/tensorflow
 def testSerializeDeserialize(self):
     """Checks that serialize followed by deserialize preserves each structure."""

     def _unit_sparse():
         # Fresh sparse tensor with value 1 at [0, 0] and dense shape [1, 1].
         return sparse_tensor.SparseTensor(indices=[[0, 0]],
                                           values=[1],
                                           dense_shape=[1, 1])

     test_cases = (
         (),
         _unit_sparse(),
         sparse_tensor.SparseTensor(indices=[[3, 4]],
                                    values=[-1],
                                    dense_shape=[4, 5]),
         sparse_tensor.SparseTensor(indices=[[0, 0], [3, 4]],
                                    values=[1, -1],
                                    dense_shape=[4, 5]),
         # Parentheses without a trailing comma do not form a tuple, so
         # this entry is again a bare sparse tensor.
         (_unit_sparse()),
         (_unit_sparse(), ()),
         ((), _unit_sparse()),
     )
     for expected in test_cases:
         serialized = sparse.serialize_sparse_tensors(expected)
         actual = sparse.deserialize_sparse_tensors(
             serialized, sparse.get_sparse_types(expected))
         nest.assert_same_structure(expected, actual)
         flat_pairs = zip(nest.flatten(actual), nest.flatten(expected))
         for actual_component, expected_component in flat_pairs:
             self.assertSparseValuesEqual(actual_component,
                                          expected_component)
コード例 #6
0
ファイル: grouping.py プロジェクト: zhuyangda/tensorflow
        def tf_init_func(key):
            """Builds the initial state for `key` and records its structure.

            Gives `key` a scalar static shape, calls `init_func`, normalizes
            the result, stores the state's classes, shapes and types on
            `self`, and returns the state flattened with sparse tensors
            serialized.
            """

            def _to_tensor_like(component):
                # `SparseTensorValue`s become `SparseTensor`s; every other
                # value is converted to a tensor.
                if sparse_tensor.is_sparse(component):
                    return sparse_tensor.SparseTensor.from_value(component)
                return ops.convert_to_tensor(component)

            key.set_shape([])
            ret = init_func(key)
            ret = nest.pack_sequence_as(
                ret,
                [_to_tensor_like(component) for component in nest.flatten(ret)])

            # Record the structure of the initial state.
            self._state_classes = sparse.get_classes(ret)
            flat_ret = nest.flatten(ret)
            self._state_shapes = nest.pack_sequence_as(
                ret, [component.get_shape() for component in flat_ret])
            self._state_types = nest.pack_sequence_as(
                ret, [component.dtype for component in flat_ret])

            dataset_ops._warn_if_collections(  # pylint: disable=protected-access
                "tf.contrib.data.group_by_reducer()")

            # Serialize any sparse tensors before returning the flat list.
            serialized = nest.pack_sequence_as(
                ret, list(nest.flatten(sparse.serialize_sparse_tensors(ret))))
            return nest.flatten(serialized)
コード例 #7
0
        def tf_finalize_func(*args):
            """Rebuilds the nested state from flat `args` and finalizes it.

            Sets static shapes on `args`, restores the nested state
            (deserializing sparse components), calls `finalize_func`, records
            the classes, shapes and types of the result on `self`, and returns
            the result flattened with sparse tensors serialized.
            """

            def _to_tensor_like(component):
                # `SparseTensorValue`s become `SparseTensor`s; every other
                # value is converted to a tensor.
                if sparse_tensor.is_sparse(component):
                    return sparse_tensor.SparseTensor.from_value(component)
                return ops.convert_to_tensor(component)

            dense_state_shapes = nest.flatten(
                sparse.as_dense_shapes(self._state_shapes,
                                       self._state_classes))
            for flat_arg, static_shape in zip(args, dense_state_shapes):
                flat_arg.set_shape(static_shape)

            packed_state = nest.pack_sequence_as(self._state_types, args)
            state = sparse.deserialize_sparse_tensors(
                packed_state, self._state_types, self._state_shapes,
                self._state_classes)

            ret = finalize_func(state)
            ret = nest.pack_sequence_as(
                ret,
                [_to_tensor_like(component) for component in nest.flatten(ret)])

            # Record the structure of the finalized output.
            self._output_classes = sparse.get_classes(ret)
            flat_ret = nest.flatten(ret)
            self._output_shapes = nest.pack_sequence_as(
                ret, [component.get_shape() for component in flat_ret])
            self._output_types = nest.pack_sequence_as(
                ret, [component.dtype for component in flat_ret])

            # Serialize any sparse tensors before returning the flat list.
            serialized = nest.pack_sequence_as(
                ret, list(nest.flatten(sparse.serialize_sparse_tensors(ret))))
            return nest.flatten(serialized)
コード例 #8
0
ファイル: optional_ops.py プロジェクト: AnishShah/tensorflow
  def from_value(value):
    """Returns an `Optional` that wraps the given value.

    Args:
      value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

    Returns:
      An `Optional` that wraps `value`.
    """

    def _to_component(index, component):
      # Sparse values become `SparseTensor`s; everything else is converted to
      # a tensor named after its position in the flattened structure.
      if sparse_tensor_lib.is_sparse(component):
        return sparse_tensor_lib.SparseTensor.from_value(component)
      return ops.convert_to_tensor(component, name="component_%d" % index)

    # TODO(b/110122868): Consolidate this destructuring logic with the
    # similar code in `Dataset.from_tensors()`.
    with ops.name_scope("optional") as scope:
      with ops.name_scope("value"):
        value = nest.pack_sequence_as(value, [
            _to_component(index, component)
            for index, component in enumerate(nest.flatten(value))
        ])

      encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
      output_classes = sparse.get_classes(value)
      flat_value = nest.flatten(value)
      output_shapes = nest.pack_sequence_as(
          value, [component.get_shape() for component in flat_value])
      output_types = nest.pack_sequence_as(
          value, [component.dtype for component in flat_value])

    variant = gen_dataset_ops.optional_from_value(encoded_value, name=scope)
    return _OptionalImpl(variant, output_shapes, output_types, output_classes)
コード例 #9
0
ファイル: prefetching_ops.py プロジェクト: imdone/tensorflow
 def _prefetch_fn(handle):
     """Fetches the next element via an iterator rebuilt from `handle`.

     Args:
       handle: Value forwarded to `Iterator.from_string_handle`.

     Returns:
       The flattened components of the next element, after
       `sparse.serialize_sparse_tensors` has been applied to it.
     """
     element_types = self.output_types
     element_shapes = self.output_shapes
     element_classes = self.output_classes
     rebuilt_iterator = iterator_ops.Iterator.from_string_handle(
         handle, element_types, element_shapes, element_classes)
     next_element = rebuilt_iterator.get_next()
     serialized = sparse.serialize_sparse_tensors(next_element)
     return nest.flatten(serialized)
コード例 #10
0
ファイル: sparse_test.py プロジェクト: abidrahmank/tensorflow
 def testSerializeDeserialize(self):
   """Checks that serialize followed by deserialize preserves each structure."""
   test_cases = (
       (),
       sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       # Parentheses without a trailing comma do not form a tuple, so this
       # entry is again a bare sparse tensor.
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
       ((), sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
   )

   def _unknown_shape(_):
     return tensor_shape.TensorShape(None)

   def _int32_type(_):
     return dtypes.int32

   for expected in test_cases:
     structure_classes = sparse.get_classes(expected)
     unknown_shapes = nest.map_structure(_unknown_shape, structure_classes)
     int_types = nest.map_structure(_int32_type, structure_classes)
     serialized = sparse.serialize_sparse_tensors(expected)
     actual = sparse.deserialize_sparse_tensors(
         serialized, int_types, unknown_shapes, sparse.get_classes(expected))
     nest.assert_same_structure(expected, actual)
     flat_pairs = zip(nest.flatten(actual), nest.flatten(expected))
     for actual_component, expected_component in flat_pairs:
       self.assertSparseValuesEqual(actual_component, expected_component)
コード例 #11
0
ファイル: grouping.py プロジェクト: xman/tensorflow
    def tf_finalize_func(*args):
      """Rebuilds the nested state from flat `args` and finalizes it.

      Sets static shapes on `args`, restores the nested state (deserializing
      sparse components), calls `finalize_func`, records the classes, shapes
      and types of the result on `self`, and returns the result flattened
      with sparse tensors serialized.
      """
      dense_state_shapes = nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))
      for flat_arg, static_shape in zip(args, dense_state_shapes):
        flat_arg.set_shape(static_shape)

      packed_state = nest.pack_sequence_as(self._state_types, args)
      state = sparse.deserialize_sparse_tensors(
          packed_state, self._state_types, self._state_shapes,
          self._state_classes)

      ret = finalize_func(state)

      # `SparseTensorValue`s become `SparseTensor`s; every other value is
      # converted to a tensor.
      normalized = []
      for component in nest.flatten(ret):
        if sparse_tensor.is_sparse(component):
          normalized.append(sparse_tensor.SparseTensor.from_value(component))
        else:
          normalized.append(ops.convert_to_tensor(component))
      ret = nest.pack_sequence_as(ret, normalized)

      # Record the structure of the finalized output.
      self._output_classes = sparse.get_classes(ret)
      flat_ret = nest.flatten(ret)
      self._output_shapes = nest.pack_sequence_as(
          ret, [component.get_shape() for component in flat_ret])
      self._output_types = nest.pack_sequence_as(
          ret, [component.dtype for component in flat_ret])

      dataset_ops._warn_if_collections(  # pylint: disable=protected-access
          "tf.contrib.data.group_by_reducer()")

      # Serialize any sparse tensors before handing back the flat list.
      serialized = nest.pack_sequence_as(
          ret, list(nest.flatten(sparse.serialize_sparse_tensors(ret))))
      return nest.flatten(serialized)
コード例 #12
0
 def _as_variant_tensor(self):
     """Builds the scan dataset op for this dataset.

     Returns:
       The result of `gen_dataset_ops.scan_dataset`, called with the input
       dataset's variant tensor, the flattened initial state (sparse tensors
       serialized), and the scan function with its captured inputs.
     """
     upstream_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
     flat_initial_state = nest.flatten(
         sparse.serialize_sparse_tensors(self._initial_state))
     captured = self._scan_func.captured_inputs
     structure_kwargs = dataset_ops.flat_structure(self)
     return gen_dataset_ops.scan_dataset(
         upstream_variant,
         flat_initial_state,
         captured,
         f=self._scan_func,
         **structure_kwargs)
コード例 #13
0
ファイル: scan_ops.py プロジェクト: AnishShah/tensorflow
 def _as_variant_tensor(self):
   """Builds the scan dataset op for this dataset.

   Returns:
     The result of `gen_dataset_ops.scan_dataset`, called with the input
     dataset's variant tensor, the flattened initial state (sparse tensors
     serialized), and the scan function with its captured inputs.
   """
   upstream_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
   serialized_state = sparse.serialize_sparse_tensors(self._initial_state)
   flat_initial_state = nest.flatten(serialized_state)
   captured = self._scan_func.captured_inputs
   structure_kwargs = dataset_ops.flat_structure(self)
   return gen_dataset_ops.scan_dataset(
       upstream_variant,
       flat_initial_state,
       captured,
       f=self._scan_func,
       **structure_kwargs)
コード例 #14
0
            def tf_reduce_func(*args):
                """Rebuilds state and input from flat `args` and reduces them.

                The leading entries of `args` hold the flattened state and the
                remaining ones the flattened input element. After
                `reduce_func` runs, the shapes of the new state are appended
                to `flat_new_state_shapes` and its dtypes are checked against
                `self._state_types`.

                Raises:
                  TypeError: If a component dtype of the new state differs
                    from the corresponding entry of `self._state_types`.
                """
                # Static shapes for the state components followed by the
                # input components, in the same order as `args`.
                dense_shapes = (
                    nest.flatten(
                        sparse.as_dense_shapes(self._state_shapes,
                                               self._state_classes)) +
                    nest.flatten(
                        sparse.as_dense_shapes(
                            input_dataset.output_shapes,
                            input_dataset.output_classes)))
                for flat_arg, static_shape in zip(args, dense_shapes):
                    flat_arg.set_shape(static_shape)

                # Split `args` into its state part and its input part.
                pivot = len(nest.flatten(self._state_shapes))
                packed_state = nest.pack_sequence_as(self._state_types,
                                                     args[:pivot])
                nested_state_args = sparse.deserialize_sparse_tensors(
                    packed_state, self._state_types, self._state_shapes,
                    self._state_classes)
                packed_input = nest.pack_sequence_as(
                    input_dataset.output_types, args[pivot:])
                nested_input_args = sparse.deserialize_sparse_tensors(
                    packed_input, input_dataset.output_types,
                    input_dataset.output_shapes, input_dataset.output_classes)

                ret = reduce_func(nested_state_args, nested_input_args)

                # `SparseTensorValue`s become `SparseTensor`s; every other
                # value is converted to a tensor.
                normalized = []
                for component in nest.flatten(ret):
                    if sparse_tensor.is_sparse(component):
                        normalized.append(
                            sparse_tensor.SparseTensor.from_value(component))
                    else:
                        normalized.append(ops.convert_to_tensor(component))
                ret = nest.pack_sequence_as(ret, normalized)

                # Extract shape information from the returned values.
                flat_new_state = nest.flatten(ret)
                flat_new_state_shapes.extend(
                    [component.get_shape() for component in flat_new_state])

                # Extract and validate type information from the returned
                # values.
                expected_dtypes = nest.flatten(self._state_types)
                for component, expected_dtype in zip(flat_new_state,
                                                     expected_dtypes):
                    if component.dtype != expected_dtype:
                        actual_types = nest.pack_sequence_as(
                            self._state_types,
                            [c.dtype for c in flat_new_state])
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types, actual_types))

                # Serialize any sparse tensors before returning the flat list.
                serialized = nest.pack_sequence_as(
                    ret,
                    list(nest.flatten(sparse.serialize_sparse_tensors(ret))))
                return nest.flatten(serialized)
コード例 #15
0
 def _as_variant_tensor(self):
     """Builds the experimental scan dataset op for this dataset.

     Returns:
       The result of `experimental_scan_dataset`, called with the input
       dataset's variant tensor, the flattened initial state (sparse tensors
       serialized), the scan function and its captured inputs, and
       `preserve_cardinality=True`.
     """
     upstream_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
     flat_initial_state = nest.flatten(
         sparse.serialize_sparse_tensors(self._initial_state))
     captured = self._scan_func.function.captured_inputs
     structure_kwargs = dataset_ops.flat_structure(self)
     return gen_experimental_dataset_ops.experimental_scan_dataset(
         upstream_variant,
         flat_initial_state,
         captured,
         f=self._scan_func.function,
         preserve_cardinality=True,
         **structure_kwargs)
コード例 #16
0
ファイル: grouping.py プロジェクト: xman/tensorflow
      def tf_reduce_func(*args):
        """Rebuilds state and input from flat `args` and reduces them.

        The leading entries of `args` hold the flattened state and the
        remaining ones the flattened input element. After `reduce_func` runs,
        the shapes of the new state are appended to `flat_new_state_shapes`
        and its dtypes are checked against `self._state_types`.

        Raises:
          TypeError: If a component dtype of the new state differs from the
            corresponding entry of `self._state_types`.
        """

        def _to_tensor_like(component):
          # `SparseTensorValue`s become `SparseTensor`s; every other value is
          # converted to a tensor.
          if sparse_tensor.is_sparse(component):
            return sparse_tensor.SparseTensor.from_value(component)
          return ops.convert_to_tensor(component)

        # Static shapes for the state components followed by the input
        # components, in the same order as `args`.
        dense_shapes = (
            nest.flatten(
                sparse.as_dense_shapes(self._state_shapes,
                                       self._state_classes)) +
            nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes)))
        for flat_arg, static_shape in zip(args, dense_shapes):
          flat_arg.set_shape(static_shape)

        # Split `args` into its state part and its input part.
        pivot = len(nest.flatten(self._state_shapes))
        packed_state = nest.pack_sequence_as(self._state_types, args[:pivot])
        nested_state_args = sparse.deserialize_sparse_tensors(
            packed_state, self._state_types, self._state_shapes,
            self._state_classes)
        packed_input = nest.pack_sequence_as(input_dataset.output_types,
                                             args[pivot:])
        nested_input_args = sparse.deserialize_sparse_tensors(
            packed_input, input_dataset.output_types,
            input_dataset.output_shapes, input_dataset.output_classes)

        ret = reduce_func(nested_state_args, nested_input_args)
        ret = nest.pack_sequence_as(
            ret, [_to_tensor_like(component) for component in nest.flatten(ret)])

        # Extract shape information from the returned values.
        flat_new_state = nest.flatten(ret)
        flat_new_state_shapes.extend(
            [component.get_shape() for component in flat_new_state])

        # Extract and validate type information from the returned values.
        expected_dtypes = nest.flatten(self._state_types)
        for component, expected_dtype in zip(flat_new_state, expected_dtypes):
          if component.dtype != expected_dtype:
            actual_types = nest.pack_sequence_as(
                self._state_types, [c.dtype for c in flat_new_state])
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types, actual_types))

        dataset_ops._warn_if_collections(  # pylint: disable=protected-access
            "tf.contrib.data.group_by_reducer()")

        # Serialize any sparse tensors before returning the flat list.
        serialized = nest.pack_sequence_as(
            ret, list(nest.flatten(sparse.serialize_sparse_tensors(ret))))
        return nest.flatten(serialized)
コード例 #17
0
ファイル: scan_ops.py プロジェクト: xman/tensorflow
 def _as_variant_tensor(self):
   """Builds the scan dataset op for this dataset.

   Returns:
     The result of `gen_dataset_ops.scan_dataset`, called with the input
     dataset's variant tensor, the flattened initial state (sparse tensors
     serialized), the scan function with its captured inputs, and the
     flattened dense output types and shapes.
   """
   upstream_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
   flat_initial_state = nest.flatten(
       sparse.serialize_sparse_tensors(self._initial_state))
   captured = self._scan_func.captured_inputs
   flat_output_types = nest.flatten(
       sparse.as_dense_types(self.output_types, self.output_classes))
   flat_output_shapes = nest.flatten(
       sparse.as_dense_shapes(self.output_shapes, self.output_classes))
   return gen_dataset_ops.scan_dataset(
       upstream_variant,
       flat_initial_state,
       captured,
       f=self._scan_func,
       output_types=flat_output_types,
       output_shapes=flat_output_shapes)
コード例 #18
0
ファイル: sparse_test.py プロジェクト: xulin2005/tensorflow-1
 def testSerializeDeserialize(self, input_fn):
     """Checks that serialize followed by deserialize preserves a structure.

     Args:
       input_fn: Zero-argument callable that returns the structure to test.
     """
     test_case = input_fn()
     structure_classes = sparse.get_classes(test_case)
     unknown_shapes = nest.map_structure(
         lambda _: tensor_shape.TensorShape(None), structure_classes)
     int_types = nest.map_structure(lambda _: dtypes.int32,
                                    structure_classes)
     serialized = sparse.serialize_sparse_tensors(test_case)
     actual = sparse.deserialize_sparse_tensors(
         serialized, int_types, unknown_shapes,
         sparse.get_classes(test_case))
     nest.assert_same_structure(test_case, actual)
     flat_pairs = zip(nest.flatten(actual), nest.flatten(test_case))
     for actual_component, expected_component in flat_pairs:
         self.assertSparseValuesEqual(actual_component, expected_component)
コード例 #19
0
 def _as_variant_tensor(self):
     """Builds the scan dataset op for this dataset.

     Returns:
       The result of `gen_dataset_ops.scan_dataset`, called with the input
       dataset's variant tensor, the flattened initial state (sparse tensors
       serialized), the scan function with its captured inputs, and the
       flattened dense output types and shapes.
     """
     upstream_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
     serialized_state = sparse.serialize_sparse_tensors(self._initial_state)
     flat_initial_state = nest.flatten(serialized_state)
     captured = self._scan_func.captured_inputs
     dense_types = sparse.as_dense_types(self.output_types,
                                         self.output_classes)
     dense_shapes = sparse.as_dense_shapes(self.output_shapes,
                                           self.output_classes)
     return gen_dataset_ops.scan_dataset(
         upstream_variant,
         flat_initial_state,
         captured,
         f=self._scan_func,
         output_types=nest.flatten(dense_types),
         output_shapes=nest.flatten(dense_shapes))
コード例 #20
0
        def _next_func(string_handle):
            """Calls get_next for created iterator.

            Args:
              string_handle: An iterator string handle created by _init_func

            Returns:
              The elements generated from `input_dataset`
            """
            with ops.device(self._source_device_string):
                source_iterator = iterator_ops.Iterator.from_string_handle(
                    string_handle, self.output_types, self.output_shapes,
                    self.output_classes)
            # Note: `get_next()` is called after leaving the device scope.
            next_element = source_iterator.get_next()
            serialized = sparse.serialize_sparse_tensors(next_element)
            return nest.flatten(serialized)
コード例 #21
0
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func

      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        source_iterator = iterator_ops.Iterator.from_string_handle(
            string_handle, self.output_types, self.output_shapes,
            self.output_classes)
      # Note: `get_next()` is called after leaving the device scope.
      next_element = source_iterator.get_next()
      serialized = sparse.serialize_sparse_tensors(next_element)
      return nest.flatten(serialized)
コード例 #22
0
    def _prefetch_fn(handle):
      """Fetches the next element via an iterator rebuilt from `handle`.

      Args:
        handle: Value forwarded to `Iterator.from_string_handle`.

      Returns:
        The flattened components of the next element, with sparse tensors
        serialized and every component converted to a tensor.
      """
      source = self._input_iterator
      rebuilt_iterator = iterator_ops.Iterator.from_string_handle(
          handle, source.output_types, source.output_shapes,
          source.output_classes)
      ret = rebuilt_iterator.get_next()

      # Convert any `SparseTensorValue`s to `SparseTensor`s; other components
      # are left untouched at this stage.
      sparse_normalized = []
      for component in nest.flatten(ret):
        if sparse_tensor_lib.is_sparse(component):
          sparse_normalized.append(
              sparse_tensor_lib.SparseTensor.from_value(component))
        else:
          sparse_normalized.append(component)
      ret = nest.pack_sequence_as(ret, sparse_normalized)

      # Serialize any sparse tensors and convert result to tensors.
      as_tensors = []
      for component in nest.flatten(sparse.serialize_sparse_tensors(ret)):
        as_tensors.append(ops.convert_to_tensor(component))
      ret = nest.pack_sequence_as(ret, as_tensors)
      return nest.flatten(ret)
コード例 #23
0
ファイル: grouping.py プロジェクト: Jackiefan/tensorflow
    def tf_init_func(key):
      """Builds the initial state for `key` and records its structure.

      Gives `key` a scalar static shape, calls `init_func`, normalizes the
      result, stores the state's classes, shapes and types on `self`, and
      returns the state flattened with sparse tensors serialized.
      """
      key.set_shape([])
      ret = init_func(key)

      # `SparseTensorValue`s become `SparseTensor`s; every other value is
      # converted to a tensor.
      normalized = []
      for component in nest.flatten(ret):
        if sparse_tensor.is_sparse(component):
          normalized.append(sparse_tensor.SparseTensor.from_value(component))
        else:
          normalized.append(ops.convert_to_tensor(component))
      ret = nest.pack_sequence_as(ret, normalized)

      # Record the structure of the initial state.
      self._state_classes = sparse.get_classes(ret)
      flat_ret = nest.flatten(ret)
      self._state_shapes = nest.pack_sequence_as(
          ret, [component.get_shape() for component in flat_ret])
      self._state_types = nest.pack_sequence_as(
          ret, [component.dtype for component in flat_ret])

      # Serialize any sparse tensors before handing back the flat list.
      serialized = nest.pack_sequence_as(
          ret, list(nest.flatten(sparse.serialize_sparse_tensors(ret))))
      return nest.flatten(serialized)
コード例 #24
0
ファイル: grouping.py プロジェクト: zhuyangda/tensorflow
        def tf_finalize_func(*args):
            """Rebuilds the nested state from flat `args` and finalizes it.

            Sets static shapes on `args`, restores the nested state
            (deserializing sparse components), calls `finalize_func`, records
            the classes, shapes and types of the result on `self`, and returns
            the result flattened with sparse tensors serialized.
            """
            dense_state_shapes = nest.flatten(
                sparse.as_dense_shapes(self._state_shapes,
                                       self._state_classes))
            for flat_arg, static_shape in zip(args, dense_state_shapes):
                flat_arg.set_shape(static_shape)

            packed_state = nest.pack_sequence_as(self._state_types, args)
            state = sparse.deserialize_sparse_tensors(
                packed_state, self._state_types, self._state_shapes,
                self._state_classes)

            ret = finalize_func(state)

            # `SparseTensorValue`s become `SparseTensor`s; every other value
            # is converted to a tensor.
            normalized = []
            for component in nest.flatten(ret):
                if sparse_tensor.is_sparse(component):
                    normalized.append(
                        sparse_tensor.SparseTensor.from_value(component))
                else:
                    normalized.append(ops.convert_to_tensor(component))
            ret = nest.pack_sequence_as(ret, normalized)

            # Record the structure of the finalized output.
            self._output_classes = sparse.get_classes(ret)
            flat_ret = nest.flatten(ret)
            self._output_shapes = nest.pack_sequence_as(
                ret, [component.get_shape() for component in flat_ret])
            self._output_types = nest.pack_sequence_as(
                ret, [component.dtype for component in flat_ret])

            dataset_ops._warn_if_collections(  # pylint: disable=protected-access
                "tf.contrib.data.group_by_reducer()")

            # Serialize any sparse tensors before returning the flat list.
            serialized = nest.pack_sequence_as(
                ret, list(nest.flatten(sparse.serialize_sparse_tensors(ret))))
            return nest.flatten(serialized)
コード例 #25
0
        def tf_init_func(key):
            """Builds the initial state for `key` and records its structure.

            Gives `key` a scalar static shape, calls `init_func`, normalizes
            the result, stores the state's classes, shapes and types on
            `self`, and returns the state flattened with sparse tensors
            serialized.
            """
            key.set_shape([])
            ret = init_func(key)

            # `SparseTensorValue`s become `SparseTensor`s; every other value
            # is converted to a tensor.
            normalized = []
            for component in nest.flatten(ret):
                if sparse_tensor.is_sparse(component):
                    normalized.append(
                        sparse_tensor.SparseTensor.from_value(component))
                else:
                    normalized.append(ops.convert_to_tensor(component))
            ret = nest.pack_sequence_as(ret, normalized)

            # Record the structure of the initial state.
            self._state_classes = sparse.get_classes(ret)
            flat_ret = nest.flatten(ret)
            self._state_shapes = nest.pack_sequence_as(
                ret, [component.get_shape() for component in flat_ret])
            self._state_types = nest.pack_sequence_as(
                ret, [component.dtype for component in flat_ret])

            # Serialize any sparse tensors before returning the flat list.
            serialized = nest.pack_sequence_as(
                ret, list(nest.flatten(sparse.serialize_sparse_tensors(ret))))
            return nest.flatten(serialized)
コード例 #26
0
ファイル: sparse_test.py プロジェクト: SylChan/tensorflow
 def testSerializeDeserialize(self):
   """Checks that serialize followed by deserialize preserves each structure."""

   def _unit_sparse():
     # Fresh sparse tensor with value 1 at [0, 0] and dense shape [1, 1].
     return sparse_tensor.SparseTensor(
         indices=[[0, 0]], values=[1], dense_shape=[1, 1])

   test_cases = (
       (),
       _unit_sparse(),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       # Parentheses without a trailing comma do not form a tuple, so this
       # entry is again a bare sparse tensor.
       (_unit_sparse()),
       (_unit_sparse(), ()),
       ((), _unit_sparse()),
   )
   for expected in test_cases:
     serialized = sparse.serialize_sparse_tensors(expected)
     actual = sparse.deserialize_sparse_tensors(
         serialized, sparse.get_sparse_types(expected))
     nest.assert_same_structure(expected, actual)
     flat_pairs = zip(nest.flatten(actual), nest.flatten(expected))
     for actual_component, expected_component in flat_pairs:
       self.assertSparseValuesEqual(actual_component, expected_component)
コード例 #27
0
ファイル: grouping.py プロジェクト: xman/tensorflow
    def tf_init_func(key):
      """Builds the initial state for `key` and records its structure.

      Gives `key` a scalar static shape, calls `init_func`, normalizes the
      result, stores the state's classes, shapes and types on `self`, and
      returns the state flattened with sparse tensors serialized.
      """

      def _to_tensor_like(component):
        # `SparseTensorValue`s become `SparseTensor`s; every other value is
        # converted to a tensor.
        if sparse_tensor.is_sparse(component):
          return sparse_tensor.SparseTensor.from_value(component)
        return ops.convert_to_tensor(component)

      key.set_shape([])
      ret = init_func(key)
      ret = nest.pack_sequence_as(
          ret, [_to_tensor_like(component) for component in nest.flatten(ret)])

      # Record the structure of the initial state.
      self._state_classes = sparse.get_classes(ret)
      flat_ret = nest.flatten(ret)
      self._state_shapes = nest.pack_sequence_as(
          ret, [component.get_shape() for component in flat_ret])
      self._state_types = nest.pack_sequence_as(
          ret, [component.dtype for component in flat_ret])

      dataset_ops._warn_if_collections(  # pylint: disable=protected-access
          "tf.contrib.data.group_by_reducer()")

      # Serialize any sparse tensors before handing back the flat list.
      serialized = nest.pack_sequence_as(
          ret, list(nest.flatten(sparse.serialize_sparse_tensors(ret))))
      return nest.flatten(serialized)
コード例 #28
0
            def tf_scan_func(*args):
                """A wrapper for Defun that facilitates shape inference."""
                # Pass in shape information from the state and input_dataset.
                for arg, shape in zip(
                        args,
                        nest.flatten(
                            sparse.as_dense_shapes(self._state_shapes,
                                                   self._state_classes)) +
                        nest.flatten(
                            sparse.as_dense_shapes(
                                input_dataset.output_shapes,
                                input_dataset.output_classes))):
                    arg.set_shape(shape)

                pivot = len(nest.flatten(self._state_shapes))
                print(self._state_classes)
                nested_state_args = nest.pack_sequence_as(
                    self._state_types, args[:pivot])
                nested_state_args = sparse.deserialize_sparse_tensors(
                    nested_state_args, self._state_types, self._state_shapes,
                    self._state_classes)
                print(input_dataset.output_classes)
                nested_input_args = nest.pack_sequence_as(
                    input_dataset.output_types, args[pivot:])
                nested_input_args = sparse.deserialize_sparse_tensors(
                    nested_input_args, input_dataset.output_types,
                    input_dataset.output_shapes, input_dataset.output_classes)

                ret = scan_func(nested_state_args, nested_input_args)
                if not isinstance(ret, collections.Sequence) or len(ret) != 2:
                    raise TypeError(
                        "The scan function must return a pair comprising the "
                        "new state and the output value.")

                # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
                # values to tensors.
                ret = nest.pack_sequence_as(ret, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                    for t in nest.flatten(ret)
                ])
                new_state, output_value = ret

                # Extract and validate class information from the returned values.
                for t, clazz in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_classes)):
                    if not isinstance(t, clazz):
                        raise TypeError(
                            "The element classes for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_classes,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [type(t) for t in nest.flatten(new_state)])))
                self._output_classes = sparse.get_classes(output_value)

                # Extract shape information from the returned values.
                flat_new_state_shapes.extend(
                    [t.get_shape() for t in nest.flatten(new_state)])
                self._output_shapes = nest.pack_sequence_as(
                    output_value,
                    [t.get_shape() for t in nest.flatten(output_value)])

                # Extract and validate type information from the returned values.
                for t, dtype in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_types)):
                    if t.dtype != dtype:
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [t.dtype for t in nest.flatten(new_state)])))
                self._output_types = nest.pack_sequence_as(
                    output_value,
                    [t.dtype for t in nest.flatten(output_value)])

                dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

                # Serialize any sparse tensors.
                new_state = nest.pack_sequence_as(new_state, [
                    t for t in nest.flatten(
                        sparse.serialize_sparse_tensors(new_state))
                ])
                output_value = nest.pack_sequence_as(output_value, [
                    t for t in nest.flatten(
                        sparse.serialize_sparse_tensors(output_value))
                ])
                return nest.flatten(new_state) + nest.flatten(output_value)
コード例 #29
0
ファイル: scan_ops.py プロジェクト: xman/tensorflow
      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference.

        Splits the flat argument list into the current state and one input
        element, calls `scan_func` on them, validates the returned state
        against the initial state, and records the output structure on `self`.

        Args:
          *args: Flat sequence of tensors. The first
            `len(nest.flatten(self._state_shapes))` entries are the state
            components; the remaining entries are the components of one
            element of `input_dataset`. Both groups are passed through
            `sparse.deserialize_sparse_tensors` before `scan_func` is called.

        Returns:
          A flat list holding the components of the new state followed by the
          components of the output value, each passed through
          `sparse.serialize_sparse_tensors`.

        Raises:
          TypeError: If `scan_func` does not return a sequence of length 2, or
            if the classes or dtypes of the new state differ from those of the
            initial state.
        """
        # Pass in shape information from the state and input_dataset.
        state_dense_shapes = nest.flatten(
            sparse.as_dense_shapes(self._state_shapes, self._state_classes))
        input_dense_shapes = nest.flatten(
            sparse.as_dense_shapes(input_dataset.output_shapes,
                                   input_dataset.output_classes))
        for arg, shape in zip(args, state_dense_shapes + input_dense_shapes):
          arg.set_shape(shape)

        # The leading `pivot` arguments belong to the state; the rest belong
        # to the input element.
        pivot = len(nest.flatten(self._state_shapes))
        nested_state_args = nest.pack_sequence_as(self._state_types,
                                                  args[:pivot])
        nested_state_args = sparse.deserialize_sparse_tensors(
            nested_state_args, self._state_types, self._state_shapes,
            self._state_classes)
        nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                                  args[pivot:])
        nested_input_args = sparse.deserialize_sparse_tensors(
            nested_input_args, input_dataset.output_types,
            input_dataset.output_shapes, input_dataset.output_classes)

        ret = scan_func(nested_state_args, nested_input_args)
        if not isinstance(ret, collections.Sequence) or len(ret) != 2:
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")

        # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
        # values to tensors.
        ret = nest.pack_sequence_as(ret, [
            sparse_tensor.SparseTensor.from_value(t)
            if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
            for t in nest.flatten(ret)
        ])
        new_state, output_value = ret
        flat_new_state = nest.flatten(new_state)
        flat_output_value = nest.flatten(output_value)

        # Extract and validate class information from the returned values.
        for t, clazz in zip(flat_new_state, nest.flatten(self._state_classes)):
          if not isinstance(t, clazz):
            raise TypeError(
                "The element classes for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_classes,
                 nest.pack_sequence_as(
                     self._state_types, [type(t) for t in flat_new_state])))
        self._output_classes = sparse.get_classes(output_value)

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.get_shape() for t in flat_output_value])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(
                     self._state_types, [t.dtype for t in flat_new_state])))
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in flat_output_value])

        dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

        # Serialize any sparse tensors and return every component as one flat
        # list: new state first, then the output value.
        return (nest.flatten(sparse.serialize_sparse_tensors(new_state)) +
                nest.flatten(sparse.serialize_sparse_tensors(output_value)))