 def testNotAMultiple(self):
    num_unroll = 3  # Not a divisor of value_length (4), so padding would be
    # necessary; with pad=False this should raise InvalidArgumentError.
   with self.test_session() as sess:
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                  ".*should be a multiple of: 3, but saw "
                                  "value: 4. Consider setting pad=True."):
       coord = coordinator.Coordinator()
       threads = None
       try:
         with coord.stop_on_exception():
           next_batch = sqss.batch_sequences_with_states(
               input_key=self.key,
               input_sequences=self.sequences,
               input_context=self.context,
               input_length=3,
               initial_states=self.initial_states,
               num_unroll=num_unroll,
               batch_size=self.batch_size,
               num_threads=3,
               # to enforce that we only move on to the next examples after
               # finishing all segments of the first ones.
               capacity=2,
               pad=False)
           threads = queue_runner_impl.start_queue_runners(coord=coord)
           sess.run([next_batch.key])
       except errors_impl.OutOfRangeError:
         pass
       finally:
         coord.request_stop()
         if threads is not None:
           coord.join(threads, stop_grace_period_secs=2)
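
  # For contrast, a minimal sketch (not part of the original test; it reuses
  # the fixtures above, and the method name is hypothetical) of the
  # non-raising variant: with pad=True, each length-4 sequence is padded up to
  # 6 steps, a multiple of num_unroll=3, instead of raising
  # InvalidArgumentError.
  def testNotAMultiplePaddedSketch(self):
    num_unroll = 3
    with self.test_session() as sess:
      next_batch = sqss.batch_sequences_with_states(
          input_key=self.key,
          input_sequences=self.sequences,
          input_context=self.context,
          input_length=3,
          initial_states=self.initial_states,
          num_unroll=num_unroll,
          batch_size=self.batch_size,
          num_threads=3,
          capacity=2,
          pad=True)  # pad sequences to the next multiple of num_unroll
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)
      sess.run([next_batch.key])  # no InvalidArgumentError this time
      coord.request_stop()
      coord.join(threads, stop_grace_period_secs=2)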
def _read_batch(cell,
                features,
                labels,
                mode,
                num_unroll,
                num_layers,
                batch_size,
                input_key_column_name,
                sequence_feature_columns,
                context_feature_columns=None,
                num_threads=3,
                queue_capacity=1000):
  """Reads a batch from a state saving sequence queue.

  Args:
    cell: An initialized `RNNCell` to be used in the RNN.
    features: A dict of Python string to an iterable of `Tensor`, the
      `features` argument of a TF.Learn model_fn.
    labels: An iterable of `Tensor`, the `labels` argument of a
      TF.Learn model_fn.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    num_unroll: Python integer, how many time steps to unroll at a time.
      The input sequences of length `k` are then split into `k / num_unroll`
      segments.
    num_layers: Python integer, number of layers in the RNN.
    batch_size: Python integer, the size of the minibatch produced by the SQSS.
    input_key_column_name: Python string, the name of the feature column
      containing a string scalar `Tensor` that serves as a unique key to
      identify input sequences across minibatches.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    num_threads: The Python integer number of threads enqueuing input examples
      into a queue. Defaults to 3.
    queue_capacity: The max capacity of the queue in number of examples.
      Needs to be at least `batch_size`. Defaults to 1000. When iterating
      over the same input examples multiple times and reusing their keys,
      `queue_capacity` must be smaller than the number of examples.

  Returns:
    batch: A `NextQueuedSequenceBatch` containing `batch_size`
      `SequenceExample` values and their saved internal states.
  """
  # Set batch_size=1 to initialize SQSS with cell's zero state.
  values = cell.zero_state(batch_size=1, dtype=dtypes.float32)

  # Set up stateful queue reader.
  states = {}
  state_names = _get_lstm_state_names(num_layers)
  for i in range(num_layers):
    states[state_names[i][0]] = array_ops.squeeze(values[i][0], axis=0)
    states[state_names[i][1]] = array_ops.squeeze(values[i][1], axis=0)

  input_key, sequences, context = _prepare_features_for_sqss(
      features, labels, mode, input_key_column_name, sequence_feature_columns,
      context_feature_columns)

  return sqss.batch_sequences_with_states(
      input_key=input_key,
      input_sequences=sequences,
      input_context=context,
      input_length=None,  # infer sequence lengths
      initial_states=states,
      num_unroll=num_unroll,
      batch_size=batch_size,
      pad=True,  # pad to a multiple of num_unroll
      num_threads=num_threads,
      capacity=queue_capacity)
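
# The snippet above relies on a `_get_lstm_state_names` helper that is not
# shown here. A minimal sketch of what it plausibly returns (the exact name
# format is an assumption, not confirmed by this listing): one (c, h)
# state-name pair per LSTM layer, matching the keys under which the SQSS
# saves and restores state.
def _get_lstm_state_names(num_layers):
  """Returns a list of (cell_state_name, hidden_state_name) pairs."""
  return [('lstm_state_c_{}'.format(i), 'lstm_state_h_{}'.format(i))
          for i in range(num_layers)]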
    def _testBasics(self, num_unroll, length, pad, expected_seq1_batch1,
                    expected_seq2_batch1, expected_seq1_batch2,
                    expected_seq2_batch2, expected_seq3_batch1,
                    expected_seq3_batch2, expected_seq4_batch1,
                    expected_seq4_batch2):

        with self.test_session() as sess:
            next_batch = sqss.batch_sequences_with_states(
                input_key=self.key,
                input_sequences=self.sequences,
                input_context=self.context,
                input_length=length,
                initial_states=self.initial_states,
                num_unroll=num_unroll,
                batch_size=self.batch_size,
                num_threads=3,
                # to enforce that we only move on to the next examples after finishing
                # all segments of the first ones.
                capacity=2,
                pad=pad)

            state1 = next_batch.state("state1")
            state2 = next_batch.state("state2")
            state1_update = next_batch.save_state("state1", state1 + 1)
            state2_update = next_batch.save_state("state2", state2 - 1)

            # Make sure queue runner with SQSS is added properly to meta graph def.
            # Saver requires at least one variable.
            v0 = variables.Variable(10.0, name="v0")
            ops.add_to_collection("variable_collection", v0)
            variables.global_variables_initializer()
            save = saver.Saver([v0])
            test_dir = os.path.join(test.get_temp_dir(), "sqss_test")
            filename = os.path.join(test_dir, "metafile")
            meta_graph_def = save.export_meta_graph(filename)
            qr_saved = meta_graph_def.collection_def[
                ops.GraphKeys.QUEUE_RUNNERS]
            self.assertTrue(qr_saved.bytes_list.value is not None)

            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(coord=coord)

            # Step 1
            (key_value, next_key_value, seq1_value, seq2_value, seq3_value,
             seq4_value, context1_value, state1_value, state2_value,
             length_value, _, _) = sess.run(
                 (next_batch.key, next_batch.next_key,
                  next_batch.sequences["seq1"], next_batch.sequences["seq2"],
                  next_batch.sequences["seq3"], next_batch.sequences["seq4"],
                  next_batch.context["context1"], state1, state2,
                  next_batch.length, state1_update, state2_update))
            expected_first_keys = set([b"00000_of_00002"])
            expected_second_keys = set([b"00001_of_00002"])
            expected_final_keys = set([b"STOP"])

            self.assertEqual(expected_first_keys, self._prefix(key_value))
            self.assertEqual(expected_second_keys,
                             self._prefix(next_key_value))
            self.assertAllEqual(
                np.tile(self.context["context1"], (self.batch_size, 1)),
                context1_value)
            self.assertAllEqual(expected_seq1_batch1, seq1_value)
            self.assertAllEqual(expected_seq2_batch1, seq2_value)
            self.assertAllEqual(expected_seq3_batch1.indices,
                                seq3_value.indices)
            self.assertAllEqual(expected_seq3_batch1.values, seq3_value.values)
            self.assertAllEqual(expected_seq3_batch1.dense_shape,
                                seq3_value.dense_shape)
            self.assertAllEqual(expected_seq4_batch1.indices,
                                seq4_value.indices)
            self.assertAllEqual(expected_seq4_batch1.values, seq4_value.values)
            self.assertAllEqual(expected_seq4_batch1.dense_shape,
                                seq4_value.dense_shape)
            self.assertAllEqual(
                np.tile(self.initial_states["state1"],
                        (self.batch_size, 1, 1)), state1_value)
            self.assertAllEqual(
                np.tile(self.initial_states["state2"], (self.batch_size, 1)),
                state2_value)
            self.assertAllEqual(length_value, [num_unroll, num_unroll])

            # Step 2
            (key_value, next_key_value, seq1_value, seq2_value, seq3_value,
             seq4_value, context1_value, state1_value, state2_value,
             length_value, _, _) = sess.run(
                 (next_batch.key, next_batch.next_key,
                  next_batch.sequences["seq1"], next_batch.sequences["seq2"],
                  next_batch.sequences["seq3"], next_batch.sequences["seq4"],
                  next_batch.context["context1"], state1, state2,
                  next_batch.length, state1_update, state2_update))

            self.assertEqual(expected_second_keys, self._prefix(key_value))
            self.assertEqual(expected_final_keys, self._prefix(next_key_value))
            self.assertAllEqual(
                np.tile(self.context["context1"], (self.batch_size, 1)),
                context1_value)
            self.assertAllEqual(expected_seq1_batch2, seq1_value)
            self.assertAllEqual(expected_seq2_batch2, seq2_value)
            self.assertAllEqual(expected_seq3_batch2.indices,
                                seq3_value.indices)
            self.assertAllEqual(expected_seq3_batch2.values, seq3_value.values)
            self.assertAllEqual(expected_seq3_batch2.dense_shape,
                                seq3_value.dense_shape)
            self.assertAllEqual(expected_seq4_batch2.indices,
                                seq4_value.indices)
            self.assertAllEqual(expected_seq4_batch2.values, seq4_value.values)
            self.assertAllEqual(expected_seq4_batch2.dense_shape,
                                seq4_value.dense_shape)
            self.assertAllEqual(
                1 + np.tile(self.initial_states["state1"],
                            (self.batch_size, 1, 1)), state1_value)
            self.assertAllEqual(
                -1 + np.tile(self.initial_states["state2"],
                             (self.batch_size, 1)), state2_value)
            self.assertAllEqual([1, 1], length_value)

            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=2)
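
# A sketch of the restore side of the meta-graph check above (reusing
# `filename` from the export is illustrative): importing the exported
# MetaGraphDef into a fresh graph recreates the SQSS queue runners in the
# QUEUE_RUNNERS collection, which is what the bytes_list assertion above
# verifies was serialized.
with ops.Graph().as_default():
  _ = saver.import_meta_graph(filename)  # `filename` as exported above
  queue_runners = ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
  assert queue_runners, "SQSS queue runner should round-trip the meta graph"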
  def _testBasics(self, num_unroll, length, pad,
                  expected_seq1_batch1, expected_seq2_batch1,
                  expected_seq1_batch2, expected_seq2_batch2,
                  expected_seq3_batch1, expected_seq3_batch2,
                  expected_seq4_batch1, expected_seq4_batch2,
                  key=None, make_keys_unique=False):

    with self.test_session() as sess:
      next_batch = sqss.batch_sequences_with_states(
          input_key=key if key is not None else self.key,
          input_sequences=self.sequences,
          input_context=self.context,
          input_length=length,
          initial_states=self.initial_states,
          num_unroll=num_unroll,
          batch_size=self.batch_size,
          num_threads=3,
          # to enforce that we only move on to the next examples after finishing
          # all segments of the first ones.
          capacity=2,
          pad=pad,
          make_keys_unique=make_keys_unique,
          make_keys_unique_seed=9)

      state1 = next_batch.state("state1")
      state2 = next_batch.state("state2")
      state1_update = next_batch.save_state("state1", state1 + 1)
      state2_update = next_batch.save_state("state2", state2 - 1)

      # Make sure queue runner with SQSS is added properly to meta graph def.
      # Saver requires at least one variable.
      v0 = variables.Variable(10.0, name="v0")
      ops.add_to_collection("variable_collection", v0)
      variables.global_variables_initializer()
      save = saver.Saver([v0])
      test_dir = os.path.join(test.get_temp_dir(), "sqss_test")
      filename = os.path.join(test_dir, "metafile")
      meta_graph_def = save.export_meta_graph(filename)
      qr_saved = meta_graph_def.collection_def[ops.GraphKeys.QUEUE_RUNNERS]
      self.assertTrue(qr_saved.bytes_list.value is not None)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)

      # Step 1
      (key_value, next_key_value, seq1_value, seq2_value, seq3_value,
       seq4_value, context1_value, context2_value, state1_value, state2_value,
       length_value, _, _) = sess.run(
           (next_batch.key, next_batch.next_key, next_batch.sequences["seq1"],
            next_batch.sequences["seq2"], next_batch.sequences["seq3"],
            next_batch.sequences["seq4"], next_batch.context["context1"],
            next_batch.context["sp_context"], state1, state2, next_batch.length,
            state1_update, state2_update))
      expected_first_keys = set([b"00000_of_00002"])
      expected_second_keys = set([b"00001_of_00002"])
      expected_final_keys = set([b"STOP"])

      self.assertEqual(expected_first_keys, self._prefix(key_value))
      self.assertEqual(expected_second_keys, self._prefix(next_key_value))
      self.assertAllEqual(
          np.tile(self.context["context1"], (self.batch_size, 1)),
          context1_value)
      self.assertAllEqual(self.sp_tensor3_expected.indices,
                          context2_value.indices)
      self.assertAllEqual(self.sp_tensor3_expected.values,
                          context2_value.values)
      self.assertAllEqual(self.sp_tensor3_expected.dense_shape,
                          context2_value.dense_shape)
      self.assertAllEqual(expected_seq1_batch1, seq1_value)
      self.assertAllEqual(expected_seq2_batch1, seq2_value)
      self.assertAllEqual(expected_seq3_batch1.indices, seq3_value.indices)
      self.assertAllEqual(expected_seq3_batch1.values, seq3_value.values)
      self.assertAllEqual(expected_seq3_batch1.dense_shape,
                          seq3_value.dense_shape)
      self.assertAllEqual(expected_seq4_batch1.indices, seq4_value.indices)
      self.assertAllEqual(expected_seq4_batch1.values, seq4_value.values)
      self.assertAllEqual(expected_seq4_batch1.dense_shape,
                          seq4_value.dense_shape)
      self.assertAllEqual(
          np.tile(self.initial_states["state1"], (self.batch_size, 1, 1)),
          state1_value)
      self.assertAllEqual(
          np.tile(self.initial_states["state2"], (self.batch_size, 1)),
          state2_value)
      self.assertAllEqual(length_value, [num_unroll, num_unroll])

      # Step 2
      (key_value, next_key_value, seq1_value, seq2_value, seq3_value,
       seq4_value, context1_value, context2_value, state1_value, state2_value,
       length_value, _, _) = sess.run(
           (next_batch.key, next_batch.next_key, next_batch.sequences["seq1"],
            next_batch.sequences["seq2"], next_batch.sequences["seq3"],
            next_batch.sequences["seq4"], next_batch.context["context1"],
            next_batch.context["sp_context"], state1, state2, next_batch.length,
            state1_update, state2_update))

      self.assertEqual(expected_second_keys, self._prefix(key_value))
      self.assertEqual(expected_final_keys, self._prefix(next_key_value))
      self.assertAllEqual(
          np.tile(self.context["context1"], (self.batch_size, 1)),
          context1_value)
      self.assertAllEqual(self.sp_tensor3_expected.indices,
                          context2_value.indices)
      self.assertAllEqual(self.sp_tensor3_expected.values,
                          context2_value.values)
      self.assertAllEqual(self.sp_tensor3_expected.dense_shape,
                          context2_value.dense_shape)
      self.assertAllEqual(expected_seq1_batch2, seq1_value)
      self.assertAllEqual(expected_seq2_batch2, seq2_value)
      self.assertAllEqual(expected_seq3_batch2.indices, seq3_value.indices)
      self.assertAllEqual(expected_seq3_batch2.values, seq3_value.values)
      self.assertAllEqual(expected_seq3_batch2.dense_shape,
                          seq3_value.dense_shape)
      self.assertAllEqual(expected_seq4_batch2.indices, seq4_value.indices)
      self.assertAllEqual(expected_seq4_batch2.values, seq4_value.values)
      self.assertAllEqual(expected_seq4_batch2.dense_shape,
                          seq4_value.dense_shape)
      self.assertAllEqual(1 + np.tile(self.initial_states["state1"],
                                      (self.batch_size, 1, 1)), state1_value)
      self.assertAllEqual(-1 + np.tile(self.initial_states["state2"],
                                       (self.batch_size, 1)), state2_value)
      self.assertAllEqual([1, 1], length_value)

      coord.request_stop()
      coord.join(threads, stop_grace_period_secs=2)
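
# A minimal, self-contained sketch (shapes and names are illustrative, not
# taken from the tests above) of the read/update/save pattern these tests
# exercise: fetch the saved state for the current segment, compute an update,
# and persist it so the next segment of the same sequence resumes from it.
import tensorflow as tf

sequences = {"tokens": tf.zeros([6, 8])}   # one sequence: 6 steps, 8 features
initial_states = {"state": tf.zeros([8])}  # per-sequence saved state
batch = tf.contrib.training.batch_sequences_with_states(
    input_key=tf.constant("example_0"),
    input_sequences=sequences,
    input_context={},
    input_length=None,  # infer the sequence length
    initial_states=initial_states,
    num_unroll=3,       # 6 steps -> 2 segments of 3
    batch_size=1,
    pad=True)
state = batch.state("state")                       # carried across segments
update_op = batch.save_state("state", state + 1.)  # run once per segment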
def _read_batch(cell,
                features,
                labels,
                mode,
                num_unroll,
                batch_size,
                sequence_feature_columns,
                context_feature_columns=None,
                num_threads=3,
                queue_capacity=1000,
                seed=None):
  """Reads a batch from a state saving sequence queue.

  Args:
    cell: An initialized `RNNCell` to be used in the RNN.
    features: A dict of Python string to an iterable of `Tensor`, the
      `features` argument of a TF.Learn model_fn.
    labels: An iterable of `Tensor`, the `labels` argument of a
      TF.Learn model_fn.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    num_unroll: Python integer, how many time steps to unroll at a time.
      The input sequences of length `k` are then split into `k / num_unroll`
      segments.
    batch_size: Python integer, the size of the minibatch produced by the SQSS.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    num_threads: The Python integer number of threads enqueuing input examples
      into a queue. Defaults to 3.
    queue_capacity: The max capacity of the queue in number of examples.
      Needs to be at least `batch_size`. Defaults to 1000. When iterating
      over the same input examples multiple times and reusing their keys,
      `queue_capacity` must be smaller than the number of examples.
    seed: Fixes the random seed used for generating input keys by the SQSS.

  Returns:
    batch: A `NextQueuedSequenceBatch` containing `batch_size`
      `SequenceExample` values and their saved internal states.
  """
  states = _get_initial_states(cell)

  sequences, context = _prepare_features_for_sqss(
      features, labels, mode, sequence_feature_columns,
      context_feature_columns)

  return sqss.batch_sequences_with_states(
      input_key='key',
      input_sequences=sequences,
      input_context=context,
      input_length=None,  # infer sequence lengths
      initial_states=states,
      num_unroll=num_unroll,
      batch_size=batch_size,
      pad=True,  # pad to a multiple of num_unroll
      make_keys_unique=True,
      make_keys_unique_seed=seed,
      num_threads=num_threads,
      capacity=queue_capacity)
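
# The snippet above relies on a `_get_initial_states` helper that is not
# shown here. A minimal sketch of one plausible implementation (the
# flattening and the numeric naming scheme are assumptions; the real helper
# presumably derives names from the cell): flatten the cell's zero state and
# strip the dummy batch dimension, since the state saver stores state per
# example.
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest

def _get_initial_states(cell):
  """Returns a {state_name: Tensor} dict built from `cell`'s zero state."""
  values = nest.flatten(cell.zero_state(batch_size=1, dtype=dtypes.float32))
  return {str(i): array_ops.squeeze(value, axis=0)
          for i, value in enumerate(values)}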
def build(input_reader_config,
          model_config,
          lstm_config,
          unroll_length,
          data_augmentation_options=None,
          batch_size=1):
  """Builds a tensor dictionary based on the InputReader config.

  Args:
    input_reader_config: An input_reader_pb2.InputReader proto.
    model_config: A model.proto object containing the config for the desired
      DetectionModel.
    lstm_config: LSTM specific configs.
    unroll_length: Unrolled length for LSTM training.
    data_augmentation_options: A list of tuples, where each tuple contains a
      data augmentation function and a dictionary containing arguments and their
      values (see preprocessor.py).
    batch_size: Batch size for queue outputs.

  Returns:
    A dictionary of tensors based on items in the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  external_reader_config = input_reader_config.external_input_reader
  google_input_reader_config = external_reader_config.Extensions[
      input_reader_google_pb2.GoogleInputReader.google_input_reader]
  input_reader_type = google_input_reader_config.WhichOneof('input_reader')

  if input_reader_type == 'tf_record_video_input_reader':
    config = google_input_reader_config.tf_record_video_input_reader
    reader_type_class = tf.TFRecordReader
  else:
    raise ValueError(
        'Unsupported reader in input_reader_config: %s' % input_reader_type)

  if not config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')
  key, value = parallel_reader.parallel_read(
      config.input_path[:],  # Convert `RepeatedScalarContainer` to list.
      reader_class=reader_type_class,
      num_epochs=(input_reader_config.num_epochs
                  if input_reader_config.num_epochs else None),
      num_readers=input_reader_config.num_readers,
      shuffle=input_reader_config.shuffle,
      dtypes=[tf.string, tf.string],
      capacity=input_reader_config.queue_capacity,
      min_after_dequeue=input_reader_config.min_after_dequeue)

  # TODO(yinxiao): Add loading instance mask option.
  decoder = tf_sequence_example_decoder.TFSequenceExampleDecoder()

  keys_to_decode = [
      fields.InputDataFields.image, fields.InputDataFields.groundtruth_boxes,
      fields.InputDataFields.groundtruth_classes
  ]
  tensor_dict = decoder.decode(value, items=keys_to_decode)

  tensor_dict['image'].set_shape([None, None, None, 3])
  tensor_dict['groundtruth_boxes'].set_shape([None, None, 4])

  height = model_config.ssd.image_resizer.fixed_shape_resizer.height
  width = model_config.ssd.image_resizer.fixed_shape_resizer.width

  # If data augmentation is specified in the config file, the preprocessor
  # will be called here to augment the data as specified. Most common
  # augmentations include horizontal flip and cropping.
  if data_augmentation_options:
    images_pre = tf.split(tensor_dict['image'], config.video_length, axis=0)
    bboxes_pre = tf.split(
        tensor_dict['groundtruth_boxes'], config.video_length, axis=0)
    labels_pre = tf.split(
        tensor_dict['groundtruth_classes'], config.video_length, axis=0)
    images_proc, bboxes_proc, labels_proc = [], [], []
    cache = preprocessor_cache.PreprocessorCache()

    for i, _ in enumerate(images_pre):
      image_dict = {
          fields.InputDataFields.image:
              images_pre[i],
          fields.InputDataFields.groundtruth_boxes:
              tf.squeeze(bboxes_pre[i], axis=0),
          fields.InputDataFields.groundtruth_classes:
              tf.squeeze(labels_pre[i], axis=0),
      }
      image_dict = preprocessor.preprocess(
          image_dict,
          data_augmentation_options,
          func_arg_map=preprocessor.get_default_func_arg_map(),
          preprocess_vars_cache=cache)
      # Pads detection count to _PADDING_SIZE.
      image_dict[fields.InputDataFields.groundtruth_boxes] = tf.pad(
          image_dict[fields.InputDataFields.groundtruth_boxes],
          [[0, _PADDING_SIZE], [0, 0]])
      image_dict[fields.InputDataFields.groundtruth_boxes] = tf.slice(
          image_dict[fields.InputDataFields.groundtruth_boxes], [0, 0],
          [_PADDING_SIZE, -1])
      image_dict[fields.InputDataFields.groundtruth_classes] = tf.pad(
          image_dict[fields.InputDataFields.groundtruth_classes],
          [[0, _PADDING_SIZE]])
      image_dict[fields.InputDataFields.groundtruth_classes] = tf.slice(
          image_dict[fields.InputDataFields.groundtruth_classes], [0],
          [_PADDING_SIZE])
      images_proc.append(image_dict[fields.InputDataFields.image])
      bboxes_proc.append(image_dict[fields.InputDataFields.groundtruth_boxes])
      labels_proc.append(image_dict[fields.InputDataFields.groundtruth_classes])
    tensor_dict['image'] = tf.concat(images_proc, axis=0)
    tensor_dict['groundtruth_boxes'] = tf.stack(bboxes_proc, axis=0)
    tensor_dict['groundtruth_classes'] = tf.stack(labels_proc, axis=0)
  else:
    # Pads detection count to _PADDING_SIZE per frame.
    tensor_dict['groundtruth_boxes'] = tf.pad(
        tensor_dict['groundtruth_boxes'], [[0, 0], [0, _PADDING_SIZE], [0, 0]])
    tensor_dict['groundtruth_boxes'] = tf.slice(
        tensor_dict['groundtruth_boxes'], [0, 0, 0], [-1, _PADDING_SIZE, -1])
    tensor_dict['groundtruth_classes'] = tf.pad(
        tensor_dict['groundtruth_classes'], [[0, 0], [0, _PADDING_SIZE]])
    tensor_dict['groundtruth_classes'] = tf.slice(
        tensor_dict['groundtruth_classes'], [0, 0], [-1, _PADDING_SIZE])

  tensor_dict['image'], _ = preprocessor.resize_image(
      tensor_dict['image'], new_height=height, new_width=width)

  # Use integer division so shape entries stay Python ints under Python 3.
  num_steps = config.video_length // unroll_length

  init_states = {
      'lstm_state_c':
          tf.zeros([height // 32, width // 32, lstm_config.lstm_state_depth]),
      'lstm_state_h':
          tf.zeros([height // 32, width // 32, lstm_config.lstm_state_depth]),
      'lstm_state_step':
          tf.constant(num_steps, shape=[]),
  }

  batch = sqss.batch_sequences_with_states(
      input_key=key,
      input_sequences=tensor_dict,
      input_context={},
      input_length=None,
      initial_states=init_states,
      num_unroll=unroll_length,
      batch_size=batch_size,
      num_threads=batch_size,
      make_keys_unique=True,
      capacity=batch_size * batch_size)

  return _build_training_batch_dict(batch, unroll_length, batch_size)
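
# `_build_training_batch_dict` is referenced above but not shown. A minimal
# sketch of one plausible implementation (the exact output layout is an
# assumption): unstack each [batch, unroll, ...] sequence tensor into a list
# of per-step tensors so the LSTM training loop can consume one frame at a
# time.
def _build_training_batch_dict(batch_sequences_with_states, unroll_length,
                               batch_size):
  """Returns a dict of per-step tensors from a NextQueuedSequenceBatch."""
  del batch_size  # kept for signature parity; unused in this sketch
  tensor_dict = {}
  for key, sequence in batch_sequences_with_states.sequences.items():
    # sequence: [batch_size, unroll_length, ...] -> a list of unroll_length
    # tensors, each of shape [batch_size, ...].
    tensor_dict[key] = tf.unstack(sequence, num=unroll_length, axis=1)
  return tensor_dict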