Example #1
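This example shows a queue-based `create_batch` method that windows and batches time series features returned by a reader; it matches the input pipeline used by TensorFlow 1.x's `tf.contrib.timeseries` module (the method is an excerpt from a larger class, so the `self._...` attributes are set elsewhere).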
# Module-level imports this method relies on (TensorFlow 1.x internal modules):
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import input as input_lib

  def create_batch(self):
    """Create queues to window and batch time series data.

    Returns:
      A dictionary of Tensors corresponding to the output of `self._reader`
      (from the `time_series_reader` constructor argument), each with shapes
      prefixed by [`batch_size`, `window_size`].
    """
    features = self._reader.read()
    if self._jitter:
      # TODO(agarwal, allenl): Figure out if more jitter is needed here.
      jitter = random_ops.random_uniform(shape=[], maxval=2, dtype=dtypes.int32)
    else:
      jitter = 0
    # To keep things efficient, we pass from the windowing batcher to the
    # batch-of-windows batcher in batches. This avoids the need for huge numbers
    # of threads, but does mean that jitter is only applied occasionally.
    # TODO(allenl): Experiment with different internal passing sizes.
    internal_passing_size = self._batch_size
    features_windowed = input_lib.batch(
        features,
        batch_size=self._window_size * internal_passing_size + jitter,
        enqueue_many=True,
        capacity=(self._queue_capacity_multiplier
                  * internal_passing_size * self._window_size),
        num_threads=self._num_threads)
    raw_features_windowed = features_windowed
    if self._jitter:
      features_windowed = {
          key: value[jitter:]
          for key, value in features_windowed.items()}
    features_windowed = {
        key: array_ops.reshape(
            value,
            array_ops.concat(
                [[internal_passing_size, self._window_size],
                 array_ops.shape(value)[1:]],
                axis=0))
        for key, value in features_windowed.items()}
    batch_and_window_shape = tensor_shape.TensorShape(
        [internal_passing_size, self._window_size])
    for key in features_windowed.keys():
      features_windowed[key].set_shape(
          batch_and_window_shape.concatenate(
              raw_features_windowed[key].get_shape()[1:]))
    # When switching files, we may end up with windows where the time is not
    # decreasing, even if times within each file are sorted (and even if those
    # files are visited in order, when looping back around to the beginning of
    # the first file). This is hard for models to deal with, so we either
    # discard such examples, creating a bias where the beginning and end of the
    # series is under-sampled, or we sort the window, creating large gaps.
    times = features_windowed[feature_keys.TrainEvalFeatures.TIMES]
    if self._discard_out_of_order:
      non_decreasing = math_ops.reduce_all(
          times[:, 1:] >= times[:, :-1], axis=1)
      # Ensure that no more than self._discard_limit complete batches are
      # discarded contiguously (resetting the count when we find a single clean
      # window). This prevents infinite looping when the dataset is smaller than
      # the window size.
      # TODO(allenl): Figure out a way to return informative errors from
      # count_up_to.
      discarded_windows_limiter = variable_scope.variable(
          initial_value=constant_op.constant(0, dtype=dtypes.int64),
          name="discarded_windows_limiter",
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES])
      def _initialized_limit_check():
        return control_flow_ops.cond(
            math_ops.reduce_any(non_decreasing),
            lambda: state_ops.assign(discarded_windows_limiter, 0),
            lambda: discarded_windows_limiter.count_up_to(self._discard_limit))
      discard_limit_op = control_flow_ops.cond(
          state_ops.is_variable_initialized(discarded_windows_limiter),
          _initialized_limit_check,
          lambda: constant_op.constant(0, dtype=dtypes.int64))
      with ops.control_dependencies([discard_limit_op]):
        non_decreasing = array_ops.identity(non_decreasing)
    else:
      _, indices_descending = nn.top_k(
          times, k=array_ops.shape(times)[-1], sorted=True)
      indices = array_ops.reverse(indices_descending, axis=[0])
      features_windowed = {
          key: array_ops.gather(params=value, indices=indices)
          for key, value in features_windowed.items()
      }
      non_decreasing = True
    features_batched = input_lib.maybe_shuffle_batch(
        features_windowed,
        num_threads=self._num_threads,
        seed=self._shuffle_seed,
        batch_size=self._batch_size,
        capacity=self._queue_capacity_multiplier * self._batch_size,
        min_after_dequeue=(self._shuffle_min_after_dequeue_multiplier *
                           self._batch_size),
        keep_input=non_decreasing,
        enqueue_many=True)
    return (features_batched, None)
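
A minimal usage sketch follows, assuming this method belongs to a queue-based window input function such as `tf.contrib.timeseries.RandomWindowInputFn` in TF 1.x; the file name, column layout, and the window/batch sizes below are illustrative assumptions, not part of the original snippet.

import tensorflow as tf

# Hypothetical two-column CSV of (time, value) rows; CSVReader's defaults
# expect exactly the "times" and "values" feature columns.
reader = tf.contrib.timeseries.CSVReader("data.csv")
input_fn = tf.contrib.timeseries.RandomWindowInputFn(
    reader, window_size=32, batch_size=16)

# create_batch() builds the windowing/batching queues and returns
# (features, None); each feature Tensor has shape [batch_size, window_size, ...].
features, _ = input_fn.create_batch()

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  window_batch = sess.run(features)  # dict of numpy arrays, e.g. "times", "values"
  coord.request_stop()
  coord.join(threads)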