Example No. 1
  def init_model_with_1_layer(self,
                              inputs,
                              num_features,
                              kernel_init=flax_layers.default_kernel_init,
                              weight_prec=None,
                              quant_act=None,
                              weight_half_shift=False):
    """Create and initialize a flax model with a single DenseAqt layer."""
    quant_context = quant_config.QuantContext(
        update_bounds=False, collect_acts_stats=False)
    layer_kwargs = {
        'kernel_init': kernel_init,
        'features': num_features,
        'use_bias': False,
        'quant_context': quant_context,
        'paxis_name': 'batch',
        'train': False,
        'dtype': jnp.float32
    }
    layer_kwargs['hparams'] = flax_layers.DenseAqt.HParams(
        weight_prec=weight_prec,
        quant_act=quant_act,
        quant_type=QuantType.fake_quant,
        weight_quant_granularity=quant_config.QuantGranularity.per_channel,
        weight_half_shift=weight_half_shift)

    dense_module = flax_layers.DenseAqt(**layer_kwargs)
    initial_state = dense_module.init(
        self.rng_key, jnp.zeros(inputs.shape), padding_mask=None)
    return dense_module, initial_state
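A minimal usage sketch for this helper (hypothetical shapes and values; it assumes the same test-class context, which provides self.rng_key as used above):

    inputs = jnp.ones((4, 8), dtype=jnp.float32)  # batch of 4, 8 input features
    model, state = self.init_model_with_1_layer(
        inputs, num_features=16, weight_prec=8)
    # Apply the initialized DenseAqt layer; padding_mask=None masks nothing out.
    outputs = model.apply(state, inputs, padding_mask=None)
    assert outputs.shape == (4, 16)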
Example No. 2
    def __call__(
        self,
        inputs,
    ):
        """Applies ResNet model. Number of residual blocks inferred from hparams."""
        num_classes = self.num_classes
        hparams = self.hparams
        num_filters = self.num_filters
        dtype = self.dtype

        x = aqt_flax_layers.ConvAqt(
            features=num_filters,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding=[(3, 3), (3, 3)],
            use_bias=False,
            dtype=dtype,
            name='init_conv',
            train=self.train,
            quant_context=self.quant_context,
            paxis_name='batch',
            hparams=hparams.conv_init,
        )(inputs)
        x = nn.BatchNorm(use_running_average=not self.train,
                         momentum=0.9,
                         epsilon=1e-5,
                         dtype=dtype,
                         name='init_bn')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        filter_multiplier = hparams.filter_multiplier
        for i, block_hparams in enumerate(hparams.residual_blocks):
            proj = block_hparams.conv_proj
            # For projection layers (unless it is the first layer), strides = (2, 2)
            if i > 0 and proj is not None:
                filter_multiplier *= 2
                strides = (2, 2)
            else:
                strides = (1, 1)
            x = ResidualBlock(filters=int(num_filters * filter_multiplier),
                              hparams=block_hparams,
                              quant_context=self.quant_context,
                              strides=strides,
                              train=self.train,
                              dtype=dtype)(x)
        x = jnp.mean(x, axis=(1, 2))

        x = aqt_flax_layers.DenseAqt(
            features=num_classes,
            dtype=dtype,
            train=self.train,
            quant_context=self.quant_context,
            paxis_name='batch',
            hparams=hparams.dense_layer,
        )(x, padding_mask=None)

        x = jnp.asarray(x, dtype)
        output = nn.log_softmax(x)
        return output
Example No. 3
  def test_padding(self):
    """Test that padding results in the right statistics being collected."""
    # Exact values don't matter here; we just need the code to think it's using
    # dynamic bounds so that it gathers activation statistics.
    bounds = get_bounds.GetBounds.Hyper(
        initial_bound=0.0,
        stddev_coeff=1.0,
        absdev_coeff=0.0,
        mix_coeff=1.0,
        reset_stats=False,
        granularity=quant_config.QuantGranularity.per_channel)
    quant_act = flax_layers.QuantOps.ActHParams(
        input_distribution=flax_layers.QuantOps.ActHParams.InputDistribution
        .symmetric,
        prec=8,
        bounds=bounds)
    hparams = flax_layers.DenseAqt.HParams(
        quant_type=flax_layers.QuantType.fake_quant,
        weight_prec=8,
        quant_act=quant_act,
        weight_quant_granularity=quant_config.QuantGranularity.per_channel)
    module = flax_layers.DenseAqt(
        hparams=hparams,
        features=1,
        paxis_name=None,
        quant_context=quant_config.QuantContext(
            update_bounds=True, collect_acts_stats=False),
        train=True,
        dtype=jnp.float32)

    # Simulate an input with a batch size of 2, three tokens per example, and
    # two channels per token.
    x = jnp.arange(12).astype(jnp.float32).reshape((2, 3, 2))
    # Reshape it to have dimensions [batch, feature]
    x = x.reshape(6, 2)

    initial_state = module.init(self.rng_key, x, padding_mask=None)

    # Check that the per-channel activation statistics are as expected with no
    # padding
    _, state_nopadding = module.apply(
        initial_state, x, padding_mask=None, mutable='get_bounds')
    expected_means = onp.array([[(0 + 2 + 4 + 6 + 8 + 10) / 6,
                                 (1 + 3 + 5 + 7 + 9 + 11) / 6]])
    actual_means = state_nopadding['get_bounds']['GetBounds_0']['stats'].mean
    onp.testing.assert_allclose(actual_means, expected_means)

    # Now we pad out some of the tokens (chosen arbitrarily) and check that the
    # computed per-channel stats are the means of the non-padding tokens only.
    # Exclude the second and third tokens from the first batch and the first
    # token from the second batch.
    padding_mask = jnp.array([[True, False, False], [False, True, True]])
    # Reshape it to have dimensions [batch, feature]
    padding_mask = padding_mask.reshape(6, 1)
    _, state_padding = module.apply(
        initial_state, x, padding_mask=padding_mask, mutable='get_bounds')
    expected_means = onp.array([[(0 + 8 + 10) / 3, (1 + 9 + 11) / 3]])
    actual_means = state_padding['get_bounds']['GetBounds_0']['stats'].mean
    onp.testing.assert_allclose(actual_means, expected_means)
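With this mask the surviving rows of x are [0, 1], [8, 9], and [10, 11], so the per-channel means are (0 + 8 + 10) / 3 = 6 and (1 + 9 + 11) / 3 = 7, which is exactly what expected_means encodes above.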
Example No. 4
    def __call__(
        self,
        inputs,
        *,
        padding_mask,
    ):
        """Applies Transformer MlpBlock module."""

        batch_size, sequence_length, channel_size = inputs.shape
        inputs = inputs.reshape((batch_size * sequence_length, channel_size))
        shape_utils.assert_shapes_equal(padding_mask.shape,
                                        (batch_size, sequence_length, 1))
        padding_mask = padding_mask.reshape((batch_size * sequence_length, 1))
        x = aqt_flax_layers.DenseAqt(features=self.mlp_dim,
                                     dtype=self.dtype,
                                     paxis_name='batch',
                                     train=self.train,
                                     quant_context=self.quant_context,
                                     hparams=self.hparams.dense_1,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init,
                                     name='dense_1')(inputs,
                                                     padding_mask=padding_mask)

        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate)(
            x, deterministic=self.deterministic)

        output = aqt_flax_layers.DenseAqt(
            # A ReLU precedes this layer, so x contains only positive values.
            features=channel_size,
            dtype=self.dtype,
            paxis_name='batch',
            train=self.train,
            quant_context=self.quant_context,
            hparams=self.hparams.dense_2,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            name='dense_2')(x, padding_mask=padding_mask)

        output = nn.Dropout(rate=self.dropout_rate)(
            output, deterministic=self.deterministic)
        output = output.reshape((batch_size, sequence_length, channel_size))
        return output
Example No. 5
 def __call__(self, inputs, hparams, num_classes, dtype=jnp.float32):
     output = aqt_flax_layers.DenseAqt(
         features=num_classes,
         dtype=dtype,
         train=False,
         quant_context=quant_config.QuantContext(
             update_bounds=False, collect_acts_stats=False),
         paxis_name='batch',
         hparams=hparams,
     )(inputs, padding_mask=None)
     return output
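A sketch of how this wrapper might be exercised (SimpleDenseModel is a hypothetical stand-in for the module class that owns this __call__; the hparams are built the same way as in Example No. 8, and jax, jnp, aqt_flax_layers, and quant_config are assumed to be imported as in the other examples):

    dense_hparams = aqt_flax_layers.DenseAqt.HParams(
        weight_prec=8,
        quant_act=None,
        quant_type=QuantType.fake_quant,
        weight_quant_granularity=quant_config.QuantGranularity.per_channel,
        weight_half_shift=False)
    model = SimpleDenseModel()  # hypothetical wrapper module defined as above
    x = jnp.ones((2, 16))
    variables = model.init(
        jax.random.PRNGKey(0), x, dense_hparams, num_classes=10)
    logits = model.apply(variables, x, dense_hparams, num_classes=10)
    assert logits.shape == (2, 10)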
Example No. 6
 def multi_batch_dense_aqt(inputs, *, name, padding_mask):
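   # Note: num_heads, head_dim, paxis_name, train, hparams, kernel_init,
   # bias_init, use_bias, and dtype are free variables captured from the
   # enclosing attention module; Example No. 9 shows the full context.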
   batch_size, sequence_length, channel_size = inputs.shape
   inputs = inputs.reshape(batch_size * sequence_length, channel_size)
   if padding_mask is not None:
     padding_mask = padding_mask.reshape(batch_size * sequence_length, 1)
   out = flax_layers.DenseAqt(
       name=name,
       features=num_heads * head_dim,
       paxis_name=paxis_name,
       train=train,
       quant_context=self.quant_context,
       hparams=hparams.dense_kqv,
       kernel_init=kernel_init,
       bias_init=bias_init,
       use_bias=use_bias,
       dtype=dtype)(
           inputs, padding_mask=padding_mask)
   return out.reshape(batch_size, sequence_length, num_heads, head_dim)
Example No. 7
 def test_quant_granularity(self, _, mock_quantized_dot, granularity, axis):
   hparams = flax_layers.DenseAqt.HParams(
       weight_prec=8,
       quant_act=None,
       quant_type=quantization.QuantType.fake_quant,
       weight_quant_granularity=granularity)
   layer = flax_layers.DenseAqt(
       features=2,
       hparams=hparams,
       quant_context=quant_config.QuantContext(
           update_bounds=False, collect_acts_stats=False),
       paxis_name=None,
       train=False,
       dtype=jnp.float32)
   x = jnp.ones((2, 2))
   state = layer.init(self.rng_key, x, padding_mask=None)
   layer.apply(state, x, padding_mask=None)
   weight_params = mock_quantized_dot.call_args[1]['weight_params']
   self.assertEqual(weight_params.axis, axis)
Example No. 8
    def test_annotation_only_changes_hlo_metadata_dense(
            self, weight_prec, acts_prec):
        FLAGS.metadata_enabled = False
        quant_act = quantization.QuantOps.ActHParams(
            input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
            prec=acts_prec,
            bounds=1.0,
            half_shift=False)
        input_shape = (1, 16)
        module_no_annotation = aqt_flax_layers.DenseAqt(
            features=4,
            use_bias=False,
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            paxis_name='batch',
            train=False,
            hparams=aqt_flax_layers.DenseAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant,
                weight_quant_granularity=quant_config.QuantGranularity.
                per_channel,
                weight_half_shift=False),
            dtype=jnp.float32)

        init_state = module_no_annotation.init(self.rng_key,
                                               jnp.ones(
                                                   input_shape, jnp.float32),
                                               padding_mask=None)
        output_no_annotation = module_no_annotation.apply(
            init_state, jnp.ones(input_shape), padding_mask=None)

        hlo_no_annotation = hlo_utils.load_hlo_proto_from_model(
            module_no_annotation, init_state, [input_shape], padding_mask=None)
        del init_state

        FLAGS.metadata_enabled = True
        module_w_annotation = aqt_flax_layers.DenseAqt(
            features=4,
            use_bias=False,
            paxis_name='batch',
            train=False,
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            dtype=jnp.float32,
            hparams=aqt_flax_layers.DenseAqt.HParams(
                weight_prec=weight_prec,
                quant_act=quant_act,
                quant_type=QuantType.fake_quant,
                weight_quant_granularity=quant_config.QuantGranularity.
                per_channel,
                weight_half_shift=False),
        )

        init_state = module_w_annotation.init(self.rng_key,
                                              jnp.ones(input_shape,
                                                       jnp.float32),
                                              padding_mask=None)
        output_w_annotation = module_w_annotation.apply(init_state,
                                                        jnp.ones(input_shape),
                                                        padding_mask=None)

        hlo_w_annotation = hlo_utils.load_hlo_proto_from_model(
            module_w_annotation, init_state, [input_shape], padding_mask=None)
        del init_state

        onp.testing.assert_array_equal(output_no_annotation,
                                       output_w_annotation)
        self.compare_hlo_instructions(hlo_no_annotation, hlo_w_annotation)
Example No. 9
    def __call__(self,
                 inputs_q,
                 inputs_kv,
                 *,
                 padding_mask,
                 key_padding_mask,
                 segmentation=None,
                 key_segmentation=None):
        """Applies multi-head dot product attention on the input data.

    If weight_prec is not None, scales and quantizes weights to signed int with
    weight_prec bits.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv` or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]` or
        None for self-attention, in which case key/values will be derived from
        inputs_q.
      padding_mask: boolean tensor specifying query tokens that are pad tokens.
      key_padding_mask: boolean tensor specifying key-value tokens that are pad
        tokens.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """
        batch_size, query_sequence_length, channel_size = inputs_q.shape
        hparams = self.hparams
        if inputs_kv is None:
            inputs_kv = inputs_q
            key_sequence_length = inputs_q.shape[1]
        else:
            key_sequence_length = inputs_kv.shape[1]
            shape_utils.assert_shapes_equal(
                inputs_kv.shape,
                (batch_size, key_sequence_length, channel_size))

        jax_precision = jax.lax.Precision.DEFAULT

        if padding_mask is not None:
            shape_utils.assert_shapes_equal(
                padding_mask.shape, (batch_size, query_sequence_length, 1))
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        else:
            shape_utils.assert_shapes_equal(
                key_padding_mask.shape, (batch_size, key_sequence_length, 1))
        attention_axis = self.attention_axis
        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        qkv_features = self.qkv_features
        qkv_features = qkv_features or inputs_q.shape[-1]

        num_heads = self.num_heads
        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        paxis_name = self.paxis_name
        train = self.train
        kernel_init = self.kernel_init
        bias_init = self.bias_init
        use_bias = self.use_bias
        dtype = self.dtype

        def multi_batch_dense_aqt(inputs, *, name, padding_mask):
            batch_size, sequence_length, channel_size = inputs.shape
            inputs = inputs.reshape(batch_size * sequence_length, channel_size)
            if padding_mask is not None:
                padding_mask = padding_mask.reshape(
                    batch_size * sequence_length, 1)
            out = flax_layers.DenseAqt(name=name,
                                       features=num_heads * head_dim,
                                       paxis_name=paxis_name,
                                       train=train,
                                       quant_context=self.quant_context,
                                       hparams=hparams.dense_kqv,
                                       kernel_init=kernel_init,
                                       bias_init=bias_init,
                                       use_bias=use_bias,
                                       dtype=dtype)(inputs,
                                                    padding_mask=padding_mask)
            return out.reshape(batch_size, sequence_length, num_heads,
                               head_dim)

        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, sequence_length, n_heads, n_features_per_head]
        query = multi_batch_dense_aqt(inputs_q,
                                      name='query',
                                      padding_mask=padding_mask)
        key = multi_batch_dense_aqt(inputs_kv,
                                    name='key',
                                    padding_mask=key_padding_mask)
        value = multi_batch_dense_aqt(inputs_kv,
                                      name='value',
                                      padding_mask=key_padding_mask)
        is_cache_initialized = False
        if self.decode:
            is_cache_initialized = self.has_variable('cache', 'cached_key')
            cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                       key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                         value.shape, value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_cache_initialized:
                expected_shape = list(cached_key.value.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                cshape = cached_key.value.shape
                indices = [0] * len(cshape)
                i = cache_index.value
                attn_size = onp.prod(onp.take(cshape, attention_axis))

                *batch_dims, max_length, num_heads, depth_per_head = (  # pylint: disable=unused-variable
                    cached_key.value.shape)
                indices = (0, ) * len(batch_dims) + (i, 0, 0)

                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value,
                                                 indices)
                one = jnp.array(1, jnp.int32)
                cache_index.value = cache_index.value + one
                cached_key.value = key
                cached_value.value = value

                # TODO(levskaya): verify this is still needed in translation decoding.
                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(max_length) < cache_index.value), cshape[:2])
                key_padding_mask = key_padding_mask.astype(
                    jnp.float32)[Ellipsis, None]

        # create attention masks
        mask_components = []
        if self.causal_mask:
            if self.decode and is_cache_initialized:
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(onp.take(key.shape, attention_axis))
                attn_size = onp.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.int32)
                mask = ii < cache_index.value
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(_make_causal_mask(key, attention_axis))
        if padding_mask is not None:
            if key_padding_mask is None:
                key_padding_mask = padding_mask
            attn_padding_mask = make_padding_mask(
                padding_mask_query=padding_mask,
                padding_mask_key=key_padding_mask,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis)
            mask_components.append(attn_padding_mask)
        if segmentation is not None:
            if key_segmentation is None:
                key_segmentation = segmentation
            segmentation_mask = make_padding_mask(
                padding_mask_query=segmentation,
                padding_mask_key=key_segmentation,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis,
                segmentation_mask=True)
            mask_components.append(segmentation_mask)
        attention_mask = None
        if mask_components:
            attention_mask = mask_components[0]
            for component in mask_components[1:]:
                attention_mask = jnp.logical_and(attention_mask, component)
            attention_mask = attention_mask.astype(jnp.bool_)

            # attention mask in the form of attention bias
            attention_bias = jnp.where(
                attention_mask,
                jnp.full(attention_mask.shape, 0.).astype(dtype),
                jnp.full(attention_mask.shape, -1e10).astype(dtype))
        else:
            attention_bias = None

        # Add an extra dimension to the mask corresponding to the head
        # dimension. eg, if inputs_q has shape [batch_size, sequence_length,
        # n_features], then padding_mask will have a shape
        # [batch_size, sequence_length, 1] and query will have shape
        # [batch_size, sequence_length, n_heads, n_features_per_head].
        # We create query_padding_mask with shape [batch_size, sequence_length,
        # 1, 1] to be broadcast-compatible with 'query'.
        if padding_mask is not None:
            padding_mask = padding_mask[Ellipsis, None]
            shape_utils.assert_shapes_equal(
                padding_mask.shape, (batch_size, query_sequence_length, 1, 1))
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[Ellipsis, None]
            # During prediction, the key padding mask is only going to be
            # broadcast-compatible with the key.
            shape_utils.assert_shapes_compatible(
                key_padding_mask.shape,
                (batch_size, key_sequence_length, 1, 1))

        # apply attention
        attention_fn = self.attention_fn
        dropout_rate = self.dropout_rate
        broadcast_dropout = self.broadcast_dropout
        deterministic = self.deterministic
        if not deterministic and self.dropout_rate > 0.0:
            dropout_rng = self.make_rng('dropout')
        else:
            dropout_rng = None
        x = attention_fn(  # pylint: disable=redundant-keyword-arg
            query=query,
            key=key,
            value=value,
            hparams=hparams.attn_acts,
            paxis_name=paxis_name,
            train=train,
            quant_context=self.quant_context,
            dtype=dtype,
            axis=attention_axis,
            bias=attention_bias,
            precision=jax_precision,
            dropout_rng=dropout_rng,
            dropout_rate=dropout_rate,
            broadcast_dropout=broadcast_dropout,
            deterministic=deterministic,
            query_padding_mask=padding_mask,
            key_padding_mask=key_padding_mask,
            attn_mask=attention_mask)
        shape_utils.assert_shapes_equal(
            x.shape, (batch_size, query_sequence_length, num_heads, head_dim))
        x = x.reshape(batch_size * query_sequence_length, num_heads * head_dim)
        if padding_mask is not None:
            padding_mask = padding_mask.reshape(
                batch_size * query_sequence_length, 1)
        # back to the original inputs dimensions
        out = flax_layers.DenseAqt(features=channel_size,
                                   hparams=hparams.dense_out,
                                   quant_context=self.quant_context,
                                   paxis_name=paxis_name,
                                   train=train,
                                   kernel_init=kernel_init,
                                   bias_init=bias_init,
                                   use_bias=use_bias,
                                   dtype=dtype,
                                   name='dense_out')(x,
                                                     padding_mask=padding_mask)
        shape_utils.assert_shapes_equal(
            out.shape, (batch_size * query_sequence_length, channel_size))
        out = out.reshape(batch_size, query_sequence_length, channel_size)
        return out
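As the docstring above notes, the same module handles self-attention and encoder-decoder attention. A minimal calling sketch (shapes are arbitrary; attn_module and variables are hypothetical stand-ins for an instance of this attention module and its initialized variables):

    bs, q_len, kv_len, features = 2, 5, 7, 16
    inputs_q = jnp.ones((bs, q_len, features))
    padding_mask = jnp.ones((bs, q_len, 1), dtype=jnp.bool_)
    # Self-attention: pass inputs_kv=None so key/values are derived from inputs_q.
    out = attn_module.apply(variables, inputs_q, None,
                            padding_mask=padding_mask, key_padding_mask=None)
    # Encoder-decoder attention: pass separate key/value inputs and their mask.
    inputs_kv = jnp.ones((bs, kv_len, features))
    key_padding_mask = jnp.ones((bs, kv_len, 1), dtype=jnp.bool_)
    out = attn_module.apply(variables, inputs_q, inputs_kv,
                            padding_mask=padding_mask,
                            key_padding_mask=key_padding_mask)
    # In both cases out has shape (bs, q_len, features).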
Example No. 10
    def __call__(
        self,
        encoded,
        src_padding_mask,
        targets,
        targets_positions=None,
        inputs_segmentation=None,
        targets_segmentation=None,
        tgt_padding_mask=None,
    ):
        """Applies Transformer model on the inputs.

    Args:
      encoded: encoded input data from encoder.
      src_padding_mask: padding mask for inputs.
      targets: target inputs.
      targets_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      tgt_padding_mask: target tokens padding mask.

    Returns:
      output of a transformer decoder.

    """
        batch_size, sequence_length, channel_size = encoded.shape  # pylint: disable=unused-variable
        target_batch_size, target_sequence_length = targets.shape  # pylint: disable=unused-variable
        shape_utils.assert_shapes_equal(targets.shape,
                                        (batch_size, target_sequence_length))

        # Padding Masks
        if tgt_padding_mask is None:
            tgt_padding_mask = (targets > 0)[Ellipsis, None]
        shape_utils.assert_shapes_equal(
            tgt_padding_mask.shape, (batch_size, target_sequence_length, 1))

        if self.use_bfloat16:
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Target Embedding
        if self.shared_embedding is None:
            output_embed = aqt_flax_layers.EmbedAqt(
                num_embeddings=self.output_vocab_size,
                features=self.emb_dim,
                hparams=self.hparams.embedding,
                embedding_init=nn.initializers.normal(
                    stddev=self.emb_dim**-0.5),
                dtype=dtype,
                name='target_embed',
                train=self.train,
                quant_context=self.quant_context,
                paxis_name='batch')
        else:
            output_embed = self.shared_embedding

        y = targets.astype('int32')
        if not self.decode:
            y = shift_right(y)
        y = output_embed(y) * jnp.sqrt(self.emb_dim)
        y = AddPositionEmbs(name='posembed_targets',
                            max_len=self.max_len,
                            decode=self.decode,
                            min_timescale=1.0,
                            max_timescale=10000.0)(
                                y, inputs_positions=targets_positions)
        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not self.train)

        if self.use_bfloat16:
            y = y.astype(jnp.bfloat16)

        # Target-Input Decoder
        num_layers = len(self.hparams.encoder_decoder_1d_blocks)
        for lyr in range(num_layers):
            y = EncoderDecoder1DBlock(
                train=self.train,
                quant_context=self.quant_context,
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                hparams=self.hparams.encoder_decoder_1d_blocks[lyr],
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                deterministic=not self.train,
                name=f'encoderdecoderblock_{lyr}',
                decode=self.decode)(y,
                                    encoded,
                                    padding_mask=tgt_padding_mask,
                                    key_padding_mask=src_padding_mask,
                                    inputs_segmentation=inputs_segmentation,
                                    targets_segmentation=targets_segmentation)
        y = aqt_flax_layers.LayerNormAqt(dtype=dtype,
                                         name='encoderdecoder_norm',
                                         hparams=self.hparams.layer_norm,
                                         quant_context=self.quant_context)(y)
        y = y.reshape((batch_size * target_sequence_length, channel_size))
        tgt_padding_mask = tgt_padding_mask.reshape(
            (batch_size * target_sequence_length, 1))
        # Decoded Logits
        if self.logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = output_embed.attend(query=y,
                                         padding_mask=tgt_padding_mask,
                                         paxis_name=self.paxis_name,
                                         train=self.train)
        else:
            if self.hparams.logits is None:
                raise ValueError(
                    'If logits_via_embedding is False, then the hparams '
                    'for the logits layer have to be provided.')
            logits = aqt_flax_layers.DenseAqt(
                features=self.output_vocab_size,
                dtype=dtype,
                paxis_name='batch',
                train=self.train,
                quant_context=self.quant_context,
                hparams=self.hparams.logits,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6),
                name='logits_dense')(y, padding_mask=tgt_padding_mask)
        return logits
Example No. 11
  def __call__(
      self,
      inputs,
  ):
    """Applies ResNet model. Number of residual blocks inferred from hparams."""
    num_classes = self.num_classes
    hparams = self.hparams
    num_filters = self.num_filters
    dtype = self.dtype
    assert hparams.act_function in act_function_zoo.keys(), (
        'Activation function type is not supported.')

    x = aqt_flax_layers.ConvAqt(
        features=num_filters,
        kernel_size=(7, 7),
        strides=(2, 2),
        padding=[(3, 3), (3, 3)],
        use_bias=False,
        dtype=dtype,
        name='init_conv',
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.conv_init,
    )(inputs)
    x = nn.BatchNorm(
        use_running_average=not self.train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=dtype,
        name='init_bn')(x)
    if hparams.act_function == 'relu':
      x = nn.relu(x)
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    else:
      # TODO(yichi): try adding other activation functions here
      # Use avg pool so that for binary nets, the distribution is symmetric.
      x = nn.avg_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    filter_multiplier = hparams.filter_multiplier
    for i, block_hparams in enumerate(hparams.residual_blocks):
      proj = block_hparams.conv_proj
      # For projection layers (unless it is the first layer), strides = (2, 2)
      if i > 0 and proj is not None:
        filter_multiplier *= 2
        strides = (2, 2)
      else:
        strides = (1, 1)
      x = ResidualBlock(
          filters=int(num_filters * filter_multiplier),
          hparams=block_hparams,
          quant_context=self.quant_context,
          strides=strides,
          train=self.train,
          dtype=dtype)(x)
    if hparams.act_function == 'none':
      # The DenseAQT below is not binarized.
      # If removing the activation functions, there will be no act function
      # between the last residual block and the dense layer.
      # So add a ReLU in that case.
      # TODO(yichi): try BPReLU
      x = nn.relu(x)
    else:
      pass
    x = jnp.mean(x, axis=(1, 2))

    x = aqt_flax_layers.DenseAqt(
        features=num_classes,
        dtype=dtype,
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.dense_layer,
    )(x, padding_mask=None)

    x = jnp.asarray(x, dtype)
    output = nn.log_softmax(x)
    return output