예제 #1
0
 def _make_reduce_func(self, reduce_func, input_dataset):
   """Make wrapping Defun for reduce_func.

   Wraps `reduce_func` in a `StructuredFunctionWrapper` whose inputs are a
   scalar `int64` tensor and a nested dataset described by `input_dataset`.
   It then records on `self` the classes, types and shapes of the dataset
   that `reduce_func` returns, together with the wrapped function.

   Args:
     reduce_func: Function of `(int64 scalar, nested dataset)` that must
       return a `Dataset`.
     input_dataset: Dataset whose element structure describes the nested
       dataset passed as the second argument of `reduce_func`.

   Raises:
     TypeError: If `reduce_func` does not return a `Dataset`.
   """
   nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset)  # pylint: disable=protected-access
   # The transformation name is the user-facing API this function belongs to
   # (presumably shown in diagnostics). The reducer belongs to
   # `group_by_window()`; there is no `reduce_by_window()` transformation.
   wrapped_func = dataset_ops.StructuredFunctionWrapper(
       reduce_func, "tf.contrib.data.group_by_window()",
       input_classes=(ops.Tensor, nested_dataset),
       input_shapes=(tensor_shape.scalar(), nested_dataset),
       input_types=(dtypes.int64, nested_dataset),
       experimental_nested_dataset_support=True)
   if not isinstance(
       wrapped_func.output_classes, dataset_ops._NestedDatasetComponent):  # pylint: disable=protected-access
     raise TypeError("`reduce_func` must return a `Dataset` object.")
   # The returned value is a nested-dataset component; unwrap it to get the
   # element structure of the dataset produced by `reduce_func`.
   self._output_classes = wrapped_func.output_classes.output_classes
   self._output_types = wrapped_func.output_types.output_types
   self._output_shapes = wrapped_func.output_shapes.output_shapes
   self._reduce_func = wrapped_func.function
예제 #2
0
 def _make_reduce_func(self, reduce_func, input_dataset):
   """Wraps `reduce_func` and records the structure of the dataset it returns.

   The wrapper's inputs are a scalar `int64` tensor and a nested dataset
   built from `input_dataset`. The function's output must itself be a nested
   dataset. Its classes, types and shapes are stored on `self`, along with
   the wrapped function.

   Args:
     reduce_func: Function of `(int64 scalar, nested dataset)` that must
       return a `Dataset`.
     input_dataset: Dataset used to describe the second argument of
       `reduce_func`.

   Raises:
     TypeError: If `reduce_func` does not return a `Dataset`.
   """
   # pylint: disable=protected-access
   component_cls = dataset_ops._NestedDatasetComponent
   # pylint: enable=protected-access
   window_arg = component_cls(input_dataset)
   wrapper = dataset_ops.StructuredFunctionWrapper(
       reduce_func,
       self._transformation_name(),
       input_classes=(ops.Tensor, window_arg),
       input_shapes=(tensor_shape.scalar(), window_arg),
       input_types=(dtypes.int64, window_arg))

   result_classes = wrapper.output_classes
   if not isinstance(result_classes, component_cls):
     raise TypeError("`reduce_func` must return a `Dataset` object.")

   # Unwrap the nested-dataset component into the element structure of the
   # dataset produced by `reduce_func`.
   self._output_classes = result_classes.output_classes
   self._output_types = wrapper.output_types.output_types
   self._output_shapes = wrapper.output_shapes.output_shapes
   self._reduce_func = wrapper.function
예제 #3
0
파일: grouping.py 프로젝트: qwerzou1/shibie
 def _make_reduce_func(self, reduce_func, input_dataset):
     """Wraps `reduce_func` and captures the structure of its result dataset.

     Args:
         reduce_func: Callable taking a scalar `int64` tensor and a nested
             dataset derived from `input_dataset`; it must return a `Dataset`.
         input_dataset: Dataset whose structure describes the nested-dataset
             argument of `reduce_func`.

     Raises:
         TypeError: If `reduce_func` does not return a `Dataset`.
     """
     # pylint: disable=protected-access
     nested_cls = dataset_ops._NestedDatasetComponent
     # pylint: enable=protected-access
     dataset_arg = nested_cls(input_dataset)
     key_and_window_classes = (ops.Tensor, dataset_arg)
     key_and_window_shapes = (tensor_shape.scalar(), dataset_arg)
     key_and_window_types = (dtypes.int64, dataset_arg)
     fn = dataset_ops.StructuredFunctionWrapper(
         reduce_func,
         self._transformation_name(),
         input_classes=key_and_window_classes,
         input_shapes=key_and_window_shapes,
         input_types=key_and_window_types)

     produced = fn.output_classes
     if not isinstance(produced, nested_cls):
         raise TypeError("`reduce_func` must return a `Dataset` object.")

     # Store the element structure of the dataset returned by `reduce_func`.
     self._output_classes = produced.output_classes
     self._output_types = fn.output_types.output_types
     self._output_shapes = fn.output_shapes.output_shapes
     self._reduce_func = fn.function
예제 #4
0
 def _make_reduce_func(self, reduce_func, input_dataset):
     """Make wrapping Defun for reduce_func.

     Wraps `reduce_func` in a `StructuredFunctionWrapper` whose inputs are a
     scalar `int64` tensor and a nested dataset described by `input_dataset`.
     It then records on `self` the classes, types and shapes of the dataset
     that `reduce_func` returns, together with the wrapped function.

     Args:
         reduce_func: Function of `(int64 scalar, nested dataset)` that must
             return a `Dataset`.
         input_dataset: Dataset whose element structure describes the nested
             dataset passed as the second argument of `reduce_func`.

     Raises:
         TypeError: If `reduce_func` does not return a `Dataset`.
     """
     nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset)  # pylint: disable=protected-access
     # The transformation name is the user-facing API this function belongs
     # to (presumably shown in diagnostics). The reducer belongs to
     # `group_by_window()`; there is no `reduce_by_window()` transformation.
     wrapped_func = dataset_ops.StructuredFunctionWrapper(
         reduce_func,
         "tf.contrib.data.group_by_window()",
         input_classes=(ops.Tensor, nested_dataset),
         input_shapes=(tensor_shape.scalar(), nested_dataset),
         input_types=(dtypes.int64, nested_dataset),
         experimental_nested_dataset_support=True)
     if not isinstance(wrapped_func.output_classes,
                       dataset_ops._NestedDatasetComponent):  # pylint: disable=protected-access
         raise TypeError("`reduce_func` must return a `Dataset` object.")
     # The returned value is a nested-dataset component; unwrap it to get the
     # element structure of the dataset produced by `reduce_func`.
     self._output_classes = wrapped_func.output_classes.output_classes
     self._output_types = wrapped_func.output_types.output_types
     self._output_shapes = wrapped_func.output_shapes.output_shapes
     self._reduce_func = wrapped_func.function
예제 #5
0
 def __init__(self, input_dataset, window_size):
   """See `window_dataset()` for more details.

   Args:
     input_dataset: Source dataset. Each flattened component of its output
       classes/shapes/types becomes a nested-dataset component of this
       dataset's structure.
     window_size: Converted to an `int64` tensor named "window_size".
   """
   super(_WindowDataset, self).__init__()
   self._input_dataset = input_dataset
   self._window_size = ops.convert_to_tensor(
       window_size, dtype=dtypes.int64, name="window_size")

   flat_classes = nest.flatten(input_dataset.output_classes)
   flat_shapes = nest.flatten(input_dataset.output_shapes)
   flat_types = nest.flatten(input_dataset.output_types)

   # Pair up the i-th class, shape and type of the input and wrap them in a
   # nested-dataset component.
   components = []
   for component_class, component_shape, component_type in zip(
       flat_classes, flat_shapes, flat_types):
     components.append(
         dataset_ops._NestedDatasetComponent(  # pylint: disable=protected-access
             output_classes=component_class,
             output_shapes=component_shape,
             output_types=component_type))

   # Rebuild the input's nesting. The same structure object is then used as
   # the classes, shapes and types of this dataset.
   structure = nest.pack_sequence_as(input_dataset.output_classes, components)
   self._output_classes = structure
   self._output_shapes = structure
   self._output_types = structure
예제 #6
0
 def __init__(self, input_dataset, window_size):
     """See `window_dataset()` for more details.

     Args:
         input_dataset: Source dataset whose flattened output
             classes/shapes/types are each wrapped as a nested-dataset
             component.
         window_size: Converted to an `int64` tensor named "window_size".
     """
     super(_WindowDataset, self).__init__()
     self._input_dataset = input_dataset
     self._window_size = ops.convert_to_tensor(
         window_size, dtype=dtypes.int64, name="window_size")

     def _as_component(spec):
         # `spec` is one (class, shape, type) triple of a flattened input
         # component.
         klass, shape, dtype = spec
         return dataset_ops._NestedDatasetComponent(  # pylint: disable=protected-access
             output_classes=klass,
             output_shapes=shape,
             output_types=dtype)

     specs = zip(nest.flatten(input_dataset.output_classes),
                 nest.flatten(input_dataset.output_shapes),
                 nest.flatten(input_dataset.output_types))
     nested = nest.pack_sequence_as(
         input_dataset.output_classes,
         [_as_component(spec) for spec in specs])

     # One shared structure describes classes, shapes and types alike.
     self._output_classes = nested
     self._output_shapes = nested
     self._output_types = nested