def test_rnn_decoder_multiple_unroll_with_skip(self):
    batch_size = 2
    num_unroll = 5
    num_units = 12
    width = 8
    height = 10
    input_channels_large = 24
    input_channels_small = 12
    bottleneck_channels = 20
    skip = 2

    initial_state_c = tf.random_normal((batch_size, width, height, num_units))
    initial_state_h = tf.random_normal((batch_size, width, height, num_units))
    initial_state = (initial_state_c, initial_state_h)
    inputs_large = tf.random_normal(
        [batch_size, width, height, input_channels_large])
    inputs_small = tf.random_normal(
        [batch_size, width, height, input_channels_small])

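    # MockRnnCell is a test double; from the assertions below it is assumed
    # to double the `c` state whenever the cell runs, pass `h` through
    # unchanged, and output the bottlenecked input concatenated with `h`.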
    rnn_cell = MockRnnCell(bottleneck_channels, num_units)
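    # With 'SKIP2' the decoder should run the cell only on every third step
    # (the key frames), which the assertions below verify.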
    outputs, states = rnn_decoder.multi_input_rnn_decoder(
        decoder_inputs=[[inputs_large] * num_unroll,
                        [inputs_small] * num_unroll],
        initial_state=initial_state,
        cell=rnn_cell,
        sequence_step=tf.zeros([batch_size]),
        pre_bottleneck=True,
        selection_strategy='SKIP%d' % skip)

    self.assertEqual(len(outputs), num_unroll)
    self.assertEqual(len(states), num_unroll)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      results = sess.run(
          (outputs, states, inputs_large, inputs_small, initial_state))
      outputs_results = results[0]
      states_results = results[1]
      inputs_large_results = results[2]
      inputs_small_results = results[3]
      initial_states_results = results[4]

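      # Each output is the bottleneck features concatenated with `h`, hence
      # depth bottleneck_channels + num_units.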
      for i in range(num_unroll):
        self.assertEqual(
            outputs_results[i].shape,
            (batch_size, width, height, bottleneck_channels + num_units))
        self.assertEqual(states_results[i][0].shape,
                         (batch_size, width, height, num_units))
        self.assertEqual(states_results[i][1].shape,
                         (batch_size, width, height, num_units))

        previous_state = (
            initial_states_results if i == 0 else states_results[i - 1])
        # The state should only update on key frames (every `skip + 1`
        # steps); on skipped frames it must pass through unchanged.
        if i % (skip + 1) == 0:
          self.assertAllEqual(states_results[i][0],
                              np.multiply(previous_state[0], 2))
          self.assertAllEqual(states_results[i][1], previous_state[1])
        else:
          self.assertAllEqual(states_results[i][0], previous_state[0])
          self.assertAllEqual(states_results[i][1], previous_state[1])

def test_rnn_decoder_multiple_unroll(self):
    batch_size = 2
    num_unroll = 3
    num_units = 12
    width = 8
    height = 10
    input_channels_large = 24
    input_channels_small = 12
    bottleneck_channels = 20

    initial_state_c = tf.random_normal((batch_size, width, height, num_units))
    initial_state_h = tf.random_normal((batch_size, width, height, num_units))
    initial_state = (initial_state_c, initial_state_h)
    inputs_large = tf.random_normal(
        [batch_size, width, height, input_channels_large])
    inputs_small = tf.random_normal(
        [batch_size, width, height, input_channels_small])

    rnn_cell = MockRnnCell(bottleneck_channels, num_units)
    outputs, states = rnn_decoder.multi_input_rnn_decoder(
        decoder_inputs=[[inputs_large] * num_unroll,
                        [inputs_small] * num_unroll],
        initial_state=initial_state,
        cell=rnn_cell,
        sequence_step=tf.zeros([batch_size]),
        pre_bottleneck=True)

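    # The decoder returns one output and one (c, h) state tuple per unroll
    # step.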
    self.assertEqual(len(outputs), num_unroll)
    self.assertEqual(len(states), num_unroll)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      results = sess.run(
          (outputs, states, inputs_large, inputs_small, initial_state))
      outputs_results = results[0]
      states_results = results[1]
      inputs_large_results = results[2]
      inputs_small_results = results[3]
      initial_states_results = results[4]

      # The first step should always update state.
      self.assertAllEqual(states_results[0][0],
                          np.multiply(initial_states_results[0], 2))
      self.assertAllEqual(states_results[0][1], initial_states_results[1])
      for i in range(num_unroll):
        self.assertEqual(
            outputs_results[i].shape,
            (batch_size, width, height, bottleneck_channels + num_units))
        self.assertEqual(states_results[i][0].shape,
                         (batch_size, width, height, num_units))
        self.assertEqual(states_results[i][1].shape,
                         (batch_size, width, height, num_units))
  def extract_features(self, preprocessed_inputs, state_saver=None,
                       state_name='lstm_state', unroll_length=10, scope=None):
    """Extract features from preprocessed inputs.

    The features include the base network features, LSTM features and SSD
    features, organized in the following name scope:

    <scope>/MobilenetV2_1/...
    <scope>/MobilenetV2_2/...
    <scope>/LSTM/...
    <scope>/FeatureMap/...

    Args:
      preprocessed_inputs: a [batch, height, width, channels] float tensor
        representing a batch of consecutive frames from video clips.
      state_saver: A state saver object with methods `state` and `save_state`.
      state_name: Python string, the name to use with the state_saver.
      unroll_length: number of steps to unroll the lstm.
      scope: Scope for the base network of the feature extractor.

    Returns:
      feature_maps: a list of tensors where the ith tensor has shape
        [batch, height_i, width_i, depth_i]

    Raises:
      ValueError: if interleave_method not recognized or large and small base
        network output feature maps of different sizes.
    """
    preprocessed_inputs = shape_utils.check_min_image_dim(
        33, preprocessed_inputs)
    preprocessed_inputs = ops.pad_to_multiple(
        preprocessed_inputs, self._pad_to_multiple)
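    # The input stacks `unroll_length` consecutive frames along the batch
    # axis, so the number of clips is the tensor batch size divided by the
    # unroll length.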
    batch_size = preprocessed_inputs.shape[0].value // unroll_length
    batch_axis = 0
    nets = []

    # Batch processing of mobilenet features.
    with slim.arg_scope(mobilenet_v2.training_scope(
        is_training=self._is_training,
        bn_decay=0.9997)), \
        slim.arg_scope([mobilenet.depth_multiplier],
                       min_depth=self._min_depth, divisible_by=8):
      # Big model.
      net, _ = self.extract_base_features_large(preprocessed_inputs)
      nets.append(net)
      large_base_feature_shape = net.shape

      # Small model.
      net, _ = self.extract_base_features_small(preprocessed_inputs)
      nets.append(net)
      small_base_feature_shape = net.shape
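      # Interleaving requires the two base networks to produce feature maps
      # with matching spatial dimensions.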
      if not (large_base_feature_shape[1] == small_base_feature_shape[1] and
              large_base_feature_shape[2] == small_base_feature_shape[2]):
        raise ValueError('Large and small base network feature map '
                         'dimensions are not equal!')

    with slim.arg_scope(self._conv_hyperparams_fn()):
      with tf.variable_scope('LSTM', reuse=self._reuse_weights):
        output_size = (large_base_feature_shape[1],
                       large_base_feature_shape[2])
        lstm_cell, init_state, step = self.create_lstm_cell(
            batch_size, output_size, state_saver, state_name)

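        # Split the stacked frame features back into per-step sequences of
        # length `unroll_length` for the RNN decoder.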
        nets_seq = [
            tf.split(net, unroll_length, axis=batch_axis) for net in nets
        ]

        net_seq, states_out = rnn_decoder.multi_input_rnn_decoder(
            nets_seq,
            init_state,
            lstm_cell,
            step,
            selection_strategy=self._interleave_method,
            is_training=self._is_training,
            is_quantized=self._is_quantized,
            pre_bottleneck=self._pre_bottleneck,
            flatten_state=self._flatten_state,
            scope=None)
        self._states_out = states_out

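      # If a state saver is provided, persist the final LSTM state and an
      # incremented step counter so the next batch of frames resumes where
      # this one left off.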
      batcher_ops = None
      if state_saver is not None:
        self._step = state_saver.state(state_name + '_step')
        batcher_ops = [
            state_saver.save_state(state_name + '_c', states_out[-1][0]),
            state_saver.save_state(state_name + '_h', states_out[-1][1]),
            state_saver.save_state(state_name + '_step', self._step + 1)]
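      # Gating feature-map construction on the save ops ensures the state is
      # written out whenever the features are consumed downstream.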
      image_features = {}
      with tf_ops.control_dependencies(batcher_ops):
        image_features['layer_19'] = tf.concat(net_seq, 0)

      # SSD layers.
      with tf.variable_scope('FeatureMap'):
        feature_maps = feature_map_generators.multi_resolution_feature_maps(
            feature_map_layout=self._feature_map_layout,
            depth_multiplier=self._depth_multiplier,
            min_depth=self._min_depth,
            insert_1x1_conv=True,
            image_features=image_features,
            pool_residual=True)
    return list(feature_maps.values())
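
  # A minimal usage sketch (hypothetical; the constructor arguments and input
  # shape below are assumptions, not this extractor's exact API):
  #
  #   extractor = LSTMSSDInterleavedMobilenetV2FeatureExtractor(...)
  #   frames = tf.random_normal([clip_batch * unroll_length, 320, 320, 3])
  #   feature_maps = extractor.extract_features(frames, unroll_length=10)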