def select_inputs(decoder_inputs, action, local_step, is_training, is_quantized,
                  get_alt_inputs=False):
  """Selects sequence from decoder_inputs based on 1D actions.

  Given multiple input batches, creates a single output batch by selecting
  from the action[i]-th input for the i-th batch element.

  Args:
    decoder_inputs: A 2-D list of tensor inputs.
    action: A tensor of shape [batch_size]. Each element corresponds to an
      index of decoder_inputs to choose.
    local_step: The current timestep.
    is_training: boolean, whether the network is training. When using learned
      selection, attempts exploration if training.
    is_quantized: flag to enable/disable quantization mode.
    get_alt_inputs: Whether the non-chosen inputs should also be returned.

  Returns:
    The constructed output. Also outputs the elements that were not chosen
    if get_alt_inputs is True, otherwise None.

  Raises:
    ValueError: if the decoder inputs contain other than two sequences.
  """
  num_seqs = len(decoder_inputs)
  if not num_seqs == 2:
    raise ValueError('Currently only supports two sets of inputs.')
  stacked_inputs = tf.stack(
      [decoder_inputs[seq_index][local_step] for seq_index in range(num_seqs)],
      axis=-1)
  action_index = tf.one_hot(action, num_seqs)
  selected_inputs = (
      lstm_utils.quantize_op(stacked_inputs * action_index, is_training,
                             is_quantized, scope='quant_selected_inputs'))
  inputs = tf.reduce_sum(selected_inputs, axis=-1)
  inputs_alt = None
  # Only works for 2 models.
  if get_alt_inputs:
    # Reverse of action_index.
    action_index_alt = tf.one_hot(action, num_seqs, on_value=0.0,
                                  off_value=1.0)
    selected_inputs = (
        lstm_utils.quantize_op(stacked_inputs * action_index_alt, is_training,
                               is_quantized, scope='quant_selected_inputs_alt'))
    inputs_alt = tf.reduce_sum(selected_inputs, axis=-1)
  return inputs, inputs_alt
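# A minimal sketch (not part of the library) of the selection mechanism used by
# select_inputs() above: stacking the candidates along a new trailing axis,
# masking with a one-hot of `action`, and reducing over that axis leaves
# exactly the action[i]-th candidate for the i-th batch element. The shapes,
# values, and the explicit reshape for broadcasting are illustration-only
# assumptions, and the quantization wrapper is omitted.
def _select_inputs_sketch():
  import tensorflow.compat.v1 as tf  # TF1-style API, as in the module above.
  batch_size, feature_dim = 4, 8
  # Two candidate inputs for the same timestep.
  input_a = tf.ones([batch_size, feature_dim])
  input_b = 2.0 * tf.ones([batch_size, feature_dim])
  # action[i] picks which candidate the i-th batch element uses.
  action = tf.constant([0, 1, 1, 0])
  stacked = tf.stack([input_a, input_b], axis=-1)  # [batch, feature, 2]
  mask = tf.one_hot(action, depth=2)               # [batch, 2]
  mask = tf.reshape(mask, [batch_size, 1, 2])      # broadcast over features
  # Rows with action == 0 come out as 1.0, rows with action == 1 as 2.0.
  return tf.reduce_sum(stacked * mask, axis=-1)    # [batch, feature]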
def __call__(self, inputs, state, scope=None):
  """Long short-term memory cell (LSTM) with bottlenecking.

  Includes logic for quantization-aware training. Note that all concats and
  activations use fixed ranges unless stated otherwise.

  Args:
    inputs: Input tensor at the current timestep.
    state: Tuple of tensors, the state at the previous timestep.
    scope: Optional scope.

  Returns:
    A tuple where the first element is the LSTM output and the second is
    an LSTMStateTuple of the state at the current timestep.
  """
  scope = scope or 'conv_lstm_cell'
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    c, h = state

    # Set nodes to be under raw_inputs/ name scope for tfmini export.
    with tf.name_scope(None):
      c = tf.identity(c, name='raw_inputs/init_lstm_c')
      # When pre_bottleneck is enabled, input h is handled in rnn_decoder.py.
      if not self._pre_bottleneck:
        h = tf.identity(h, name='raw_inputs/init_lstm_h')

    # Unflatten the state if necessary.
    if self._flatten_state:
      c = tf.reshape(c, [-1] + self.output_size)
      h = tf.reshape(h, [-1] + self.output_size)

    c_list = tf.split(c, self._groups, axis=3)
    if self._pre_bottleneck:
      inputs_list = tf.split(inputs, self._groups, axis=3)
    else:
      h_list = tf.split(h, self._groups, axis=3)
    out_bottleneck = []
    out_c = []
    out_h = []

    # Summary of the input passed into the cell.
    if self._viz_gates:
      slim.summaries.add_histogram_summary(inputs, 'cell_input')

    for k in range(self._groups):
      if self._pre_bottleneck:
        bottleneck = inputs_list[k]
      else:
        if self._use_batch_norm:
          b_x = lstm_utils.quantizable_separable_conv2d(
              inputs,
              self._num_units // self._groups,
              self._filter_size,
              is_quantized=self._is_quantized,
              depth_multiplier=1,
              activation_fn=None,
              normalizer_fn=None,
              scope='bottleneck_%d_x' % k)
          b_h = lstm_utils.quantizable_separable_conv2d(
              h_list[k],
              self._num_units // self._groups,
              self._filter_size,
              is_quantized=self._is_quantized,
              depth_multiplier=1,
              activation_fn=None,
              normalizer_fn=None,
              scope='bottleneck_%d_h' % k)
          b_x = slim.batch_norm(
              b_x,
              scale=True,
              is_training=self._is_training,
              scope='BatchNorm_%d_X' % k)
          b_h = slim.batch_norm(
              b_h,
              scale=True,
              is_training=self._is_training,
              scope='BatchNorm_%d_H' % k)
          bottleneck = b_x + b_h
        else:
          # All concats use fixed quantization ranges to prevent rescaling
          # at inference. Both |inputs| and |h_list| are tensors resulting
          # from Relu6 operations so we fix the ranges to [0, 6].
          bottleneck_concat = lstm_utils.quantizable_concat(
              [inputs, h_list[k]],
              axis=3,
              is_training=False,
              is_quantized=self._is_quantized,
              scope='bottleneck_%d/quantized_concat' % k)
          bottleneck = lstm_utils.quantizable_separable_conv2d(
              bottleneck_concat,
              self._num_units // self._groups,
              self._filter_size,
              is_quantized=self._is_quantized,
              depth_multiplier=1,
              activation_fn=self._activation,
              normalizer_fn=None,
              scope='bottleneck_%d' % k)

      concat = lstm_utils.quantizable_separable_conv2d(
          bottleneck,
          4 * self._num_units // self._groups,
          self._filter_size,
          is_quantized=self._is_quantized,
          depth_multiplier=1,
          activation_fn=None,
          normalizer_fn=None,
          scope='concat_conv_%d' % k)

      # Since there is no activation in the previous separable conv, we
      # quantize here. A starting range of [-6, 6] is used because the
      # tensors are input to a Sigmoid function that saturates at these
      # ranges.
      concat = lstm_utils.quantize_op(
          concat,
          is_training=self._is_training,
          default_min=-6,
          default_max=6,
          is_quantized=self._is_quantized,
          scope='gates_%d/act_quant' % k)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate.
      i, j, f, o = tf.split(concat, 4, 3)

      f_add = f + self._forget_bias
      f_add = lstm_utils.quantize_op(
          f_add,
          is_training=self._is_training,
          default_min=-6,
          default_max=6,
          is_quantized=self._is_quantized,
          scope='forget_gate_%d/add_quant' % k)
      f_act = tf.sigmoid(f_add)
      # The quantization range is fixed for the sigmoid to ensure that zero
      # is exactly representable.
      f_act = lstm_utils.quantize_op(
          f_act,
          is_training=False,
          default_min=0,
          default_max=1,
          is_quantized=self._is_quantized,
          scope='forget_gate_%d/act_quant' % k)

      a = c_list[k] * f_act
      a = lstm_utils.quantize_op(
          a,
          is_training=self._is_training,
          is_quantized=self._is_quantized,
          scope='forget_gate_%d/mul_quant' % k)

      i_act = tf.sigmoid(i)
      # The quantization range is fixed for the sigmoid to ensure that zero
      # is exactly representable.
      i_act = lstm_utils.quantize_op(
          i_act,
          is_training=False,
          default_min=0,
          default_max=1,
          is_quantized=self._is_quantized,
          scope='input_gate_%d/act_quant' % k)

      j_act = self._activation(j)
      # The quantization range is fixed for the relu6 to ensure that zero
      # is exactly representable.
      j_act = lstm_utils.quantize_op(
          j_act,
          is_training=False,
          default_min=0,
          default_max=6,
          is_quantized=self._is_quantized,
          scope='new_input_%d/act_quant' % k)

      b = i_act * j_act
      b = lstm_utils.quantize_op(
          b,
          is_training=self._is_training,
          is_quantized=self._is_quantized,
          scope='input_gate_%d/mul_quant' % k)

      new_c = a + b
      # The quantization range is fixed to [0, 6] due to an optimization in
      # TFLite. The order of operations is as follows:
      #   Add -> FakeQuant -> Relu6 -> FakeQuant -> Concat.
      # The fakequant ranges to the concat must be fixed to ensure all inputs
      # to the concat have the same range, removing the need for rescaling.
      # The quantization ranges input to the relu6 are propagated to its
      # output. Any mismatch between these two ranges will cause an error.
      new_c = lstm_utils.quantize_op(
          new_c,
          is_training=False,
          default_min=0,
          default_max=6,
          is_quantized=self._is_quantized,
          scope='new_c_%d/add_quant' % k)

      if not self._is_quantized:
        if self._scale_state:
          normalizer = tf.maximum(
              1.0, tf.reduce_max(new_c, axis=(1, 2, 3)) / 6)
          new_c /= tf.reshape(normalizer, [tf.shape(new_c)[0], 1, 1, 1])
        elif self._clip_state:
          new_c = tf.clip_by_value(new_c, -6, 6)

      new_c_act = self._activation(new_c)
      # The quantization range is fixed for the relu6 to ensure that zero
      # is exactly representable.
      new_c_act = lstm_utils.quantize_op(
          new_c_act,
          is_training=False,
          default_min=0,
          default_max=6,
          is_quantized=self._is_quantized,
          scope='new_c_%d/act_quant' % k)

      o_act = tf.sigmoid(o)
      # The quantization range is fixed for the sigmoid to ensure that zero
      # is exactly representable.
      o_act = lstm_utils.quantize_op(
          o_act,
          is_training=False,
          default_min=0,
          default_max=1,
          is_quantized=self._is_quantized,
          scope='output_%d/act_quant' % k)

      new_h = new_c_act * o_act
      # The quantization range is fixed since it is input to a concat.
      # A range of [0, 6] is used since |new_h| is a product of ranges [0, 6]
      # and [0, 1].
      new_h_act = lstm_utils.quantize_op(
          new_h,
          is_training=False,
          default_min=0,
          default_max=6,
          is_quantized=self._is_quantized,
          scope='new_h_%d/act_quant' % k)

      out_bottleneck.append(bottleneck)
      out_c.append(new_c_act)
      out_h.append(new_h_act)

    # Since all inputs to the below concats are already quantized, we can use
    # a regular concat operation.
    new_c = tf.concat(out_c, axis=3)
    new_h = tf.concat(out_h, axis=3)

    # |bottleneck| is input to a concat with |new_h|. We must use
    # quantizable_concat() with a fixed range that matches |new_h|.
    bottleneck = lstm_utils.quantizable_concat(
        out_bottleneck,
        axis=3,
        is_training=False,
        is_quantized=self._is_quantized,
        scope='out_bottleneck/quantized_concat')

    # Summary of the cell output and the new state.
    if self._viz_gates:
      slim.summaries.add_histogram_summary(new_h, 'cell_output')
      slim.summaries.add_histogram_summary(new_c, 'cell_state')

    output = new_h
    if self._output_bottleneck:
      output = lstm_utils.quantizable_concat(
          [new_h, bottleneck],
          axis=3,
          is_training=False,
          is_quantized=self._is_quantized,
          scope='new_output/quantized_concat')

    # Reflatten the state to store it.
    if self._flatten_state:
      new_c = tf.reshape(new_c, [-1, self._param_count], name='lstm_c')
      new_h = tf.reshape(new_h, [-1, self._param_count], name='lstm_h')

    # Set nodes to be under raw_outputs/ name scope for tfmini export.
    with tf.name_scope(None):
      new_c = tf.identity(new_c, name='raw_outputs/lstm_c')
      new_h = tf.identity(new_h, name='raw_outputs/lstm_h')
    states_and_output = contrib_rnn.LSTMStateTuple(new_c, new_h)

    return output, states_and_output
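# A minimal sketch (not part of the library) of the per-group gate arithmetic
# implemented in __call__ above, with the quantization, batch-norm, and summary
# plumbing stripped away. It assumes the cell activation is relu6, as the fixed
# [0, 6] ranges above imply; the function name and arguments are hypothetical.
def _lstm_gate_math_sketch(gates, c_prev, forget_bias=1.0):
  import tensorflow.compat.v1 as tf  # TF1-style API, as in the module above.
  # `gates` stands for the 4 * num_units pre-activations produced by the
  # 'concat_conv_%d' separable conv; the channel axis splits into
  # i = input_gate, j = new_input, f = forget_gate, o = output_gate.
  i, j, f, o = tf.split(gates, 4, axis=3)
  # new_c mirrors a + b above, i.e. c_list[k] * f_act + i_act * j_act.
  new_c = c_prev * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.nn.relu6(j)
  # new_h mirrors new_c_act * o_act above.
  new_h = tf.nn.relu6(new_c) * tf.sigmoid(o)
  return new_c, new_h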
def test_quantize_op_inference(self):
  inputs = tf.zeros([4, 10, 10, 128], dtype=tf.float32)
  outputs = utils.quantize_op(inputs, is_training=False)
  self.assertAllEqual(inputs.shape.as_list(), outputs.shape.as_list())
  self._check_no_min_max_ema(tf.get_default_graph())
  self._check_min_max_vars(tf.get_default_graph())
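# A hypothetical companion sketch (not in the excerpt above) for the
# training-mode counterpart of the inference test. It reuses only the helpers
# that appear in the test above; whether EMA update ops are created during
# training depends on lstm_utils internals not shown here, so no EMA check is
# made.
def test_quantize_op_training(self):
  inputs = tf.zeros([4, 10, 10, 128], dtype=tf.float32)
  outputs = utils.quantize_op(inputs, is_training=True)
  self.assertAllEqual(inputs.shape.as_list(), outputs.shape.as_list())
  self._check_min_max_vars(tf.get_default_graph())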