Example #1
    def __call__(self, inputs, *args, **kwargs):
        def make_quantizer_fn(training, x, quantizer_vars):
            """Use currying to return True/False specialized fns to the cond."""
            def quantizer_fn(x=x,
                             quantizer=self.quantizer,
                             quantizer_vars=quantizer_vars):
                return quantizer(x, training, weights=quantizer_vars)

            return quantizer_fn

        x = inputs
        if self._should_pre_quantize():
            x = utils.smart_cond(
                self._training,
                make_quantizer_fn(True, x, self._pre_activation_vars),
                make_quantizer_fn(False, x, self._pre_activation_vars))

        x = self.activation(x, *args, **kwargs)

        if self._should_post_quantize():
            x = utils.smart_cond(
                self._training,
                make_quantizer_fn(True, x, self._post_activation_vars),
                make_quantizer_fn(False, x, self._post_activation_vars))

        return x
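A note on the currying idiom above: binding `x=x` (and the other arguments) as default parameters freezes their values at the moment `quantizer_fn` is defined. A plain closure would look the names up at call time instead, which matters because `x` is rebound between the pre- and post-activation calls. A minimal, TensorFlow-free sketch of the pitfall and the fix:

    def make_fns_late_binding(values):
        return [lambda: v for v in values]      # every closure shares one `v`

    def make_fns_early_binding(values):
        return [lambda v=v: v for v in values]  # default arg freezes each `v`

    print([f() for f in make_fns_late_binding([1, 2, 3])])   # [3, 3, 3]
    print([f() for f in make_fns_early_binding([1, 2, 3])])  # [1, 2, 3]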
Example #2
    def call(self, inputs, training=None, **kwargs):
        if training is None:
            training = tf.keras.backend.learning_phase()

        # Quantize all weights, and replace them in the underlying layer.

        quantized_weights = []
        for unquantized_weight, quantizer, quantizer_vars in self._weight_vars:
            quantized_weight = utils.smart_cond(
                training,
                self._make_quantizer_fn(quantizer, unquantized_weight, True,
                                        quantizer_vars),
                self._make_quantizer_fn(quantizer, unquantized_weight, False,
                                        quantizer_vars))
            quantized_weights.append(quantized_weight)

        self.quantize_config.set_quantize_weights(self.layer,
                                                  quantized_weights)

        # Replace all activations with `QuantizeAwareActivation`s which can
        # quantize activation tensors during graph construction.

        for quantize_activation in self._quantize_activations:
            quantize_activation.training = training

        self.quantize_config.set_quantize_activations(
            self.layer, self._quantize_activations)

        args = tf_inspect.getfullargspec(self.layer.call).args
        if 'training' in args:
            outputs = self.layer.call(inputs, training=training, **kwargs)
        else:
            outputs = self.layer.call(inputs, **kwargs)

        if not self._output_quantizers:
            return outputs

        # Assuming outputs is a single tensor. There might be some rare layers
        # where this is not true. Handle them when enabling such a layer.
        if isinstance(outputs, (list, tuple)):
            raise RuntimeError(
                'Multiple output tensors not handled currently.')

        output_quantizer = self._output_quantizers[0]
        return utils.smart_cond(
            training,
            self._make_quantizer_fn(output_quantizer, outputs, True,
                                    self._output_quantizer_vars),
            self._make_quantizer_fn(output_quantizer, outputs, False,
                                    self._output_quantizer_vars))
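All of these examples dispatch through `utils.smart_cond`, which can pick a branch eagerly when `training` is a plain Python bool and falls back to `tf.cond` when it is a tensor (e.g. the Keras learning phase). A simplified, hypothetical reimplementation to show the idea; the real helper also resolves tensor predicates whose value is statically known:

    import tensorflow as tf

    def smart_cond(pred, true_fn, false_fn):
        if isinstance(pred, (bool, int)):
            # Static predicate: only the branch that is taken gets traced.
            return true_fn() if pred else false_fn()
        # Tensor predicate: both branches are traced into a tf.cond.
        return tf.cond(pred, true_fn, false_fn)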
Example #3
    def call(self, inputs, training=None, **kwargs):
        if training is None:
            training = K.learning_phase()

        def increment_step():
            with tf.control_dependencies(
                [tf_compat.assign(self.pruning_step, self.pruning_step + 1)]):
                return tf.no_op('update')

        def add_update():
            with tf.control_dependencies([
                    tf.debugging.assert_greater_equal(
                        self.pruning_step,
                        np.int64(1),
                        message=self._PRUNE_CALLBACK_ERROR_MSG)
            ]):
                with tf.control_dependencies(
                    [self.pruning_obj.conditional_mask_update()]):
                    return tf.no_op('update')

        def no_op():
            return tf.no_op('no_update')

        # Increment the 'pruning_step' after each step.
        update_pruning_step = utils.smart_cond(training, increment_step, no_op)
        self.add_update(update_pruning_step)

        # Update mask tensor after each 'pruning_frequency' steps.
        update_mask = utils.smart_cond(training, add_update, no_op)
        self.add_update(update_mask)

        # Always execute the op that performs weights = weights * mask
        # Relies on UpdatePruningStep callback to ensure the weights
        # are sparse after the final backpropagation.
        #
        # self.add_update does nothing during eager execution.
        self.add_update(self.pruning_obj.weight_mask_op())
        # TODO(evcu) remove this check after dropping py2 support. In py3 getargspec
        # is deprecated.
        if hasattr(inspect, 'getfullargspec'):
            args = inspect.getfullargspec(self.layer.call).args
        else:
            args = inspect.getargspec(self.layer.call).args
        # Propagate the training bool to the underlying layer if it accepts
        # training as an arg.
        if 'training' in args:
            return self.layer.call(inputs, training=training, **kwargs)

        return self.layer.call(inputs, **kwargs)
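The `assert_greater_equal` above fails unless something has advanced `pruning_step` past zero; in practice that is the `UpdatePruningStep` callback referenced by the comments and the error message. A minimal usage sketch with the public tfmot API (shapes and hyperparameters are illustrative):

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    model = tf.keras.Sequential([
        tfmot.sparsity.keras.prune_low_magnitude(
            tf.keras.layers.Dense(10, input_shape=(4,)))
    ])
    model.compile(optimizer='adam', loss='mse')
    model.fit(tf.random.normal((32, 4)), tf.random.normal((32, 10)),
              epochs=1,
              callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])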
Example #4
    def call(self, inputs, training=None):
        if training is None:
            training = tf.keras.backend.learning_phase()

        def _make_quantizer_fn(train_var):
            def quantizer_fn():
                return self.quantizer(inputs,
                                      train_var,
                                      weights=self.quantizer_vars)

            return quantizer_fn

        return utils.smart_cond(training, _make_quantizer_fn(True),
                                _make_quantizer_fn(False))
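Every quantizer in these examples is callable as `quantizer(inputs, training, weights=...)`: the `training` flag decides whether the tracked range is updated from the current batch, and `weights` carries the range variables. A toy stand-in honoring that interface (hypothetical; real quantizers live in tensorflow_model_optimization, and the `min_var`/`max_var` keys are borrowed from Example #6):

    import tensorflow as tf

    class ToyMovingAverageQuantizer:
        """Hypothetical quantizer matching the call signature used above."""

        def __call__(self, inputs, training, weights):
            if training:
                # Training: update the tracked range with a moving average.
                weights['min_var'].assign(
                    0.99 * weights['min_var'] + 0.01 * tf.reduce_min(inputs))
                weights['max_var'].assign(
                    0.99 * weights['max_var'] + 0.01 * tf.reduce_max(inputs))
            # Both modes fake-quantize with the tracked range.
            return tf.quantization.fake_quant_with_min_max_vars(
                inputs, weights['min_var'], weights['max_var'])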
Example #5
    def _apply_weight_quantizer(self, training, folded_conv_kernel):
        """All Keras call() logic for applying weight quantization."""
        def make_quantizer_fn(training):
            """Return quantizer conditioned on whether training or not."""
            def quantizer_fn():
                return self.weight_quantizer(
                    folded_conv_kernel,
                    training,
                    weights=self._weight_quantizer_vars)  # pylint: disable=protected-access

            return quantizer_fn

        return utils.smart_cond(training, make_quantizer_fn(True),
                                make_quantizer_fn(False))
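Example #5 quantizes the batch-norm-folded kernel rather than the raw conv kernel, so fake quantization sees the weights the fused inference graph will actually execute. A sketch of the folding arithmetic it presupposes (hypothetical helper; the folded bias term is omitted):

    import tensorflow as tf

    def fold_batch_norm_kernel(kernel, gamma, moving_variance, epsilon=1e-3):
        # Fused conv+BN scales each output channel by gamma / sqrt(var + eps),
        # so the folded kernel is what the weight quantizer should see.
        scale = gamma * tf.math.rsqrt(moving_variance + epsilon)
        return kernel * scale  # broadcasts over the output-channel axis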
Example #6
    def _apply_activation_quantizer(self, training, activation_output):
        """All Keras call() logic for applying weight quantization."""
        def make_quantizer_fn(training):
            """Return quantizer conditioned on whether training or not."""
            def quantizer_fn():
                weights = {
                    'min_var': self._activation_min_var,  # pylint: disable=protected-access
                    'max_var': self._activation_max_var,  # pylint: disable=protected-access
                }
                return self.activation_quantizer(activation_output,
                                                 training,
                                                 weights=weights)

            return quantizer_fn

        return utils.smart_cond(training, make_quantizer_fn(True),
                                make_quantizer_fn(False))
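The `weights` dict in Example #6 assumes `min_var`/`max_var` range variables already exist on the layer. A sketch of how such variables might be created in `build()` (hypothetical; it mirrors the interface above, not the library's exact code):

    import tensorflow as tf

    class ActivationRangeLayer(tf.keras.layers.Layer):
        def build(self, input_shape):
            # Non-trainable scalars; the quantizer itself updates them.
            self._activation_min_var = self.add_weight(
                'activation_min', shape=(), initializer='zeros',
                trainable=False)
            self._activation_max_var = self.add_weight(
                'activation_max', shape=(), initializer='zeros',
                trainable=False)
            super().build(input_shape)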