Example 1
  def test_embed(self, weight_prec):
    # Since the dummy embedding matrix has a row of all zeros, we need 'epsilon'
    # to be added to it before calculating scale factors.
    quantization.DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING = False
    rng = random.PRNGKey(0)
    x = jnp.arange(4)[None]
    dummy_embedding = jnp.broadcast_to(jnp.arange(4)[Ellipsis, None],
                                       (4, 3)).astype(jnp.float32)
    embed_module = flax_layers.EmbedAqt(
        num_embeddings=4,
        features=3,
        dtype=jnp.float32,
        hparams=flax_layers.EmbedAqt.HParams(
            weight_prec=weight_prec,
            quant_act=None,
            quant_type=QuantType.fake_quant,
            weight_half_shift=False),
        embedding_init=lambda _rng, _shape: dummy_embedding,
        train=False,
        paxis_name=None,
        quant_context=quant_config.QuantContext(update_bounds=False),
    )
    y, state = embed_module.init_with_output(rng, x)
    test_utils.assert_all_close_prec(dummy_embedding[None], y, weight_prec)

    z = embed_module.apply(
        state, jnp.ones((1, 3)), padding_mask=None, method=embed_module.attend)
    test_utils.assert_all_close_prec(3. * jnp.arange(4), z[0, Ellipsis], weight_prec)
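
The expected value in the last assertion follows from a plain dot product of the query against the embedding table: each row i of the dummy table is [i, i, i], so attending with a ones query sums it to 3 * i. A minimal plain-JAX sanity check of that arithmetic (not the quantized EmbedAqt.attend code path):

import jax.numpy as jnp

dummy_embedding = jnp.broadcast_to(jnp.arange(4)[:, None], (4, 3)).astype(jnp.float32)
query = jnp.ones((1, 3))
print(query @ dummy_embedding.T)  # [[0. 3. 6. 9.]] == 3. * jnp.arange(4)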
Example 2
  def test_embed_equality(self, weight_prec):
    rng = random.PRNGKey(0)
    x = 2 * jnp.ones(4, dtype=jnp.int32)[None]
    dummy_embedding = 2 * jnp.ones((4, 2)).astype(jnp.float32)
    embed_module = flax_layers.EmbedAqt(
        num_embeddings=4,
        features=2,
        dtype=jnp.float32,
        hparams=flax_layers.EmbedAqt.HParams(
            weight_prec=weight_prec,
            quant_act=None,
            quant_type=QuantType.fake_quant,
            weight_half_shift=False),
        embedding_init=lambda _rng, _shape: dummy_embedding,
        train=False,
        quant_context=quant_config.QuantContext(update_bounds=False),
        paxis_name=None)
    y, init_state = embed_module.init_with_output(rng, x)
    onp.testing.assert_array_equal(dummy_embedding[None], y)

    z = embed_module.apply(
        init_state,
        jnp.ones((1, 2)),
        padding_mask=None,
        method=embed_module.attend)
    onp.testing.assert_array_equal(2. * (2 * jnp.ones(4)), z[0, Ellipsis])
Example 3
    def setup(self):

        if self.use_bfloat16:
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        if self.hparams.share_embeddings:
            if self.output_vocab_size is not None:
                assert self.output_vocab_size == self.vocab_size, (
                    "can't share embedding with different vocab sizes.")
            self.shared_embedding = aqt_flax_layers.EmbedAqt(  # pylint: disable=missing-from-attributes
                num_embeddings=self.vocab_size,
                features=self.hparams.emb_dim,
                hparams=self.hparams.encoder.embedding,
                dtype=dtype,
                embedding_init=nn.initializers.normal(
                    stddev=self.hparams.emb_dim**-0.5),
                train=self.train,
                quant_context=self.quant_context,
                paxis_name='batch')
        else:
            self.shared_embedding = None

        self.encoder = Encoder(  # pylint: disable=missing-from-attributes
            hparams=self.hparams.encoder,
            vocab_size=self.vocab_size,
            shared_embedding=self.shared_embedding,
            use_bfloat16=self.use_bfloat16,
            emb_dim=self.hparams.emb_dim,
            num_heads=self.hparams.num_heads,
            qkv_dim=self.hparams.qkv_dim,
            mlp_dim=self.hparams.mlp_dim,
            max_len=self.max_len,
            train=self.train,
            quant_context=self.quant_context,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
        )

        self.decoder = Decoder(  # pylint: disable=missing-from-attributes
            hparams=self.hparams.decoder,
            output_vocab_size=self.output_vocab_size,
            shared_embedding=self.shared_embedding,
            logits_via_embedding=self.hparams.logits_via_embedding,
            use_bfloat16=self.use_bfloat16,
            emb_dim=self.hparams.emb_dim,
            num_heads=self.hparams.num_heads,
            qkv_dim=self.hparams.qkv_dim,
            mlp_dim=self.hparams.mlp_dim,
            max_len=self.max_len,
            train=self.train,
            quant_context=self.quant_context,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            paxis_name='batch',
            decode=self.should_decode)
Example 4
  def test_embed_should_call_clip_and_round(self, floor_with_gradient,
                                            round_with_gradient, weight_prec,
                                            acts_prec, fixed_bounds):

    round_with_gradient.side_effect = lambda x: x
    floor_with_gradient.side_effect = lambda x: x

    if fixed_bounds:
      bounds = 6.0
    else:
      bounds = get_bounds.GetBounds.Hyper(
          initial_bound=6.0,
          stddev_coeff=3.0,
          absdev_coeff=2.0,
          mix_coeff=0.5,
          granularity=quant_config.QuantGranularity.per_tensor)
    quant_act = quantization.QuantOps.ActHParams(
        input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
        prec=acts_prec,
        bounds=bounds,
        half_shift=False)
    rng = random.PRNGKey(0)
    x = jnp.ones((1, 3))

    embed_module = flax_layers.EmbedAqt(
        num_embeddings=4,
        features=3,
        dtype=jnp.float32,
        hparams=flax_layers.EmbedAqt.HParams(
            weight_prec=weight_prec,
            quant_act=quant_act,
            quant_type=QuantType.fake_quant,
            weight_half_shift=False),
        quant_context=quant_config.QuantContext(update_bounds=False),
        paxis_name=None,
        train=False)
    init_state = embed_module.init(
        rng, x, method=embed_module.attend, padding_mask=None)
    round_with_gradient.reset_mock()
    floor_with_gradient.reset_mock()
    embed_module.apply(
        init_state, x, padding_mask=None, method=embed_module.attend)
    round_with_gradient.assert_called_with(mock.ANY)
    self.assertEqual(round_with_gradient.call_count, 1)
    floor_with_gradient.assert_not_called()
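
The mocks above replace the rounding and flooring primitives with identity side effects, so the test only asserts which primitives the quantized attend path calls, without changing the numerics. For context, a rounding op with a pass-through gradient is commonly built with the stop_gradient trick; this is a generic sketch, not necessarily how round_with_gradient is implemented in quantization.py:

import jax
import jax.numpy as jnp

def round_straight_through(x):
  # Forward pass rounds; backward pass lets the gradient flow through unchanged.
  return x + jax.lax.stop_gradient(jnp.round(x) - x)

print(round_straight_through(jnp.array([0.2, 1.7])))  # [0. 2.]
print(jax.grad(lambda v: jnp.sum(round_straight_through(v)))(jnp.array([0.2, 1.7])))  # [1. 1.]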
Example 5
    def __call__(
        self,
        encoded,
        src_padding_mask,
        targets,
        targets_positions=None,
        inputs_segmentation=None,
        targets_segmentation=None,
        tgt_padding_mask=None,
    ):
        """Applies Transformer model on the inputs.

    Args:
      encoded: encoded input data from encoder.
      src_padding_mask: padding mask for inputs.
      targets: target inputs.
      targets_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      tgt_padding_mask: target tokens padding mask.

    Returns:
      output of a transformer decoder.

    """
        batch_size, sequence_length, channel_size = encoded.shape  # pylint: disable=unused-variable
        target_batch_size, target_sequence_length = targets.shape  # pylint: disable=unused-variable
        shape_utils.assert_shapes_equal(targets.shape,
                                        (batch_size, target_sequence_length))

        # Padding Masks
        if tgt_padding_mask is None:
            tgt_padding_mask = (targets > 0)[Ellipsis, None]
        shape_utils.assert_shapes_equal(
            tgt_padding_mask.shape, (batch_size, target_sequence_length, 1))

        if self.use_bfloat16:
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Target Embedding
        if self.shared_embedding is None:
            output_embed = aqt_flax_layers.EmbedAqt(
                num_embeddings=self.output_vocab_size,
                features=self.emb_dim,
                hparams=self.hparams.embedding,
                embedding_init=nn.initializers.normal(
                    stddev=self.emb_dim**-0.5),
                dtype=dtype,
                name='target_embed',
                train=self.train,
                quant_context=self.quant_context,
                paxis_name='batch')
        else:
            output_embed = self.shared_embedding

        y = targets.astype('int32')
        if not self.decode:
            y = shift_right(y)
        y = output_embed(y) * jnp.sqrt(self.emb_dim)
        y = AddPositionEmbs(name='posembed_targets',
                            max_len=self.max_len,
                            decode=self.decode,
                            min_timescale=1.0,
                            max_timescale=10000.0)(
                                y, inputs_positions=targets_positions)
        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not self.train)

        if self.use_bfloat16:
            y = y.astype(jnp.bfloat16)

        # Target-Input Decoder
        num_layers = len(self.hparams.encoder_decoder_1d_blocks)
        for lyr in range(num_layers):
            y = EncoderDecoder1DBlock(
                train=self.train,
                quant_context=self.quant_context,
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                hparams=self.hparams.encoder_decoder_1d_blocks[lyr],
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                deterministic=not self.train,
                name=f'encoderdecoderblock_{lyr}',
                decode=self.decode)(y,
                                    encoded,
                                    padding_mask=tgt_padding_mask,
                                    key_padding_mask=src_padding_mask,
                                    inputs_segmentation=inputs_segmentation,
                                    targets_segmentation=targets_segmentation)
        y = aqt_flax_layers.LayerNormAqt(dtype=dtype,
                                         name='encoderdecoder_norm',
                                         hparams=self.hparams.layer_norm,
                                         quant_context=self.quant_context)(y)
        y = y.reshape((batch_size * target_sequence_length, channel_size))
        tgt_padding_mask = tgt_padding_mask.reshape(
            (batch_size * target_sequence_length, 1))
        # Decoded Logits
        if self.logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = output_embed.attend(query=y,
                                         padding_mask=tgt_padding_mask,
                                         paxis_name=self.paxis_name,
                                         train=self.train)
        else:
            if self.hparams.logits is None:
                raise ValueError(
                    'If logits_via_embedding is False, then the hparams '
                    'for the logits layer have to be provided.')
            logits = aqt_flax_layers.DenseAqt(
                features=self.output_vocab_size,
                dtype=dtype,
                paxis_name='batch',
                train=self.train,
                quant_context=self.quant_context,
                hparams=self.hparams.logits,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6),
                name='logits_dense')(y, padding_mask=tgt_padding_mask)
        return logits
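
The logits_via_embedding branch reuses the embedding matrix, transposed, as the output projection (weight tying), which is why EmbedAqt.attend takes the flattened decoder outputs as the query. A minimal plain-JAX sketch of the idea, ignoring quantization and padding masks (shapes here are illustrative):

import jax.numpy as jnp

vocab_size, emb_dim, num_tokens = 8, 4, 5
embedding = jnp.ones((vocab_size, emb_dim))   # [vocab, features], as stored by the embed layer
y_flat = jnp.ones((num_tokens, emb_dim))      # decoder outputs flattened to [tokens, features]
logits = y_flat @ embedding.T                 # [tokens, vocab]
print(logits.shape)                           # (5, 8)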
Example 6
    def __call__(self,
                 inputs,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      output of a transformer decoder.

    """
        batch_size, sequence_length = inputs.shape

        # Padding Masks
        src_padding_mask = (inputs > 0)[Ellipsis, None]
        shape_utils.assert_shapes_equal(src_padding_mask.shape,
                                        (batch_size, sequence_length, 1))

        if self.use_bfloat16:
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input Embedding
        if self.shared_embedding is None:
            input_embed = aqt_flax_layers.EmbedAqt(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                hparams=self.hparams.embedding,
                embedding_init=nn.initializers.normal(
                    stddev=self.emb_dim**-0.5),
                dtype=dtype,
                name='input_embed',
                paxis_name='batch',
                train=self.train,
                quant_context=self.quant_context)
        else:
            input_embed = self.shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x) * jnp.sqrt(self.emb_dim)
        x = AddPositionEmbs(name='posembed_input',
                            max_len=self.max_len,
                            min_timescale=1.0,
                            max_timescale=10000.0,
                            decode=False)(x, inputs_positions=inputs_positions)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not self.train)

        if self.use_bfloat16:
            x = x.astype(jnp.bfloat16)

        # Input Encoder
        num_layers = len(self.hparams.encoder_1d_blocks)
        for lyr in range(num_layers):
            x = Encoder1DBlock(
                train=self.train,
                quant_context=self.quant_context,
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                hparams=self.hparams.encoder_1d_blocks[lyr],
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                deterministic=not self.train,
                name=f'encoderblock_{lyr}')(
                    x,
                    padding_mask=src_padding_mask,
                    inputs_segmentation=inputs_segmentation)
        encoded = aqt_flax_layers.LayerNormAqt(
            dtype=dtype,
            name='encoder_norm',
            hparams=self.hparams.layer_norm,
            quant_context=self.quant_context)(x)
        shape_utils.assert_shapes_equal(
            encoded.shape, (batch_size, sequence_length, self.emb_dim))
        return encoded
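
The source padding mask above assumes token id 0 marks padding, the same convention the decoder uses for tgt_padding_mask. A quick illustration of the resulting shape:

import jax.numpy as jnp

inputs = jnp.array([[5, 3, 0, 0]])           # batch of one 0-padded sequence
src_padding_mask = (inputs > 0)[Ellipsis, None]
print(src_padding_mask.shape)                # (1, 4, 1)
print(src_padding_mask[0, :, 0])             # [ True  True False False]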