def __call__(self, inputs, *args, **kwargs):

  def make_quantizer_fn(quantizer, x, training, mode, quantizer_vars):
    """Use currying to return True/False specialized fns to the cond."""

    def quantizer_fn():
      return quantizer(x, training, mode, weights=quantizer_vars)

    return quantizer_fn

  x = inputs
  if self._should_pre_quantize():
    x = common_utils.smart_cond(
        self._training,
        make_quantizer_fn(self.quantizer, x, True, self.mode,
                          self._pre_activation_vars),
        make_quantizer_fn(self.quantizer, x, False, self.mode,
                          self._pre_activation_vars))

  x = self.activation(x, *args, **kwargs)

  if self._should_post_quantize():
    x = common_utils.smart_cond(
        self._training,
        make_quantizer_fn(self.quantizer, x, True, self.mode,
                          self._post_activation_vars),
        make_quantizer_fn(self.quantizer, x, False, self.mode,
                          self._post_activation_vars))

  return x
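
# --- Hedged sketch: why the quantizer calls above are curried into
# zero-argument closures. `smart_cond` (like `tf.cond`) takes callables for
# its branches so it can trace both when the predicate is a tensor, and can
# short-circuit when it is a plain Python bool. Everything below is an
# illustrative stand-in, not library code: the local `smart_cond` and
# `fake_quant` are hypothetical.
import tensorflow as tf


def smart_cond(pred, true_fn, false_fn):
  if isinstance(pred, bool):
    return true_fn() if pred else false_fn()   # resolved at trace time
  return tf.cond(pred, true_fn, false_fn)      # resolved in the graph


def fake_quant(x, training):
  # A real quantizer would also update its min/max range variables when
  # training is True; this stand-in only applies a fixed-range fake-quant.
  return tf.quantization.fake_quant_with_min_max_args(
      x, min=-6.0, max=6.0, num_bits=8)


def make_quantizer_fn(x, training):

  def quantizer_fn():
    return fake_quant(x, training)

  return quantizer_fn


x = tf.constant([0.1, 3.14159, 7.5])
y = smart_cond(tf.constant(True), make_quantizer_fn(x, True),
               make_quantizer_fn(x, False))   # -> fake-quantized tensor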
def _quantize_weights(self, training):
  # Quantize the folded kernel and bias in place on the layer.
  for weight, quantizer, quantizer_vars in self._weight_vars:
    weight_tensor = getattr(self, weight)
    quantized_weight = common_utils.smart_cond(
        training,
        self._make_quantizer_fn(quantizer, weight_tensor, True, self.mode,
                                quantizer_vars),
        self._make_quantizer_fn(quantizer, weight_tensor, False, self.mode,
                                quantizer_vars))
    setattr(self, weight, quantized_weight)
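
# Hedged sketch of the `self._make_quantizer_fn` helper relied on above (its
# definition is not shown in this excerpt); assumed to curry the quantizer
# call exactly like the local `make_quantizer_fn`/`quantizer_fn` helpers
# elsewhere in this file:
def _make_quantizer_fn(self, quantizer, x, training, mode, quantizer_vars):

  def quantizer_fn():
    return quantizer(x, training, mode, weights=quantizer_vars)

  return quantizer_fn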
def call(self, inputs, training=None):
  if training is None:
    training = tf.keras.backend.learning_phase()

  def _make_quantizer_fn(train_var):

    def quantizer_fn():
      return self.quantizer(
          inputs, train_var, self.mode, weights=self.quantizer_vars)

    return quantizer_fn

  return common_utils.smart_cond(training, _make_quantizer_fn(True),
                                 _make_quantizer_fn(False))
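
# Hedged sketch of the `training=None` convention these `call` methods share:
# an explicit flag from the caller wins; otherwise Keras's global learning
# phase (0 = inference, 1 = training) decides. Stand-alone illustration, not
# library code.
import tensorflow as tf


def resolve_training(training=None):
  if training is None:
    training = tf.keras.backend.learning_phase()
  return training


print(resolve_training(training=True))   # explicit flag wins
print(resolve_training())                # falls back to the learning phase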
def call(self, inputs, training=None):
  if training is None:
    training = tf.keras.backend.learning_phase()

  bias = common_utils.smart_cond(self.conv_layer.use_bias,
                                 lambda: self.conv_layer.bias, lambda: 0)

  if training:
    self.optimizer_step.assign_add(1)

  freeze_bn = common_utils.smart_cond(
      self.freeze_bn_delay is not None,
      lambda: math_ops.greater_equal(self.optimizer_step,
                                     self.freeze_bn_delay),
      lambda: False)

  if training and not freeze_bn:
    # Run the float conv and bn to update the moving mean and variance.
    conv_out = self.conv_layer.call(inputs)
    bn_out = self.bn_layer.call(conv_out, training=training)
    mu_bt, var_bt = self._get_batch_mean_var(conv_out)
    sigma_bt = math_ops.rsqrt(var_bt + self.bn_layer.epsilon)

    # Get the folded depthwise_kernel and bias, using batch statistics.
    self.depthwise_kernel, self.bias = _get_folded_kernel_bias(
        conv_type='DepthwiseConv2D',
        kernel=self.conv_layer.depthwise_kernel,
        bias=bias,
        mu=mu_bt,
        var=var_bt,
        gamma=self.bn_layer.gamma,
        beta=self.bn_layer.beta,
        epsilon=self.bn_layer.epsilon)

    # Batch norm correction: rescale the folded weights toward the
    # moving-average statistics before they are quantized.
    corr_scale, corr_recip, corr_offset = _get_bn_correction(
        conv_type='DepthwiseConv2D',
        kernel=self.conv_layer.depthwise_kernel,
        bias=bias,
        mu_bt=mu_bt,
        var_bt=var_bt,
        mu_mv=self.bn_layer.moving_mean,
        var_mv=self.bn_layer.moving_variance,
        gamma=self.bn_layer.gamma,
        epsilon=self.bn_layer.epsilon)

    self.depthwise_kernel = math_ops.mul(self.depthwise_kernel, corr_scale)
    self.bias = math_ops.add(self.bias, corr_offset)

    self._quantize_weights(training)
    outputs = self._run_folded_conv(inputs, training)

    # Undo the batch norm correction on the convolution outputs.
    outputs = math_ops.mul(outputs, corr_recip)
  else:
    # Fold with the frozen moving statistics.
    self.depthwise_kernel, self.bias = _get_folded_kernel_bias(
        conv_type='DepthwiseConv2D',
        kernel=self.conv_layer.depthwise_kernel,
        bias=bias,
        mu=self.bn_layer.moving_mean,
        var=self.bn_layer.moving_variance,
        gamma=self.bn_layer.gamma,
        beta=self.bn_layer.beta,
        epsilon=self.bn_layer.epsilon)

    self._quantize_weights(training)
    outputs = self._run_folded_conv(inputs, training)

  # Bias add.
  outputs = self._run_folded_bias_add(outputs)

  # Quantize the activation.
  for quantize_activation in self._quantize_activations:
    quantize_activation.training = training

  self.quantize_config.set_quantize_activations(self,
                                                self._quantize_activations)

  if self.activation is not None:
    return self.activation(outputs)
  return outputs
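
# --- Hedged sketch of the conv+BN folding assumed behind
# `_get_folded_kernel_bias` (standard folding; the helper below is an
# illustration, not the library's implementation). Since
#   bn(conv(x)) = gamma * (W*x + b - mu) / sqrt(var + eps) + beta,
# the same result comes from a single conv with
#   W_fold = gamma * W / sqrt(var + eps)
#   b_fold = beta + gamma * (b - mu) / sqrt(var + eps).
# The `corr_scale`/`corr_recip` pair above is then assumed to rescale W_fold
# from the batch-statistics scale to the moving-average scale
# (corr_scale ~ sqrt(var_bt + eps) / sqrt(var_mv + eps)), so the quantizer
# sees the weights inference will use, while corr_recip = 1 / corr_scale
# undoes the rescale on the float path; corr_offset is assumed to apply the
# matching bias-side correction.
import tensorflow as tf


def folded_kernel_bias(kernel, bias, mu, var, gamma, beta, epsilon,
                       depthwise=False):
  scale = gamma * tf.math.rsqrt(var + epsilon)   # per-channel gamma / sigma
  if depthwise:
    # DepthwiseConv2D kernels are [H, W, in_channels, multiplier]; the BN
    # channel axis covers in_channels * multiplier, so reshape the scale
    # before broadcasting (this is why the real helper takes `conv_type`).
    scale_k = tf.reshape(scale, kernel.shape[2:])
  else:
    scale_k = scale                              # broadcasts over [..., out]
  return kernel * scale_k, beta + (bias - mu) * scale


# Tiny usage check with made-up shapes:
k = tf.random.normal([3, 3, 8, 1])               # depthwise kernel
gamma, beta = tf.ones([8]), tf.zeros([8])
mu, var = tf.zeros([8]), tf.ones([8])
w_fold, b_fold = folded_kernel_bias(k, 0.0, mu, var, gamma, beta, 1e-3,
                                    depthwise=True)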
def call(self, inputs, training=None):
  if training is None:
    training = tf.keras.backend.learning_phase()

  # Quantize all weights, and replace them in the underlying layer.
  quantized_weights = []
  for unquantized_weight, quantizer, quantizer_vars in self._weight_vars:
    quantized_weight = common_utils.smart_cond(
        training,
        self._make_quantizer_fn(quantizer, unquantized_weight, True,
                                self.mode, quantizer_vars),
        self._make_quantizer_fn(quantizer, unquantized_weight, False,
                                self.mode, quantizer_vars))
    quantized_weights.append(quantized_weight)
  self.quantize_config.set_quantize_weights(self.layer, quantized_weights)

  # Quantize all biases, and replace them in the underlying layer.
  quantized_biases = []
  for unquantized_bias, quantizer, quantizer_vars in self._bias_vars:
    quantized_bias = common_utils.smart_cond(
        training,
        self._make_quantizer_fn(quantizer, unquantized_bias, True, self.mode,
                                quantizer_vars),
        self._make_quantizer_fn(quantizer, unquantized_bias, False, self.mode,
                                quantizer_vars))
    quantized_biases.append(quantized_bias)
  self.quantize_config.set_quantize_biases(self.layer, quantized_biases)

  # Replace all activations with `QuantizeAwareActivation`s which can
  # quantize activation tensors during graph construction.
  for quantize_activation in self._quantize_activations:
    quantize_activation.training = training
  self.quantize_config.set_quantize_activations(self.layer,
                                                self._quantize_activations)

  # Only pass `training` through if the wrapped layer's `call` accepts it.
  args = tf_inspect.getfullargspec(self.layer.call).args
  if 'training' in args:
    outputs = self.layer.call(inputs, training=training)
  else:
    outputs = self.layer.call(inputs)

  if not self._output_quantizer_vars:
    return outputs

  # Handle layers with multiple outputs. Copy to a list first: `outputs` may
  # be a tuple, which does not support item assignment.
  if isinstance(outputs, (list, tuple)):
    quantized_outputs = list(outputs)
    for (output_id, output_quantizer,
         output_quantizer_vars) in self._output_quantizer_vars:
      quantized_outputs[output_id] = common_utils.smart_cond(
          training,
          self._make_quantizer_fn(output_quantizer, outputs[output_id], True,
                                  self.mode, output_quantizer_vars),
          self._make_quantizer_fn(output_quantizer, outputs[output_id], False,
                                  self.mode, output_quantizer_vars))
    return quantized_outputs

  # Single output: quantize it directly.
  output_id, output_quantizer, output_quantizer_vars = (
      self._output_quantizer_vars[0])
  return common_utils.smart_cond(
      training,
      self._make_quantizer_fn(output_quantizer, outputs, True, self.mode,
                              output_quantizer_vars),
      self._make_quantizer_fn(output_quantizer, outputs, False, self.mode,
                              output_quantizer_vars))
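
# Hedged sketch of the argspec check used above: forward `training` to the
# wrapped layer only when its `call` signature accepts it. Python's stdlib
# `inspect` stands in here for `tf_inspect`; `call_wrapped` is a hypothetical
# helper, not library code.
import inspect

import tensorflow as tf


def call_wrapped(layer, inputs, training):
  args = inspect.getfullargspec(layer.call).args
  if 'training' in args:
    return layer.call(inputs, training=training)
  return layer.call(inputs)


print('training' in inspect.getfullargspec(
    tf.keras.layers.Dropout(0.5).call).args)   # True: Dropout takes training
print('training' in inspect.getfullargspec(
    tf.keras.layers.Dense(4).call).args)       # False: Dense does not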