示例#1
0
    def tf_finalize_func(*args):
      """Defun body for `finalize_func` with static state shapes restored.

      Re-attaches the recorded state shapes to the flat Defun arguments,
      rebuilds and sparse-deserializes the nested state, applies
      `finalize_func`, and records the output classes/shapes/types on `self`.

      Returns:
        The flattened result of `finalize_func`, with sparse tensors
        serialized.
      """
      state_dense_shapes = nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))
      for tensor, static_shape in zip(args, state_dense_shapes):
        tensor.set_shape(static_shape)

      state = sparse.deserialize_sparse_tensors(
          nest.pack_sequence_as(self._state_types, args),
          self._state_types, self._state_shapes, self._state_classes)

      result = finalize_func(state)

      # Normalize each leaf: `SparseTensorValue` -> `SparseTensor`, anything
      # else -> tensor.
      normalized = []
      for leaf in nest.flatten(result):
        if sparse_tensor.is_sparse(leaf):
          normalized.append(sparse_tensor.SparseTensor.from_value(leaf))
        else:
          normalized.append(ops.convert_to_tensor(leaf))
      result = nest.pack_sequence_as(result, normalized)

      flat_result = nest.flatten(result)
      self._output_classes = sparse.get_classes(result)
      self._output_shapes = nest.pack_sequence_as(
          result, [leaf.get_shape() for leaf in flat_result])
      self._output_types = nest.pack_sequence_as(
          result, [leaf.dtype for leaf in flat_result])

      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

      # Serialize any sparse tensors before returning the flat result.
      serialized = nest.flatten(sparse.serialize_sparse_tensors(result))
      result = nest.pack_sequence_as(result, list(serialized))
      return nest.flatten(result)
示例#2
0
        def tf_key_func(*args):
            """Defun body for `key_func` with the input shapes restored.

            Returns:
              The key computed by `key_func`, converted to a `tf.int64` tensor.

            Raises:
              ValueError: if the converted key's dtype is not `tf.int64`.
            """
            # Re-attach the input dataset's static shapes to the Defun args.
            input_shapes = nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes))
            for tensor, static_shape in zip(args, input_shapes):
                tensor.set_shape(static_shape)

            element = sparse.deserialize_sparse_tensors(
                nest.pack_sequence_as(input_dataset.output_types, args),
                input_dataset.output_types, input_dataset.output_shapes,
                input_dataset.output_classes)
            # pylint: disable=protected-access
            unpack = dataset_ops._should_unpack_args(element)
            # pylint: enable=protected-access
            key = key_func(*element) if unpack else key_func(element)
            key = ops.convert_to_tensor(key, dtype=dtypes.int64)
            if key.dtype != dtypes.int64:
                raise ValueError(
                    "`key_func` must return a single tf.int64 tensor.")
            dataset_ops._warn_if_collections(
                "tf.contrib.data.group_by_window()")  # pylint: disable=protected-access
            return key
示例#3
0
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference.

      Restores the static shapes of `input_dataset` on the flat Defun
      arguments, rebuilds the nested (sparse-deserialized) element, and applies
      `key_func` to it.

      Returns:
        The key produced by `key_func`, as a tensor.

      Raises:
        ValueError: if the key is not a scalar `tf.int64` tensor. A key whose
          rank is statically unknown is rejected with the same message.
      """
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      # pylint: disable=protected-access
      if dataset_ops._should_unpack_args(nested_args):
        ret = key_func(*nested_args)
      # pylint: enable=protected-access
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret)
      # Compare the rank directly rather than `shape != tensor_shape.scalar()`.
      # In TF 1.x, `TensorShape.__ne__` raises its own "inequality of unknown
      # TensorShapes is undefined" ValueError for unknown-rank shapes, which
      # hides the descriptive message below. `ndims` is None in that case, so
      # unknown-rank keys now reach the intended error.
      if ret.dtype != dtypes.int64 or ret.get_shape().ndims != 0:
        raise ValueError(
            "`key_func` must return a single tf.int64 tensor. "
            "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access
      return ret
示例#4
0
        def tf_init_func(key):
            """Defun body for `init_func` that records the initial state spec.

            Args:
              key: the group key; its static shape is set to scalar.

            Returns:
              The flattened initial state, with sparse tensors serialized.
            """
            key.set_shape([])
            state = init_func(key)

            # Normalize each leaf: `SparseTensorValue`s become `SparseTensor`s,
            # everything else is converted to a tensor.
            normalized = []
            for leaf in nest.flatten(state):
                if sparse_tensor.is_sparse(leaf):
                    normalized.append(
                        sparse_tensor.SparseTensor.from_value(leaf))
                else:
                    normalized.append(ops.convert_to_tensor(leaf))
            state = nest.pack_sequence_as(state, normalized)

            # Remember the state structure for the reduce/finalize wrappers.
            flat_state = nest.flatten(state)
            self._state_classes = sparse.get_classes(state)
            self._state_shapes = nest.pack_sequence_as(
                state, [leaf.get_shape() for leaf in flat_state])
            self._state_types = nest.pack_sequence_as(
                state, [leaf.dtype for leaf in flat_state])

            dataset_ops._warn_if_collections(
                "tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

            # Serialize any sparse tensors before returning the flat state.
            serialized = nest.flatten(sparse.serialize_sparse_tensors(state))
            state = nest.pack_sequence_as(state, list(serialized))
            return nest.flatten(state)
示例#5
0
 def tf_window_size_func(key):
   """Defun body for `window_size_func`; returns the size as a tf.int64 tensor.

   Raises:
     ValueError: if the converted window size's dtype is not `tf.int64`.
   """
   key.set_shape([])
   requested = window_size_func(key)
   window_size = ops.convert_to_tensor(requested, dtype=dtypes.int64)
   if window_size.dtype != dtypes.int64:
     raise ValueError(
         "`window_size_func` must return a single tf.int64 tensor.")
   dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()")  # pylint: disable=protected-access
   return window_size
示例#6
0
 def tf_window_size_func(key):
     """Defun body for `window_size_func`; returns the size as a tf.int64 tensor.

     Raises:
       ValueError: if the converted window size's dtype is not `tf.int64`.
     """
     key.set_shape([])
     size_value = window_size_func(key)
     window_size = ops.convert_to_tensor(size_value, dtype=dtypes.int64)
     if window_size.dtype != dtypes.int64:
         raise ValueError(
             "`window_size_func` must return a single tf.int64 tensor.")
     dataset_ops._warn_if_collections(
         "tf.contrib.data.group_by_window()")  # pylint: disable=protected-access
     return window_size
示例#7
0
      def tf_reduce_func(*args):
        """Defun body for `reduce_func` with static shapes restored.

        The flat Defun arguments are the current state followed by the input
        element. Their static shapes are re-attached, both halves are rebuilt
        and sparse-deserialized, and `reduce_func` is applied. The new state's
        shapes are appended to `flat_new_state_shapes` and its dtypes are
        checked against the initial state's.

        Raises:
          TypeError: if a new-state dtype differs from the initial state's.
        """
        state_shapes = nest.flatten(
            sparse.as_dense_shapes(self._state_shapes, self._state_classes))
        input_shapes = nest.flatten(
            sparse.as_dense_shapes(input_dataset.output_shapes,
                                   input_dataset.output_classes))
        for tensor, static_shape in zip(args, state_shapes + input_shapes):
          tensor.set_shape(static_shape)

        pivot = len(nest.flatten(self._state_shapes))
        state_args, input_args = args[:pivot], args[pivot:]
        state = sparse.deserialize_sparse_tensors(
            nest.pack_sequence_as(self._state_types, state_args),
            self._state_types, self._state_shapes, self._state_classes)
        element = sparse.deserialize_sparse_tensors(
            nest.pack_sequence_as(input_dataset.output_types, input_args),
            input_dataset.output_types, input_dataset.output_shapes,
            input_dataset.output_classes)

        new_state = reduce_func(state, element)

        # `SparseTensorValue` -> `SparseTensor`; all other leaves -> tensors.
        converted = []
        for leaf in nest.flatten(new_state):
          if sparse_tensor.is_sparse(leaf):
            converted.append(sparse_tensor.SparseTensor.from_value(leaf))
          else:
            converted.append(ops.convert_to_tensor(leaf))
        new_state = nest.pack_sequence_as(new_state, converted)

        # Record the new state's static shapes.
        flat_new_state = nest.flatten(new_state)
        flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])

        # The new state's dtypes must match the initial state's.
        expected_dtypes = nest.flatten(self._state_types)
        for leaf, expected in zip(flat_new_state, expected_dtypes):
          if leaf.dtype != expected:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(self._state_types,
                                       [t.dtype for t in flat_new_state])))

        dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

        # Serialize any sparse tensors before returning the flat state.
        serialized = nest.flatten(sparse.serialize_sparse_tensors(new_state))
        new_state = nest.pack_sequence_as(new_state, list(serialized))
        return nest.flatten(new_state)
示例#8
0
 def tf_reduce_func(key, window_dataset_variant):
   """Defun body that applies `reduce_func` to one window of elements.

   Args:
     key: the group key; its static shape is set to scalar.
     window_dataset_variant: variant tensor that is wrapped in a
       `_VariantDataset` with `input_dataset`'s element structure.

   Returns:
     The variant tensor of the dataset returned by `reduce_func`.

   Raises:
     TypeError: if the window or `reduce_func`'s result is not a `Dataset`.
   """
   key.set_shape([])
   window_dataset = _VariantDataset(
       window_dataset_variant,
       input_dataset.output_types,
       input_dataset.output_shapes,
       input_dataset.output_classes)
   if not isinstance(window_dataset, dataset_ops.Dataset):
     raise TypeError("`window_dataset` must return a `Dataset` object.")
   reduced = reduce_func(key, window_dataset)
   if not isinstance(reduced, dataset_ops.Dataset):
     raise TypeError("`reduce_func` must return a `Dataset` object.")
   # Record the reduced dataset's structure on the enclosing object.
   self._output_classes = reduced.output_classes
   self._output_types = reduced.output_types
   self._output_shapes = reduced.output_shapes
   dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()")  # pylint: disable=protected-access
   return reduced._as_variant_tensor()  # pylint: disable=protected-access
示例#9
0
 def tf_reduce_func(key, window_dataset_variant):
     """Defun body that applies `reduce_func` to one window of elements.

     Args:
       key: the group key; its static shape is set to scalar.
       window_dataset_variant: variant tensor that is wrapped in a
         `_VariantDataset` with `input_dataset`'s element structure.

     Returns:
       The variant tensor of the dataset returned by `reduce_func`.

     Raises:
       TypeError: if the window or `reduce_func`'s result is not a `Dataset`.
     """
     key.set_shape([])
     element_types = input_dataset.output_types
     element_shapes = input_dataset.output_shapes
     element_classes = input_dataset.output_classes
     window_dataset = _VariantDataset(window_dataset_variant, element_types,
                                      element_shapes, element_classes)
     if not isinstance(window_dataset, dataset_ops.Dataset):
         raise TypeError(
             "`window_dataset` must return a `Dataset` object.")
     result_dataset = reduce_func(key, window_dataset)
     if not isinstance(result_dataset, dataset_ops.Dataset):
         raise TypeError(
             "`reduce_func` must return a `Dataset` object.")
     # Record the reduced dataset's structure on the enclosing object.
     self._output_classes = result_dataset.output_classes
     self._output_types = result_dataset.output_types
     self._output_shapes = result_dataset.output_shapes
     dataset_ops._warn_if_collections(
         "tf.contrib.data.group_by_window()")  # pylint: disable=protected-access
     return result_dataset._as_variant_tensor()  # pylint: disable=protected-access
示例#10
0
        def tf_finalize_func(*args):
            """Defun body for `finalize_func` with static state shapes restored.

            Re-attaches the recorded state shapes to the flat Defun arguments,
            rebuilds and sparse-deserializes the nested state, applies
            `finalize_func`, and records the output classes/shapes/types on
            `self`.

            Returns:
              The flattened result of `finalize_func`, with sparse tensors
              serialized.
            """
            state_dense_shapes = nest.flatten(
                sparse.as_dense_shapes(self._state_shapes,
                                       self._state_classes))
            for tensor, static_shape in zip(args, state_dense_shapes):
                tensor.set_shape(static_shape)

            state = sparse.deserialize_sparse_tensors(
                nest.pack_sequence_as(self._state_types, args),
                self._state_types, self._state_shapes, self._state_classes)

            result = finalize_func(state)

            # Normalize each leaf: `SparseTensorValue` -> `SparseTensor`,
            # anything else -> tensor.
            normalized = []
            for leaf in nest.flatten(result):
                if sparse_tensor.is_sparse(leaf):
                    normalized.append(
                        sparse_tensor.SparseTensor.from_value(leaf))
                else:
                    normalized.append(ops.convert_to_tensor(leaf))
            result = nest.pack_sequence_as(result, normalized)

            flat_result = nest.flatten(result)
            self._output_classes = sparse.get_classes(result)
            self._output_shapes = nest.pack_sequence_as(
                result, [leaf.get_shape() for leaf in flat_result])
            self._output_types = nest.pack_sequence_as(
                result, [leaf.dtype for leaf in flat_result])

            dataset_ops._warn_if_collections(
                "tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

            # Serialize any sparse tensors before returning the flat result.
            serialized = nest.flatten(sparse.serialize_sparse_tensors(result))
            result = nest.pack_sequence_as(result, list(serialized))
            return nest.flatten(result)
示例#11
0
    def tf_init_func(key):
      """Defun body for `init_func` that records the initial state spec.

      Args:
        key: the group key; its static shape is set to scalar.

      Returns:
        The flattened initial state, with sparse tensors serialized.
      """
      key.set_shape([])
      state = init_func(key)

      # Normalize each leaf: `SparseTensorValue`s become `SparseTensor`s,
      # everything else is converted to a tensor.
      normalized = []
      for leaf in nest.flatten(state):
        if sparse_tensor.is_sparse(leaf):
          normalized.append(sparse_tensor.SparseTensor.from_value(leaf))
        else:
          normalized.append(ops.convert_to_tensor(leaf))
      state = nest.pack_sequence_as(state, normalized)

      # Remember the state structure for the reduce/finalize wrappers.
      flat_state = nest.flatten(state)
      self._state_classes = sparse.get_classes(state)
      self._state_shapes = nest.pack_sequence_as(
          state, [leaf.get_shape() for leaf in flat_state])
      self._state_types = nest.pack_sequence_as(
          state, [leaf.dtype for leaf in flat_state])

      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

      # Serialize any sparse tensors before returning the flat state.
      serialized = nest.flatten(sparse.serialize_sparse_tensors(state))
      state = nest.pack_sequence_as(state, list(serialized))
      return nest.flatten(state)
示例#12
0
            def tf_scan_func(*args):
                """A wrapper for Defun that facilitates shape inference.

                The flat Defun arguments are the current state followed by the
                input element. Their static shapes are re-attached, both halves
                are rebuilt and sparse-deserialized, and `scan_func` is applied.
                Output classes/shapes/types are recorded on `self`, and the new
                state's shapes are appended to `flat_new_state_shapes`.

                Returns:
                  The flattened new state followed by the flattened output
                  value, with sparse tensors serialized.

                Raises:
                  TypeError: if `scan_func` does not return a pair, or if the
                    new state's classes or dtypes differ from the initial
                    state's.
                """
                # `collections.Sequence` was removed in Python 3.10; the ABC
                # lives in `collections.abc` on Python 3.3+.
                try:
                    from collections.abc import Sequence as _Sequence
                except ImportError:  # Python 2.
                    _Sequence = collections.Sequence

                # Pass in shape information from the state and input_dataset.
                for arg, shape in zip(
                        args,
                        nest.flatten(
                            sparse.as_dense_shapes(self._state_shapes,
                                                   self._state_classes)) +
                        nest.flatten(
                            sparse.as_dense_shapes(
                                input_dataset.output_shapes,
                                input_dataset.output_classes))):
                    arg.set_shape(shape)

                # The first `pivot` args are the state; the rest are the input.
                pivot = len(nest.flatten(self._state_shapes))
                nested_state_args = nest.pack_sequence_as(
                    self._state_types, args[:pivot])
                nested_state_args = sparse.deserialize_sparse_tensors(
                    nested_state_args, self._state_types, self._state_shapes,
                    self._state_classes)
                nested_input_args = nest.pack_sequence_as(
                    input_dataset.output_types, args[pivot:])
                nested_input_args = sparse.deserialize_sparse_tensors(
                    nested_input_args, input_dataset.output_types,
                    input_dataset.output_shapes, input_dataset.output_classes)

                ret = scan_func(nested_state_args, nested_input_args)
                if not isinstance(ret, _Sequence) or len(ret) != 2:
                    raise TypeError(
                        "The scan function must return a pair comprising the "
                        "new state and the output value.")

                # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
                # values to tensors.
                ret = nest.pack_sequence_as(ret, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                    for t in nest.flatten(ret)
                ])
                new_state, output_value = ret

                # Extract and validate class information from the returned values.
                for t, clazz in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_classes)):
                    if not isinstance(t, clazz):
                        raise TypeError(
                            "The element classes for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_classes,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [type(t) for t in nest.flatten(new_state)])))
                self._output_classes = sparse.get_classes(output_value)

                # Extract shape information from the returned values.
                flat_new_state_shapes.extend(
                    [t.get_shape() for t in nest.flatten(new_state)])
                self._output_shapes = nest.pack_sequence_as(
                    output_value,
                    [t.get_shape() for t in nest.flatten(output_value)])

                # Extract and validate type information from the returned values.
                for t, dtype in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_types)):
                    if t.dtype != dtype:
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [t.dtype for t in nest.flatten(new_state)])))
                self._output_types = nest.pack_sequence_as(
                    output_value,
                    [t.dtype for t in nest.flatten(output_value)])

                dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

                # Serialize any sparse tensors.
                new_state = nest.pack_sequence_as(
                    new_state,
                    list(nest.flatten(
                        sparse.serialize_sparse_tensors(new_state))))
                output_value = nest.pack_sequence_as(
                    output_value,
                    list(nest.flatten(
                        sparse.serialize_sparse_tensors(output_value))))
                return nest.flatten(new_state) + nest.flatten(output_value)
示例#13
0
      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference.

        The flat Defun arguments are the current state followed by the input
        element. Their static shapes are re-attached, both halves are rebuilt
        and sparse-deserialized, and `scan_func` is applied. Output
        classes/shapes/types are recorded on `self`, and the new state's shapes
        are appended to `flat_new_state_shapes`.

        Returns:
          The flattened new state followed by the flattened output value, with
          sparse tensors serialized.

        Raises:
          TypeError: if `scan_func` does not return a pair, or if the new
            state's classes or dtypes differ from the initial state's.
        """
        # `collections.Sequence` was removed in Python 3.10; the ABC lives in
        # `collections.abc` on Python 3.3+.
        try:
          from collections.abc import Sequence as _Sequence
        except ImportError:  # Python 2.
          _Sequence = collections.Sequence

        # Pass in shape information from the state and input_dataset.
        for arg, shape in zip(
            args,
            nest.flatten(
                sparse.as_dense_shapes(self._state_shapes, self._state_classes))
            + nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes))):
          arg.set_shape(shape)

        # The first `pivot` args are the state; the rest are the input.
        pivot = len(nest.flatten(self._state_shapes))
        nested_state_args = nest.pack_sequence_as(self._state_types,
                                                  args[:pivot])
        nested_state_args = sparse.deserialize_sparse_tensors(
            nested_state_args, self._state_types, self._state_shapes,
            self._state_classes)
        nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                                  args[pivot:])
        nested_input_args = sparse.deserialize_sparse_tensors(
            nested_input_args, input_dataset.output_types,
            input_dataset.output_shapes, input_dataset.output_classes)

        ret = scan_func(nested_state_args, nested_input_args)
        if not isinstance(ret, _Sequence) or len(ret) != 2:
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")

        # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
        # values to tensors.
        ret = nest.pack_sequence_as(ret, [
            sparse_tensor.SparseTensor.from_value(t)
            if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
            for t in nest.flatten(ret)
        ])
        new_state, output_value = ret

        # Extract and validate class information from the returned values.
        for t, clazz in zip(
            nest.flatten(new_state), nest.flatten(self._state_classes)):
          if not isinstance(t, clazz):
            raise TypeError(
                "The element classes for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_classes,
                 nest.pack_sequence_as(
                     self._state_types,
                     [type(t) for t in nest.flatten(new_state)])))
        self._output_classes = sparse.get_classes(output_value)

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend(
            [t.get_shape() for t in nest.flatten(new_state)])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.get_shape() for t in nest.flatten(output_value)])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(
            nest.flatten(new_state), nest.flatten(self._state_types)):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(
                     self._state_types,
                     [t.dtype for t in nest.flatten(new_state)])))
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in nest.flatten(output_value)])

        dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

        # Serialize any sparse tensors.
        new_state = nest.pack_sequence_as(
            new_state,
            list(nest.flatten(sparse.serialize_sparse_tensors(new_state))))
        output_value = nest.pack_sequence_as(
            output_value,
            list(nest.flatten(sparse.serialize_sparse_tensors(output_value))))
        return nest.flatten(new_state) + nest.flatten(output_value)
示例#14
0
            def tf_reduce_func(*args):
                """Defun body for `reduce_func` with static shapes restored.

                The flat Defun arguments are the current state followed by the
                input element. Their static shapes are re-attached, both halves
                are rebuilt and sparse-deserialized, and `reduce_func` is
                applied. The new state's shapes are appended to
                `flat_new_state_shapes` and its dtypes are checked against the
                initial state's.

                Raises:
                  TypeError: if a new-state dtype differs from the initial
                    state's.
                """
                state_shapes = nest.flatten(
                    sparse.as_dense_shapes(self._state_shapes,
                                           self._state_classes))
                input_shapes = nest.flatten(
                    sparse.as_dense_shapes(input_dataset.output_shapes,
                                           input_dataset.output_classes))
                for tensor, static_shape in zip(args,
                                                state_shapes + input_shapes):
                    tensor.set_shape(static_shape)

                pivot = len(nest.flatten(self._state_shapes))
                state_args, input_args = args[:pivot], args[pivot:]
                state = sparse.deserialize_sparse_tensors(
                    nest.pack_sequence_as(self._state_types, state_args),
                    self._state_types, self._state_shapes,
                    self._state_classes)
                element = sparse.deserialize_sparse_tensors(
                    nest.pack_sequence_as(input_dataset.output_types,
                                          input_args),
                    input_dataset.output_types, input_dataset.output_shapes,
                    input_dataset.output_classes)

                new_state = reduce_func(state, element)

                # `SparseTensorValue` -> `SparseTensor`; other leaves -> tensors.
                converted = []
                for leaf in nest.flatten(new_state):
                    if sparse_tensor.is_sparse(leaf):
                        converted.append(
                            sparse_tensor.SparseTensor.from_value(leaf))
                    else:
                        converted.append(ops.convert_to_tensor(leaf))
                new_state = nest.pack_sequence_as(new_state, converted)

                # Record the new state's static shapes.
                flat_new_state = nest.flatten(new_state)
                flat_new_state_shapes.extend(
                    [t.get_shape() for t in flat_new_state])

                # The new state's dtypes must match the initial state's.
                expected_dtypes = nest.flatten(self._state_types)
                for leaf, expected in zip(flat_new_state, expected_dtypes):
                    if leaf.dtype != expected:
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [t.dtype for t in flat_new_state])))

                dataset_ops._warn_if_collections(
                    "tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

                # Serialize any sparse tensors before returning the flat state.
                serialized = nest.flatten(
                    sparse.serialize_sparse_tensors(new_state))
                new_state = nest.pack_sequence_as(new_state, list(serialized))
                return nest.flatten(new_state)