Example #1
    def test_annotation_only_changes_hlo_metadata_conv(self, weight_prec,
                                                       acts_prec):
        FLAGS.metadata_enabled = False
        quant_act = quantization.QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            prec=acts_prec,
            bounds=1.0)
        input_shape = (1, 8, 8, 3)
        module_no_annotation = aqt_flax_layers.ConvAqt(
            features=4,
            kernel_size=(3, 3),
            padding='VALID',
            paxis_name='batch',
            quant_context=quant_config.QuantContext(update_bounds=False),
            train=False,
            hparams=aqt_flax_layers.ConvAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant),
            kernel_init=initializers.ones,
            bias_init=initializers.ones,
            dtype=jnp.float32)

        init_state = module_no_annotation.init(
            self.rng_key, jnp.ones(input_shape, jnp.float32))
        output_no_annotation = module_no_annotation.apply(
            init_state, jnp.ones(input_shape))

        hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
            module_no_annotation, init_state, [input_shape])
        del init_state

        FLAGS.metadata_enabled = True
        module_w_annotation = aqt_flax_layers.ConvAqt(
            features=4,
            kernel_size=(3, 3),
            padding='VALID',
            paxis_name='batch',
            quant_context=quant_config.QuantContext(update_bounds=False),
            train=False,
            hparams=aqt_flax_layers.ConvAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant),
            kernel_init=initializers.ones,
            bias_init=initializers.ones,
            dtype=jnp.float32)

        init_state = module_w_annotation.init(
            self.rng_key, jnp.ones(input_shape, jnp.float32))
        output_w_annotation = module_w_annotation.apply(
            init_state, jnp.ones(input_shape))

        hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
            module_w_annotation, init_state, [input_shape])
        del init_state

        onp.testing.assert_array_equal(output_no_annotation,
                                       output_w_annotation)
        self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
Example #2
 def test_softmax_vs_original(self, input_tensor, softmax_hparams):
     dtype = jnp.float32
     norm_dims = (0, )
     input_tensor = jnp.array(input_tensor)
     output = flax_attention.softmax(
         input_tensor, norm_dims, dtype, softmax_hparams,
         quant_config.QuantContext(update_bounds=False, quantize_acts=True))
     expected_output = flax_attention.softmax(
         input_tensor, norm_dims, dtype, SoftmaxHParams(None, None, None),
         quant_config.QuantContext(update_bounds=False, quantize_acts=True))
     self.assertAllClose(expected_output, output, atol=1e-8)
Example #3
 def test_multihead_encoder_decoder_attention(self, weight_prec):
     rng = random.PRNGKey(0)
     q = jnp.ones((4, 3, 5))
     kv = jnp.ones((4, 3, 5))
     sa_module = flax_attention.MultiHeadDotProductAttentionAqt(
         num_heads=8,
         hparams=self.construct_hparams(weight_prec),
         attention_axis=(1, ),
         quant_context=quant_config.QuantContext(update_bounds=False,
                                                 collect_acts_stats=False),
         train=False,
         paxis_name=None,
         qkv_features=16,
         kernel_init=initializers.ones,
         bias_init=initializers.zeros,
         dtype=jnp.float32,
         causal_mask=False,
         dropout_rate=0.0,
         deterministic=False,
         decode=False)
     y, _ = sa_module.init_with_output(rng,
                                       q,
                                       kv,
                                       padding_mask=None,
                                       key_padding_mask=None)
     self.assertEqual(y.shape, q.shape)
Example #4
 def init_model_with_1_layer(self,
                             inputs,
                             num_features,
                             kernel_size,
                             kernel_init=flax_layers.default_kernel_init,
                             weight_prec=None,
                             quant_act=None,
                             weight_half_shift=False):
   """Create and initialize a flax model with a single ConvAqt layer."""
   layer_kwargs = {
       'kernel_init': kernel_init,
       'features': num_features,
       'use_bias': False,
       'quant_context': quant_config.QuantContext(update_bounds=False),
       'paxis_name': 'batch',
       'train': False,
       'kernel_size': kernel_size,
       'dtype': jnp.float32
   }
   layer_class = flax_layers.ConvAqt
   layer_kwargs['hparams'] = flax_layers.ConvAqt.HParams(
       weight_prec=weight_prec,
       quant_act=quant_act,
       quant_type=QuantType.fake_quant,
       weight_half_shift=weight_half_shift,
   )
   conv_module = layer_class(**layer_kwargs)
   initial_state = conv_module.init(self.rng_key, jnp.zeros(inputs.shape))
   return conv_module, initial_state
Example #5
 def test_custom_softmax_vs_mock(self, input_tensor, norm_dims,
                                 softmax_hparams, expected_output):
   dtype = jnp.float32
   output = flax_attention.softmax(
       input_tensor, norm_dims, dtype, softmax_hparams,
       quant_config.QuantContext(update_bounds=False, quantize_acts=False))
   self.assertAllClose(expected_output, output, atol=1e-6)
Example #6
def get_quant_context_for_step(
    *,
    activation_bound_update_freq,
    activation_bound_start_step,
    step,
    collect_acts_stats,
    prefer_int8_to_int32_dot):
  """Returns correct quantization context for a given step.

  Args:
    activation_bound_update_freq: How frequently to update bounds after the
      initial bounds update. A value of '-1' indicates to not update the bounds
      after the first update.
    activation_bound_start_step: The first step to update bounds on. '-1'
      indicates to never update bounds.
    step: The current training step.
    collect_acts_stats: Whether to collect activation statistics.
    prefer_int8_to_int32_dot: Whether to feed lax.dot inputs with an int8 dtype
      and accumulate to int32.

  Returns:
    A quant_config.QuantContext instance.
  """
  update_bounds = should_update_bounds(
      activation_bound_start_step=activation_bound_start_step,
      activation_bound_update_freq=activation_bound_update_freq,
      step=step)
  quantize_acts = step >= activation_bound_start_step
  return quant_config.QuantContext(
      update_bounds=update_bounds,
      quantize_acts=quantize_acts,
      collect_acts_stats=collect_acts_stats,
      prefer_int8_to_int32_dot=prefer_int8_to_int32_dot,
  )
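For context, a hypothetical call site for this helper might look like the following. The step and frequency values are made up for illustration, and the behavior noted in the comments follows the docstring above together with should_update_bounds.

# Illustrative sketch only; all values here are hypothetical.
for step in (0, 999, 1000, 1100, 1150):
  ctx = get_quant_context_for_step(
      activation_bound_update_freq=100,
      activation_bound_start_step=1000,
      step=step,
      collect_acts_stats=False,
      prefer_int8_to_int32_dot=True)
  # Per the docstring: quantize_acts is False at steps 0 and 999 and True from
  # step 1000 on; update_bounds (via should_update_bounds) is True at the
  # start step 1000 and then every 100 steps (e.g. 1100), otherwise False.
  print(step, ctx.update_bounds, ctx.quantize_acts)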
Example #7
  def test_embed(self, weight_prec):
    # Since the dummy embedding matrix has a row of all zeros, we need 'epsilon'
    # to be added to it before calculating scale factors.
    quantization.DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING = False
    rng = random.PRNGKey(0)
    x = jnp.arange(4)[None]
    dummy_embedding = jnp.broadcast_to(jnp.arange(4)[Ellipsis, None],
                                       (4, 3)).astype(jnp.float32)
    embed_module = flax_layers.EmbedAqt(
        num_embeddings=4,
        features=3,
        dtype=jnp.float32,
        hparams=flax_layers.EmbedAqt.HParams(
            weight_prec=weight_prec,
            quant_act=None,
            quant_type=QuantType.fake_quant,
            weight_half_shift=False),
        embedding_init=lambda _rng, _shape: dummy_embedding,
        train=False,
        paxis_name=None,
        quant_context=quant_config.QuantContext(update_bounds=False),
    )
    y, state = embed_module.init_with_output(rng, x)
    test_utils.assert_all_close_prec(dummy_embedding[None], y, weight_prec)

    z = embed_module.apply(
        state, jnp.ones((1, 3)), padding_mask=None, method=embed_module.attend)
    test_utils.assert_all_close_prec(3. * jnp.arange(4), z[0, Ellipsis], weight_prec)
Example #8
  def test_embed_equality(self, weight_prec):
    rng = random.PRNGKey(0)
    x = 2 * jnp.ones(4, dtype=jnp.int32)[None]
    dummy_embedding = 2 * jnp.ones((4, 2)).astype(jnp.float32)
    embed_module = flax_layers.EmbedAqt(
        num_embeddings=4,
        features=2,
        dtype=jnp.float32,
        hparams=flax_layers.EmbedAqt.HParams(
            weight_prec=weight_prec,
            quant_act=None,
            quant_type=QuantType.fake_quant,
            weight_half_shift=False),
        embedding_init=lambda _rng, _shape: dummy_embedding,
        train=False,
        quant_context=quant_config.QuantContext(update_bounds=False),
        paxis_name=None)
    y, init_state = embed_module.init_with_output(rng, x)
    onp.testing.assert_array_equal(dummy_embedding[None], y)

    z = embed_module.apply(
        init_state,
        jnp.ones((1, 2)),
        padding_mask=None,
        method=embed_module.attend)
    onp.testing.assert_array_equal(2. * (2 * jnp.ones(4)), z[0, Ellipsis])
Example #9
  def init_model_with_1_layer(self,
                              inputs,
                              num_features,
                              kernel_init=flax_layers.default_kernel_init,
                              weight_prec=None,
                              quant_act=None,
                              weight_half_shift=False):
    """Create and initialize a flax model with a single DenseAqt layer."""
    quant_context = quant_config.QuantContext(
        update_bounds=False, collect_acts_stats=False)
    layer_kwargs = {
        'kernel_init': kernel_init,
        'features': num_features,
        'use_bias': False,
        'quant_context': quant_context,
        'paxis_name': 'batch',
        'train': False,
        'dtype': jnp.float32
    }
    layer_kwargs['hparams'] = flax_layers.DenseAqt.HParams(
        weight_prec=weight_prec,
        quant_act=quant_act,
        quant_type=QuantType.fake_quant,
        weight_quant_granularity=quant_config.QuantGranularity.per_channel,
        weight_half_shift=weight_half_shift)

    dense_module = flax_layers.DenseAqt(**layer_kwargs)
    initial_state = dense_module.init(
        self.rng_key, jnp.zeros(inputs.shape), padding_mask=None)
    return dense_module, initial_state
Example #10
 def test_multihead_self_attention_w_dropout(self, weight_prec):
     rng = random.PRNGKey(0)
     x = jnp.ones((4, 3, 5))
     sa_module = flax_attention.SelfAttentionAqt(
         num_heads=8,
         hparams=self.construct_hparams(weight_prec),
         attention_axis=(1, ),
         quant_context=quant_config.QuantContext(update_bounds=False,
                                                 collect_acts_stats=False),
         train=False,
         paxis_name=None,
         qkv_features=16,
         kernel_init=initializers.ones,
         bias_init=initializers.zeros,
         dropout_rate=0.1,
         dtype=jnp.float32,
         causal_mask=False,
         deterministic=False,
         decode=False)
     rng_dropout, rng_params = random.split(rng)
     y, _ = sa_module.init_with_output(
         {
             'dropout': rng_dropout,
             'params': rng_params
         },
         x,
         padding_mask=None)
     self.assertEqual(y.shape, x.shape)
Example #11
def get_quant_context_for_step(*, activation_bound_update_freq,
                               activation_bound_start_step, step,
                               collect_acts_stats):
    """Returns correct quantization context for a given step.

  Args:
    activation_bound_update_freq: How frequently to update bounds after the
      initial bounds update. A value of '-1' indicates to not update the bounds
      after the first update.
    activation_bound_start_step: The first step to update bounds on. '-1'
      indicates to never update bounds.
    step: The current training step.
    collect_acts_stats: Whether to collect activation statistics.

  Returns:
    A quant_config.QuantContext instance.
  """
    update_bounds = should_update_bounds(
        activation_bound_start_step=activation_bound_start_step,
        activation_bound_update_freq=activation_bound_update_freq,
        step=step)
    quantize_acts = step >= activation_bound_start_step
    # TODO(shivaniagrawal): We hardcode this to False to force the inputs to
    # lax.dot_general to be floating-point. Otherwise, training diverges for
    # 8-bit quantization for unknown reasons.
    prefer_int8_to_int32_dot = False
    return quant_config.QuantContext(
        update_bounds=update_bounds,
        quantize_acts=quantize_acts,
        collect_acts_stats=collect_acts_stats,
        prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)
Example #12
  def test_padding(self):
    """Test that padding results in the right statistics being collected."""
    # Exact values don't matter here; we just need the code to think it's using
    # dynamic bounds so that it gathers activation statistics.
    bounds = get_bounds.GetBounds.Hyper(
        initial_bound=0.0,
        stddev_coeff=1.0,
        absdev_coeff=0.0,
        mix_coeff=1.0,
        reset_stats=False,
        granularity=quant_config.QuantGranularity.per_channel)
    quant_act = flax_layers.QuantOps.ActHParams(
        input_distribution=flax_layers.QuantOps.ActHParams.InputDistribution
        .symmetric,
        prec=8,
        bounds=bounds)
    hparams = flax_layers.DenseAqt.HParams(
        quant_type=flax_layers.QuantType.fake_quant,
        weight_prec=8,
        quant_act=quant_act,
        weight_quant_granularity=quant_config.QuantGranularity.per_channel)
    module = flax_layers.DenseAqt(
        hparams=hparams,
        features=1,
        paxis_name=None,
        quant_context=quant_config.QuantContext(
            update_bounds=True, collect_acts_stats=False),
        train=True,
        dtype=jnp.float32)

    # Simulate an input with a batch size of 2, three tokens per example, two
    # channels per token
    x = jnp.arange(12).astype(jnp.float32).reshape((2, 3, 2))
    # Reshape it to have dimensions [batch, feature]
    x = x.reshape(6, 2)

    initial_state = module.init(self.rng_key, x, padding_mask=None)

    # Check that the per-channel activation statistics are as expected with no
    # padding
    _, state_nopadding = module.apply(
        initial_state, x, padding_mask=None, mutable='get_bounds')
    expected_means = onp.array([[(0 + 2 + 4 + 6 + 8 + 10) / 6,
                                 (1 + 3 + 5 + 7 + 9 + 11) / 6]])
    actual_means = state_nopadding['get_bounds']['GetBounds_0']['stats'].mean
    onp.testing.assert_allclose(actual_means, expected_means)

    # Now we pad out some of the tokens (chosen arbitrarily) and check that the
    # computed per-channel stats are the means of the non-padding tokens only.
    # We exclude the second and third tokens from the first batch and the first
    # token from the second batch.
    padding_mask = jnp.array([[True, False, False], [False, True, True]])
    # Reshape it to have dimensions [batch, feature]
    padding_mask = padding_mask.reshape(6, 1)
    _, state_padding = module.apply(
        initial_state, x, padding_mask=padding_mask, mutable='get_bounds')
    expected_means = onp.array([[(0 + 8 + 10) / 3, (1 + 9 + 11) / 3]])
    actual_means = state_padding['get_bounds']['GetBounds_0']['stats'].mean
    onp.testing.assert_allclose(actual_means, expected_means)
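As a cross-check of the expected values hard-coded in this test, the per-channel means with and without the padding mask can be reproduced with plain NumPy (a standalone sketch, not part of the original test):

import numpy as onp

x = onp.arange(12, dtype=onp.float32).reshape(6, 2)        # 6 tokens, 2 channels
keep = onp.array([True, False, False, False, True, True])  # non-padding tokens

print(x.mean(axis=0))        # [5. 6.], the unpadded per-channel means above
print(x[keep].mean(axis=0))  # [6. 7.], the means over the three kept tokens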
Example #13
def create_resnet(hparams, dtype, train, **kwargs):
    return ResNet(num_classes=1000,
                  dtype=dtype,
                  hparams=hparams,
                  quant_context=quant_config.QuantContext(
                      update_bounds=False, quantize_weights=True),
                  num_filters=64,
                  train=train,
                  **kwargs)
Example #14
    def test_autoregresive_receptive_field_1d(self, weight_prec):
        """Tests the autoregresive self-attention receptive field."""
        rng = random.PRNGKey(0)
        rng1, rng2 = random.split(rng, num=2)

        def model_loss(inputs, pos):
            out = module.apply(initial_vars, inputs, padding_mask=None)
            assert out.shape == input_shape
            assert len(out.shape) == 3
            return out[0, pos, :].sum()

        grad_fn = jax.jit(jax.grad(model_loss))

        def get_receptive_field_1d(pos):
            g = grad_fn(inputs, pos)[0, :, :]
            return jnp.any((jnp.abs(g) > 1e-5).astype(jnp.uint32), axis=-1)

        length = 10
        dim = 1
        num_heads = 1
        input_shape = (1, length, dim)
        inputs = random.normal(rng2, input_shape)

        module = flax_attention.SelfAttentionAqt(
            num_heads=num_heads,
            hparams=self.construct_hparams(weight_prec),
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            train=False,
            paxis_name=None,
            causal_mask=True,
            kernel_init=initializers.ones,
            dtype=jnp.float32,
            qkv_features=None,
            attention_axis=None,
            dropout_rate=0.0,
            deterministic=False,
            decode=False)
        initial_vars = module.init(rng1,
                                   jnp.ones((1, ) + (length, dim),
                                            jnp.float32),
                                   padding_mask=None)
        # model = nn.Model(module, initial_params)

        for i in range(length):
            deps = get_receptive_field_1d(i)
            assert (deps[:i] == 1).all(), (
                'Receptive Field Error: Some of the '
                'previous positions are not reachable '
                'in autoregressive self-attention.')
            if i != length - 1:
                k = i + 1
                assert (deps[k:] == 0).all(), (
                    'Receptive Field Error: Some of the '
                    'future positions are reachable in '
                    'autoregressive self-attention.')
Example #15
 def _without_weights(inputs, params):
     return predict.step(inputs,
                         params,
                         cache,
                         state,
                         EOS_TOKEN,
                         decode_length,
                         transformer_kwargs=transformer_kwargs,
                         hparams=model_hparams,
                         quant_context=quant_config.QuantContext(
                             update_bounds=False, quantize_acts=True))
Example #16
 def __call__(self, inputs, hparams, num_classes, dtype=jnp.float32):
     output = aqt_flax_layers.DenseAqt(
         features=num_classes,
         dtype=dtype,
         train=False,
         quant_context=quant_config.QuantContext(
             update_bounds=False, collect_acts_stats=False),
         paxis_name='batch',
         hparams=hparams,
     )(inputs, padding_mask=None)
     return output
Example #17
 def init_model(self, transformer_kwargs):
     model = models.Transformer(use_bfloat16=False,
                                quant_context=quant_config.QuantContext(
                                    collect_acts_stats=False,
                                    update_bounds=False),
                                dropout_rate=.1,
                                attention_dropout_rate=.1,
                                should_decode=False,
                                **transformer_kwargs)
     state = model.init(self.key, jnp.zeros(self.input_shape, jnp.float32),
                        jnp.zeros(self.target_shape, jnp.float32))
     return model, state
Example #18
    def test_decoding(self, weight_prec, spatial_shape, attn_dims):
        bs = 2
        num_heads = 3
        num_features = 4
        rng = random.PRNGKey(0)
        key1, key2 = random.split(rng)
        inputs = random.normal(key1, (bs, ) + spatial_shape +
                               (num_heads * num_features, ))
        module = flax_attention.SelfAttentionAqt(
            num_heads=num_heads,
            hparams=self.construct_hparams(weight_prec),
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            train=False,
            paxis_name=None,
            qkv_features=num_heads * num_features,
            attention_axis=attn_dims,
            decode=False,
            causal_mask=True,
            dtype=jnp.float32,
            dropout_rate=0.0,
            deterministic=False)

        initial_vars = module.init(key2, inputs, padding_mask=None)
        y_ref = module.apply(initial_vars, inputs, padding_mask=None)
        module.decode = True
        initial_vars_decode = module.init(key2, inputs, padding_mask=None)
        cache0 = initial_vars_decode['cache']

        def body_fn(cache, x):
            y, new_vars = module.apply({
                **initial_vars, 'cache': cache
            },
                                       x,
                                       mutable='cache',
                                       padding_mask=None)
            return new_vars['cache'], y

        # scan_in_dim supports scanning multiple dims
        _, y = jax_utils.scan_in_dim(body_fn,
                                     cache0,
                                     inputs,
                                     axis=attn_dims,
                                     keepdims=True)

        onp.testing.assert_allclose(y_ref, y, atol=1e-5)
Example #19
  def test_embed_should_call_clip_and_round(self, floor_with_gradient,
                                            round_with_gradient, weight_prec,
                                            acts_prec, fixed_bounds):

    round_with_gradient.side_effect = lambda x: x
    floor_with_gradient.side_effect = lambda x: x

    if fixed_bounds:
      bounds = 6.0
    else:
      bounds = get_bounds.GetBounds.Hyper(
          initial_bound=6.0,
          stddev_coeff=3.0,
          absdev_coeff=2.0,
          mix_coeff=0.5,
          granularity=quant_config.QuantGranularity.per_tensor)
    quant_act = quantization.QuantOps.ActHParams(
        input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
        prec=acts_prec,
        bounds=bounds,
        half_shift=False)
    rng = random.PRNGKey(0)
    x = jnp.ones((1, 3))

    embed_module = flax_layers.EmbedAqt(
        num_embeddings=4,
        features=3,
        dtype=jnp.float32,
        hparams=flax_layers.EmbedAqt.HParams(
            weight_prec=weight_prec,
            quant_act=quant_act,
            quant_type=QuantType.fake_quant,
            weight_half_shift=False),
        quant_context=quant_config.QuantContext(update_bounds=False),
        paxis_name=None,
        train=False)
    init_state = embed_module.init(
        rng, x, method=embed_module.attend, padding_mask=None)
    round_with_gradient.reset_mock()
    floor_with_gradient.reset_mock()
    embed_module.apply(
        init_state, x, padding_mask=None, method=embed_module.attend)
    round_with_gradient.assert_called_with(mock.ANY)
    self.assertEqual(round_with_gradient.call_count, 1)
    floor_with_gradient.assert_not_called()
Example #20
 def __call__(self,
              inputs,
              hparams,
              kernel_size,
              num_filters,
              strides,
              dtype=jnp.float32):
     output = aqt_flax_layers.ConvAqt(
         features=num_filters,
         kernel_size=kernel_size,
         strides=strides,
         use_bias=False,
         dtype=dtype,
         train=False,
         quant_context=quant_config.QuantContext(update_bounds=False),
         paxis_name='batch',
         hparams=hparams)(inputs)
     return output
Example #21
    def test_self_attention_act_quant_should_call_quant_ops(
            self, mock_inputs_fake_quant, attn_act_q, attn_act_k,
            attn_act_probs, attn_act_v, update_bounds, paxis_name, train):

        mock_inputs_fake_quant.side_effect = (
            lambda inputs, hparams, get_bounds_params: inputs)

        rng = random.PRNGKey(0)
        x = jnp.ones((4, 3, 7))
        hparams = self.construct_hparams(attn_act_q, attn_act_k,
                                         attn_act_probs, attn_act_v)
        sa_module = flax_attention.SelfAttentionAqt(
            hparams=hparams,
            num_heads=4,
            quant_context=quant_config.QuantContext(
                update_bounds=update_bounds, collect_acts_stats=False),
            train=train,
            paxis_name=paxis_name,
            attention_axis=None,
            qkv_features=8,
            kernel_init=initializers.ones,
            bias_init=initializers.zeros,
            causal_mask=False,
            dtype=jnp.float32,
            dropout_rate=0.0,
            deterministic=False,
            decode=False)
        sa_module.init(rng, x, padding_mask=None)
        calls = []
        for hparam in [attn_act_q, attn_act_k, attn_act_probs, attn_act_v]:
            if hparam is not None:
                calls.append(
                    unittest.mock.call(
                        unittest.mock.ANY,
                        hparams=hparam,
                        get_bounds_params=get_bounds.GetBounds.Params(
                            update_stats=train,
                            update_bounds=update_bounds,
                            paxis_name=paxis_name,
                            mask=unittest.mock.ANY,
                            module_name=unittest.mock.ANY)))
        mock_inputs_fake_quant.assert_has_calls(calls, any_order=True)

        self.assertLen(calls, mock_inputs_fake_quant.call_count)
Example #22
 def test_quant_granularity(self, _, mock_quantized_dot, granularity, axis):
   hparams = flax_layers.DenseAqt.HParams(
       weight_prec=8,
       quant_act=None,
       quant_type=quantization.QuantType.fake_quant,
       weight_quant_granularity=granularity)
   layer = flax_layers.DenseAqt(
       features=2,
       hparams=hparams,
       quant_context=quant_config.QuantContext(
           update_bounds=False, collect_acts_stats=False),
       paxis_name=None,
       train=False,
       dtype=jnp.float32)
   x = jnp.ones((2, 2))
   state = layer.init(self.rng_key, x, padding_mask=None)
   layer.apply(state, x, padding_mask=None)
   weight_params = mock_quantized_dot.call_args[1]['weight_params']
   self.assertEqual(weight_params.axis, axis)
Example #23
def create_model(key,
                 batch_size,
                 image_size,
                 model_dtype,
                 hparams,
                 train=True,
                 **kwargs):
    """Creates the ResNet model using hparams."""
    input_shape = (batch_size, image_size, image_size, 3)
    model = models.ResNet(
        num_classes=1000,
        dtype=model_dtype,
        hparams=hparams,
        quant_context=quant_config.QuantContext(update_bounds=False),
        num_filters=64,
        train=train,
        **kwargs)
    init_state = model.init(key, jnp.zeros(input_shape, dtype=model_dtype))
    return model, init_state
Example #24
 def test_group_conv(self, weight_prec=None):
   x = jnp.ones((1, 8, 8, 4))
   conv_module = flax_layers.ConvAqt(
       features=4,
       kernel_size=(3, 3),
       feature_group_count=2,
       padding='VALID',
       paxis_name='batch',
       quant_context=quant_config.QuantContext(update_bounds=False),
       train=False,
       hparams=flax_layers.ConvAqt.HParams(
           weight_prec=weight_prec,
           quant_act=None,
           quant_type=QuantType.fake_quant),
       kernel_init=initializers.ones,
       bias_init=initializers.ones,
       dtype=jnp.float32)
   y, state = conv_module.init_with_output(self.rng_key, x)
   self.assertEqual(state['params']['kernel'].shape, (3, 3, 2, 4))
   test_utils.assert_all_close_prec(y, onp.full((1, 6, 6, 4), 19.),
                                    weight_prec)
Example #25
 def test_epsilon_rounding(self):
   # We give LayerNorm a constant input. Since that input has a variance of
   # zero, we would expect layernorm to return NaN (0/0) unless the 'epsilon'
   # parameter, which nudges the denominator away from zero, were having an
   # effect. We test the case where the default epsilon value of 1e-6 would
   # ordinarily flush to zero after quantization with a high value of exp_min.
   # This test makes sure our code to round epsilon up to the smallest non-zero
   # representable value is working.
   hparams = self.make_hparams(
       exp_min=-2**2, exp_max=2**7, sig_bits=23, quantize_reductions=False)
   layer_norm = flax_layers.LayerNormAqt(
       hparams=hparams,
       use_bias=False,
       use_scale=False,
       epsilon=1e-6,
       dtype=jnp.float32,
       quant_context=quant_config.QuantContext(
           update_bounds=False, quantize_acts=True))
   x = jnp.ones((2, 5))
   y = layer_norm.apply({}, x)
   onp.testing.assert_equal(onp.array(y), onp.zeros(x.shape))
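The comment above comes down to the denominator arithmetic; a minimal NumPy sketch of the standard layer-norm formula (not the AQT implementation) shows why a flushed-to-zero epsilon yields NaN while a non-zero one yields the zeros asserted here:

import numpy as onp

x = onp.ones((2, 5), dtype=onp.float32)   # constant input -> variance is zero
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)

print((x - mean) / onp.sqrt(var + 0.0))   # 0/0 -> NaN (epsilon flushed to zero)
print((x - mean) / onp.sqrt(var + 1e-6))  # 0/small -> zeros, matching the test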
Example #26
 def test_quantized_layer_norm_matches_unquantized_in_fp32(
     self, quantize_acts, quantize_reductions):
   # We 'quantize' to a custom floating-point format that is approximately
   # equivalent to IEEE float32 and test that results are the same as using
   # Flax's upstream unquantized LayerNorm.
   hparams = self.make_hparams(
       exp_min=-2**7,
       exp_max=2**7,
       sig_bits=23,
       quantize_reductions=quantize_reductions)
   quantized_layer_norm = flax_layers.LayerNormAqt(
       hparams=hparams,
       dtype=jnp.float32,
       quant_context=quant_config.QuantContext(
           update_bounds=False, quantize_acts=quantize_acts))
   x_rng, param_rng = jax.random.split(self.rng)
   x = jax.random.normal(x_rng, (3, 5))
   initial_params = quantized_layer_norm.init(param_rng, x)
   y_quantized = quantized_layer_norm.apply(initial_params, x)
   unquantized_layer_norm = nn.LayerNorm()
   y_unquantized = unquantized_layer_norm.apply(initial_params, x)
   onp.testing.assert_allclose(y_quantized, y_unquantized, rtol=2e-6)
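As an aside on why these hyperparameters are "approximately equivalent to IEEE float32": binary32 has 23 fraction bits and an exponent range that fits inside [-2**7, 2**7], which a quick NumPy check confirms (illustrative only, not part of the test):

import numpy as onp

info = onp.finfo(onp.float32)
print(info.nmant)   # 23 significand bits, matching sig_bits=23
print(info.minexp)  # -126, inside exp_min=-2**7
print(info.maxexp)  # 128, matching exp_max=2**7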
Example #27
    def test_annotation_only_changes_hlo_metadata_dense(
            self, weight_prec, acts_prec):
        FLAGS.metadata_enabled = False
        quant_act = quantization.QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            prec=acts_prec,
            bounds=1.0,
            half_shift=False)
        input_shape = (1, 16)
        module_no_annotation = aqt_flax_layers.DenseAqt(
            features=4,
            use_bias=False,
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            paxis_name='batch',
            train=False,
            hparams=aqt_flax_layers.DenseAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant,
                weight_quant_granularity=quant_config.QuantGranularity.
                per_channel,
                weight_half_shift=False),
            dtype=jnp.float32)

        init_state = module_no_annotation.init(self.rng_key,
                                               jnp.ones(
                                                   input_shape, jnp.float32),
                                               padding_mask=None)
        output_no_annotation = module_no_annotation.apply(
            init_state, jnp.ones(input_shape), padding_mask=None)

        hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
            module_no_annotation, init_state, [input_shape], padding_mask=None)
        del init_state

        FLAGS.metadata_enabled = True
        module_w_annotation = aqt_flax_layers.DenseAqt(
            features=4,
            use_bias=False,
            paxis_name='batch',
            train=False,
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            dtype=jnp.float32,
            hparams=aqt_flax_layers.DenseAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant,
                weight_quant_granularity=quant_config.QuantGranularity.
                per_channel,
                weight_half_shift=False),
        )

        init_state = module_w_annotation.init(self.rng_key,
                                              jnp.ones(input_shape,
                                                       jnp.float32),
                                              padding_mask=None)
        output_w_annotation = module_w_annotation.apply(init_state,
                                                        jnp.ones(input_shape),
                                                        padding_mask=None)

        hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
            module_w_annotation, init_state, [input_shape], padding_mask=None)
        del init_state

        onp.testing.assert_array_equal(output_no_annotation,
                                       output_w_annotation)
        self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
Example #28
    def test_hparams_without_logits_when_logits_not_shared_raises_error(self):
        # Create hparams without logits hparams by passing in
        # logits_via_embeddings=True.
        inputs_hyper = get_bounds.GetBounds.Hyper(
            initial_bound=6.0,
            stddev_coeff=3.0,
            absdev_coeff=2.0,
            mix_coeff=0.5,
            granularity=quant_config.QuantGranularity.per_channel)
        hparams = training_hparams_generator_lib.create_base_transformer_hparams(
            mlp_weight_prec=8,
            embedding_weight_prec=None,
            attention_weight_prec=8,
            mlp_pos_inputs_prec=8,
            mlp_pos_inputs_hyper=inputs_hyper,
            mlp_signed_inputs_prec=8,
            mlp_signed_inputs_hyper=inputs_hyper,
            attention_kqv_inputs_prec=8,
            attention_kqv_inputs_hyper=inputs_hyper,
            attention_out_inputs_prec=8,
            attention_out_inputs_hyper=inputs_hyper,
            logits_inputs_prec=8,
            logits_inputs_hyper=inputs_hyper,
            logits_via_embeddings=True,
            attention_act_q_inputs_prec=8,
            attention_act_q_inputs_hyper=inputs_hyper,
            attention_act_k_inputs_prec=8,
            attention_act_k_inputs_hyper=inputs_hyper,
            attention_act_probs_inputs_prec=8,
            attention_act_v_inputs_prec=8,
            attention_act_v_inputs_hyper=inputs_hyper,
            num_layers=2,
            emb_dim=5,
            num_heads=2,
            qkv_dim=4,
            mlp_dim=4,
            quant_type=QuantType.fake_quant)

        self.assertIsNone(hparams.decoder.logits)

        # Now set logits_via_embedding in the model hparams to False.
        hparams.logits_via_embedding = False
        module = models.Transformer(hparams=hparams,
                                    quant_context=quant_config.QuantContext(
                                        update_bounds=True,
                                        collect_acts_stats=True),
                                    vocab_size=3,
                                    output_vocab_size=3,
                                    max_len=10,
                                    use_bfloat16=False,
                                    train=False,
                                    dropout_rate=.1,
                                    attention_dropout_rate=.1,
                                    should_decode=False)
        key = jax.random.PRNGKey(0)
        # Mark the first token of the target and last token of the inputs as padding
        # tokens.
        targets = onp.array([[0, 2]])
        inputs = onp.array([[1, 0]])
        # Because the model is not sharing logits with embeddings, but the logits
        # hparams are missing, it should raise an error.
        with self.assertRaises(ValueError):
            module.init(key, inputs=inputs, targets=targets)
Example #29
    def test_padding_mask(self):
        # Fuzzing test to make sure activation statistics aren't affected by padding
        # tokens.
        #
        # This test works by changing the embedding of the padding token (token
        # with id '0') and making sure all the stats stay the same.
        #
        # It also tests that the stats *do* change when the embedding of a
        # non-padding token changes.
        inputs_hyper = get_bounds.GetBounds.Hyper(
            initial_bound=6.0,
            stddev_coeff=3.0,
            absdev_coeff=2.0,
            mix_coeff=0.5,
            granularity=quant_config.QuantGranularity.per_channel)
        # Set logits_via_embedding to false so that the embedding of the padding
        # token doesn't affect the logits calculation at the end of the decoder.
        hparams = training_hparams_generator_lib.create_base_transformer_hparams(
            mlp_weight_prec=8,
            embedding_weight_prec=None,
            attention_weight_prec=8,
            mlp_pos_inputs_prec=8,
            mlp_pos_inputs_hyper=inputs_hyper,
            mlp_signed_inputs_prec=8,
            mlp_signed_inputs_hyper=inputs_hyper,
            attention_kqv_inputs_prec=8,
            attention_kqv_inputs_hyper=inputs_hyper,
            attention_out_inputs_prec=8,
            attention_out_inputs_hyper=inputs_hyper,
            logits_inputs_prec=8,
            logits_inputs_hyper=inputs_hyper,
            logits_via_embeddings=False,
            attention_act_q_inputs_prec=8,
            attention_act_q_inputs_hyper=inputs_hyper,
            attention_act_k_inputs_prec=8,
            attention_act_k_inputs_hyper=inputs_hyper,
            attention_act_probs_inputs_prec=8,
            attention_act_v_inputs_prec=8,
            attention_act_v_inputs_hyper=inputs_hyper,
            num_layers=2,
            emb_dim=5,
            num_heads=2,
            qkv_dim=4,
            mlp_dim=4,
            quant_type=QuantType.fake_quant)
        module = models.Transformer(hparams=hparams,
                                    quant_context=quant_config.QuantContext(
                                        update_bounds=True,
                                        collect_acts_stats=True),
                                    vocab_size=3,
                                    output_vocab_size=3,
                                    max_len=10,
                                    train=False,
                                    use_bfloat16=False,
                                    dropout_rate=.1,
                                    attention_dropout_rate=.1,
                                    should_decode=False)
        key = jax.random.PRNGKey(0)
        # Mark the first token of the target and last token of the inputs as padding
        # tokens.
        targets = onp.array([[0, 2]])
        inputs = onp.array([[1, 0]])
        initial_state = module.init(key, inputs=inputs, targets=targets)
        # Change the embedding of the padding token.
        initial_state = initial_state.unfreeze()
        initial_state['params']['shared_embedding'][
            'embedding'] = initial_state['params']['shared_embedding'][
                'embedding'].at[0, :].set(10.0)
        module.train = True
        _, state1 = module.apply(flax.core.freeze(initial_state),
                                 inputs=inputs,
                                 targets=targets,
                                 mutable=True,
                                 rngs={'dropout': key})
        initial_state['params']['shared_embedding'][
            'embedding'] = initial_state['params']['shared_embedding'][
                'embedding'].at[0, :].set(20.0)
        _, state2 = module.apply(flax.core.freeze(initial_state),
                                 inputs=inputs,
                                 targets=targets,
                                 mutable=True,
                                 rngs={'dropout': key})
        # This tests the statistics in both the GetBounds and StatsTag modules.
        test_utils.assert_stats_are_equal(state1, state2)

        # Now we repeat the test, but changing the embedding of a non-padding token
        # (token with ID 1 here). We expect to see the stats change.
        # print(initial_state)
        initial_state['params']['shared_embedding'][
            'embedding'] = initial_state['params']['shared_embedding'][
                'embedding'].at[1, :].set(10.0)
        _, state1 = module.apply(flax.core.freeze(initial_state),
                                 inputs=inputs,
                                 targets=targets,
                                 mutable=True,
                                 rngs={'dropout': key})
        initial_state['params']['shared_embedding'][
            'embedding'] = initial_state['params']['shared_embedding'][
                'embedding'].at[1, :].set(200.0)
        _, state2 = module.apply(flax.core.freeze(initial_state),
                                 inputs=inputs,
                                 targets=targets,
                                 mutable=True,
                                 rngs={'dropout': key})
        print(initial_state['get_bounds']['encoder']['encoderblock_0']
              ['enc_self_att']['K']['bounds'])
        print(state1['get_bounds']['encoder']['encoderblock_0']['enc_self_att']
              ['K']['bounds'])
        print(state2['get_bounds']['encoder']['encoderblock_0']['enc_self_att']
              ['K']['bounds'])
        print('')
        test_utils.assert_stats_are_unequal(state1, state2)
Example #30
def encoder_from_file(config,
                      batch_size=8,
                      encode_length=16,
                      use_bfloat16=True,
                      use_xla_optimizations=True):
    """Generates HLO for just the encoder of the WMT model.

  Args:
    config: A ConfigDict instance.
    batch_size: Batch size.
    encode_length: Max length of an input sentence.
    use_bfloat16: Use bfloat16 mixed precision training instead of float32.
    use_xla_optimizations: Whether to use xla optimizations.
  """
    if FLAGS.checkpoint:
        raise app.UsageError('Checkpoints not yet supported for WMT encoder.')

    input_shape = (batch_size, encode_length)
    rng = jax.random.PRNGKey(0)
    hparams = hparams_utils.load_dataclass_from_config_dict(
        training_hparams.TrainingHParams, config)
    model_hparams = hparams.model_hparams
    model = models.Encoder(vocab_size=32711,
                           hparams=model_hparams.encoder,
                           shared_embedding=None,
                           use_bfloat16=use_bfloat16,
                           emb_dim=model_hparams.emb_dim,
                           num_heads=model_hparams.num_heads,
                           qkv_dim=model_hparams.qkv_dim,
                           mlp_dim=model_hparams.mlp_dim,
                           max_len=encode_length,
                           train=False,
                           dropout_rate=0.1,
                           attention_dropout_rate=0.1,
                           quant_context=quant_config.QuantContext(
                               update_bounds=False,
                               collect_acts_stats=False,
                               quantize_acts=True))
    init_state = model.init(rng, jnp.ones(input_shape, jnp.float32))

    def _fn(state, inputs):
        return model.apply(state, inputs, mutable=False)

    if not use_xla_optimizations:
        computation = jax.xla_computation(_fn)(init_state,
                                               jnp.ones(
                                                   input_shape, jnp.float32))
        hlo_utils.output_hlo(computation, FLAGS.hlo_output)

    else:

        def _wrapped_fn(inputs):
            return _fn(init_state, inputs)

        def to_shape_str(shape_tuple):
            return 'f32[%s]' % ','.join(map(str, shape_tuple))

        hlo_module_proto_str, hlo_txt = jax_to_ir.jax_to_hlo(
            _wrapped_fn,
            [('inputs', jax_to_ir.parse_shape_str(to_shape_str(input_shape)))])
        hlo_utils.output_hlo_to_file(hlo_module_proto_str, hlo_txt,
                                     FLAGS.hlo_output)
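As a small aside on the XLA path above: the shape string handed to jax_to_ir.parse_shape_str is just the f32[...] spelling of the input shape, so for the default batch_size=8 and encode_length=16 the wrapped function is traced with an f32[8,16] argument. A self-contained copy of that helper for illustration:

def to_shape_str(shape_tuple):
  return 'f32[%s]' % ','.join(map(str, shape_tuple))

print(to_shape_str((8, 16)))  # 'f32[8,16]'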