Beispiel #1
0
def select_inputs(decoder_inputs,
                  action,
                  local_step,
                  is_training,
                  is_quantized,
                  get_alt_inputs=False):
    """Selects sequence from decoder_inputs based on 1D actions.

    The per-timestep tensors of every sequence are stacked along a new
    trailing axis. A one-hot mask built from `action` then keeps, for each
    batch element, only the sequence indexed by that element's action. The
    masked tensor is passed through `lstm_utils.quantize_op` and summed over
    the trailing axis.

    Args:
      decoder_inputs: A 2-D list of tensor inputs, indexed as
        decoder_inputs[sequence][timestep].
      action: A tensor of shape [batch_size]. Each element corresponds to an
        index of decoder_inputs to choose.
      local_step: The current timestep.
      is_training: boolean, whether the network is training. When using
        learned selection, attempts exploration if training.
      is_quantized: flag to enable/disable quantization mode.
      get_alt_inputs: Whether the non-chosen inputs should also be returned.

    Returns:
      A tuple (inputs, inputs_alt). `inputs` is the selected output.
      `inputs_alt` holds the elements that were not chosen when
      get_alt_inputs is True, otherwise None.

    Raises:
      ValueError: if the decoder inputs contains other than two sequences.
    """
    num_seqs = len(decoder_inputs)
    if num_seqs != 2:
        raise ValueError('Currently only supports two sets of inputs.')

    step_inputs = [seq[local_step] for seq in decoder_inputs]
    stacked_inputs = tf.stack(step_inputs, axis=-1)

    def _masked_sum(mask, quant_scope):
        # Zero out the unwanted sequences, fake-quantize the product, then
        # collapse the trailing (sequence) axis.
        masked = lstm_utils.quantize_op(stacked_inputs * mask,
                                        is_training,
                                        is_quantized,
                                        scope=quant_scope)
        return tf.reduce_sum(masked, axis=-1)

    inputs = _masked_sum(tf.one_hot(action, num_seqs),
                         'quant_selected_inputs')
    if not get_alt_inputs:
        return inputs, None

    # Inverted one-hot mask: with exactly two sequences this keeps the one
    # that was NOT chosen (only valid for 2 models).
    alt_mask = tf.one_hot(action, num_seqs, on_value=0.0, off_value=1.0)
    return inputs, _masked_sum(alt_mask, 'quant_selected_inputs_alt')
Beispiel #2
0
    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM) with bottlenecking.

        Includes logic for quantization-aware training. Note that all concats
        and activations use fixed ranges unless stated otherwise.

        The state's axis 3 is split into `self._groups` groups. Each group
        computes its own bottleneck, gates and new state. The per-group
        results are then concatenated back along axis 3.

        Args:
          inputs: Input tensor at the current timestep. When
            `self._pre_bottleneck` is set, it is treated as the already
            bottlenecked input and split into groups directly. In that mode
            `h` is not used to compute the gates in this method.
          state: Tuple `(c, h)` of tensors, the state at the previous
            timestep. Reshaped to `[-1] + self.output_size` when
            `self._flatten_state` is set.
          scope: Optional variable scope name; defaults to 'conv_lstm_cell'.
            Variables are created with `reuse=tf.AUTO_REUSE`.

        Returns:
          A tuple `(output, state)`:
          - `output` is the new `h`. When `self._output_bottleneck` is set,
            it is concatenated with the bottleneck along axis 3.
          - `state` is an `LSTMStateTuple(new_c, new_h)`. It is reshaped to
            `[-1, self._param_count]` when `self._flatten_state` is set;
            `output` itself is not flattened.
        """
        scope = scope or 'conv_lstm_cell'
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            c, h = state

            # Set nodes to be under raw_inputs/ name scope for tfmini export.
            with tf.name_scope(None):
                c = tf.identity(c, name='raw_inputs/init_lstm_c')
                # When pre_bottleneck is enabled, the input h is handled in
                # rnn_decoder.py and is not used to compute the gates below.
                if not self._pre_bottleneck:
                    h = tf.identity(h, name='raw_inputs/init_lstm_h')

            # unflatten state if necessary
            if self._flatten_state:
                c = tf.reshape(c, [-1] + self.output_size)
                h = tf.reshape(h, [-1] + self.output_size)

            # Split into self._groups groups along axis 3. NOTE(review):
            # presumably channels-last (NHWC) layout -- confirm. h_list is
            # only defined (and only needed) when pre_bottleneck is off.
            c_list = tf.split(c, self._groups, axis=3)
            if self._pre_bottleneck:
                inputs_list = tf.split(inputs, self._groups, axis=3)
            else:
                h_list = tf.split(h, self._groups, axis=3)
            out_bottleneck = []
            out_c = []
            out_h = []
            # summary of input passed into cell
            if self._viz_gates:
                slim.summaries.add_histogram_summary(inputs, 'cell_input')

            # Per group: bottleneck -> gate convolution -> gate activations
            # -> new c and h, each fake-quantized as it is produced.
            for k in range(self._groups):
                if self._pre_bottleneck:
                    bottleneck = inputs_list[k]
                else:
                    if self._use_batch_norm:
                        # Separate separable convs on inputs and on this
                        # group's h, each batch-normalized, then summed.
                        b_x = lstm_utils.quantizable_separable_conv2d(
                            inputs,
                            self._num_units // self._groups,
                            self._filter_size,
                            is_quantized=self._is_quantized,
                            depth_multiplier=1,
                            activation_fn=None,
                            normalizer_fn=None,
                            scope='bottleneck_%d_x' % k)
                        b_h = lstm_utils.quantizable_separable_conv2d(
                            h_list[k],
                            self._num_units // self._groups,
                            self._filter_size,
                            is_quantized=self._is_quantized,
                            depth_multiplier=1,
                            activation_fn=None,
                            normalizer_fn=None,
                            scope='bottleneck_%d_h' % k)
                        b_x = slim.batch_norm(b_x,
                                              scale=True,
                                              is_training=self._is_training,
                                              scope='BatchNorm_%d_X' % k)
                        b_h = slim.batch_norm(b_h,
                                              scale=True,
                                              is_training=self._is_training,
                                              scope='BatchNorm_%d_H' % k)
                        bottleneck = b_x + b_h
                    else:
                        # All concats use fixed quantization ranges to prevent rescaling
                        # at inference. Both |inputs| and |h_list| are tensors resulting
                        # from Relu6 operations so we fix the ranges to [0, 6].
                        bottleneck_concat = lstm_utils.quantizable_concat(
                            [inputs, h_list[k]],
                            axis=3,
                            is_training=False,
                            is_quantized=self._is_quantized,
                            scope='bottleneck_%d/quantized_concat' % k)

                        bottleneck = lstm_utils.quantizable_separable_conv2d(
                            bottleneck_concat,
                            self._num_units // self._groups,
                            self._filter_size,
                            is_quantized=self._is_quantized,
                            depth_multiplier=1,
                            activation_fn=self._activation,
                            normalizer_fn=None,
                            scope='bottleneck_%d' % k)

                # One conv produces all four gates for this group at once
                # (4 * units-per-group channels), split apart below.
                concat = lstm_utils.quantizable_separable_conv2d(
                    bottleneck,
                    4 * self._num_units // self._groups,
                    self._filter_size,
                    is_quantized=self._is_quantized,
                    depth_multiplier=1,
                    activation_fn=None,
                    normalizer_fn=None,
                    scope='concat_conv_%d' % k)

                # Since there is no activation in the previous separable conv, we
                # quantize here. A starting range of [-6, 6] is used because the
                # tensors are input to a Sigmoid function that saturates at these
                # ranges.
                concat = lstm_utils.quantize_op(
                    concat,
                    is_training=self._is_training,
                    default_min=-6,
                    default_max=6,
                    is_quantized=self._is_quantized,
                    scope='gates_%d/act_quant' % k)

                # i = input_gate, j = new_input, f = forget_gate, o = output_gate
                i, j, f, o = tf.split(concat, 4, 3)

                f_add = f + self._forget_bias
                f_add = lstm_utils.quantize_op(
                    f_add,
                    is_training=self._is_training,
                    default_min=-6,
                    default_max=6,
                    is_quantized=self._is_quantized,
                    scope='forget_gate_%d/add_quant' % k)
                f_act = tf.sigmoid(f_add)
                # The quantization range is fixed for the sigmoid to ensure that zero
                # is exactly representable.
                f_act = lstm_utils.quantize_op(
                    f_act,
                    is_training=False,
                    default_min=0,
                    default_max=1,
                    is_quantized=self._is_quantized,
                    scope='forget_gate_%d/act_quant' % k)

                # Forget-gated previous cell state for this group.
                a = c_list[k] * f_act
                a = lstm_utils.quantize_op(a,
                                           is_training=self._is_training,
                                           is_quantized=self._is_quantized,
                                           scope='forget_gate_%d/mul_quant' %
                                           k)

                i_act = tf.sigmoid(i)
                # The quantization range is fixed for the sigmoid to ensure that zero
                # is exactly representable.
                i_act = lstm_utils.quantize_op(
                    i_act,
                    is_training=False,
                    default_min=0,
                    default_max=1,
                    is_quantized=self._is_quantized,
                    scope='input_gate_%d/act_quant' % k)

                j_act = self._activation(j)
                # The quantization range is fixed for the relu6 to ensure that zero
                # is exactly representable.
                j_act = lstm_utils.quantize_op(j_act,
                                               is_training=False,
                                               default_min=0,
                                               default_max=6,
                                               is_quantized=self._is_quantized,
                                               scope='new_input_%d/act_quant' %
                                               k)

                # Input-gated new candidate values for this group.
                b = i_act * j_act
                b = lstm_utils.quantize_op(b,
                                           is_training=self._is_training,
                                           is_quantized=self._is_quantized,
                                           scope='input_gate_%d/mul_quant' % k)

                new_c = a + b
                # The quantization range is fixed to [0, 6] due to an optimization in
                # TFLite. The order of operations is as follows:
                #     Add -> FakeQuant -> Relu6 -> FakeQuant -> Concat.
                # The fakequant ranges to the concat must be fixed to ensure all inputs
                # to the concat have the same range, removing the need for rescaling.
                # The quantization ranges input to the relu6 are propagated to its
                # output. Any mismatch between these two ranges will cause an error.
                new_c = lstm_utils.quantize_op(new_c,
                                               is_training=False,
                                               default_min=0,
                                               default_max=6,
                                               is_quantized=self._is_quantized,
                                               scope='new_c_%d/add_quant' % k)

                # Float-only path: either rescale each batch element so that
                # its maximum value is at most 6, or clip to [-6, 6].
                if not self._is_quantized:
                    if self._scale_state:
                        normalizer = tf.maximum(
                            1.0,
                            tf.reduce_max(new_c, axis=(1, 2, 3)) / 6)
                        new_c /= tf.reshape(normalizer,
                                            [tf.shape(new_c)[0], 1, 1, 1])
                    elif self._clip_state:
                        new_c = tf.clip_by_value(new_c, -6, 6)

                new_c_act = self._activation(new_c)
                # The quantization range is fixed for the relu6 to ensure that zero
                # is exactly representable.
                new_c_act = lstm_utils.quantize_op(
                    new_c_act,
                    is_training=False,
                    default_min=0,
                    default_max=6,
                    is_quantized=self._is_quantized,
                    scope='new_c_%d/act_quant' % k)

                o_act = tf.sigmoid(o)
                # The quantization range is fixed for the sigmoid to ensure that zero
                # is exactly representable.
                o_act = lstm_utils.quantize_op(o_act,
                                               is_training=False,
                                               default_min=0,
                                               default_max=1,
                                               is_quantized=self._is_quantized,
                                               scope='output_%d/act_quant' % k)

                new_h = new_c_act * o_act
                # The quantization range is fixed since it is input to a concat.
                # A range of [0, 6] is used since |new_h| is a product of ranges [0, 6]
                # and [0, 1].
                new_h_act = lstm_utils.quantize_op(
                    new_h,
                    is_training=False,
                    default_min=0,
                    default_max=6,
                    is_quantized=self._is_quantized,
                    scope='new_h_%d/act_quant' % k)

                out_bottleneck.append(bottleneck)
                # NOTE(review): the stored cell state is the activated
                # new_c_act, not the pre-activation new_c. This is consistent
                # with the Add -> FakeQuant -> Relu6 -> FakeQuant -> Concat
                # ordering described above; confirm it is intended.
                out_c.append(new_c_act)
                out_h.append(new_h_act)

            # Since all inputs to the below concats are already quantized, we can use
            # a regular concat operation.
            new_c = tf.concat(out_c, axis=3)
            new_h = tf.concat(out_h, axis=3)

            # |bottleneck| is input to a concat with |new_h|. We must use
            # quantizable_concat() with a fixed range that matches |new_h|.
            bottleneck = lstm_utils.quantizable_concat(
                out_bottleneck,
                axis=3,
                is_training=False,
                is_quantized=self._is_quantized,
                scope='out_bottleneck/quantized_concat')

            # summary of cell output and new state
            if self._viz_gates:
                slim.summaries.add_histogram_summary(new_h, 'cell_output')
                slim.summaries.add_histogram_summary(new_c, 'cell_state')

            # The output is taken before any state reflattening below.
            output = new_h
            if self._output_bottleneck:
                output = lstm_utils.quantizable_concat(
                    [new_h, bottleneck],
                    axis=3,
                    is_training=False,
                    is_quantized=self._is_quantized,
                    scope='new_output/quantized_concat')

            # reflatten state to store it
            if self._flatten_state:
                new_c = tf.reshape(new_c, [-1, self._param_count],
                                   name='lstm_c')
                new_h = tf.reshape(new_h, [-1, self._param_count],
                                   name='lstm_h')

            # Set nodes to be under raw_outputs/ name scope for tfmini export.
            with tf.name_scope(None):
                new_c = tf.identity(new_c, name='raw_outputs/lstm_c')
                new_h = tf.identity(new_h, name='raw_outputs/lstm_h')
            states_and_output = contrib_rnn.LSTMStateTuple(new_c, new_h)

            return output, states_and_output
Beispiel #3
0
 def test_quantize_op_inferene(self):
     """quantize_op with is_training=False keeps shape; graph checks pass."""
     zeros = tf.zeros([4, 10, 10, 128], dtype=tf.float32)
     quantized = utils.quantize_op(zeros, is_training=False)
     self.assertAllEqual(zeros.shape.as_list(), quantized.shape.as_list())
     # Inspect the default graph once and run both variable checks on it.
     graph = tf.get_default_graph()
     self._check_no_min_max_ema(graph)
     self._check_min_max_vars(graph)