Example #1
def report_uninitialized_variables(var_list=None, name="report_uninitialized_variables"):
    """Adds ops to list the names of uninitialized variables.

  When run, it returns a 1-D tensor containing the names of uninitialized
  variables if there are any, or an empty array if there are none.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the
      value of `all_variables() + local_variables()`
    name: Optional name of the `Operation`.

  Returns:
    A 1-D tensor containing names of the uninitialized variables, or an empty 1-D
    tensor if there are no variables or no uninitialized variables.
  """
    if var_list is None:
        var_list = all_variables() + local_variables()
    # Backwards compatibility for old-style variables. TODO(touts): remove.
    if not var_list:
        var_list = []
        for op in ops.get_default_graph().get_operations():
            if op.type in ["Variable", "AutoReloadVariable"]:
                var_list.append(op.outputs[0])
    if not var_list:
        # Return an empty tensor so we only need to check for returned tensor
        # size being 0 as an indication of model ready.
        return array_ops.constant([], dtype=dtypes.string, name=name)
    else:
        # Get a 1-D boolean tensor listing whether each variable is initialized.
        variables_mask = math_ops.logical_not(
            array_ops.pack(
                [state_ops.is_variable_initialized(v) for v in var_list]))
        # Get a 1-D string tensor containing all the variable names.
        variable_names_tensor = array_ops.constant([s.op.name for s in var_list])
        # Return a 1-D tensor containing all the names of uninitialized variables.
        return array_ops.boolean_mask(variable_names_tensor, variables_mask, name=name)
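
A minimal usage sketch (assuming the legacy TF 1.x graph-mode API, where this function is exposed as `tf.report_uninitialized_variables`; the variables below are illustrative):

import tensorflow as tf  # assumes a TF 1.x-style graph-mode API

a = tf.Variable(0, name="a")
b = tf.Variable(1, name="b")

with tf.Session() as sess:
    sess.run(a.initializer)  # initialize only `a`
    # Evaluates to a 1-D string tensor of names; here it contains [b"b"],
    # since `b` was never initialized.
    print(sess.run(tf.report_uninitialized_variables()))
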
Example #2
def is_variable_initialized(variable):
    """Returns an Op to check if a variable has been initialized.

  Args:
    variable: A `Variable`.

  Returns:
    An operation to check whether a variable has been initialized.
  """
    return state_ops.is_variable_initialized(variable)
Example #3
def is_variable_initialized(variable):
    """Tests if a variable has been initialized.

  Args:
    variable: A `Variable`.

  Returns:
    Returns a scalar boolean Tensor, `True` if the variable has been
    initialized, `False` otherwise.
  """
    return state_ops.is_variable_initialized(variable)
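
A minimal usage sketch (assuming the legacy TF 1.x graph-mode API, where this function is exposed as `tf.is_variable_initialized`):

import tensorflow as tf  # assumes a TF 1.x-style graph-mode API

v = tf.Variable(42, name="v")
check = tf.is_variable_initialized(v)

with tf.Session() as sess:
    print(sess.run(check))  # False: `v` has not been initialized yet
    sess.run(v.initializer)
    print(sess.run(check))  # True
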
Example #4
    def _get_filename_queue(self, epoch_limit):
        """Constructs a filename queue with an epoch limit.

    `epoch_limit` is intended as an error checking fallback to prevent a reader
    from infinitely looping in its requests for more work items if none are
    available in any file. It should be set high enough that it is never reached
    assuming at least one record exists in some file.

    Args:
      epoch_limit: The maximum number of times to read through the complete list
        of files before throwing an OutOfRangeError.

    Returns:
      A tuple of (filename_queue, epoch_limiter):
        filename_queue: A FIFOQueue with filename work items.
        epoch_limiter: The local variable used for epoch limitation. This should
          be set to zero before a reader is passed `filename_queue` in order to
          reset the epoch limiter's state.
    """
        epoch_limiter = variable_scope.variable(
            initial_value=constant_op.constant(0, dtype=dtypes.int64),
            name="epoch_limiter",
            trainable=False,
            collections=[ops.GraphKeys.LOCAL_VARIABLES])
        filenames_tensor = array_ops.reshape(
            ops.convert_to_tensor(self._filenames), [-1])
        # We can't rely on epoch_limiter being initialized, since queue runners are
        # started before local variables are initialized. Instead, we ignore epoch
        # limits before variable initialization. This means that prior to variable
        # initialization, a QueueRunner may cause a reader to enter an un-checked
        # infinite loop. However, as soon as local variables are initialized, we
        # will start incrementing and checking epoch_limiter, which will interrupt
        # any in-progress loops.
        conditional_count_up_to = control_flow_ops.cond(
            state_ops.is_variable_initialized(epoch_limiter),
            lambda: epoch_limiter.count_up_to(epoch_limit),
            lambda: constant_op.constant(0, dtype=dtypes.int64))
        with ops.control_dependencies([conditional_count_up_to]):
            filenames_tensor = array_ops.identity(filenames_tensor)
        filename_queue = input_lib.string_input_producer(filenames_tensor,
                                                         shuffle=False,
                                                         capacity=1)
        return filename_queue, epoch_limiter
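
A hedged sketch of the reset contract described in the docstring, written against the public TF 1.x API for brevity; `data_source` is a hypothetical instance of the class that defines `_get_filename_queue`, and the point is only the ordering: zero the limiter before the reader starts pulling work items.

import tensorflow as tf  # assumes a TF 1.x-style graph-mode API

# `data_source` is hypothetical; it stands in for the owning reader class.
filename_queue, epoch_limiter = data_source._get_filename_queue(epoch_limit=1000)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    # Per the docstring: reset the limiter to zero before the reader consumes
    # work items, so earlier progress does not count against this reader.
    sess.run(tf.assign(epoch_limiter, 0))
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coord)
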
Example #5
  def _get_filename_queue(self, epoch_limit):
    """Constructs a filename queue with an epoch limit.

    `epoch_limit` is intended as an error checking fallback to prevent a reader
    from infinitely looping in its requests for more work items if none are
    available in any file. It should be set high enough that it is never reached
    assuming at least one record exists in some file.

    Args:
      epoch_limit: The maximum number of times to read through the complete list
        of files before throwing an OutOfRangeError.

    Returns:
      A tuple of (filename_queue, epoch_limiter):
        filename_queue: A FIFOQueue with filename work items.
        epoch_limiter: The local variable used for epoch limitation. This should
          be set to zero before a reader is passed `filename_queue` in order to
          reset the epoch limiter's state.
    """
    epoch_limiter = variable_scope.variable(
        initial_value=constant_op.constant(0, dtype=dtypes.int64),
        name="epoch_limiter",
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES])
    filenames_tensor = array_ops.reshape(
        ops.convert_to_tensor(self._filenames), [-1])
    # We can't rely on epoch_limiter being initialized, since queue runners are
    # started before local variables are initialized. Instead, we ignore epoch
    # limits before variable initialization. This means that prior to variable
    # initialization, a QueueRunner may cause a reader to enter an un-checked
    # infinite loop. However, as soon as local variables are initialized, we
    # will start incrementing and checking epoch_limiter, which will interrupt
    # any in-progress loops.
    conditional_count_up_to = control_flow_ops.cond(
        state_ops.is_variable_initialized(
            epoch_limiter), lambda: epoch_limiter.count_up_to(epoch_limit),
        lambda: constant_op.constant(0, dtype=dtypes.int64))
    with ops.control_dependencies([conditional_count_up_to]):
      filenames_tensor = array_ops.identity(filenames_tensor)
    filename_queue = input_lib.string_input_producer(
        filenames_tensor, shuffle=False, capacity=1)
    return filename_queue, epoch_limiter
Example #6
def report_uninitialized_variables(var_list=None,
                                   name="report_uninitialized_variables"):
    """Adds ops to list the names of uninitialized variables.

  When run, it returns a 1-D tensor containing the names of uninitialized
  variables if there are any, or an empty array if there are none.

  Args:
    var_list: List of `Variable` objects to check. Defaults to the
      value of `all_variables() + local_variables()`
    name: Optional name of the `Operation`.

  Returns:
    A 1-D tensor containing names of the uninitialized variables, or an empty
    1-D tensor if there are no variables or no uninitialized variables.
  """
    if var_list is None:
        var_list = all_variables() + local_variables()
        # Backwards compatibility for old-style variables. TODO(touts): remove.
        if not var_list:
            var_list = []
            for op in ops.get_default_graph().get_operations():
                if op.type in ["Variable", "AutoReloadVariable"]:
                    var_list.append(op.outputs[0])
    with ops.name_scope(name):
        if not var_list:
            # Return an empty tensor so we only need to check for returned tensor
            # size being 0 as an indication of model ready.
            return array_ops.constant([], dtype=dtypes.string)
        else:
            # Get a 1-D boolean tensor listing whether each variable is initialized.
            variables_mask = math_ops.logical_not(
                array_ops.pack(
                    [state_ops.is_variable_initialized(v) for v in var_list]))
            # Get a 1-D string tensor containing all the variable names.
            variable_names_tensor = array_ops.constant(
                [s.op.name for s in var_list])
            # Return a 1-D tensor containing all the names of uninitialized variables.
            return array_ops.boolean_mask(variable_names_tensor,
                                          variables_mask)
Example #7
  def create_batch(self):
    """Create queues to window and batch time series data.

    Returns:
      A dictionary of Tensors corresponding to the output of `self._reader`
      (from the `time_series_reader` constructor argument), each with shapes
      prefixed by [`batch_size`, `window_size`].
    """
    features = self._reader.read()
    if self._jitter:
      # TODO(agarwal, allenl): Figure out if more jitter is needed here.
      jitter = random_ops.random_uniform(shape=[], maxval=2, dtype=dtypes.int32)
    else:
      jitter = 0
    # To keep things efficient, we pass from the windowing batcher to the
    # batch-of-windows batcher in batches. This avoids the need for huge numbers
    # of threads, but does mean that jitter is only applied occasionally.
    # TODO(allenl): Experiment with different internal passing sizes.
    internal_passing_size = self._batch_size
    features_windowed = input_lib.batch(
        features,
        batch_size=self._window_size * internal_passing_size + jitter,
        enqueue_many=True,
        capacity=(self._queue_capacity_multiplier
                  * internal_passing_size * self._window_size),
        num_threads=self._num_threads)
    raw_features_windowed = features_windowed
    if self._jitter:
      features_windowed = {
          key: value[jitter:]
          for key, value in features_windowed.items()}
    features_windowed = {
        key: array_ops.reshape(
            value,
            array_ops.concat(
                [[internal_passing_size, self._window_size],
                 array_ops.shape(value)[1:]],
                axis=0))
        for key, value in features_windowed.items()}
    batch_and_window_shape = tensor_shape.TensorShape(
        [internal_passing_size, self._window_size])
    for key in features_windowed.keys():
      features_windowed[key].set_shape(
          batch_and_window_shape.concatenate(
              raw_features_windowed[key].get_shape()[1:]))
    # When switching files, we may end up with windows where the time is not
    # decreasing, even if times within each file are sorted (and even if those
    # files are visited in order, when looping back around to the beginning of
    # the first file). This is hard for models to deal with, so we either
    # discard such examples, creating a bias where the beginning and end of the
    # series is under-sampled, or we sort the window, creating large gaps.
    times = features_windowed[feature_keys.TrainEvalFeatures.TIMES]
    if self._discard_out_of_order:
      non_decreasing = math_ops.reduce_all(
          times[:, 1:] >= times[:, :-1], axis=1)
      # Ensure that no more than self._discard_limit complete batches are
      # discarded contiguously (resetting the count when we find a single clean
      # window). This prevents infinite looping when the dataset is smaller than
      # the window size.
      # TODO(allenl): Figure out a way to return informative errors from
      # count_up_to.
      discarded_windows_limiter = variable_scope.variable(
          initial_value=constant_op.constant(0, dtype=dtypes.int64),
          name="discarded_windows_limiter",
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES])
      def _initialized_limit_check():
        return control_flow_ops.cond(
            math_ops.reduce_any(non_decreasing),
            lambda: state_ops.assign(discarded_windows_limiter, 0),
            lambda: discarded_windows_limiter.count_up_to(self._discard_limit))
      discard_limit_op = control_flow_ops.cond(
          state_ops.is_variable_initialized(discarded_windows_limiter),
          _initialized_limit_check,
          lambda: constant_op.constant(0, dtype=dtypes.int64))
      with ops.control_dependencies([discard_limit_op]):
        non_decreasing = array_ops.identity(non_decreasing)
    else:
      _, indices_descending = nn.top_k(
          times, k=array_ops.shape(times)[-1], sorted=True)
      indices = array_ops.reverse(indices_descending, axis=[0])
      features_windowed = {
          key: array_ops.gather(params=value, indices=indices)
          for key, value in features_windowed.items()
      }
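      # With each window sorted by time, every window is usable, so keep_input
      # below accepts all of them.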
      non_decreasing = True
    features_batched = input_lib.maybe_shuffle_batch(
        features_windowed,
        num_threads=self._num_threads,
        seed=self._shuffle_seed,
        batch_size=self._batch_size,
        capacity=self._queue_capacity_multiplier * self._batch_size,
        min_after_dequeue=(self._shuffle_min_after_dequeue_multiplier *
                           self._batch_size),
        keep_input=non_decreasing,
        enqueue_many=True)
    return (features_batched, None)
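
A small numeric sketch (plain NumPy, not part of the original code) of the reshape performed above: a flat run of `window_size * internal_passing_size` consecutive records becomes an `[internal_passing_size, window_size]` block of non-overlapping windows, one row per window.

import numpy as np

window_size = 4
internal_passing_size = 3  # equals batch_size in the code above

# Twelve consecutive records from the underlying reader, e.g. timestamps.
flat = np.arange(window_size * internal_passing_size)

# Counterpart of the array_ops.reshape call: consecutive records are grouped
# into non-overlapping windows.
windows = flat.reshape([internal_passing_size, window_size])
# windows == [[ 0  1  2  3]
#             [ 4  5  6  7]
#             [ 8  9 10 11]]
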
Example #8
    def create_op(self, *args, **kwargs):
        """Creates an `Operation`.

    For operations of the following form

      orig_value = op(*args, **kwargs)

    this function constructs the following subgraph :

      v = Variable()
      if v is not initialized:
        orig_value = op(*args, **kwargs)
        v.assign(orig_value) # Initializes v
        return orig_value
      else:
        return v

    The above transformation is not performed and the original op is returned
    as is if any of the following is true:
    * `_return_as_is` flag is set to true.
    * op_type is listed in _PASS_THROUGH_OPS
    * op has no outputs.
    * One of the op's return value has a ref type.

    Args:
      *args: Arguments for create_op()
      **kwargs: Keyword arguments for create_op(). Refer to
        tensorflow.python.framework.ops.Graph.create_op() for the mandatory
        and optional arguments.

    Returns:
      An Operation.

    Raises:
      UnimplementedError: if output type is a reference and the op's type
        is not one of the supported types in `_REF_OPS_WHITELIST`.
    """
        op_type = kwargs['op_type'] if 'op_type' in kwargs else args[0]
        output_dtypes = kwargs['dtypes'] if 'dtypes' in kwargs else args[2]
        output_dtypes = [dtypes.as_dtype(d) for d in output_dtypes]

        if self._return_as_is or op_type in _PASS_THROUGH_OPS:
            return self._wrap(
                super(ImperativeGraph, self).create_op(*args, **kwargs))

        if not output_dtypes:
            return self._wrap(
                super(ImperativeGraph, self).create_op(*args, **kwargs))

        output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

        if output_has_ref:
            if op_type not in _REF_OPS_WHITELIST:
                raise errors.UnimplementedError(
                    None, None, op_type + ' op not supported in '
                    'imperative graph')

            ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

            if self._in_variable_creation:
                if op_type == 'Assign':
                    self.add_pending_init(ret)

            return self._wrap(ret)

        with self.return_as_is():
            # Declares the variables to hold the output values of this op.
            op_output_var = [
                state_ops.variable_op_v2(tensor_shape.TensorShape(None),
                                         dtype,
                                         container=self._name)
                for dtype in output_dtypes
            ]
            # Ops to free the resources used by the temporary cache variables.
            # The following two ops are created for each cache variable,
            # having no control dependencies on any other ops :
            # var_handle_op ----> destroy_resource_op
            for dtype, v in zip(output_dtypes, op_output_var):
                with ops.control_dependencies(None):
                    self._variable_cleanup_ops += [
                        gen_resource_variable_ops.destroy_resource_op(
                            gen_resource_variable_ops.var_handle_op(
                                dtype,
                                tensor_shape.TensorShape(None),
                                container=self._name,
                                shared_name=v.op.name),
                            ignore_lookup_error=True)
                    ]

            # Create the conditional to run the original op only when the variable
            # corresponding to the first output is not initialized.
            inited = state_ops.is_variable_initialized(op_output_var[0])
            v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
            # pylint: disable=protected-access
            v_f_op = gen_array_ops._ref_identity(v_f)
            v_t_op = gen_array_ops._ref_identity(v_t)
            # pylint: enable=protected-access

            with ops.control_dependencies([v_f_op.op]):
                # Create the original op
                orig_op = self._wrap(
                    super(ImperativeGraph, self).create_op(*args, **kwargs))
            shapes = [val.get_shape() for val in orig_op.outputs]

            controls = []
            for var, val in zip(op_output_var, orig_op.outputs):
                if (not val.get_shape().is_fully_defined()
                        or val.get_shape().num_elements() > 0):
                    assign_op = state_ops.assign(var,
                                                 val,
                                                 validate_shape=False)
                    assign_op.set_shape(val.get_shape())
                    controls.append(assign_op)

            values = []
            if len(controls) > 1:
                if control_flow_ops.IsSwitch(orig_op):
                    # pylint: disable=protected-access
                    controls = gen_control_flow_ops._ref_merge(controls)
                    # pylint: enable=protected-access
                else:
                    controls = control_flow_ops.tuple(controls)

            for var, val in zip(op_output_var, orig_op.outputs):
                with ops.control_dependencies(controls):
                    with self.colocate_with(v_f_op):
                        real_val = array_ops.identity(val)
                with ops.control_dependencies([v_t_op.op]):
                    with self.colocate_with(v_t_op):
                        stored_val = array_ops.identity(var)
                    stored_val.set_shape(val.get_shape())
                    real_val, _ = control_flow_ops.merge(
                        [real_val, stored_val])
                real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        s=compat.as_bytes(self._merge_op_type)))
                values.append(real_val)

            for i, _ in enumerate(shapes):
                values[i].set_shape(shapes[i])
            self._outputs_map[orig_op.name] = values
            try:
                self._gradient_function_map[
                    orig_op.name] = ops.get_gradient_function(orig_op)
            except (KeyError, LookupError):
                pass
            else:
                orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        s=compat.as_bytes(self._imperative_op_type)))

            return MultiOutputOperation(values)
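
The control flow above boils down to "compute once, then serve from a cache variable". A simplified, hedged sketch of that idea using `tf.cond` (an illustration of the pattern only, not the ref-switch/ref-merge machinery the class actually builds):

import tensorflow as tf  # assumes a TF 1.x-style graph-mode API

cache = tf.Variable(0.0, name="cache", trainable=False)

def _compute_and_store():
    # Stands in for the original op; its value is written into the cache.
    value = tf.constant(42.0)
    with tf.control_dependencies([tf.assign(cache, value)]):
        return tf.identity(value)

result = tf.cond(
    tf.is_variable_initialized(cache),
    lambda: tf.identity(cache),  # already computed: serve the cached value
    _compute_and_store)          # first evaluation: run the op and cache it
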
Example #9
  def create_op(self, *args, **kwargs):
    """Creates an `Operation`.

    For operations of the following form

      orig_value = op(*args, **kwargs)

    this function constructs the following subgraph :

      v = Variable()
      if v is not initialized:
        orig_value = op(*args, **kwargs)
        v.assign(orig_value) # Initializes v
        return orig_value
      else:
        return v

    The above transformation is not performed and the original op is returned
    as is if any of the following is true:
    * `_return_as_is` flag is set to true.
    * op_type is listed in _PASS_THROUGH_OPS
    * op has no outputs.
    * One of the op's return value has a ref type.

    Args:
      *args: Arguments for create_op()
      **kwargs: Keyword arguments for create_op(). Refer to
        tensorflow.python.framework.ops.Graph.create_op() for the mandatory
        and optional arguments.

    Returns:
      An Operation.

    Raises:
      UnimplementedError: if output type is a reference and the op's type
        is not one of the supported types in `_REF_OPS_WHITELIST`.
    """
    op_type = kwargs['op_type'] if 'op_type' in kwargs else args[0]
    output_dtypes = kwargs['dtypes'] if 'dtypes' in kwargs else args[2]
    output_dtypes = [dtypes.as_dtype(d) for d in output_dtypes]

    if self._return_as_is or op_type in _PASS_THROUGH_OPS:
      return self._wrap(super(ImperativeGraph, self).create_op(*args, **kwargs))

    if not output_dtypes:
      return self._wrap(
          super(ImperativeGraph, self).create_op(*args, **kwargs))

    output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

    if output_has_ref:
      if op_type not in _REF_OPS_WHITELIST:
        raise errors.UnimplementedError(None, None,
                                        op_type + ' op not supported in '
                                        'imperative graph')

      ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

      if self._in_variable_creation:
        if op_type == 'Assign':
          self.add_pending_init(ret)

      return self._wrap(ret)

    with self.return_as_is():
      # Declares the variables to hold the output values of this op.
      op_output_var = [state_ops.variable_op_v2(
          tensor_shape.TensorShape(None), dtype, container=self._name)
                       for dtype in output_dtypes]
      # Ops to free the resources used by the temporary cache variables.
      # The following two ops are created for each cache variable,
      # having no control dependencies on any other ops :
      # var_handle_op ----> destroy_resource_op
      for dtype, v in zip(output_dtypes, op_output_var):
        with ops.control_dependencies(None):
          self._variable_cleanup_ops += [
              gen_resource_variable_ops.destroy_resource_op(
                  gen_resource_variable_ops.var_handle_op(
                      dtype, tensor_shape.TensorShape(None),
                      container=self._name, shared_name=v.op.name),
                  ignore_lookup_error=True)]

      # Create the conditional to run the original op only when the variable
      # corresponding to the first output is not initialized.
      inited = state_ops.is_variable_initialized(op_output_var[0])
      v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
      # pylint: disable=protected-access
      v_f_op = gen_array_ops._ref_identity(v_f)
      v_t_op = gen_array_ops._ref_identity(v_t)
      # pylint: enable=protected-access

      with ops.control_dependencies([v_f_op.op]):
        # Create the original op
        orig_op = self._wrap(
            super(ImperativeGraph, self).create_op(*args, **kwargs))
      shapes = [val.get_shape() for val in orig_op.outputs]

      controls = []
      for var, val in zip(op_output_var, orig_op.outputs):
        if (not val.get_shape().is_fully_defined() or
            val.get_shape().num_elements() > 0):
          assign_op = state_ops.assign(var, val, validate_shape=False)
          assign_op.set_shape(val.get_shape())
          controls.append(assign_op)

      values = []
      if len(controls) > 1:
        if control_flow_ops.IsSwitch(orig_op):
          # pylint: disable=protected-access
          controls = gen_control_flow_ops._ref_merge(controls)
          # pylint: enable=protected-access
        else:
          controls = control_flow_ops.tuple(controls)

      for var, val in zip(op_output_var, orig_op.outputs):
        with ops.control_dependencies(controls):
          with self.colocate_with(v_f_op):
            real_val = array_ops.identity(val)
        with ops.control_dependencies([v_t_op.op]):
          with self.colocate_with(v_t_op):
            stored_val = array_ops.identity(var)
          stored_val.set_shape(val.get_shape())
          real_val, _ = control_flow_ops.merge([real_val, stored_val])
        real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
            attr_value_pb2.AttrValue(s=compat.as_bytes(self._merge_op_type)))
        values.append(real_val)

      for i, _ in enumerate(shapes):
        values[i].set_shape(shapes[i])
      self._outputs_map[orig_op.name] = values
      try:
        self._gradient_function_map[orig_op.name] = ops.get_gradient_function(
            orig_op)
      except (KeyError, LookupError):
        pass
      else:
        orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
            attr_value_pb2.AttrValue(
                s=compat.as_bytes(self._imperative_op_type)))

      return MultiOutputOperation(values, orig_op)
Example #10
    def create_batch(self):
        """Create queues to window and batch time series data.

    Returns:
      A dictionary of Tensors corresponding to the output of `self._reader`
      (from the `time_series_reader` constructor argument), each with shapes
      prefixed by [`batch_size`, `window_size`].
    """
        features = self._reader.read()
        if self._jitter:
            # TODO(agarwal, allenl): Figure out if more jitter is needed here.
            jitter = random_ops.random_uniform(shape=[],
                                               maxval=2,
                                               dtype=dtypes.int32)
        else:
            jitter = 0
        # To keep things efficient, we pass from the windowing batcher to the
        # batch-of-windows batcher in batches. This avoids the need for huge numbers
        # of threads, but does mean that jitter is only applied occasionally.
        # TODO(allenl): Experiment with different internal passing sizes.
        internal_passing_size = self._batch_size
        features_windowed = input_lib.batch(
            features,
            batch_size=self._window_size * internal_passing_size + jitter,
            enqueue_many=True,
            capacity=(self._queue_capacity_multiplier * internal_passing_size *
                      self._window_size),
            num_threads=self._num_threads)
        raw_features_windowed = features_windowed
        if self._jitter:
            features_windowed = {
                key: value[jitter:]
                for key, value in features_windowed.items()
            }
        features_windowed = {
            key: array_ops.reshape(
                value,
                array_ops.concat([[internal_passing_size, self._window_size],
                                  array_ops.shape(value)[1:]],
                                 axis=0))
            for key, value in features_windowed.items()
        }
        batch_and_window_shape = tensor_shape.TensorShape(
            [internal_passing_size, self._window_size])
        for key in features_windowed.keys():
            features_windowed[key].set_shape(
                batch_and_window_shape.concatenate(
                    raw_features_windowed[key].get_shape()[1:]))
        # When switching files, we may end up with windows where the time is not
        # decreasing, even if times within each file are sorted (and even if those
        # files are visited in order, when looping back around to the beginning of
        # the first file). This is hard for models to deal with, so we either
        # discard such examples, creating a bias where the beginning and end of the
        # series is under-sampled, or we sort the window, creating large gaps.
        times = features_windowed[feature_keys.TrainEvalFeatures.TIMES]
        if self._discard_out_of_order:
            non_decreasing = math_ops.reduce_all(times[:, 1:] >= times[:, :-1],
                                                 axis=1)
            # Ensure that no more than self._discard_limit complete batches are
            # discarded contiguously (resetting the count when we find a single clean
            # window). This prevents infinite looping when the dataset is smaller than
            # the window size.
            # TODO(allenl): Figure out a way to return informative errors from
            # count_up_to.
            discarded_windows_limiter = variable_scope.variable(
                initial_value=constant_op.constant(0, dtype=dtypes.int64),
                name="discarded_windows_limiter",
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES])

            def _initialized_limit_check():
                return control_flow_ops.cond(
                    math_ops.reduce_any(non_decreasing),
                    lambda: state_ops.assign(discarded_windows_limiter, 0),
                    lambda: discarded_windows_limiter.count_up_to(
                        self._discard_limit))

            discard_limit_op = control_flow_ops.cond(
                state_ops.is_variable_initialized(discarded_windows_limiter),
                _initialized_limit_check,
                lambda: constant_op.constant(0, dtype=dtypes.int64))
            with ops.control_dependencies([discard_limit_op]):
                non_decreasing = array_ops.identity(non_decreasing)
        else:
            _, indices_descending = nn.top_k(times,
                                             k=array_ops.shape(times)[-1],
                                             sorted=True)
            indices = array_ops.reverse(indices_descending, axis=[0])
            features_windowed = {
                key: array_ops.gather(params=value, indices=indices)
                for key, value in features_windowed.items()
            }
            non_decreasing = True
        features_batched = input_lib.maybe_shuffle_batch(
            features_windowed,
            num_threads=self._num_threads,
            seed=self._shuffle_seed,
            batch_size=self._batch_size,
            capacity=self._queue_capacity_multiplier * self._batch_size,
            min_after_dequeue=(self._shuffle_min_after_dequeue_multiplier *
                               self._batch_size),
            keep_input=non_decreasing,
            enqueue_many=True)
        return (features_batched, None)