def test_annotation_only_changes_hlo_metadata_conv(self, weight_prec,
                                                   acts_prec):
  quant_act = quantization.QuantOps.ActHParams(
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      prec=acts_prec,
      bounds=1.0)
  input_shape = (1, 8, 8, 3)

  # Build and run the layer with HLO metadata annotation disabled.
  FLAGS.metadata_enabled = False
  module_no_annotation = aqt_flax_layers.ConvAqt(
      features=4,
      kernel_size=(3, 3),
      padding='VALID',
      paxis_name='batch',
      quant_context=quant_config.QuantContext(update_bounds=False),
      train=False,
      hparams=aqt_flax_layers.ConvAqt.HParams(
          weight_prec=weight_prec,
          quant_act=quant_act,
          quant_type=QuantType.fake_quant),
      kernel_init=initializers.ones,
      bias_init=initializers.ones,
      dtype=jnp.float32)
  init_state = module_no_annotation.init(self.rng_key,
                                         jnp.ones(input_shape, jnp.float32))
  output_no_annotation = module_no_annotation.apply(init_state,
                                                    jnp.ones(input_shape))
  hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
      module_no_annotation, init_state, [input_shape])
  del init_state

  # Build and run an identical layer with HLO metadata annotation enabled.
  FLAGS.metadata_enabled = True
  module_w_annotation = aqt_flax_layers.ConvAqt(
      features=4,
      kernel_size=(3, 3),
      padding='VALID',
      paxis_name='batch',
      quant_context=quant_config.QuantContext(update_bounds=False),
      train=False,
      hparams=aqt_flax_layers.ConvAqt.HParams(
          weight_prec=weight_prec,
          quant_act=quant_act,
          quant_type=QuantType.fake_quant),
      kernel_init=initializers.ones,
      bias_init=initializers.ones,
      dtype=jnp.float32)
  init_state = module_w_annotation.init(self.rng_key,
                                        jnp.ones(input_shape, jnp.float32))
  output_w_annotation = module_w_annotation.apply(init_state,
                                                  jnp.ones(input_shape))
  hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
      module_w_annotation, init_state, [input_shape])
  del init_state

  # Outputs should be identical with and without annotation.
  onp.testing.assert_array_equal(output_no_annotation, output_w_annotation)
  self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
def test_softmax_vs_original(self, input_tensor, softmax_hparams):
  dtype = jnp.float32
  norm_dims = (0,)
  input_tensor = jnp.array(input_tensor)
  output = flax_attention.softmax(
      input_tensor, norm_dims, dtype, softmax_hparams,
      quant_config.QuantContext(update_bounds=False, quantize_acts=True))
  expected_output = flax_attention.softmax(
      input_tensor, norm_dims, dtype, SoftmaxHParams(None, None, None),
      quant_config.QuantContext(update_bounds=False, quantize_acts=True))
  self.assertAllClose(expected_output, output, atol=1e-8)
def test_multihead_encoder_decoder_attention(self, weight_prec):
  rng = random.PRNGKey(0)
  q = jnp.ones((4, 3, 5))
  kv = jnp.ones((4, 3, 5))
  sa_module = flax_attention.MultiHeadDotProductAttentionAqt(
      num_heads=8,
      hparams=self.construct_hparams(weight_prec),
      attention_axis=(1,),
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      train=False,
      paxis_name=None,
      qkv_features=16,
      kernel_init=initializers.ones,
      bias_init=initializers.zeros,
      dtype=jnp.float32,
      causal_mask=False,
      dropout_rate=0.0,
      deterministic=False,
      decode=False)
  y, _ = sa_module.init_with_output(
      rng, q, kv, padding_mask=None, key_padding_mask=None)
  self.assertEqual(y.shape, q.shape)
def init_model_with_1_layer(self,
                            inputs,
                            num_features,
                            kernel_size,
                            kernel_init=flax_layers.default_kernel_init,
                            weight_prec=None,
                            quant_act=None,
                            weight_half_shift=False):
  """Create and initialize a flax model with a single ConvAqt layer."""
  layer_kwargs = {
      'kernel_init': kernel_init,
      'features': num_features,
      'use_bias': False,
      'quant_context': quant_config.QuantContext(update_bounds=False),
      'paxis_name': 'batch',
      'train': False,
      'kernel_size': kernel_size,
      'dtype': jnp.float32
  }
  layer_class = flax_layers.ConvAqt
  layer_kwargs['hparams'] = flax_layers.ConvAqt.HParams(
      weight_prec=weight_prec,
      quant_act=quant_act,
      quant_type=QuantType.fake_quant,
      weight_half_shift=weight_half_shift,
  )
  conv_module = layer_class(**layer_kwargs)
  initial_state = conv_module.init(self.rng_key, jnp.zeros(inputs.shape))
  return conv_module, initial_state
def test_custom_softmax_vs_mock(self, input_tensor, norm_dims,
                                softmax_hparams, expected_output):
  dtype = jnp.float32
  output = flax_attention.softmax(
      input_tensor, norm_dims, dtype, softmax_hparams,
      quant_config.QuantContext(update_bounds=False, quantize_acts=False))
  self.assertAllClose(expected_output, output, atol=1e-6)
def get_quant_context_for_step(*, activation_bound_update_freq,
                               activation_bound_start_step, step,
                               collect_acts_stats, prefer_int8_to_int32_dot):
  """Returns correct quantization context for a given step.

  Args:
    activation_bound_update_freq: How frequently to update bounds after the
      initial bounds update. A value of '-1' indicates to not update the
      bounds after the first update.
    activation_bound_start_step: The first step to update bounds on. '-1'
      indicates to never update bounds.
    step: The current training step.
    collect_acts_stats: Whether to collect activation statistics.
    prefer_int8_to_int32_dot: Whether to feed lax.dot inputs with an int8
      dtype and accumulate to int32.

  Returns:
    A quant_config.QuantContext instance.
  """
  update_bounds = should_update_bounds(
      activation_bound_start_step=activation_bound_start_step,
      activation_bound_update_freq=activation_bound_update_freq,
      step=step)
  quantize_acts = step >= activation_bound_start_step
  return quant_config.QuantContext(
      update_bounds=update_bounds,
      quantize_acts=quantize_acts,
      collect_acts_stats=collect_acts_stats,
      prefer_int8_to_int32_dot=prefer_int8_to_int32_dot,
  )
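# Illustrative usage sketch (hypothetical values; behavior as described in the
# docstring above): with activation_bound_start_step=1000 and
# activation_bound_update_freq=100, bounds are first updated at step 1000 and
# then every 100 steps, and activations are quantized from step 1000 onward.
#
#   ctx = get_quant_context_for_step(
#       activation_bound_update_freq=100,
#       activation_bound_start_step=1000,
#       step=1100,
#       collect_acts_stats=False,
#       prefer_int8_to_int32_dot=True)
#   # Expected: ctx.update_bounds and ctx.quantize_acts are both True.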
def test_embed(self, weight_prec):
  # Since the dummy embedding matrix has a row of all zeros, we need 'epsilon'
  # to be added to it before calculating scale factors.
  quantization.DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING = False
  rng = random.PRNGKey(0)
  x = jnp.arange(4)[None]
  dummy_embedding = jnp.broadcast_to(jnp.arange(4)[Ellipsis, None],
                                     (4, 3)).astype(jnp.float32)
  embed_module = flax_layers.EmbedAqt(
      num_embeddings=4,
      features=3,
      dtype=jnp.float32,
      hparams=flax_layers.EmbedAqt.HParams(
          weight_prec=weight_prec,
          quant_act=None,
          quant_type=QuantType.fake_quant,
          weight_half_shift=False),
      embedding_init=lambda _rng, _shape: dummy_embedding,
      train=False,
      paxis_name=None,
      quant_context=quant_config.QuantContext(update_bounds=False),
  )
  y, state = embed_module.init_with_output(rng, x)
  test_utils.assert_all_close_prec(dummy_embedding[None], y, weight_prec)
  z = embed_module.apply(
      state, jnp.ones((1, 3)), padding_mask=None, method=embed_module.attend)
  test_utils.assert_all_close_prec(3. * jnp.arange(4), z[0, Ellipsis],
                                   weight_prec)
def test_embed_equality(self, weight_prec):
  rng = random.PRNGKey(0)
  x = 2 * jnp.ones(4, dtype=jnp.int32)[None]
  dummy_embedding = 2 * jnp.ones((4, 2)).astype(jnp.float32)
  embed_module = flax_layers.EmbedAqt(
      num_embeddings=4,
      features=2,
      dtype=jnp.float32,
      hparams=flax_layers.EmbedAqt.HParams(
          weight_prec=weight_prec,
          quant_act=None,
          quant_type=QuantType.fake_quant,
          weight_half_shift=False),
      embedding_init=lambda _rng, _shape: dummy_embedding,
      train=False,
      quant_context=quant_config.QuantContext(update_bounds=False),
      paxis_name=None)
  y, init_state = embed_module.init_with_output(rng, x)
  onp.testing.assert_array_equal(dummy_embedding[None], y)
  z = embed_module.apply(
      init_state,
      jnp.ones((1, 2)),
      padding_mask=None,
      method=embed_module.attend)
  onp.testing.assert_array_equal(2. * (2 * jnp.ones(4)), z[0, Ellipsis])
def init_model_with_1_layer(self,
                            inputs,
                            num_features,
                            kernel_init=flax_layers.default_kernel_init,
                            weight_prec=None,
                            quant_act=None,
                            weight_half_shift=False):
  """Create and initialize a flax model with a single DenseAqt layer."""
  quant_context = quant_config.QuantContext(
      update_bounds=False, collect_acts_stats=False)
  layer_kwargs = {
      'kernel_init': kernel_init,
      'features': num_features,
      'use_bias': False,
      'quant_context': quant_context,
      'paxis_name': 'batch',
      'train': False,
      'dtype': jnp.float32
  }
  layer_kwargs['hparams'] = flax_layers.DenseAqt.HParams(
      weight_prec=weight_prec,
      quant_act=quant_act,
      quant_type=QuantType.fake_quant,
      weight_quant_granularity=quant_config.QuantGranularity.per_channel,
      weight_half_shift=weight_half_shift)
  dense_module = flax_layers.DenseAqt(**layer_kwargs)
  initial_state = dense_module.init(
      self.rng_key, jnp.zeros(inputs.shape), padding_mask=None)
  return dense_module, initial_state
def test_multihead_self_attention_w_dropout(self, weight_prec):
  rng = random.PRNGKey(0)
  x = jnp.ones((4, 3, 5))
  sa_module = flax_attention.SelfAttentionAqt(
      num_heads=8,
      hparams=self.construct_hparams(weight_prec),
      attention_axis=(1,),
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      train=False,
      paxis_name=None,
      qkv_features=16,
      kernel_init=initializers.ones,
      bias_init=initializers.zeros,
      dropout_rate=0.1,
      dtype=jnp.float32,
      causal_mask=False,
      deterministic=False,
      decode=False)
  rng_dropout, rng_params = random.split(rng)
  y, _ = sa_module.init_with_output(
      {'dropout': rng_dropout, 'params': rng_params}, x, padding_mask=None)
  self.assertEqual(y.shape, x.shape)
def get_quant_context_for_step(*, activation_bound_update_freq,
                               activation_bound_start_step, step,
                               collect_acts_stats):
  """Returns correct quantization context for a given step.

  Args:
    activation_bound_update_freq: How frequently to update bounds after the
      initial bounds update. A value of '-1' indicates to not update the
      bounds after the first update.
    activation_bound_start_step: The first step to update bounds on. '-1'
      indicates to never update bounds.
    step: The current training step.
    collect_acts_stats: Whether to collect activation statistics.

  Returns:
    A quant_config.QuantContext instance.
  """
  update_bounds = should_update_bounds(
      activation_bound_start_step=activation_bound_start_step,
      activation_bound_update_freq=activation_bound_update_freq,
      step=step)
  quantize_acts = step >= activation_bound_start_step
  # TODO(shivaniagrawal): We hardcode this to False to force the inputs to
  # lax.dot_general to be floating-point type. Otherwise, training diverges
  # for 8bit quantization for unknown reasons.
  prefer_int8_to_int32_dot = False
  return quant_config.QuantContext(
      update_bounds=update_bounds,
      quantize_acts=quantize_acts,
      collect_acts_stats=collect_acts_stats,
      prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)
def test_padding(self):
  """Test that padding results in the right statistics being collected."""
  # Exact values don't matter here; we just need the code to think it's using
  # dynamic bounds so that it gathers activation statistics.
  bounds = get_bounds.GetBounds.Hyper(
      initial_bound=0.0,
      stddev_coeff=1.0,
      absdev_coeff=0.0,
      mix_coeff=1.0,
      reset_stats=False,
      granularity=quant_config.QuantGranularity.per_channel)
  quant_act = flax_layers.QuantOps.ActHParams(
      input_distribution=flax_layers.QuantOps.ActHParams.InputDistribution
      .symmetric,
      prec=8,
      bounds=bounds)
  hparams = flax_layers.DenseAqt.HParams(
      quant_type=flax_layers.QuantType.fake_quant,
      weight_prec=8,
      quant_act=quant_act,
      weight_quant_granularity=quant_config.QuantGranularity.per_channel)
  module = flax_layers.DenseAqt(
      hparams=hparams,
      features=1,
      paxis_name=None,
      quant_context=quant_config.QuantContext(
          update_bounds=True, collect_acts_stats=False),
      train=True,
      dtype=jnp.float32)
  # Simulate an input with a batch size of 2, three tokens per example, two
  # channels per token.
  x = jnp.arange(12).astype(jnp.float32).reshape((2, 3, 2))
  # Reshape it to have dimensions [batch, feature].
  x = x.reshape(6, 2)
  initial_state = module.init(self.rng_key, x, padding_mask=None)

  # Check that the per-channel activation statistics are as expected with no
  # padding.
  _, state_nopadding = module.apply(
      initial_state, x, padding_mask=None, mutable='get_bounds')
  expected_means = onp.array([[(0 + 2 + 4 + 6 + 8 + 10) / 6,
                               (1 + 3 + 5 + 7 + 9 + 11) / 6]])
  actual_means = state_nopadding['get_bounds']['GetBounds_0']['stats'].mean
  onp.testing.assert_allclose(actual_means, expected_means)

  # Now we pad out some of the tokens (chosen arbitrarily) and check that the
  # computed per-channel stats are the means of the non-padding tokens only.
  # Exclude the second and third tokens from the first batch and the first
  # token from the second batch.
  padding_mask = jnp.array([[True, False, False], [False, True, True]])
  # Reshape it to have dimensions [batch, feature].
  padding_mask = padding_mask.reshape(6, 1)
  _, state_padding = module.apply(
      initial_state, x, padding_mask=padding_mask, mutable='get_bounds')
  expected_means = onp.array([[(0 + 8 + 10) / 3, (1 + 9 + 11) / 3]])
  actual_means = state_padding['get_bounds']['GetBounds_0']['stats'].mean
  onp.testing.assert_allclose(actual_means, expected_means)
def create_resnet(hparams, dtype, train, **kwargs):
  return ResNet(
      num_classes=1000,
      dtype=dtype,
      hparams=hparams,
      quant_context=quant_config.QuantContext(
          update_bounds=False, quantize_weights=True),
      num_filters=64,
      train=train,
      **kwargs)
def test_autoregresive_receptive_field_1d(self, weight_prec):
  """Tests the autoregressive self-attention receptive field."""
  rng = random.PRNGKey(0)
  rng1, rng2 = random.split(rng, num=2)

  def model_loss(inputs, pos):
    out = module.apply(initial_vars, inputs, padding_mask=None)
    assert out.shape == input_shape
    assert len(out.shape) == 3
    return out[0, pos, :].sum()

  grad_fn = jax.jit(jax.grad(model_loss))

  def get_receptive_field_1d(pos):
    g = grad_fn(inputs, pos)[0, :, :]
    return jnp.any((jnp.abs(g) > 1e-5).astype(jnp.uint32), axis=-1)

  length = 10
  dim = 1
  num_heads = 1
  input_shape = (1, length, dim)
  inputs = random.normal(rng2, input_shape)

  module = flax_attention.SelfAttentionAqt(
      num_heads=num_heads,
      hparams=self.construct_hparams(weight_prec),
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      train=False,
      paxis_name=None,
      causal_mask=True,
      kernel_init=initializers.ones,
      dtype=jnp.float32,
      qkv_features=None,
      attention_axis=None,
      dropout_rate=0.0,
      deterministic=False,
      decode=False)
  initial_vars = module.init(
      rng1, jnp.ones((1,) + (length, dim), jnp.float32), padding_mask=None)

  for i in range(length):
    deps = get_receptive_field_1d(i)
    assert (deps[:i] == 1).all(), ('Receptive Field Error: Some of the '
                                   'previous positions are not reachable '
                                   'in autoregressive self-attention.')
    if i != length - 1:
      k = i + 1
      assert (deps[k:] == 0).all(), ('Receptive Field Error: Some of the '
                                     'future positions are reachable in '
                                     'autoregressive self-attention.')
def _without_weights(inputs, params):
  return predict.step(
      inputs,
      params,
      cache,
      state,
      EOS_TOKEN,
      decode_length,
      transformer_kwargs=transformer_kwargs,
      hparams=model_hparams,
      quant_context=quant_config.QuantContext(
          update_bounds=False, quantize_acts=True))
def __call__(self, inputs, hparams, num_classes, dtype=jnp.float32):
  output = aqt_flax_layers.DenseAqt(
      features=num_classes,
      dtype=dtype,
      train=False,
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      paxis_name='batch',
      hparams=hparams,
  )(inputs, padding_mask=None)
  return output
def init_model(self, transformer_kwargs):
  model = models.Transformer(
      use_bfloat16=False,
      quant_context=quant_config.QuantContext(
          collect_acts_stats=False, update_bounds=False),
      dropout_rate=.1,
      attention_dropout_rate=.1,
      should_decode=False,
      **transformer_kwargs)
  state = model.init(self.key, jnp.zeros(self.input_shape, jnp.float32),
                     jnp.zeros(self.target_shape, jnp.float32))
  return model, state
def test_decoding(self, weight_prec, spatial_shape, attn_dims):
  bs = 2
  num_heads = 3
  num_features = 4
  rng = random.PRNGKey(0)
  key1, key2 = random.split(rng)
  inputs = random.normal(
      key1, (bs,) + spatial_shape + (num_heads * num_features,))
  module = flax_attention.SelfAttentionAqt(
      num_heads=num_heads,
      hparams=self.construct_hparams(weight_prec),
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      train=False,
      paxis_name=None,
      qkv_features=num_heads * num_features,
      attention_axis=attn_dims,
      decode=False,
      causal_mask=True,
      dtype=jnp.float32,
      dropout_rate=0.0,
      deterministic=False)
  initial_vars = module.init(key2, inputs, padding_mask=None)
  y_ref = module.apply(initial_vars, inputs, padding_mask=None)

  # Switch the module to decode mode and check that step-by-step decoding
  # through the cache reproduces the reference output.
  module.decode = True
  initial_vars_decode = module.init(key2, inputs, padding_mask=None)
  cache0 = initial_vars_decode['cache']

  def body_fn(cache, x):
    y, new_vars = module.apply(
        {**initial_vars, 'cache': cache},
        x,
        mutable='cache',
        padding_mask=None)
    return new_vars['cache'], y

  # scan_in_dim supports scanning multiple dims.
  _, y = jax_utils.scan_in_dim(
      body_fn, cache0, inputs, axis=attn_dims, keepdims=True)
  onp.testing.assert_allclose(y_ref, y, atol=1e-5)
def test_embed_should_call_clip_and_round(self, floor_with_gradient,
                                          round_with_gradient, weight_prec,
                                          acts_prec, fixed_bounds):
  round_with_gradient.side_effect = lambda x: x
  floor_with_gradient.side_effect = lambda x: x

  if fixed_bounds:
    bounds = 6.0
  else:
    bounds = get_bounds.GetBounds.Hyper(
        initial_bound=6.0,
        stddev_coeff=3.0,
        absdev_coeff=2.0,
        mix_coeff=0.5,
        granularity=quant_config.QuantGranularity.per_tensor)
  quant_act = quantization.QuantOps.ActHParams(
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      prec=acts_prec,
      bounds=bounds,
      half_shift=False)
  rng = random.PRNGKey(0)
  x = jnp.ones((1, 3))
  embed_module = flax_layers.EmbedAqt(
      num_embeddings=4,
      features=3,
      dtype=jnp.float32,
      hparams=flax_layers.EmbedAqt.HParams(
          weight_prec=weight_prec,
          quant_act=quant_act,
          quant_type=QuantType.fake_quant,
          weight_half_shift=False),
      quant_context=quant_config.QuantContext(update_bounds=False),
      paxis_name=None,
      train=False)
  init_state = embed_module.init(
      rng, x, method=embed_module.attend, padding_mask=None)
  round_with_gradient.reset_mock()
  floor_with_gradient.reset_mock()
  embed_module.apply(
      init_state, x, padding_mask=None, method=embed_module.attend)
  round_with_gradient.assert_called_with(mock.ANY)
  self.assertEqual(round_with_gradient.call_count, 1)
  floor_with_gradient.assert_not_called()
def __call__(self,
             inputs,
             hparams,
             kernel_size,
             num_filters,
             strides,
             dtype=jnp.float32):
  output = aqt_flax_layers.ConvAqt(
      features=num_filters,
      kernel_size=kernel_size,
      strides=strides,
      use_bias=False,
      dtype=dtype,
      train=False,
      quant_context=quant_config.QuantContext(update_bounds=False),
      paxis_name='batch',
      hparams=hparams)(inputs)
  return output
def test_self_attention_act_quant_should_call_quant_ops(
    self, mock_inputs_fake_quant, attn_act_q, attn_act_k, attn_act_probs,
    attn_act_v, update_bounds, paxis_name, train):
  mock_inputs_fake_quant.side_effect = (
      lambda inputs, hparams, get_bounds_params: inputs)

  rng = random.PRNGKey(0)
  x = jnp.ones((4, 3, 7))
  hparams = self.construct_hparams(attn_act_q, attn_act_k, attn_act_probs,
                                   attn_act_v)
  sa_module = flax_attention.SelfAttentionAqt(
      hparams=hparams,
      num_heads=4,
      quant_context=quant_config.QuantContext(
          update_bounds=update_bounds, collect_acts_stats=False),
      train=train,
      paxis_name=paxis_name,
      attention_axis=None,
      qkv_features=8,
      kernel_init=initializers.ones,
      bias_init=initializers.zeros,
      causal_mask=False,
      dtype=jnp.float32,
      dropout_rate=0.0,
      deterministic=False,
      decode=False)
  sa_module.init(rng, x, padding_mask=None)

  # Expect one call to the activation fake-quant op for each activation hparam
  # that is not None.
  calls = []
  for hparam in [attn_act_q, attn_act_k, attn_act_probs, attn_act_v]:
    if hparam is not None:
      calls.append(
          unittest.mock.call(
              unittest.mock.ANY,
              hparams=hparam,
              get_bounds_params=get_bounds.GetBounds.Params(
                  update_stats=train,
                  update_bounds=update_bounds,
                  paxis_name=paxis_name,
                  mask=unittest.mock.ANY,
                  module_name=unittest.mock.ANY)))
  mock_inputs_fake_quant.assert_has_calls(calls, any_order=True)
  self.assertLen(calls, mock_inputs_fake_quant.call_count)
def test_quant_granularity(self, _, mock_quantized_dot, granularity, axis):
  hparams = flax_layers.DenseAqt.HParams(
      weight_prec=8,
      quant_act=None,
      quant_type=quantization.QuantType.fake_quant,
      weight_quant_granularity=granularity)
  layer = flax_layers.DenseAqt(
      features=2,
      hparams=hparams,
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      paxis_name=None,
      train=False,
      dtype=jnp.float32)
  x = jnp.ones((2, 2))
  state = layer.init(self.rng_key, x, padding_mask=None)
  layer.apply(state, x, padding_mask=None)
  weight_params = mock_quantized_dot.call_args[1]['weight_params']
  self.assertEqual(weight_params.axis, axis)
def create_model(key,
                 batch_size,
                 image_size,
                 model_dtype,
                 hparams,
                 train=True,
                 **kwargs):
  """Creates the ResNet model using hparams."""
  input_shape = (batch_size, image_size, image_size, 3)
  model = models.ResNet(
      num_classes=1000,
      dtype=model_dtype,
      hparams=hparams,
      quant_context=quant_config.QuantContext(update_bounds=False),
      num_filters=64,
      train=train,
      **kwargs)
  init_state = model.init(key, jnp.zeros(input_shape, dtype=model_dtype))
  return model, init_state
def test_group_conv(self, weight_prec=None):
  x = jnp.ones((1, 8, 8, 4))
  conv_module = flax_layers.ConvAqt(
      features=4,
      kernel_size=(3, 3),
      feature_group_count=2,
      padding='VALID',
      paxis_name='batch',
      quant_context=quant_config.QuantContext(update_bounds=False),
      train=False,
      hparams=flax_layers.ConvAqt.HParams(
          weight_prec=weight_prec,
          quant_act=None,
          quant_type=QuantType.fake_quant),
      kernel_init=initializers.ones,
      bias_init=initializers.ones,
      dtype=jnp.float32)
  y, state = conv_module.init_with_output(self.rng_key, x)
  self.assertEqual(state['params']['kernel'].shape, (3, 3, 2, 4))
  test_utils.assert_all_close_prec(y, onp.full((1, 6, 6, 4), 19.),
                                   weight_prec)
def test_epsilon_rounding(self):
  # We give LayerNorm a constant input. Since that input has a variance of
  # zero, we would expect layernorm to return NaN (0/0) unless the 'epsilon'
  # parameter which nudges the denominator away from zero was having an
  # effect. We test the case where the default epsilon value of 1e-6 would
  # ordinarily flush to zero after quantization with a high value of exp_min.
  # This test makes sure our code to round epsilon up to the smallest non-zero
  # representable value is working.
  hparams = self.make_hparams(
      exp_min=-2**2, exp_max=2**7, sig_bits=23, quantize_reductions=False)
  layer_norm = flax_layers.LayerNormAqt(
      hparams=hparams,
      use_bias=False,
      use_scale=False,
      epsilon=1e-6,
      dtype=jnp.float32,
      quant_context=quant_config.QuantContext(
          update_bounds=False, quantize_acts=True))
  x = jnp.ones((2, 5))
  y = layer_norm.apply({}, x)
  onp.testing.assert_equal(onp.array(y), onp.zeros(x.shape))
def test_quantized_layer_norm_matches_unquantized_in_fp32(
    self, quantize_acts, quantize_reductions):
  # We 'quantize' to a custom floating-point format that is approximately
  # equivalent to IEEE float32 and test that results are the same as using
  # Flax's upstream unquantized LayerNorm.
  hparams = self.make_hparams(
      exp_min=-2**7,
      exp_max=2**7,
      sig_bits=23,
      quantize_reductions=quantize_reductions)
  quantized_layer_norm = flax_layers.LayerNormAqt(
      hparams=hparams,
      dtype=jnp.float32,
      quant_context=quant_config.QuantContext(
          update_bounds=False, quantize_acts=quantize_acts))
  x_rng, param_rng = jax.random.split(self.rng)
  x = jax.random.normal(x_rng, (3, 5))
  initial_params = quantized_layer_norm.init(param_rng, x)
  y_quantized = quantized_layer_norm.apply(initial_params, x)
  unquantized_layer_norm = nn.LayerNorm()
  y_unquantized = unquantized_layer_norm.apply(initial_params, x)
  onp.testing.assert_allclose(y_quantized, y_unquantized, rtol=2e-6)
def test_annotation_only_changes_hlo_metadata_dense(self, weight_prec,
                                                    acts_prec):
  quant_act = quantization.QuantOps.ActHParams(
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      prec=acts_prec,
      bounds=1.0,
      half_shift=False)
  input_shape = (1, 16)

  # Build and run the layer with HLO metadata annotation disabled.
  FLAGS.metadata_enabled = False
  module_no_annotation = aqt_flax_layers.DenseAqt(
      features=4,
      use_bias=False,
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      paxis_name='batch',
      train=False,
      hparams=aqt_flax_layers.DenseAqt.HParams(
          weight_prec=weight_prec,
          quant_act=quant_act,
          quant_type=QuantType.fake_quant,
          weight_quant_granularity=quant_config.QuantGranularity.per_channel,
          weight_half_shift=False),
      dtype=jnp.float32)
  init_state = module_no_annotation.init(
      self.rng_key, jnp.ones(input_shape, jnp.float32), padding_mask=None)
  output_no_annotation = module_no_annotation.apply(
      init_state, jnp.ones(input_shape), padding_mask=None)
  hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
      module_no_annotation, init_state, [input_shape], padding_mask=None)
  del init_state

  # Build and run an identical layer with HLO metadata annotation enabled.
  FLAGS.metadata_enabled = True
  module_w_annotation = aqt_flax_layers.DenseAqt(
      features=4,
      use_bias=False,
      paxis_name='batch',
      train=False,
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      dtype=jnp.float32,
      hparams=aqt_flax_layers.DenseAqt.HParams(
          weight_prec=weight_prec,
          quant_act=quant_act,
          quant_type=QuantType.fake_quant,
          weight_quant_granularity=quant_config.QuantGranularity.per_channel,
          weight_half_shift=False),
  )
  init_state = module_w_annotation.init(
      self.rng_key, jnp.ones(input_shape, jnp.float32), padding_mask=None)
  output_w_annotation = module_w_annotation.apply(
      init_state, jnp.ones(input_shape), padding_mask=None)
  hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
      module_w_annotation, init_state, [input_shape], padding_mask=None)
  del init_state

  # Outputs should be identical with and without annotation.
  onp.testing.assert_array_equal(output_no_annotation, output_w_annotation)
  self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
def test_hparams_without_logits_when_logits_not_shared_raises_error(self):
  # Create hparams without logits hparams by passing in
  # logits_via_embeddings=True.
  inputs_hyper = get_bounds.GetBounds.Hyper(
      initial_bound=6.0,
      stddev_coeff=3.0,
      absdev_coeff=2.0,
      mix_coeff=0.5,
      granularity=quant_config.QuantGranularity.per_channel)
  hparams = training_hparams_generator_lib.create_base_transformer_hparams(
      mlp_weight_prec=8,
      embedding_weight_prec=None,
      attention_weight_prec=8,
      mlp_pos_inputs_prec=8,
      mlp_pos_inputs_hyper=inputs_hyper,
      mlp_signed_inputs_prec=8,
      mlp_signed_inputs_hyper=inputs_hyper,
      attention_kqv_inputs_prec=8,
      attention_kqv_inputs_hyper=inputs_hyper,
      attention_out_inputs_prec=8,
      attention_out_inputs_hyper=inputs_hyper,
      logits_inputs_prec=8,
      logits_inputs_hyper=inputs_hyper,
      logits_via_embeddings=True,
      attention_act_q_inputs_prec=8,
      attention_act_q_inputs_hyper=inputs_hyper,
      attention_act_k_inputs_prec=8,
      attention_act_k_inputs_hyper=inputs_hyper,
      attention_act_probs_inputs_prec=8,
      attention_act_v_inputs_prec=8,
      attention_act_v_inputs_hyper=inputs_hyper,
      num_layers=2,
      emb_dim=5,
      num_heads=2,
      qkv_dim=4,
      mlp_dim=4,
      quant_type=QuantType.fake_quant)
  self.assertIsNone(hparams.decoder.logits)

  # Now set logits_via_embedding in the model hparams to False.
  hparams.logits_via_embedding = False
  module = models.Transformer(
      hparams=hparams,
      quant_context=quant_config.QuantContext(
          update_bounds=True, collect_acts_stats=True),
      vocab_size=3,
      output_vocab_size=3,
      max_len=10,
      use_bfloat16=False,
      train=False,
      dropout_rate=.1,
      attention_dropout_rate=.1,
      should_decode=False)
  key = jax.random.PRNGKey(0)
  # Mark the first token of the target and last token of the inputs as padding
  # tokens.
  targets = onp.array([[0, 2]])
  inputs = onp.array([[1, 0]])
  # Because the model is not sharing logits with embeddings, but the logits
  # hparams are missing, it should raise an error.
  with self.assertRaises(ValueError):
    module.init(key, inputs=inputs, targets=targets)
def test_padding_mask(self):
  # Fuzzing test to make sure activation statistics aren't affected by padding
  # tokens.
  #
  # This test works by changing the embedding of the padding token (token with
  # id '0') and making sure all the stats stay the same.
  #
  # It also tests that the stats *do* change when the embedding of a
  # non-padding token changes.
  inputs_hyper = get_bounds.GetBounds.Hyper(
      initial_bound=6.0,
      stddev_coeff=3.0,
      absdev_coeff=2.0,
      mix_coeff=0.5,
      granularity=quant_config.QuantGranularity.per_channel)
  # Set logits_via_embedding to false so that the embedding of the padding
  # token doesn't affect the logits calculation at the end of the decoder.
  hparams = training_hparams_generator_lib.create_base_transformer_hparams(
      mlp_weight_prec=8,
      embedding_weight_prec=None,
      attention_weight_prec=8,
      mlp_pos_inputs_prec=8,
      mlp_pos_inputs_hyper=inputs_hyper,
      mlp_signed_inputs_prec=8,
      mlp_signed_inputs_hyper=inputs_hyper,
      attention_kqv_inputs_prec=8,
      attention_kqv_inputs_hyper=inputs_hyper,
      attention_out_inputs_prec=8,
      attention_out_inputs_hyper=inputs_hyper,
      logits_inputs_prec=8,
      logits_inputs_hyper=inputs_hyper,
      logits_via_embeddings=False,
      attention_act_q_inputs_prec=8,
      attention_act_q_inputs_hyper=inputs_hyper,
      attention_act_k_inputs_prec=8,
      attention_act_k_inputs_hyper=inputs_hyper,
      attention_act_probs_inputs_prec=8,
      attention_act_v_inputs_prec=8,
      attention_act_v_inputs_hyper=inputs_hyper,
      num_layers=2,
      emb_dim=5,
      num_heads=2,
      qkv_dim=4,
      mlp_dim=4,
      quant_type=QuantType.fake_quant)
  module = models.Transformer(
      hparams=hparams,
      quant_context=quant_config.QuantContext(
          update_bounds=True, collect_acts_stats=True),
      vocab_size=3,
      output_vocab_size=3,
      max_len=10,
      train=False,
      use_bfloat16=False,
      dropout_rate=.1,
      attention_dropout_rate=.1,
      should_decode=False)
  key = jax.random.PRNGKey(0)
  # Mark the first token of the target and last token of the inputs as padding
  # tokens.
  targets = onp.array([[0, 2]])
  inputs = onp.array([[1, 0]])
  initial_state = module.init(key, inputs=inputs, targets=targets)

  # Change the embedding of the padding token.
  initial_state = initial_state.unfreeze()
  initial_state['params']['shared_embedding']['embedding'] = initial_state[
      'params']['shared_embedding']['embedding'].at[0, :].set(10.0)
  module.train = True
  _, state1 = module.apply(
      flax.core.freeze(initial_state),
      inputs=inputs,
      targets=targets,
      mutable=True,
      rngs={'dropout': key})
  initial_state['params']['shared_embedding']['embedding'] = initial_state[
      'params']['shared_embedding']['embedding'].at[0, :].set(20.0)
  _, state2 = module.apply(
      flax.core.freeze(initial_state),
      inputs=inputs,
      targets=targets,
      mutable=True,
      rngs={'dropout': key})
  # This tests the statistics in both the GetBounds and StatsTag modules.
  test_utils.assert_stats_are_equal(state1, state2)

  # Now we repeat the test, but changing the embedding of a non-padding token
  # (token with ID 1 here). We expect to see the stats change.
  initial_state['params']['shared_embedding']['embedding'] = initial_state[
      'params']['shared_embedding']['embedding'].at[1, :].set(10.0)
  _, state1 = module.apply(
      flax.core.freeze(initial_state),
      inputs=inputs,
      targets=targets,
      mutable=True,
      rngs={'dropout': key})
  initial_state['params']['shared_embedding']['embedding'] = initial_state[
      'params']['shared_embedding']['embedding'].at[1, :].set(200.0)
  _, state2 = module.apply(
      flax.core.freeze(initial_state),
      inputs=inputs,
      targets=targets,
      mutable=True,
      rngs={'dropout': key})
  print(initial_state['get_bounds']['encoder']['encoderblock_0']
        ['enc_self_att']['K']['bounds'])
  print(state1['get_bounds']['encoder']['encoderblock_0']['enc_self_att']
        ['K']['bounds'])
  print(state2['get_bounds']['encoder']['encoderblock_0']['enc_self_att']
        ['K']['bounds'])
  print('')
  test_utils.assert_stats_are_unequal(state1, state2)
def encoder_from_file(config,
                      batch_size=8,
                      encode_length=16,
                      use_bfloat16=True,
                      use_xla_optimizations=True):
  """Generates HLO for just the encoder of the WMT model.

  Args:
    config: A ConfigDict instance.
    batch_size: Batch size.
    encode_length: Max length of an input sentence.
    use_bfloat16: Use bfloat16 mixed precision training instead of float32.
    use_xla_optimizations: Whether to use xla optimizations.
  """
  if FLAGS.checkpoint:
    raise app.UsageError('Checkpoints not yet supported for WMT encoder.')
  input_shape = (batch_size, encode_length)
  rng = jax.random.PRNGKey(0)
  hparams = hparams_utils.load_dataclass_from_config_dict(
      training_hparams.TrainingHParams, config)
  model_hparams = hparams.model_hparams
  model = models.Encoder(
      vocab_size=32711,
      hparams=model_hparams.encoder,
      shared_embedding=None,
      use_bfloat16=use_bfloat16,
      emb_dim=model_hparams.emb_dim,
      num_heads=model_hparams.num_heads,
      qkv_dim=model_hparams.qkv_dim,
      mlp_dim=model_hparams.mlp_dim,
      max_len=encode_length,
      train=False,
      dropout_rate=0.1,
      attention_dropout_rate=0.1,
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False, quantize_acts=True))
  init_state = model.init(rng, jnp.ones(input_shape, jnp.float32))

  def _fn(state, inputs):
    return model.apply(state, inputs, mutable=False)

  if not use_xla_optimizations:
    computation = jax.xla_computation(_fn)(init_state,
                                           jnp.ones(input_shape, jnp.float32))
    hlo_utils.output_hlo(computation, FLAGS.hlo_output)
  else:

    def _wrapped_fn(inputs):
      return _fn(init_state, inputs)

    def to_shape_str(shape_tuple):
      return 'f32[%s]' % ','.join(map(str, shape_tuple))

    hlo_module_proto_str, hlo_txt = jax_to_ir.jax_to_hlo(
        _wrapped_fn,
        [('inputs', jax_to_ir.parse_shape_str(to_shape_str(input_shape)))])
    hlo_utils.output_hlo_to_file(hlo_module_proto_str, hlo_txt,
                                 FLAGS.hlo_output)