def test_multihead_self_attention_w_dropout(self, weight_prec):
  """Tests self-attention with dropout enabled; output shape must match input."""
  rng = random.PRNGKey(0)
  x = jnp.ones((4, 3, 5))
  sa_module = flax_attention.SelfAttentionAqt(
      num_heads=8,
      hparams=self.construct_hparams(weight_prec),
      attention_axis=(1,),
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      train=False,
      paxis_name=None,
      qkv_features=16,
      kernel_init=initializers.ones,
      bias_init=initializers.zeros,
      dropout_rate=0.1,
      dtype=jnp.float32,
      causal_mask=False,
      deterministic=False,
      decode=False)
  rng_dropout, rng_params = random.split(rng)
  # init_with_output returns both the module output and the initialized
  # variables; only the output is needed here.
  y, _ = sa_module.init_with_output(
      {'dropout': rng_dropout, 'params': rng_params}, x, padding_mask=None)
  self.assertEqual(y.shape, x.shape)
def test_autoregresive_receptive_field_1d(self, weight_prec):
  """Tests the autoregressive self-attention receptive field."""
  rng = random.PRNGKey(0)
  rng1, rng2 = random.split(rng, num=2)

  # The closures below refer to `module`, `initial_vars`, `inputs`, and
  # `input_shape`, which are defined further down; they are only called after
  # those names are bound.
  def model_loss(inputs, pos):
    out = module.apply(initial_vars, inputs, padding_mask=None)
    assert out.shape == input_shape
    assert len(out.shape) == 3
    return out[0, pos, :].sum()

  grad_fn = jax.jit(jax.grad(model_loss))

  def get_receptive_field_1d(pos):
    # A nonzero gradient of the output at `pos` w.r.t. an input position means
    # that input position is inside the receptive field.
    g = grad_fn(inputs, pos)[0, :, :]
    return jnp.any((jnp.abs(g) > 1e-5).astype(jnp.uint32), axis=-1)

  length = 10
  dim = 1
  num_heads = 1
  input_shape = (1, length, dim)
  inputs = random.normal(rng2, input_shape)
  module = flax_attention.SelfAttentionAqt(
      num_heads=num_heads,
      hparams=self.construct_hparams(weight_prec),
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      train=False,
      paxis_name=None,
      causal_mask=True,
      kernel_init=initializers.ones,
      dtype=jnp.float32,
      qkv_features=None,
      attention_axis=None,
      dropout_rate=0.0,
      deterministic=False,
      decode=False)
  initial_vars = module.init(
      rng1, jnp.ones((1,) + (length, dim), jnp.float32), padding_mask=None)

  for i in range(length):
    deps = get_receptive_field_1d(i)
    assert (deps[:i] == 1).all(), (
        'Receptive Field Error: Some of the previous positions are not '
        'reachable in autoregressive self-attention.')
    if i != length - 1:
      k = i + 1
      assert (deps[k:] == 0).all(), (
          'Receptive Field Error: Some of the future positions are reachable '
          'in autoregressive self-attention.')
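# A minimal sketch of the property the loop above checks (illustrative only;
# it reuses `get_receptive_field_1d` and `length` from the test body).
# Stacking the per-position dependency vectors yields a lower-triangular
# boolean matrix: position i reaches positions 0..i and nothing after. In
# practice the diagonal is also 1, although the test leaves deps[i] itself
# unconstrained.
#
#   dep_matrix = jnp.stack([get_receptive_field_1d(i) for i in range(length)])
#   assert (dep_matrix == jnp.tril(jnp.ones((length, length)))).all()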
def test_decoding(self, weight_prec, spatial_shape, attn_dims):
  """Tests that cached autoregressive decoding matches the full forward pass."""
  bs = 2
  num_heads = 3
  num_features = 4
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  inputs = random.normal(
      key1, (bs,) + spatial_shape + (num_heads * num_features,))
  module = flax_attention.SelfAttentionAqt(
      num_heads=num_heads,
      hparams=self.construct_hparams(weight_prec),
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      train=False,
      paxis_name=None,
      qkv_features=num_heads * num_features,
      attention_axis=attn_dims,
      decode=False,
      causal_mask=True,
      dtype=jnp.float32,
      dropout_rate=0.0,
      deterministic=False)
  initial_vars = module.init(key2, inputs, padding_mask=None)
  y_ref = module.apply(initial_vars, inputs, padding_mask=None)

  # Re-initialize in decode mode to create the autoregressive cache.
  module.decode = True
  initial_vars_decode = module.init(key2, inputs, padding_mask=None)
  cache0 = initial_vars_decode['cache']

  def body_fn(cache, x):
    y, new_vars = module.apply(
        {**initial_vars, 'cache': cache},
        x,
        mutable='cache',
        padding_mask=None)
    return new_vars['cache'], y

  # scan_in_dim scans body_fn over one or more input dimensions at once,
  # threading the attention cache through as carry state.
  _, y = jax_utils.scan_in_dim(
      body_fn, cache0, inputs, axis=attn_dims, keepdims=True)
  onp.testing.assert_allclose(y_ref, y, atol=1e-5)
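# For reference, a single decode step looks like the following sketch
# (exposition only, not part of the test; names mirror the test body, and a
# 1-D spatial shape is assumed for the slicing). Each `module.apply` call with
# `mutable='cache'` consumes one input position and returns the updated cache
# to carry into the next step:
#
#   cache = initial_vars_decode['cache']
#   token0 = inputs[:, :1, :]  # first position only
#   y0, updated = module.apply({**initial_vars, 'cache': cache},
#                              token0, mutable='cache', padding_mask=None)
#   cache = updated['cache']  # carry forward into the next step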
def test_self_attention_act_quant_should_call_quant_ops(
    self, mock_inputs_fake_quant, attn_act_q, attn_act_k, attn_act_probs,
    attn_act_v, update_bounds, paxis_name, train):
  """Checks that activation fake-quant is invoked once per configured hparam."""
  # Make the mocked fake-quant op a pass-through so the forward pass still
  # produces well-formed activations.
  mock_inputs_fake_quant.side_effect = (
      lambda inputs, hparams, get_bounds_params: inputs)

  rng = random.PRNGKey(0)
  x = jnp.ones((4, 3, 7))
  hparams = self.construct_hparams(attn_act_q, attn_act_k, attn_act_probs,
                                   attn_act_v)
  sa_module = flax_attention.SelfAttentionAqt(
      hparams=hparams,
      num_heads=4,
      quant_context=quant_config.QuantContext(
          update_bounds=update_bounds, collect_acts_stats=False),
      train=train,
      paxis_name=paxis_name,
      attention_axis=None,
      qkv_features=8,
      kernel_init=initializers.ones,
      bias_init=initializers.zeros,
      causal_mask=False,
      dtype=jnp.float32,
      dropout_rate=0.0,
      deterministic=False,
      decode=False)
  sa_module.init(rng, x, padding_mask=None)
  # Expect exactly one fake-quant call for each activation hparam that is set.
  calls = []
  for hparam in [attn_act_q, attn_act_k, attn_act_probs, attn_act_v]:
    if hparam is not None:
      calls.append(
          unittest.mock.call(
              unittest.mock.ANY,
              hparams=hparam,
              get_bounds_params=get_bounds.GetBounds.Params(
                  update_stats=train,
                  update_bounds=update_bounds,
                  paxis_name=paxis_name,
                  mask=unittest.mock.ANY,
                  module_name=unittest.mock.ANY)))
  mock_inputs_fake_quant.assert_has_calls(calls, any_order=True)
  self.assertLen(calls, mock_inputs_fake_quant.call_count)
def test_padding_mask(self):
  """Tests that the activation stats respect masking."""
  # This test's strategy is to change the value of a channel of a padding
  # token and make sure the stats don't change. Because the attention
  # calculation is fairly involved, this is more robust and less tedious than
  # trying to directly test numeric expected values.

  # Construct HParams with dynamic bounds. Exact values don't matter; the
  # bounds just need to be dynamic so that stats are collected.
  bounds = get_bounds.GetBounds.Hyper(
      initial_bound=0.0,
      stddev_coeff=0.4,
      absdev_coeff=0.6,
      mix_coeff=0.4,
      reset_stats=False,
      granularity=quant_config.QuantGranularity.per_channel)
  quant_act = flax_layers.QuantOps.ActHParams(
      input_distribution=flax_layers.QuantOps.ActHParams.InputDistribution
      .symmetric,
      prec=8,
      bounds=bounds,
      half_shift=False)
  attn_quant_act = flax_layers.QuantOps.ActHParams(
      input_distribution=flax_layers.QuantOps.ActHParams.InputDistribution
      .positive,
      prec=8,
      bounds=1.0,
      half_shift=False)
  dense_hparams = flax_layers.DenseAqt.HParams(
      quant_type=flax_layers.QuantType.fake_quant,
      weight_prec=8,
      quant_act=quant_act,
      weight_quant_granularity=quant_config.QuantGranularity.per_channel,
      weight_half_shift=False)
  dotproduct_attn_hparams = flax_attention.DotProductAttnHParams(
      attn_act_q=quant_act,
      attn_act_k=quant_act,
      attn_act_v=quant_act,
      attn_act_probs=attn_quant_act,
      quant_type=QuantType.fake_quant,
      softmax=SoftmaxHParams(None, None, None))
  attn_hparams = flax_attention.MultiHeadDotProductAttentionAqt.HParams(
      dense_kqv=dense_hparams,
      dense_out=dense_hparams,
      attn_acts=dotproduct_attn_hparams)
  module = flax_attention.SelfAttentionAqt(
      hparams=attn_hparams,
      num_heads=2,
      paxis_name=None,
      train=True,
      quant_context=quant_config.QuantContext(
          update_bounds=True, collect_acts_stats=False),
      dtype=jnp.float32,
      qkv_features=None,
      attention_axis=None,
      causal_mask=False,
      dropout_rate=0.0,
      deterministic=False,
      decode=False)
  # Simulate a batch of size 1 with two tokens, each with four features.
  x = onp.arange(8).astype(onp.float32).reshape((1, 2, 4))
  initial_state = module.init(random.PRNGKey(0), x, padding_mask=None)
  padding_mask = onp.full((1, 2, 1), True)
  padding_mask[0, 1, 0] = False  # Mask out the second token.
  _, state1 = module.apply(
      initial_state, x, padding_mask=padding_mask, mutable=True)
  # Now we adjust the input for the masked token and recompute the stats.
  # They should be the same as before.
  x[0, 1, 0] = 100
  _, state2 = module.apply(
      initial_state, x, padding_mask=padding_mask, mutable=True)
  test_utils.assert_stats_are_equal(state1, state2)
  # Now we adjust the input for an unmasked token and verify that the stats
  # have changed.
  x[0, 0, 0] = 200
  _, state3 = module.apply(
      initial_state, x, padding_mask=padding_mask, mutable=True)
  test_utils.assert_stats_are_unequal(state1, state3)
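# For reference, a padding mask with this (batch, seq_len, 1) boolean layout
# can also be derived from per-example token counts rather than built by hand
# (a sketch; `lengths` is hypothetical and not part of the test above):
#
#   lengths = onp.array([1])             # number of valid tokens per example
#   positions = onp.arange(2)[None, :]   # shape (1, seq_len)
#   padding_mask = (positions < lengths[:, None])[..., None]  # (1, 2, 1) bool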
def __call__(
    self,
    targets,
    encoded,
    padding_mask,
    key_padding_mask,
    inputs_segmentation=None,
    targets_segmentation=None,
):
  """Applies the EncoderDecoder1DBlock module.

  Args:
    targets: input data for the decoder.
    encoded: input data from the encoder.
    padding_mask: bool, mask for padding tokens in the targets.
    key_padding_mask: bool, mask for padding tokens in the encoder output.
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.

  Returns:
    Output after the transformer decoder block.
  """
  # Decoder self-attention block.
  batch_size, sequence_length, num_channels = targets.shape
  encoded_sequence_length = encoded.shape[1]
  shape_utils.assert_shapes_equal(padding_mask.shape,
                                  (batch_size, sequence_length, 1))
  shape_utils.assert_shapes_equal(
      encoded.shape, (batch_size, encoded_sequence_length, num_channels))
  shape_utils.assert_shapes_equal(
      key_padding_mask.shape, (batch_size, encoded_sequence_length, 1))
  x = aqt_flax_layers.LayerNormAqt(
      dtype=self.dtype,
      hparams=self.hparams.layer_norm,
      quant_context=self.quant_context)(targets)
  x = aqt_flax_attention.SelfAttentionAqt(
      hparams=self.hparams.self_attention,
      num_heads=self.num_heads,
      dtype=self.dtype,
      qkv_features=self.qkv_dim,
      attention_axis=(1,),
      paxis_name='batch',
      train=self.train,
      quant_context=self.quant_context,
      causal_mask=True,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=self.attention_dropout_rate,
      deterministic=self.deterministic,
      name='dec_self_att',
      decode=self.decode)(
          x,
          padding_mask=padding_mask,
          segmentation=targets_segmentation)
  x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
  x = x + targets

  # Encoder-decoder cross-attention block.
  y = aqt_flax_layers.LayerNormAqt(
      dtype=self.dtype,
      hparams=self.hparams.layer_norm,
      quant_context=self.quant_context)(x)
  y = aqt_flax_attention.MultiHeadDotProductAttentionAqt(
      hparams=self.hparams.enc_dec_attention,
      num_heads=self.num_heads,
      dtype=self.dtype,
      qkv_features=self.qkv_dim,
      attention_axis=(1,),
      paxis_name='batch',
      train=self.train,
      quant_context=self.quant_context,
      causal_mask=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=self.attention_dropout_rate,
      deterministic=self.deterministic,
      decode=False,
      name='dec_enc_att')(
          inputs_q=y,
          inputs_kv=encoded,
          padding_mask=padding_mask,
          key_padding_mask=key_padding_mask,
          segmentation=targets_segmentation,
          key_segmentation=inputs_segmentation)
  y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=self.deterministic)
  y = y + x

  # MLP block.
  z = aqt_flax_layers.LayerNormAqt(
      dtype=self.dtype,
      hparams=self.hparams.layer_norm,
      quant_context=self.quant_context)(y)
  z = MlpBlock(
      mlp_dim=self.mlp_dim,
      dtype=self.dtype,
      # Inputs are signed since this is called after the attention layer.
      hparams=self.hparams.mlp_block,
      train=self.train,
      quant_context=self.quant_context,
      dropout_rate=self.dropout_rate,
      deterministic=self.deterministic,
      name='mlp_block')(z, padding_mask=padding_mask)
  return y + z
def __call__(self, inputs, padding_mask, inputs_segmentation=None):
  """Applies the Encoder1DBlock module.

  Args:
    inputs: input data.
    padding_mask: bool, mask for padding tokens.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    Output after the transformer encoder block.
  """
  # Attention block.
  batch_size, sequence_length, channel_size = inputs.shape
  shape_utils.assert_shapes_equal(padding_mask.shape,
                                  (batch_size, sequence_length, 1))
  x = aqt_flax_layers.LayerNormAqt(
      dtype=self.dtype,
      hparams=self.hparams.layer_norm,
      quant_context=self.quant_context)(inputs)
  x = aqt_flax_attention.SelfAttentionAqt(
      hparams=self.hparams.attention,
      num_heads=self.num_heads,
      dtype=self.dtype,
      qkv_features=self.qkv_dim,
      attention_axis=(1,),
      paxis_name='batch',
      train=self.train,
      quant_context=self.quant_context,
      causal_mask=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=self.attention_dropout_rate,
      deterministic=self.deterministic,
      decode=False,
      name='enc_self_att')(
          x,
          padding_mask=padding_mask,
          segmentation=inputs_segmentation)
  x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
  x = x + inputs

  # MLP block.
  y = aqt_flax_layers.LayerNormAqt(
      dtype=self.dtype,
      hparams=self.hparams.layer_norm,
      quant_context=self.quant_context)(x)
  y = MlpBlock(
      mlp_dim=self.mlp_dim,
      # Inputs are signed since this is called after the attention layer.
      hparams=self.hparams.mlp_block,
      train=self.train,
      quant_context=self.quant_context,
      dtype=self.dtype,
      dropout_rate=self.dropout_rate,
      deterministic=self.deterministic,
      name='mlp_block')(y, padding_mask=padding_mask)
  out = x + y
  shape_utils.assert_shapes_equal(
      out.shape, (batch_size, sequence_length, channel_size))
  return out
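# A hedged construction sketch for Encoder1DBlock, using only the attributes
# referenced in __call__ above. `encoder_block_hparams` is a placeholder for a
# fully populated HParams instance with `layer_norm`, `attention`, and
# `mlp_block` fields, and the dimensions are illustrative:
#
#   block = Encoder1DBlock(
#       num_heads=8, qkv_dim=512, mlp_dim=2048, dtype=jnp.float32,
#       hparams=encoder_block_hparams,
#       quant_context=quant_config.QuantContext(
#           update_bounds=False, collect_acts_stats=False),
#       train=False, dropout_rate=0.1, attention_dropout_rate=0.1,
#       deterministic=True)
#   out = block.apply(variables, inputs, padding_mask=padding_mask)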