Example 1
    def test_annotation_only_changes_hlo_metadata_conv(self, weight_prec,
                                                       acts_prec):
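        # With metadata annotation disabled, build a reference quantized conv layer.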
        FLAGS.metadata_enabled = False
        quant_act = quantization.QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            prec=acts_prec,
            bounds=1.0)
        input_shape = (1, 8, 8, 3)
        module_no_annotation = aqt_flax_layers.ConvAqt(
            features=4,
            kernel_size=(3, 3),
            padding='VALID',
            paxis_name='batch',
            quant_context=quant_config.QuantContext(update_bounds=False),
            train=False,
            hparams=aqt_flax_layers.ConvAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant),
            kernel_init=initializers.ones,
            bias_init=initializers.ones,
            dtype=jnp.float32)

        init_state = module_no_annotation.init(
            self.rng_key, jnp.ones(input_shape, jnp.float32))
        output_no_annotation = module_no_annotation.apply(
            init_state, jnp.ones(input_shape))

        hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
            module_no_annotation, init_state, [input_shape])
        del init_state

        FLAGS.metadata_enabled = True
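        # Rebuild the identical module, now with metadata annotation enabled.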
        module_w_annotation = aqt_flax_layers.ConvAqt(
            features=4,
            kernel_size=(3, 3),
            padding='VALID',
            paxis_name='batch',
            quant_context=quant_config.QuantContext(update_bounds=False),
            train=False,
            hparams=aqt_flax_layers.ConvAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant),
            kernel_init=initializers.ones,
            bias_init=initializers.ones,
            dtype=jnp.float32)

        init_state = module_w_annotation.init(
            self.rng_key, jnp.ones(input_shape, jnp.float32))
        output_w_annotation = module_w_annotation.apply(
            init_state, jnp.ones(input_shape))

        hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
            module_w_annotation, init_state, [input_shape])
        del init_state

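        # Annotation should change only HLO metadata, not the outputs or the
        # HLO instructions themselves.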
        onp.testing.assert_array_equal(output_no_annotation,
                                       output_w_annotation)
        self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
Example 2
    def __call__(
        self,
        inputs,
    ):
        """Applies ResNet model. Number of residual blocks inferred from hparams."""
        num_classes = self.num_classes
        hparams = self.hparams
        num_filters = self.num_filters
        dtype = self.dtype

        x = aqt_flax_layers.ConvAqt(
            features=num_filters,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding=[(3, 3), (3, 3)],
            use_bias=False,
            dtype=dtype,
            name='init_conv',
            train=self.train,
            quant_context=self.quant_context,
            paxis_name='batch',
            hparams=hparams.conv_init,
        )(inputs)
        x = nn.BatchNorm(use_running_average=not self.train,
                         momentum=0.9,
                         epsilon=1e-5,
                         dtype=dtype,
                         name='init_bn')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        filter_multiplier = hparams.filter_multiplier
        for i, block_hparams in enumerate(hparams.residual_blocks):
            proj = block_hparams.conv_proj
            # Blocks with a projection conv (other than the first block) double
            # the filter count and downsample with strides = (2, 2).
            if i > 0 and proj is not None:
                filter_multiplier *= 2
                strides = (2, 2)
            else:
                strides = (1, 1)
            x = ResidualBlock(filters=int(num_filters * filter_multiplier),
                              hparams=block_hparams,
                              quant_context=self.quant_context,
                              strides=strides,
                              train=self.train,
                              dtype=dtype)(x)
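        # Global average pooling over the spatial dimensions.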
        x = jnp.mean(x, axis=(1, 2))

        x = aqt_flax_layers.DenseAqt(
            features=num_classes,
            dtype=dtype,
            train=self.train,
            quant_context=self.quant_context,
            paxis_name='batch',
            hparams=hparams.dense_layer,
        )(x, padding_mask=None)

        x = jnp.asarray(x, dtype)
        output = nn.log_softmax(x)
        return output
Example 3
  def __call__(self,
               inputs,
               hparams,
               kernel_size,
               num_filters,
               strides,
               dtype=jnp.float32):
    output = aqt_flax_layers.ConvAqt(
        features=num_filters,
        kernel_size=kernel_size,
        strides=strides,
        use_bias=False,
        dtype=dtype,
        train=False,
        quant_context=quant_config.QuantContext(update_bounds=False),
        paxis_name='batch',
        hparams=hparams)(inputs)
    return output

  def test_group_conv(self, weight_prec=None):
    x = jnp.ones((1, 8, 8, 4))
    conv_module = flax_layers.ConvAqt(
        features=4,
        kernel_size=(3, 3),
        feature_group_count=2,
        padding='VALID',
        paxis_name='batch',
        quant_context=quant_config.QuantContext(update_bounds=False),
        train=False,
        hparams=flax_layers.ConvAqt.HParams(
            weight_prec=weight_prec,
            quant_act=None,
            quant_type=QuantType.fake_quant),
        kernel_init=initializers.ones,
        bias_init=initializers.ones,
        dtype=jnp.float32)
    y, state = conv_module.init_with_output(self.rng_key, x)
    # With feature_group_count=2, the kernel's input-channel dimension is
    # 4 // 2 = 2, giving a kernel of shape (3, 3, 2, 4).
    self.assertEqual(state['params']['kernel'].shape, (3, 3, 2, 4))
    test_utils.assert_all_close_prec(y, onp.full((1, 6, 6, 4), 19.),
                                     weight_prec)
Example 5
  def __call__(
      self,
      inputs,
  ):
    """Applies ResNet model. Number of residual blocks inferred from hparams."""
    num_classes = self.num_classes
    hparams = self.hparams
    num_filters = self.num_filters
    dtype = self.dtype
    assert hparams.act_function in act_function_zoo.keys(), (
        'Activation function type is not supported.')

    x = aqt_flax_layers.ConvAqt(
        features=num_filters,
        kernel_size=(7, 7),
        strides=(2, 2),
        padding=[(3, 3), (3, 3)],
        use_bias=False,
        dtype=dtype,
        name='init_conv',
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.conv_init,
    )(
        inputs)
    x = nn.BatchNorm(
        use_running_average=not self.train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=dtype,
        name='init_bn')(
            x)
    if hparams.act_function == 'relu':
      x = nn.relu(x)
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    else:
      # TODO(yichi): try adding other activation functions here
      # Use avg pool so that for binary nets, the distribution is symmetric.
      x = nn.avg_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    filter_multiplier = hparams.filter_multiplier
    for i, block_hparams in enumerate(hparams.residual_blocks):
      proj = block_hparams.conv_proj
      # Blocks with a projection conv (other than the first block) double the
      # filter count and downsample with strides = (2, 2).
      if i > 0 and proj is not None:
        filter_multiplier *= 2
        strides = (2, 2)
      else:
        strides = (1, 1)
      x = ResidualBlock(
          filters=int(num_filters * filter_multiplier),
          hparams=block_hparams,
          quant_context=self.quant_context,
          strides=strides,
          train=self.train,
          dtype=dtype)(
              x)
    if hparams.act_function == 'none':
      # The DenseAqt below is not binarized. With the activation functions
      # removed there would be no nonlinearity between the last residual block
      # and the dense layer, so add a ReLU in that case.
      # TODO(yichi): try BPReLU
      x = nn.relu(x)
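    # Global average pooling over the spatial dimensions.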
    x = jnp.mean(x, axis=(1, 2))

    x = aqt_flax_layers.DenseAqt(
        features=num_classes,
        dtype=dtype,
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.dense_layer,
    )(x, padding_mask=None)

    x = jnp.asarray(x, dtype)
    output = nn.log_softmax(x)
    return output