def init_model_with_1_layer(self,
                            inputs,
                            num_features,
                            kernel_init=flax_layers.default_kernel_init,
                            weight_prec=None,
                            quant_act=None,
                            weight_half_shift=False):
  """Create and initialize a flax model with a single DenseAqt layer."""
  quant_context = quant_config.QuantContext(
      update_bounds=False, collect_acts_stats=False)
  layer_kwargs = {
      'kernel_init': kernel_init,
      'features': num_features,
      'use_bias': False,
      'quant_context': quant_context,
      'paxis_name': 'batch',
      'train': False,
      'dtype': jnp.float32
  }
  layer_kwargs['hparams'] = flax_layers.DenseAqt.HParams(
      weight_prec=weight_prec,
      quant_act=quant_act,
      quant_type=QuantType.fake_quant,
      weight_quant_granularity=quant_config.QuantGranularity.per_channel,
      weight_half_shift=weight_half_shift)
  dense_module = flax_layers.DenseAqt(**layer_kwargs)
  initial_state = dense_module.init(
      self.rng_key, jnp.zeros(inputs.shape), padding_mask=None)
  return dense_module, initial_state
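A minimal usage sketch for the helper above (editor's example, not from the original test file): it assumes the enclosing test class provides `self.rng_key` and the same `jnp`/`flax_layers` imports; shapes and precisions are purely illustrative.

# Hypothetical example: build a single 8-bit DenseAqt layer and run a forward
# pass on dummy data; `apply` mirrors how the other tests call the module.
inputs = jnp.ones((4, 8), dtype=jnp.float32)
dense_module, state = self.init_model_with_1_layer(
    inputs, num_features=16, weight_prec=8)
outputs = dense_module.apply(state, inputs, padding_mask=None)
assert outputs.shape == (4, 16)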
def __call__(
    self,
    inputs,
):
  """Applies ResNet model. Number of residual blocks inferred from hparams."""
  num_classes = self.num_classes
  hparams = self.hparams
  num_filters = self.num_filters
  dtype = self.dtype

  x = aqt_flax_layers.ConvAqt(
      features=num_filters,
      kernel_size=(7, 7),
      strides=(2, 2),
      padding=[(3, 3), (3, 3)],
      use_bias=False,
      dtype=dtype,
      name='init_conv',
      train=self.train,
      quant_context=self.quant_context,
      paxis_name='batch',
      hparams=hparams.conv_init,
  )(inputs)
  x = nn.BatchNorm(
      use_running_average=not self.train,
      momentum=0.9,
      epsilon=1e-5,
      dtype=dtype,
      name='init_bn')(x)
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  filter_multiplier = hparams.filter_multiplier
  for i, block_hparams in enumerate(hparams.residual_blocks):
    proj = block_hparams.conv_proj
    # For projection layers (unless it is the first layer), strides = (2, 2).
    if i > 0 and proj is not None:
      filter_multiplier *= 2
      strides = (2, 2)
    else:
      strides = (1, 1)
    x = ResidualBlock(
        filters=int(num_filters * filter_multiplier),
        hparams=block_hparams,
        quant_context=self.quant_context,
        strides=strides,
        train=self.train,
        dtype=dtype)(x)
  x = jnp.mean(x, axis=(1, 2))
  x = aqt_flax_layers.DenseAqt(
      features=num_classes,
      dtype=dtype,
      train=self.train,
      quant_context=self.quant_context,
      paxis_name='batch',
      hparams=hparams.dense_layer,
  )(x, padding_mask=None)
  x = jnp.asarray(x, dtype)
  output = nn.log_softmax(x)
  return output
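To make the filter schedule in the loop above concrete, here is a standalone trace (editor's sketch; the 16-block layout with projection blocks at indices 0, 3, 7 and 13 is hypothetical, not taken from any hparams in this repo):

# Illustrative trace of filter_multiplier doubling at each projection block
# after the first, for num_filters = 64 and a made-up block layout.
proj_flags = [i in (0, 3, 7, 13) for i in range(16)]
filters, mult = [], 1
for i, has_proj in enumerate(proj_flags):
  if i > 0 and has_proj:
    mult *= 2
  filters.append(int(64 * mult))
# filters -> 64 (blocks 0-2), 128 (3-6), 256 (7-12), 512 (13-15)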
def test_padding(self):
  """Test that padding results in the right statistics being collected."""
  # Exact values don't matter here, we just need code to think it's using
  # dynamic bounds so it gathers activation statistics.
  bounds = get_bounds.GetBounds.Hyper(
      initial_bound=0.0,
      stddev_coeff=1.0,
      absdev_coeff=0.0,
      mix_coeff=1.0,
      reset_stats=False,
      granularity=quant_config.QuantGranularity.per_channel)
  quant_act = flax_layers.QuantOps.ActHParams(
      input_distribution=flax_layers.QuantOps.ActHParams.InputDistribution
      .symmetric,
      prec=8,
      bounds=bounds)
  hparams = flax_layers.DenseAqt.HParams(
      quant_type=flax_layers.QuantType.fake_quant,
      weight_prec=8,
      quant_act=quant_act,
      weight_quant_granularity=quant_config.QuantGranularity.per_channel)
  module = flax_layers.DenseAqt(
      hparams=hparams,
      features=1,
      paxis_name=None,
      quant_context=quant_config.QuantContext(
          update_bounds=True, collect_acts_stats=False),
      train=True,
      dtype=jnp.float32)

  # Simulate an input with a batch size of 2, three tokens per example, two
  # channels per token.
  x = jnp.arange(12).astype(jnp.float32).reshape((2, 3, 2))
  # Reshape it to have dimensions [batch, feature].
  x = x.reshape(6, 2)

  initial_state = module.init(self.rng_key, x, padding_mask=None)

  # Check that the per-channel activation statistics are as expected with no
  # padding.
  _, state_nopadding = module.apply(
      initial_state, x, padding_mask=None, mutable='get_bounds')
  expected_means = onp.array([[(0 + 2 + 4 + 6 + 8 + 10) / 6,
                               (1 + 3 + 5 + 7 + 9 + 11) / 6]])
  actual_means = state_nopadding['get_bounds']['GetBounds_0']['stats'].mean
  onp.testing.assert_allclose(actual_means, expected_means)

  # Now we pad out some of the tokens (chosen arbitrarily) and check that the
  # computed per-channel stats are the means of the non-padding tokens only.
  # Exclude the second and third tokens from the first batch and the first
  # token from the second batch.
  padding_mask = jnp.array([[True, False, False], [False, True, True]])
  # Reshape it to have dimensions [batch, feature].
  padding_mask = padding_mask.reshape(6, 1)
  _, state_padding = module.apply(
      initial_state, x, padding_mask=padding_mask, mutable='get_bounds')
  expected_means = onp.array([[(0 + 8 + 10) / 3, (1 + 9 + 11) / 3]])
  actual_means = state_padding['get_bounds']['GetBounds_0']['stats'].mean
  onp.testing.assert_allclose(actual_means, expected_means)
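For reference, the masked means asserted at the end of the test can be reproduced with plain numpy; this sketch only restates the arithmetic already in the assertions above.

# Rows 0, 4 and 5 of the flattened (6, 2) input are the non-padding tokens.
x = onp.arange(12, dtype=onp.float32).reshape(6, 2)
mask = onp.array([True, False, False, False, True, True])
masked_means = x[mask].mean(axis=0)  # -> [6., 7.], matching expected_means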
def __call__(
    self,
    inputs,
    *,
    padding_mask,
):
  """Applies Transformer MlpBlock module."""
  batch_size, sequence_length, channel_size = inputs.shape
  inputs = inputs.reshape((batch_size * sequence_length, channel_size))
  shape_utils.assert_shapes_equal(padding_mask.shape,
                                  (batch_size, sequence_length, 1))
  padding_mask = padding_mask.reshape((batch_size * sequence_length, 1))
  x = aqt_flax_layers.DenseAqt(
      features=self.mlp_dim,
      dtype=self.dtype,
      paxis_name='batch',
      train=self.train,
      quant_context=self.quant_context,
      hparams=self.hparams.dense_1,
      kernel_init=self.kernel_init,
      bias_init=self.bias_init,
      name='dense_1')(inputs, padding_mask=padding_mask)
  x = nn.relu(x)
  x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=self.deterministic)
  output = aqt_flax_layers.DenseAqt(
      # We have relu before this layer, so x contains only positive values.
      features=channel_size,
      dtype=self.dtype,
      paxis_name='batch',
      train=self.train,
      quant_context=self.quant_context,
      hparams=self.hparams.dense_2,
      kernel_init=self.kernel_init,
      bias_init=self.bias_init,
      name='dense_2')(x, padding_mask=padding_mask)
  output = nn.Dropout(rate=self.dropout_rate)(
      output, deterministic=self.deterministic)
  output = output.reshape((batch_size, sequence_length, channel_size))
  return output
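The `padding_mask` consumed here has shape `[batch, seq_len, 1]`. Elsewhere in these snippets it is derived from token ids (the decoder below computes `tgt_padding_mask = (targets > 0)[..., None]`); a small sketch with made-up token ids:

# Hypothetical token batch: non-zero ids are real tokens, zeros are padding.
tokens = jnp.array([[5, 9, 0], [7, 0, 0]])
padding_mask = (tokens > 0)[Ellipsis, None]  # boolean, shape (2, 3, 1)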
def __call__(self, inputs, hparams, num_classes, dtype=jnp.float32):
  output = aqt_flax_layers.DenseAqt(
      features=num_classes,
      dtype=dtype,
      train=False,
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      paxis_name='batch',
      hparams=hparams,
  )(inputs, padding_mask=None)
  return output
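The `hparams` argument forwarded to DenseAqt here uses the same `DenseAqt.HParams` structure as the other snippets; an illustrative construction (the field values are examples only, not prescribed defaults):

# Example-only hparams: 8-bit weights, unquantized activations.
hparams = aqt_flax_layers.DenseAqt.HParams(
    weight_prec=8,
    quant_act=None,
    quant_type=QuantType.fake_quant,
    weight_quant_granularity=quant_config.QuantGranularity.per_channel,
    weight_half_shift=False)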
def multi_batch_dense_aqt(inputs, *, name, padding_mask):
  batch_size, sequence_length, channel_size = inputs.shape
  inputs = inputs.reshape(batch_size * sequence_length, channel_size)
  if padding_mask is not None:
    padding_mask = padding_mask.reshape(batch_size * sequence_length, 1)
  out = flax_layers.DenseAqt(
      name=name,
      features=num_heads * head_dim,
      paxis_name=paxis_name,
      train=train,
      quant_context=self.quant_context,
      hparams=hparams.dense_kqv,
      kernel_init=kernel_init,
      bias_init=bias_init,
      use_bias=use_bias,
      dtype=dtype)(inputs, padding_mask=padding_mask)
  return out.reshape(batch_size, sequence_length, num_heads, head_dim)
def test_quant_granularity(self, _, mock_quantized_dot, granularity, axis):
  hparams = flax_layers.DenseAqt.HParams(
      weight_prec=8,
      quant_act=None,
      quant_type=quantization.QuantType.fake_quant,
      weight_quant_granularity=granularity)
  layer = flax_layers.DenseAqt(
      features=2,
      hparams=hparams,
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      paxis_name=None,
      train=False,
      dtype=jnp.float32)
  x = jnp.ones((2, 2))
  state = layer.init(self.rng_key, x, padding_mask=None)
  layer.apply(state, x, padding_mask=None)
  weight_params = mock_quantized_dot.call_args[1]['weight_params']
  self.assertEqual(weight_params.axis, axis)
def test_annotation_only_changes_hlo_metadata_dense(self, weight_prec,
                                                    acts_prec):
  FLAGS.metadata_enabled = False
  quant_act = quantization.QuantOps.ActHParams(
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      prec=acts_prec,
      bounds=1.0,
      half_shift=False)
  input_shape = (1, 16)
  module_no_annotation = aqt_flax_layers.DenseAqt(
      features=4,
      use_bias=False,
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      paxis_name='batch',
      train=False,
      hparams=aqt_flax_layers.DenseAqt.HParams(
          weight_prec=weight_prec,
          quant_act=quant_act,
          quant_type=QuantType.fake_quant,
          weight_quant_granularity=quant_config.QuantGranularity.per_channel,
          weight_half_shift=False),
      dtype=jnp.float32)
  init_state = module_no_annotation.init(
      self.rng_key, jnp.ones(input_shape, jnp.float32), padding_mask=None)
  output_no_annotation = module_no_annotation.apply(
      init_state, jnp.ones(input_shape), padding_mask=None)
  hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
      module_no_annotation, init_state, [input_shape], padding_mask=None)
  del init_state

  FLAGS.metadata_enabled = True
  module_w_annotation = aqt_flax_layers.DenseAqt(
      features=4,
      use_bias=False,
      paxis_name='batch',
      train=False,
      quant_context=quant_config.QuantContext(
          update_bounds=False, collect_acts_stats=False),
      dtype=jnp.float32,
      hparams=aqt_flax_layers.DenseAqt.HParams(
          weight_prec=weight_prec,
          quant_act=quant_act,
          quant_type=QuantType.fake_quant,
          weight_quant_granularity=quant_config.QuantGranularity.per_channel,
          weight_half_shift=False),
  )
  init_state = module_w_annotation.init(
      self.rng_key, jnp.ones(input_shape, jnp.float32), padding_mask=None)
  output_w_annotation = module_w_annotation.apply(
      init_state, jnp.ones(input_shape), padding_mask=None)
  hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
      module_w_annotation, init_state, [input_shape], padding_mask=None)
  del init_state

  onp.testing.assert_array_equal(output_no_annotation, output_w_annotation)
  self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
def __call__(self,
             inputs_q,
             inputs_kv,
             *,
             padding_mask,
             key_padding_mask,
             segmentation=None,
             key_segmentation=None):
  """Applies multi-head dot product attention on the input data.

  If weight_prec is not None, scales and quantizes weights to signed int with
  weight_prec bits.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both
  `inputs_q` and `inputs_kv` or for self-attention by only specifying
  `inputs_q` and setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]` or
      None for self-attention, in which case key/values will be derived from
      inputs_q.
    padding_mask: boolean tensor specifying query tokens that are pad tokens.
    key_padding_mask: boolean tensor specifying key-value tokens that are pad
      tokens.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """
  batch_size, query_sequence_length, channel_size = inputs_q.shape
  hparams = self.hparams
  if inputs_kv is None:
    inputs_kv = inputs_q
    key_sequence_length = inputs_q.shape[1]
  else:
    key_sequence_length = inputs_kv.shape[1]
    shape_utils.assert_shapes_equal(
        inputs_kv.shape, (batch_size, key_sequence_length, channel_size))
  jax_precision = jax.lax.Precision.DEFAULT

  if padding_mask is not None:
    shape_utils.assert_shapes_equal(padding_mask.shape,
                                    (batch_size, query_sequence_length, 1))
  if key_padding_mask is None:
    key_padding_mask = padding_mask
  else:
    shape_utils.assert_shapes_equal(key_padding_mask.shape,
                                    (batch_size, key_sequence_length, 1))
  attention_axis = self.attention_axis
  if attention_axis is None:
    attention_axis = tuple(range(1, inputs_q.ndim - 1))

  qkv_features = self.qkv_features
  qkv_features = qkv_features or inputs_q.shape[-1]

  num_heads = self.num_heads
  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  paxis_name = self.paxis_name
  train = self.train
  kernel_init = self.kernel_init
  bias_init = self.bias_init
  use_bias = self.use_bias
  dtype = self.dtype

  def multi_batch_dense_aqt(inputs, *, name, padding_mask):
    batch_size, sequence_length, channel_size = inputs.shape
    inputs = inputs.reshape(batch_size * sequence_length, channel_size)
    if padding_mask is not None:
      padding_mask = padding_mask.reshape(batch_size * sequence_length, 1)
    out = flax_layers.DenseAqt(
        name=name,
        features=num_heads * head_dim,
        paxis_name=paxis_name,
        train=train,
        quant_context=self.quant_context,
        hparams=hparams.dense_kqv,
        kernel_init=kernel_init,
        bias_init=bias_init,
        use_bias=use_bias,
        dtype=dtype)(inputs, padding_mask=padding_mask)
    return out.reshape(batch_size, sequence_length, num_heads, head_dim)

  # Project inputs_q to multi-headed q/k/v.
  # Dimensions are then [bs, sequence_length, n_heads, n_features_per_head].
  query = multi_batch_dense_aqt(inputs_q, name='query',
                                padding_mask=padding_mask)
  key = multi_batch_dense_aqt(inputs_kv, name='key',
                              padding_mask=key_padding_mask)
  value = multi_batch_dense_aqt(inputs_kv, name='value',
                                padding_mask=key_padding_mask)

  is_cache_initialized = False
  if self.decode:
    is_cache_initialized = self.has_variable('cache', 'cached_key')
    cached_key = self.variable('cache', 'cached_key', jnp.zeros, key.shape,
                               key.dtype)
    cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                 value.shape, value.dtype)
    cache_index = self.variable('cache', 'cache_index',
                                lambda: jnp.array(0, dtype=jnp.int32))
    if is_cache_initialized:
      expected_shape = list(cached_key.value.shape[:-2])
      for attn_dim in attention_axis:
        expected_shape[attn_dim] = 1
      expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
      if expected_shape != inputs_q.shape:
        raise ValueError('Invalid shape provided, '
                         'expected shape %s instead got %s.' %
                         (expected_shape, inputs_q.shape))

      cshape = cached_key.value.shape
      indices = [0] * len(cshape)
      i = cache_index.value
      attn_size = onp.prod(onp.take(cshape, attention_axis))

      *batch_dims, max_length, num_heads, depth_per_head = (  # pylint: disable=unused-variable
          cached_key.value.shape)
      indices = (0,) * len(batch_dims) + (i, 0, 0)

      key = lax.dynamic_update_slice(cached_key.value, key, indices)
      value = lax.dynamic_update_slice(cached_value.value, value, indices)
      one = jnp.array(1, jnp.int32)
      cache_index.value = cache_index.value + one
      cached_key.value = key
      cached_value.value = value

      # TODO(levskaya): verify this is still needed in translation decoding.
      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(max_length) < cache_index.value), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[Ellipsis, None]

  # Create attention masks.
  mask_components = []
  if self.causal_mask:
    if self.decode and is_cache_initialized:
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(onp.take(key.shape, attention_axis))
      attn_size = onp.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.int32)
      mask = ii < cache_index.value
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(_make_causal_mask(key, attention_axis))
  if padding_mask is not None:
    if key_padding_mask is None:
      key_padding_mask = padding_mask
    attn_padding_mask = make_padding_mask(
        padding_mask_query=padding_mask,
        padding_mask_key=key_padding_mask,
        query_shape=query.shape,
        key_shape=key.shape,
        attention_axis=attention_axis)
    mask_components.append(attn_padding_mask)
  if segmentation is not None:
    if key_segmentation is None:
      key_segmentation = segmentation
    segmentation_mask = make_padding_mask(
        padding_mask_query=segmentation,
        padding_mask_key=key_segmentation,
        query_shape=query.shape,
        key_shape=key.shape,
        attention_axis=attention_axis,
        segmentation_mask=True)
    mask_components.append(segmentation_mask)
  attention_mask = None
  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)
    attention_mask = attention_mask.astype(jnp.bool_)

    # Attention mask in the form of an attention bias.
    attention_bias = jnp.where(
        attention_mask,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  # Add an extra dimension to the mask corresponding to the head dimension.
  # E.g., if inputs_q has shape [batch_size, sequence_length, n_features],
  # then padding_mask will have shape [batch_size, sequence_length, 1] and
  # query will have shape
  # [batch_size, sequence_length, n_heads, n_features_per_head]. We create
  # query_padding_mask with shape [batch_size, sequence_length, 1, 1] to be
  # broadcast-compatible with 'query'.
  if padding_mask is not None:
    padding_mask = padding_mask[Ellipsis, None]
    shape_utils.assert_shapes_equal(
        padding_mask.shape, (batch_size, query_sequence_length, 1, 1))
  if key_padding_mask is not None:
    key_padding_mask = key_padding_mask[Ellipsis, None]
    # During prediction, the key padding mask is only going to be
    # broadcast-compatible with the key.
    shape_utils.assert_shapes_compatible(
        key_padding_mask.shape, (batch_size, key_sequence_length, 1, 1))

  # Apply attention.
  attention_fn = self.attention_fn
  dropout_rate = self.dropout_rate
  broadcast_dropout = self.broadcast_dropout
  deterministic = self.deterministic
  if not deterministic and self.dropout_rate > 0.0:
    dropout_rng = self.make_rng('dropout')
  else:
    dropout_rng = None
  x = attention_fn(  # pylint: disable=redundant-keyword-arg
      query=query,
      key=key,
      value=value,
      hparams=hparams.attn_acts,
      paxis_name=paxis_name,
      train=train,
      quant_context=self.quant_context,
      dtype=dtype,
      axis=attention_axis,
      bias=attention_bias,
      precision=jax_precision,
      dropout_rng=dropout_rng,
      dropout_rate=dropout_rate,
      broadcast_dropout=broadcast_dropout,
      deterministic=deterministic,
      query_padding_mask=padding_mask,
      key_padding_mask=key_padding_mask,
      attn_mask=attention_mask)
  shape_utils.assert_shapes_equal(
      x.shape, (batch_size, query_sequence_length, num_heads, head_dim))
  x = x.reshape(batch_size * query_sequence_length, num_heads * head_dim)
  if padding_mask is not None:
    padding_mask = padding_mask.reshape(batch_size * query_sequence_length, 1)

  # Back to the original input dimensions.
  out = flax_layers.DenseAqt(
      features=channel_size,
      hparams=hparams.dense_out,
      quant_context=self.quant_context,
      paxis_name=paxis_name,
      train=train,
      kernel_init=kernel_init,
      bias_init=bias_init,
      use_bias=use_bias,
      dtype=dtype,
      name='dense_out')(x, padding_mask=padding_mask)
  shape_utils.assert_shapes_equal(
      out.shape, (batch_size * query_sequence_length, channel_size))
  out = out.reshape(batch_size, query_sequence_length, channel_size)
  return out
def __call__(
    self,
    encoded,
    src_padding_mask,
    targets,
    targets_positions=None,
    inputs_segmentation=None,
    targets_segmentation=None,
    tgt_padding_mask=None,
):
  """Applies Transformer model on the inputs.

  Args:
    encoded: encoded input data from encoder.
    src_padding_mask: padding mask for inputs.
    targets: target inputs.
    targets_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.
    tgt_padding_mask: target tokens padding mask.

  Returns:
    output of a transformer decoder.
  """
  batch_size, sequence_length, channel_size = encoded.shape  # pylint: disable=unused-variable
  target_batch_size, target_sequence_length = targets.shape  # pylint: disable=unused-variable
  shape_utils.assert_shapes_equal(targets.shape,
                                  (batch_size, target_sequence_length))

  # Padding Masks
  if tgt_padding_mask is None:
    tgt_padding_mask = (targets > 0)[Ellipsis, None]
  shape_utils.assert_shapes_equal(tgt_padding_mask.shape,
                                  (batch_size, target_sequence_length, 1))

  if self.use_bfloat16:
    dtype = jnp.bfloat16
  else:
    dtype = jnp.float32

  # Target Embedding
  if self.shared_embedding is None:
    output_embed = aqt_flax_layers.EmbedAqt(
        num_embeddings=self.output_vocab_size,
        features=self.emb_dim,
        hparams=self.hparams.embedding,
        embedding_init=nn.initializers.normal(stddev=self.emb_dim**-0.5),
        dtype=dtype,
        name='target_embed',
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch')
  else:
    output_embed = self.shared_embedding

  y = targets.astype('int32')
  if not self.decode:
    y = shift_right(y)
  y = output_embed(y) * jnp.sqrt(self.emb_dim)
  y = AddPositionEmbs(
      name='posembed_targets',
      max_len=self.max_len,
      decode=self.decode,
      min_timescale=1.0,
      max_timescale=10000.0)(y, inputs_positions=targets_positions)
  y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not self.train)

  if self.use_bfloat16:
    y = y.astype(jnp.bfloat16)

  # Target-Input Decoder
  num_layers = len(self.hparams.encoder_decoder_1d_blocks)
  for lyr in range(num_layers):
    y = EncoderDecoder1DBlock(
        train=self.train,
        quant_context=self.quant_context,
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        hparams=self.hparams.encoder_decoder_1d_blocks[lyr],
        dtype=dtype,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        deterministic=not self.train,
        name=f'encoderdecoderblock_{lyr}',
        decode=self.decode)(
            y,
            encoded,
            padding_mask=tgt_padding_mask,
            key_padding_mask=src_padding_mask,
            inputs_segmentation=inputs_segmentation,
            targets_segmentation=targets_segmentation)
  y = aqt_flax_layers.LayerNormAqt(
      dtype=dtype,
      name='encoderdecoder_norm',
      hparams=self.hparams.layer_norm,
      quant_context=self.quant_context)(y)
  y = y.reshape((batch_size * target_sequence_length, channel_size))
  tgt_padding_mask = tgt_padding_mask.reshape(
      (batch_size * target_sequence_length, 1))

  # Decoded Logits
  if self.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(
        query=y,
        padding_mask=tgt_padding_mask,
        paxis_name=self.paxis_name,
        train=self.train)
  else:
    if self.hparams.logits is None:
      raise ValueError('If logits_via_embedding is False, then the hparams '
                       'for the logits layer have to be provided.')
    logits = aqt_flax_layers.DenseAqt(
        features=self.output_vocab_size,
        dtype=dtype,
        paxis_name='batch',
        train=self.train,
        quant_context=self.quant_context,
        hparams=self.hparams.logits,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        name='logits_dense')(y, padding_mask=tgt_padding_mask)
  return logits
def __call__(
    self,
    inputs,
):
  """Applies ResNet model. Number of residual blocks inferred from hparams."""
  num_classes = self.num_classes
  hparams = self.hparams
  num_filters = self.num_filters
  dtype = self.dtype

  assert hparams.act_function in act_function_zoo.keys(), (
      'Activation function type is not supported.')

  x = aqt_flax_layers.ConvAqt(
      features=num_filters,
      kernel_size=(7, 7),
      strides=(2, 2),
      padding=[(3, 3), (3, 3)],
      use_bias=False,
      dtype=dtype,
      name='init_conv',
      train=self.train,
      quant_context=self.quant_context,
      paxis_name='batch',
      hparams=hparams.conv_init,
  )(inputs)
  x = nn.BatchNorm(
      use_running_average=not self.train,
      momentum=0.9,
      epsilon=1e-5,
      dtype=dtype,
      name='init_bn')(x)
  if hparams.act_function == 'relu':
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  else:
    # TODO(yichi): try adding other activation functions here.
    # Use avg pool so that for binary nets, the distribution is symmetric.
    x = nn.avg_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  filter_multiplier = hparams.filter_multiplier
  for i, block_hparams in enumerate(hparams.residual_blocks):
    proj = block_hparams.conv_proj
    # For projection layers (unless it is the first layer), strides = (2, 2).
    if i > 0 and proj is not None:
      filter_multiplier *= 2
      strides = (2, 2)
    else:
      strides = (1, 1)
    x = ResidualBlock(
        filters=int(num_filters * filter_multiplier),
        hparams=block_hparams,
        quant_context=self.quant_context,
        strides=strides,
        train=self.train,
        dtype=dtype)(x)
  if hparams.act_function == 'none':
    # The DenseAqt below is not binarized. If the activation functions are
    # removed, there is no act function between the last residual block and
    # the dense layer, so add a ReLU in that case.
    # TODO(yichi): try BPReLU.
    x = nn.relu(x)
  x = jnp.mean(x, axis=(1, 2))
  x = aqt_flax_layers.DenseAqt(
      features=num_classes,
      dtype=dtype,
      train=self.train,
      quant_context=self.quant_context,
      paxis_name='batch',
      hparams=hparams.dense_layer,
  )(x, padding_mask=None)
  x = jnp.asarray(x, dtype)
  output = nn.log_softmax(x)
  return output