Exemplo n.º 1
0
 def setup(self):
     """Create the embedding parameter and the quantized dot-product op.

     Registers the 'embedding' parameter with shape
     (num_embeddings, features). When the activation hyperparameters carry
     GetBounds.Hyper bounds, a GetBounds module is also created and stored as
     `get_bounds_logits`. Finally a QuantizedDot is built and stored as
     `quantized_dot`.
     """
     table_shape = (self.num_embeddings, self.features)
     self.embedding = self.param(
         'embedding',
         self.embedding_init,  # pylint: disable=missing-from-attributes
         table_shape)

     hparams = self.hparams
     act_hparams = hparams.quant_act

     # A GetBounds module is only created when activation quantization is
     # configured and its bounds are given as GetBounds.Hyper.
     uses_get_bounds = act_hparams is not None and isinstance(
         act_hparams.bounds, get_bounds.GetBounds.Hyper)
     if uses_get_bounds:
         self.get_bounds_logits = get_bounds.GetBounds(  # pylint: disable=missing-from-attributes
             hyper=act_hparams.bounds)

     # The expected scale shape is (1, size of the embedding's first axis).
     weight_params = QuantOps.WeightParams(
         prec=hparams.weight_prec,
         axis=(0, ),
         expected_scale_shape=(1, self.embedding.shape[0]))
     context = self.quant_context
     self.quantized_dot = quantization.QuantizedDot(  # pylint: disable=missing-from-attributes
         act_hparams=act_hparams,
         quant_type=hparams.quant_type,
         dot_precision=None,
         prefer_int8_to_int32_dot=context.prefer_int8_to_int32_dot,
         weight_params=weight_params)
 def init_model(self,
                update_bounds,
                update_stats=True,
                reset_stats=False,
                use_cams=False,
                granularity=quant_config.QuantGranularity.per_tensor,
                ema_coeff=None):
   """Rebuild `self.hyperparam` and initialize a GetBounds module from it.

   The coefficient fields (initial_bound, stddev_coeff, absdev_coeff,
   mix_coeff) are carried over from the current `self.hyperparam`; the
   remaining fields come from the arguments. `self.hyperparam` is replaced
   by the new value.

   Args:
     update_bounds: Forwarded to GetBounds.Params.
     update_stats: Forwarded to GetBounds.Params.
     reset_stats: Forwarded to GetBounds.Hyper.
     use_cams: Forwarded to GetBounds.Hyper.
     granularity: Forwarded to GetBounds.Hyper.
     ema_coeff: Forwarded to GetBounds.Hyper.

   Returns:
     A tuple (GetBounds module, result of its `init` call on `self.key2` and
     `self.x`, GetBounds.Params used for that call).
   """
   previous = self.hyperparam
   self.hyperparam = get_bounds.GetBounds.Hyper(
       initial_bound=previous.initial_bound,
       stddev_coeff=previous.stddev_coeff,
       absdev_coeff=previous.absdev_coeff,
       mix_coeff=previous.mix_coeff,
       reset_stats=reset_stats,
       use_cams=use_cams,
       ema_coeff=ema_coeff,
       granularity=granularity)

   params = get_bounds.GetBounds.Params(
       update_bounds=update_bounds, update_stats=update_stats)
   module = get_bounds.GetBounds(hyper=self.hyperparam)
   variables = module.init(self.key2, self.x, bounds_params=params)
   return module, variables, params
Exemplo n.º 3
0
    def create_input_ops(cls, inputs, *, hparams, get_bounds_params):
        """Create a QuantOps that can quantize and dequantize an activation tensor.

        Note that an integer `hparams.bounds` is converted to float and written
        back onto `hparams` in place.

        Args:
          inputs: The inputs to quantize. Only used to compute bounds when
            `hparams.bounds` is a GetBounds.Hyper.
          hparams: Input hyperparameter (ActHParams).
          get_bounds_params: GetBoundsParams. Parameters for GetBounds. Must be
            provided when `hparams.bounds` is a GetBounds.Hyper.

        Returns:
          The QuantOps instance built by `create_symmetric_fp`,
          `create_symmetric` or `create_positive`, depending on `hparams.prec`
          and `hparams.input_distribution`.

        Raises:
          ValueError: If `hparams.bounds` is a GetBounds.Hyper but
            `get_bounds_params` is missing.
          AssertionError: If `hparams.bounds` has an unsupported type or
            `hparams.input_distribution` is not a supported value.
        """

        # TODO(shivaniagrawal): investigate why pytype allows types other than
        # ActsBoundT.
        if isinstance(hparams.bounds, int):
            hparams.bounds = float(hparams.bounds)

        # NOTE: if flax module name is None, default name is used.

        # If we want to train with no quantization at first and then turn on
        # GetBounds quantization, we still have to call GetBounds even before
        # quantization is enabled since GetBounds calculates and stores the running
        # statistics that we will use once quantization is enabled. But before
        # quantization is enabled, we want to ignore the returned bounds and just
        # return the original unquantized input. To do so, we take advantage of the
        # fact that GetBounds returns a constant fixed bound for an initial time
        # period and set that initial bound to a special value (-1) to indicate we
        # want to store activation statistics without applying quantization. That
        # will cause clip_bounds to be a tensor of all '-1', which is meant to be
        # detected later via a lax.cond check (not performed in this method).

        # TODO(malmaud): Refactor code to separate bounds calculation from tracking
        # activation statistics to avoid the need to rely on special bounds values
        # when disabling quantization.
        if isinstance(hparams.bounds, get_bounds.GetBounds.Hyper):
            if not get_bounds_params:
                raise ValueError(
                    'act_hparams.bounds is of type GetBounds.Hyper, user must '
                    'provide get_bounds_params, parameters for GetBounds.')
            clip_bounds = get_bounds.GetBounds(
                hyper=hparams.bounds, name=get_bounds_params.module_name)(
                    inputs,
                    bounds_params=get_bounds_params,
                )
        elif isinstance(hparams.bounds, (float, jnp.ndarray)):
            clip_bounds = hparams.bounds
        else:
            # Raised explicitly (rather than `assert False`) so the check still
            # fires under `python -O`; otherwise `clip_bounds` would be unbound.
            raise AssertionError(
                '%s is not a valid type for hparams.bounds, should be int, '
                'float, jnp.ndarray, or GetBounds.Hyper.' %
                (type(hparams.bounds)))

        if isinstance(hparams.prec, _FloatQuant):
            ops = cls.create_symmetric_fp(bounds=clip_bounds,
                                          fp_quant=hparams.prec)
        elif hparams.input_distribution == cls.ActHParams.InputDistribution.symmetric:
            ops = cls.create_symmetric(bounds=clip_bounds, prec=hparams.prec)
        elif hparams.input_distribution == cls.ActHParams.InputDistribution.positive:
            ops = cls.create_positive(bounds=clip_bounds, prec=hparams.prec)
        else:
            # Same reasoning as above: keep the guard alive under `python -O`.
            raise AssertionError("can't happen.")

        if get_bounds_params and get_bounds_params.expected_bounds_shape is not None:
            if isinstance(hparams.bounds, get_bounds.GetBounds.Hyper):
                ops.assert_scale_shape_is(
                    shape=get_bounds_params.expected_bounds_shape)
            else:
                # The shape check is skipped when bounds were given directly.
                logging.info(
                    'Ignoring value of argument expected_bounds_shape. Scale for '
                    'fixed bounds would be scalar.')
        return ops