Example #1
    def test_multihead_self_attention_w_dropout(self, weight_prec):
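        """Tests that self-attention with dropout preserves the input shape."""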
        rng = random.PRNGKey(0)
        x = jnp.ones((4, 3, 5))
        sa_module = flax_attention.SelfAttentionAqt(
            num_heads=8,
            hparams=self.construct_hparams(weight_prec),
            attention_axis=(1, ),
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            train=False,
            paxis_name=None,
            qkv_features=16,
            kernel_init=initializers.ones,
            bias_init=initializers.zeros,
            dropout_rate=0.1,
            dtype=jnp.float32,
            causal_mask=False,
            deterministic=False,
            decode=False)
        rng_dropout, rng_params = random.split(rng)
        y, _ = sa_module.init_with_output(
            {
                'dropout': rng_dropout,
                'params': rng_params
            },
            x,
            padding_mask=None)
        self.assertEqual(y.shape, x.shape)
Example #2
    def test_autoregressive_receptive_field_1d(self, weight_prec):
        """Tests the autoregressive self-attention receptive field."""
        rng = random.PRNGKey(0)
        rng1, rng2 = random.split(rng, num=2)

        def model_loss(inputs, pos):
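            # Scalar readout of the output at position `pos`; its gradient
            # w.r.t. the inputs exposes the receptive field.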
            out = module.apply(initial_vars, inputs, padding_mask=None)
            assert out.shape == input_shape
            assert len(out.shape) == 3
            return out[0, pos, :].sum()

        grad_fn = jax.jit(jax.grad(model_loss))

        def get_receptive_field_1d(pos):
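            # A nonzero input gradient at a position means the output at `pos`
            # depends on the input at that position.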
            g = grad_fn(inputs, pos)[0, :, :]
            return jnp.any((jnp.abs(g) > 1e-5).astype(jnp.uint32), axis=-1)

        length = 10
        dim = 1
        num_heads = 1
        input_shape = (1, length, dim)
        inputs = random.normal(rng2, input_shape)

        module = flax_attention.SelfAttentionAqt(
            num_heads=num_heads,
            hparams=self.construct_hparams(weight_prec),
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            train=False,
            paxis_name=None,
            causal_mask=True,
            kernel_init=initializers.ones,
            dtype=jnp.float32,
            qkv_features=None,
            attention_axis=None,
            dropout_rate=0.0,
            deterministic=False,
            decode=False)
        initial_vars = module.init(rng1,
                                   jnp.ones((1, ) + (length, dim),
                                            jnp.float32),
                                   padding_mask=None)

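        # In causal self-attention, the output at position i must depend on
        # every earlier position and on no later position.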
        for i in range(length):
            deps = get_receptive_field_1d(i)
            assert (deps[:i] == 1).all(), (
                'Receptive Field Error: Some of the '
                'previous positions are not reachable '
                'in autoregressive self-attention.')
            if i != length - 1:
                k = i + 1
                assert (deps[k:] == 0).all(), (
                    'Receptive Field Error: Some of the '
                    'future positions are reachable in '
                    'autoregressive self-attention.')
Example #3
    def test_decoding(self, weight_prec, spatial_shape, attn_dims):
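        """Checks that cached autoregressive decoding matches the full forward pass."""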
        bs = 2
        num_heads = 3
        num_features = 4
        rng = random.PRNGKey(0)
        key1, key2 = random.split(rng)
        inputs = random.normal(key1, (bs, ) + spatial_shape +
                               (num_heads * num_features, ))
        module = flax_attention.SelfAttentionAqt(
            num_heads=num_heads,
            hparams=self.construct_hparams(weight_prec),
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            train=False,
            paxis_name=None,
            qkv_features=num_heads * num_features,
            attention_axis=attn_dims,
            decode=False,
            causal_mask=True,
            dtype=jnp.float32,
            dropout_rate=0.0,
            deterministic=False)

        initial_vars = module.init(key2, inputs, padding_mask=None)
        y_ref = module.apply(initial_vars, inputs, padding_mask=None)
        # Module attributes are frozen after construction, so make a copy with
        # decoding enabled.
        module = module.clone(decode=True)
        initial_vars_decode = module.init(key2, inputs, padding_mask=None)
        cache0 = initial_vars_decode['cache']

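        # Decode one position at a time, threading the attention cache through
        # each step.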
        def body_fn(cache, x):
            y, new_vars = module.apply({
                **initial_vars, 'cache': cache
            },
                                       x,
                                       mutable='cache',
                                       padding_mask=None)
            return new_vars['cache'], y

        # scan_in_dim supports scanning multiple dims
        _, y = jax_utils.scan_in_dim(body_fn,
                                     cache0,
                                     inputs,
                                     axis=attn_dims,
                                     keepdims=True)

        onp.testing.assert_allclose(y_ref, y, atol=1e-5)
Example #4
    def test_self_attention_act_quant_should_call_quant_ops(
            self, mock_inputs_fake_quant, attn_act_q, attn_act_k,
            attn_act_probs, attn_act_v, update_bounds, paxis_name, train):
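        """Checks that the input fake-quant op is called once per configured activation."""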

        mock_inputs_fake_quant.side_effect = (
            lambda inputs, hparams, get_bounds_params: inputs)

        rng = random.PRNGKey(0)
        x = jnp.ones((4, 3, 7))
        hparams = self.construct_hparams(attn_act_q, attn_act_k,
                                         attn_act_probs, attn_act_v)
        sa_module = flax_attention.SelfAttentionAqt(
            hparams=hparams,
            num_heads=4,
            quant_context=quant_config.QuantContext(
                update_bounds=update_bounds, collect_acts_stats=False),
            train=train,
            paxis_name=paxis_name,
            attention_axis=None,
            qkv_features=8,
            kernel_init=initializers.ones,
            bias_init=initializers.zeros,
            causal_mask=False,
            dtype=jnp.float32,
            dropout_rate=0.0,
            deterministic=False,
            decode=False)
        sa_module.init(rng, x, padding_mask=None)
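        # Expect one fake-quant call for each activation that has quantization
        # hparams configured.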
        calls = []
        for hparam in [attn_act_q, attn_act_k, attn_act_probs, attn_act_v]:
            if hparam is not None:
                calls.append(
                    unittest.mock.call(
                        unittest.mock.ANY,
                        hparams=hparam,
                        get_bounds_params=get_bounds.GetBounds.Params(
                            update_stats=train,
                            update_bounds=update_bounds,
                            paxis_name=paxis_name,
                            mask=unittest.mock.ANY,
                            module_name=unittest.mock.ANY)))
        mock_inputs_fake_quant.assert_has_calls(calls, any_order=True)

        self.assertLen(calls, mock_inputs_fake_quant.call_count)
Example #5
    def test_padding_mask(self):
        """Test that the activation stats respect masking."""
        # This test's strategy is to change the value of a channel of a padding
        # token and make sure the stats don't change. Because the attention
        # calculation is fairly involved, this is more robust and less tedious
        # than trying to directly test expected numeric values.

        # Construct HParams with dynamic bounds. The exact values don't matter;
        # the bounds just need to be dynamic so that stats are collected.
        bounds = get_bounds.GetBounds.Hyper(
            initial_bound=0.0,
            stddev_coeff=0.4,
            absdev_coeff=0.6,
            mix_coeff=0.4,
            reset_stats=False,
            granularity=quant_config.QuantGranularity.per_channel)
        quant_act = flax_layers.QuantOps.ActHParams(
            input_distribution=flax_layers.QuantOps.ActHParams.
            InputDistribution.symmetric,
            prec=8,
            bounds=bounds,
            half_shift=False)
        attn_quant_act = flax_layers.QuantOps.ActHParams(
            input_distribution=flax_layers.QuantOps.ActHParams.
            InputDistribution.positive,
            prec=8,
            bounds=1.0,
            half_shift=False)
        dense_hparams = flax_layers.DenseAqt.HParams(
            quant_type=flax_layers.QuantType.fake_quant,
            weight_prec=8,
            quant_act=quant_act,
            weight_quant_granularity=quant_config.QuantGranularity.per_channel,
            weight_half_shift=False)
        dotproduct_attn_hparams = flax_attention.DotProductAttnHParams(
            attn_act_q=quant_act,
            attn_act_k=quant_act,
            attn_act_v=quant_act,
            attn_act_probs=attn_quant_act,
            quant_type=QuantType.fake_quant,
            softmax=SoftmaxHParams(None, None, None))
        attn_hparams = flax_attention.MultiHeadDotProductAttentionAqt.HParams(
            dense_kqv=dense_hparams,
            dense_out=dense_hparams,
            attn_acts=dotproduct_attn_hparams)

        module = flax_attention.SelfAttentionAqt(
            hparams=attn_hparams,
            num_heads=2,
            paxis_name=None,
            train=True,
            quant_context=quant_config.QuantContext(update_bounds=True,
                                                    collect_acts_stats=False),
            dtype=jnp.float32,
            qkv_features=None,
            attention_axis=None,
            causal_mask=False,
            dropout_rate=0.0,
            deterministic=False,
            decode=False)
        # Simulate an input with a batch size of 1 and two tokens, each with
        # four features.
        x = onp.arange(8).astype(onp.float32).reshape((1, 2, 4))
        initial_state = module.init(random.PRNGKey(0), x, padding_mask=None)

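        # The padding mask has shape (batch, seq_len, 1); True marks real
        # tokens and False marks padding.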
        padding_mask = onp.full((1, 2, 1), True)
        padding_mask[0, 1, 0] = False  # Mask out the second token
        _, state1 = module.apply(initial_state,
                                 x,
                                 padding_mask=padding_mask,
                                 mutable=True)
        # Now we adjust the input for the masked token and recompute the stats.
        # They should be the same as before.
        x[0, 1, 0] = 100
        _, state2 = module.apply(initial_state,
                                 x,
                                 padding_mask=padding_mask,
                                 mutable=True)
        test_utils.assert_stats_are_equal(state1, state2)
        # Now we adjust the input for an unmasked token and verify that the stats
        # have changed.
        x[0, 0, 0] = 200
        _, state3 = module.apply(initial_state,
                                 x,
                                 padding_mask=padding_mask,
                                 mutable=True)
        test_utils.assert_stats_are_unequal(state1, state3)
Example #6
    def __call__(
        self,
        targets,
        encoded,
        padding_mask,
        key_padding_mask,
        inputs_segmentation=None,
        targets_segmentation=None,
    ):
        """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder
      encoded: input data from encoder
      padding_mask: bool, mask padding tokens
      key_padding_mask: bool, mask padding tokens
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.

    Returns:
      output after transformer block.
    """

        # Decoder block.
        batch_size, sequence_length, num_channels = targets.shape
        encoded_sequence_length = encoded.shape[1]
        shape_utils.assert_shapes_equal(padding_mask.shape,
                                        (batch_size, sequence_length, 1))
        shape_utils.assert_shapes_equal(
            encoded.shape, (batch_size, encoded_sequence_length, num_channels))
        shape_utils.assert_shapes_equal(
            key_padding_mask.shape, (batch_size, encoded_sequence_length, 1))

        x = aqt_flax_layers.LayerNormAqt(
            dtype=self.dtype,
            hparams=self.hparams.layer_norm,
            quant_context=self.quant_context)(targets)
        x = aqt_flax_attention.SelfAttentionAqt(
            hparams=self.hparams.self_attention,
            num_heads=self.num_heads,
            dtype=self.dtype,
            qkv_features=self.qkv_dim,
            attention_axis=(1, ),
            paxis_name='batch',
            train=self.train,
            quant_context=self.quant_context,
            causal_mask=True,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=self.attention_dropout_rate,
            deterministic=self.deterministic,
            name='dec_self_att',
            decode=self.decode)(
                x,
                padding_mask=padding_mask,
                segmentation=targets_segmentation,
            )
        x = nn.Dropout(rate=self.dropout_rate)(
            x, deterministic=self.deterministic)
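        # Residual connection around the decoder self-attention sub-block.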
        x = x + targets

        # Encoder-Decoder block.
        y = aqt_flax_layers.LayerNormAqt(dtype=self.dtype,
                                         hparams=self.hparams.layer_norm,
                                         quant_context=self.quant_context)(x)
        y = aqt_flax_attention.MultiHeadDotProductAttentionAqt(
            hparams=self.hparams.enc_dec_attention,
            num_heads=self.num_heads,
            dtype=self.dtype,
            qkv_features=self.qkv_dim,
            attention_axis=(1, ),
            paxis_name='batch',
            train=self.train,
            quant_context=self.quant_context,
            causal_mask=False,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=self.attention_dropout_rate,
            deterministic=self.deterministic,
            decode=False,
            name='dec_enc_att')(inputs_q=y,
                                inputs_kv=encoded,
                                padding_mask=padding_mask,
                                key_padding_mask=key_padding_mask,
                                segmentation=targets_segmentation,
                                key_segmentation=inputs_segmentation)
        y = nn.Dropout(rate=self.dropout_rate)(
            y, deterministic=self.deterministic)
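        # Residual connection around the encoder-decoder attention sub-block.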
        y = y + x

        # MLP block.
        z = aqt_flax_layers.LayerNormAqt(dtype=self.dtype,
                                         hparams=self.hparams.layer_norm,
                                         quant_context=self.quant_context)(y)
        z = MlpBlock(
            mlp_dim=self.mlp_dim,
            dtype=self.dtype,
            hparams=self.hparams.mlp_block,
            # Inputs are signed, since this is applied after the attention layer.
            train=self.train,
            quant_context=self.quant_context,
            dropout_rate=self.dropout_rate,
            deterministic=self.deterministic,
            name='mlp_block')(z, padding_mask=padding_mask)

        return y + z
Example #7
    def __call__(self, inputs, padding_mask, inputs_segmentation=None):
        """Applies Encoder1DBlock module.

    Args:
      inputs: input data
      padding_mask: bool, mask padding tokens
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      output after transformer block.
    """

        # Attention block.
        batch_size, sequence_length, channel_size = inputs.shape
        shape_utils.assert_shapes_equal(padding_mask.shape,
                                        (batch_size, sequence_length, 1))
        x = aqt_flax_layers.LayerNormAqt(
            dtype=self.dtype,
            hparams=self.hparams.layer_norm,
            quant_context=self.quant_context)(inputs)
        x = aqt_flax_attention.SelfAttentionAqt(
            hparams=self.hparams.attention,
            num_heads=self.num_heads,
            dtype=self.dtype,
            qkv_features=self.qkv_dim,
            attention_axis=(1, ),
            paxis_name='batch',
            train=self.train,
            quant_context=self.quant_context,
            causal_mask=False,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=self.attention_dropout_rate,
            deterministic=self.deterministic,
            decode=False,
            name='enc_self_att')(x,
                                 padding_mask=padding_mask,
                                 segmentation=inputs_segmentation)

        x = nn.Dropout(rate=self.dropout_rate)(
            x, deterministic=self.deterministic)
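        # Residual connection around the encoder self-attention sub-block.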
        x = x + inputs

        # MLP block.
        y = aqt_flax_layers.LayerNormAqt(dtype=self.dtype,
                                         hparams=self.hparams.layer_norm,
                                         quant_context=self.quant_context)(x)
        y = MlpBlock(
            mlp_dim=self.mlp_dim,
            hparams=self.hparams.mlp_block,
            # Inputs are signed, since this is applied after the attention layer.
            train=self.train,
            quant_context=self.quant_context,
            dtype=self.dtype,
            dropout_rate=self.dropout_rate,
            deterministic=self.deterministic,
            name='mlp_block')(y, padding_mask=padding_mask)
        out = x + y
        shape_utils.assert_shapes_equal(
            out.shape, (batch_size, sequence_length, channel_size))
        return out