def test_rnn_decoder_multiple_unroll(self):
        """Runs rnn_decoder over several unrolled steps and checks every step.

        The same input tensor is fed at each step together with a
        (state, state) initial tuple. For step i the assertions require that:
          * the output equals the input concatenated along axis 3 with the
            first element of the previous state,
          * the first element of the new state is twice its previous value,
          * the second element of the new state equals its previous value.
        """
        batch_size, num_unroll = 2, 3
        width, height = 8, 10
        input_channels, num_units = 128, 64

        state_shape = (batch_size, width, height, num_units)
        output_shape = (batch_size, width, height,
                        input_channels + num_units)

        # The state op is created before the input op.
        initial_state = tf.random_normal(state_shape)
        inputs = tf.random_normal([batch_size, width, height, input_channels])

        cell = MockRnnCell(input_channels, num_units)
        outputs, states = rnn_decoder.rnn_decoder(
            decoder_inputs=[inputs] * num_unroll,
            initial_state=(initial_state, initial_state),
            cell=cell)

        # One output and one state per unrolled step.
        self.assertEqual(len(outputs), num_unroll)
        self.assertEqual(len(states), num_unroll)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Fetch everything in a single run so the random tensors are
            # sampled once and all fetched values are mutually consistent.
            (outputs_np, states_np, inputs_np, init_state_np) = sess.run(
                (outputs, states, inputs, initial_state))

            # State feeding the current step; starts as the initial pair.
            prev_state = [init_state_np, init_state_np]
            for step in range(num_unroll):
                step_output = outputs_np[step]
                step_state = states_np[step]

                self.assertEqual(step_output.shape, output_shape)
                expected_output = np.concatenate(
                    (inputs_np, prev_state[0]), axis=3)
                self.assertAllEqual(step_output, expected_output)

                self.assertEqual(step_state[0].shape, state_shape)
                self.assertEqual(step_state[1].shape, state_shape)
                self.assertAllEqual(step_state[0],
                                    np.multiply(prev_state[0], 2.0))
                self.assertAllEqual(step_state[1], prev_state[1])

                prev_state = step_state
    def extract_features(self,
                         preprocessed_inputs,
                         state_saver=None,
                         state_name='lstm_state',
                         unroll_length=5,
                         scope=None):
        """Extracts features from preprocessed inputs.

        The features include the base network features, lstm features and SSD
        features, organized in the following name scope:

        <parent scope>/MobilenetV1/...
        <parent scope>/LSTM/...
        <parent scope>/FeatureMaps/...

        Side effects: sets `self._states_out` to the per-step LSTM states and,
        when `state_saver` is given, sets `self._step` to the saved step
        tensor. Identity ops named `lstm_state_in_c`, `lstm_state_in_h`,
        `lstm_state_out_c` and `lstm_state_out_h` are created inside the LSTM
        variable scope.

        Args:
          preprocessed_inputs: A [batch, height, width, channels] float tensor
            representing a batch of consecutive frames from video clips.
          state_saver: A state saver object with methods `state` and
            `save_state`. When None, the initial state comes from
            `lstm_cell.init_state`.
          state_name: A python string for the name to use with the
            state_saver.
          unroll_length: The number of steps to unroll the lstm. The base
            network output is split into this many pieces with `tf.split`.
          scope: The scope for the base network of the feature extractor.

        Returns:
          A list of tensors where the ith tensor has shape [batch, height_i,
          width_i, depth_i]
        """
        preprocessed_inputs = shape_utils.check_min_image_dim(
            33, preprocessed_inputs)
        with slim.arg_scope(
                mobilenet_v1.mobilenet_v1_arg_scope(
                    is_training=self._is_training)):
            # The custom conv hyperparams are applied to the base network only
            # when the override flag is set; otherwise a no-op context is used.
            with (slim.arg_scope(self._conv_hyperparams_fn())
                  if self._override_base_feature_extractor_hyperparams else
                  context_manager.IdentityContextManager()):
                with slim.arg_scope([slim.batch_norm], fused=False):
                    # Base network.
                    with tf.variable_scope(scope,
                                           self._base_network_scope,
                                           reuse=self._reuse_weights) as scope:
                        net, image_features = mobilenet_v1.mobilenet_v1_base(
                            ops.pad_to_multiple(preprocessed_inputs,
                                                self._pad_to_multiple),
                            final_endpoint='Conv2d_13_pointwise',
                            min_depth=self._min_depth,
                            depth_multiplier=self._depth_multiplier,
                            scope=scope)

        with slim.arg_scope(self._conv_hyperparams_fn()):
            with slim.arg_scope([slim.batch_norm],
                                fused=False,
                                is_training=self._is_training):
                # ConvLSTM layers.
                with tf.variable_scope(
                        'LSTM', reuse=self._reuse_weights) as lstm_scope:
                    lstm_cell = lstm_cells.BottleneckConvLSTMCell(
                        filter_size=(3, 3),
                        output_size=(net.shape[1].value, net.shape[2].value),
                        num_units=max(self._min_depth, self._lstm_state_depth),
                        activation=tf.nn.relu6,
                        visualize_gates=True)

                    # One chunk of the base network output per unroll step.
                    net_seq = list(tf.split(net, unroll_length))
                    if state_saver is None:
                        # Floor division keeps the per-step batch size an
                        # integer; `/` yields a float under Python 3.
                        # NOTE(review): assumes the static batch dimension is
                        # known (not None) -- confirm against callers.
                        init_state = lstm_cell.init_state(
                            state_name, net.shape[0].value // unroll_length,
                            tf.float32)
                    else:
                        c = state_saver.state('%s_c' % state_name)
                        h = state_saver.state('%s_h' % state_name)
                        init_state = (c, h)

                    # Identities added for inputing state tensors externally.
                    c_ident = tf.identity(init_state[0],
                                          name='lstm_state_in_c')
                    h_ident = tf.identity(init_state[1],
                                          name='lstm_state_in_h')
                    init_state = (c_ident, h_ident)

                    net_seq, states_out = rnn_decoder.rnn_decoder(
                        net_seq, init_state, lstm_cell, scope=lstm_scope)
                    batcher_ops = None
                    self._states_out = states_out
                    if state_saver is not None:
                        self._step = state_saver.state('%s_step' % state_name)
                        # Save the final (c, h) state and the step counter.
                        # NOTE(review): the saved step is `self._step - 1`;
                        # confirm the counter is meant to count down.
                        batcher_ops = [
                            state_saver.save_state('%s_c' % state_name,
                                                   states_out[-1][0]),
                            state_saver.save_state('%s_h' % state_name,
                                                   states_out[-1][1]),
                            state_saver.save_state('%s_step' % state_name,
                                                   self._step - 1)
                        ]
                    # With a state_saver, the concat op is created under
                    # control dependencies on the save ops above.
                    with tf_ops.control_dependencies(batcher_ops):
                        image_features['Conv2d_13_pointwise_lstm'] = tf.concat(
                            net_seq, 0)

                    # Identities added for reading output states, to be reused externally.
                    tf.identity(states_out[-1][0], name='lstm_state_out_c')
                    tf.identity(states_out[-1][1], name='lstm_state_out_h')

                # SSD layers.
                with tf.variable_scope('FeatureMaps',
                                       reuse=self._reuse_weights):
                    feature_maps = feature_map_generators.multi_resolution_feature_maps(
                        feature_map_layout=self._feature_map_layout,
                        depth_multiplier=(self._depth_multiplier),
                        min_depth=self._min_depth,
                        insert_1x1_conv=True,
                        image_features=image_features)

        # Materialize as a list: under Python 3 `.values()` is a dict view,
        # while the documented return type is a list.
        return list(feature_maps.values())
  def extract_features(self,
                       preprocessed_inputs,
                       state_saver=None,
                       state_name='lstm_state',
                       unroll_length=5,
                       scope=None):
    """Extracts features from preprocessed inputs.

    The features include the base network features, lstm features and SSD
    features, organized in the following name scope:

    <parent scope>/MobilenetV1/...
    <parent scope>/LSTM/...
    <parent scope>/FeatureMaps/...

    Side effects: sets `self._states_out` to the per-step LSTM states and,
    when `state_saver` is given, sets `self._step` to the saved step tensor.
    Identity ops named `lstm_state_in_c`, `lstm_state_in_h`,
    `lstm_state_out_c` and `lstm_state_out_h` are created inside the LSTM
    variable scope.

    Args:
      preprocessed_inputs: A [batch, height, width, channels] float tensor
        representing a batch of consecutive frames from video clips.
      state_saver: A state saver object with methods `state` and `save_state`.
        When None, the initial state comes from `lstm_cell.init_state`.
      state_name: A python string for the name to use with the state_saver.
      unroll_length: The number of steps to unroll the lstm. The base network
        output is split into this many pieces with `tf.split`.
      scope: The scope for the base network of the feature extractor.

    Returns:
      A list of tensors where the ith tensor has shape [batch, height_i,
      width_i, depth_i]
    """
    preprocessed_inputs = shape_utils.check_min_image_dim(
        33, preprocessed_inputs)
    with slim.arg_scope(
        mobilenet_v1.mobilenet_v1_arg_scope(is_training=self._is_training)):
      # The custom conv hyperparams are applied to the base network only when
      # the override flag is set; otherwise a no-op context is used.
      with (slim.arg_scope(self._conv_hyperparams_fn())
            if self._override_base_feature_extractor_hyperparams else
            context_manager.IdentityContextManager()):
        with slim.arg_scope([slim.batch_norm], fused=False):
          # Base network.
          with tf.variable_scope(
              scope, self._base_network_scope,
              reuse=self._reuse_weights) as scope:
            net, image_features = mobilenet_v1.mobilenet_v1_base(
                ops.pad_to_multiple(preprocessed_inputs, self._pad_to_multiple),
                final_endpoint='Conv2d_13_pointwise',
                min_depth=self._min_depth,
                depth_multiplier=self._depth_multiplier,
                scope=scope)

    with slim.arg_scope(self._conv_hyperparams_fn()):
      with slim.arg_scope(
          [slim.batch_norm], fused=False, is_training=self._is_training):
        # ConvLSTM layers.
        with tf.variable_scope('LSTM', reuse=self._reuse_weights) as lstm_scope:
          lstm_cell = lstm_cells.BottleneckConvLSTMCell(
              filter_size=(3, 3),
              output_size=(net.shape[1].value, net.shape[2].value),
              num_units=max(self._min_depth, self._lstm_state_depth),
              activation=tf.nn.relu6,
              visualize_gates=True)

          # One chunk of the base network output per unroll step.
          net_seq = list(tf.split(net, unroll_length))
          if state_saver is None:
            # Floor division keeps the per-step batch size an integer; `/`
            # yields a float under Python 3.
            # NOTE(review): assumes the static batch dimension is known (not
            # None) -- confirm against callers.
            init_state = lstm_cell.init_state(
                state_name, net.shape[0].value // unroll_length, tf.float32)
          else:
            c = state_saver.state('%s_c' % state_name)
            h = state_saver.state('%s_h' % state_name)
            init_state = (c, h)

          # Identities added for inputing state tensors externally.
          c_ident = tf.identity(init_state[0], name='lstm_state_in_c')
          h_ident = tf.identity(init_state[1], name='lstm_state_in_h')
          init_state = (c_ident, h_ident)

          net_seq, states_out = rnn_decoder.rnn_decoder(
              net_seq, init_state, lstm_cell, scope=lstm_scope)
          batcher_ops = None
          self._states_out = states_out
          if state_saver is not None:
            self._step = state_saver.state('%s_step' % state_name)
            # Save the final (c, h) state and the step counter.
            # NOTE(review): the saved step is `self._step - 1`; confirm the
            # counter is meant to count down.
            batcher_ops = [
                state_saver.save_state('%s_c' % state_name, states_out[-1][0]),
                state_saver.save_state('%s_h' % state_name, states_out[-1][1]),
                state_saver.save_state('%s_step' % state_name, self._step - 1)
            ]
          # With a state_saver, the concat op is created under control
          # dependencies on the save ops above.
          with tf_ops.control_dependencies(batcher_ops):
            image_features['Conv2d_13_pointwise_lstm'] = tf.concat(net_seq, 0)

          # Identities added for reading output states, to be reused externally.
          tf.identity(states_out[-1][0], name='lstm_state_out_c')
          tf.identity(states_out[-1][1], name='lstm_state_out_h')

        # SSD layers.
        with tf.variable_scope('FeatureMaps', reuse=self._reuse_weights):
          feature_maps = feature_map_generators.multi_resolution_feature_maps(
              feature_map_layout=self._feature_map_layout,
              depth_multiplier=(self._depth_multiplier),
              min_depth=self._min_depth,
              insert_1x1_conv=True,
              image_features=image_features)

    # Materialize as a list: under Python 3 `.values()` is a dict view, while
    # the documented return type is a list.
    return list(feature_maps.values())