def test_embed(self, weight_prec):
  # Since the dummy embedding matrix has a row of all zeros, we need
  # 'epsilon' to be added to it before calculating scale factors.
  quantization.DISABLE_EPSILON_IN_SCALE_FUN_FOR_TESTING = False
  rng = random.PRNGKey(0)
  x = jnp.arange(4)[None]
  dummy_embedding = jnp.broadcast_to(jnp.arange(4)[Ellipsis, None],
                                     (4, 3)).astype(jnp.float32)
  embed_module = flax_layers.EmbedAqt(
      num_embeddings=4,
      features=3,
      dtype=jnp.float32,
      hparams=flax_layers.EmbedAqt.HParams(
          weight_prec=weight_prec,
          quant_act=None,
          quant_type=QuantType.fake_quant,
          weight_half_shift=False),
      embedding_init=lambda _rng, _shape: dummy_embedding,
      train=False,
      paxis_name=None,
      quant_context=quant_config.QuantContext(update_bounds=False),
  )
  y, state = embed_module.init_with_output(rng, x)
  test_utils.assert_all_close_prec(dummy_embedding[None], y, weight_prec)

  z = embed_module.apply(
      state, jnp.ones((1, 3)), padding_mask=None, method=embed_module.attend)
  test_utils.assert_all_close_prec(3. * jnp.arange(4), z[0, Ellipsis],
                                   weight_prec)
def test_embed_equality(self, weight_prec):
  rng = random.PRNGKey(0)
  x = 2 * jnp.ones(4, dtype=jnp.int32)[None]
  dummy_embedding = 2 * jnp.ones((4, 2)).astype(jnp.float32)
  embed_module = flax_layers.EmbedAqt(
      num_embeddings=4,
      features=2,
      dtype=jnp.float32,
      hparams=flax_layers.EmbedAqt.HParams(
          weight_prec=weight_prec,
          quant_act=None,
          quant_type=QuantType.fake_quant,
          weight_half_shift=False),
      embedding_init=lambda _rng, _shape: dummy_embedding,
      train=False,
      quant_context=quant_config.QuantContext(update_bounds=False),
      paxis_name=None)
  y, init_state = embed_module.init_with_output(rng, x)
  onp.testing.assert_array_equal(dummy_embedding[None], y)

  z = embed_module.apply(
      init_state,
      jnp.ones((1, 2)),
      padding_mask=None,
      method=embed_module.attend)
  onp.testing.assert_array_equal(2. * (2 * jnp.ones(4)), z[0, Ellipsis])
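# Illustrative usage sketch (not part of the test suite): as the two tests
# above show, EmbedAqt behaves like flax.linen.Embed, with `__call__` doing
# the quantized table lookup and `attend` multiplying activations by the
# transposed embedding table. This sketch assumes the same imports as the
# tests (flax_layers, quant_config, QuantType, jnp, random); the sizes and
# `weight_prec=8` are arbitrary example choices.
def embed_aqt_usage_sketch():
  embed = flax_layers.EmbedAqt(
      num_embeddings=16,
      features=8,
      dtype=jnp.float32,
      hparams=flax_layers.EmbedAqt.HParams(
          weight_prec=8,
          quant_act=None,
          quant_type=QuantType.fake_quant,
          weight_half_shift=False),
      embedding_init=lambda _rng, shape: jnp.ones(shape, jnp.float32),
      train=False,
      paxis_name=None,
      quant_context=quant_config.QuantContext(update_bounds=False))
  token_ids = jnp.array([[1, 2, 3]])
  # Embedding lookup: (1, 3) int ids -> (1, 3, 8) embeddings.
  embeddings, params = embed.init_with_output(random.PRNGKey(0), token_ids)
  # Attend: (1, 8) activations -> (1, 16) scores over the vocabulary.
  scores = embed.apply(
      params, jnp.ones((1, 8)), padding_mask=None, method=embed.attend)
  return embeddings, scores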
def setup(self):
  if self.use_bfloat16:
    dtype = jnp.bfloat16
  else:
    dtype = jnp.float32

  if self.hparams.share_embeddings:
    if self.output_vocab_size is not None:
      assert self.output_vocab_size == self.vocab_size, (
          "can't share embedding with different vocab sizes.")
    self.shared_embedding = aqt_flax_layers.EmbedAqt(  # pylint: disable=missing-from-attributes
        num_embeddings=self.vocab_size,
        features=self.hparams.emb_dim,
        hparams=self.hparams.encoder.embedding,
        dtype=dtype,
        embedding_init=nn.initializers.normal(
            stddev=self.hparams.emb_dim**-0.5),
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch')
  else:
    self.shared_embedding = None

  self.encoder = Encoder(  # pylint: disable=missing-from-attributes
      hparams=self.hparams.encoder,
      vocab_size=self.vocab_size,
      shared_embedding=self.shared_embedding,
      use_bfloat16=self.use_bfloat16,
      emb_dim=self.hparams.emb_dim,
      num_heads=self.hparams.num_heads,
      qkv_dim=self.hparams.qkv_dim,
      mlp_dim=self.hparams.mlp_dim,
      max_len=self.max_len,
      train=self.train,
      quant_context=self.quant_context,
      dropout_rate=self.dropout_rate,
      attention_dropout_rate=self.attention_dropout_rate,
  )
  self.decoder = Decoder(  # pylint: disable=missing-from-attributes
      hparams=self.hparams.decoder,
      output_vocab_size=self.output_vocab_size,
      shared_embedding=self.shared_embedding,
      logits_via_embedding=self.hparams.logits_via_embedding,
      use_bfloat16=self.use_bfloat16,
      emb_dim=self.hparams.emb_dim,
      num_heads=self.hparams.num_heads,
      qkv_dim=self.hparams.qkv_dim,
      mlp_dim=self.hparams.mlp_dim,
      max_len=self.max_len,
      train=self.train,
      quant_context=self.quant_context,
      dropout_rate=self.dropout_rate,
      attention_dropout_rate=self.attention_dropout_rate,
      paxis_name='batch',
      decode=self.should_decode)
def test_embed_should_call_clip_and_round(self, floor_with_gradient,
                                           round_with_gradient, weight_prec,
                                           acts_prec, fixed_bounds):
  round_with_gradient.side_effect = lambda x: x
  floor_with_gradient.side_effect = lambda x: x

  if fixed_bounds:
    bounds = 6.0
  else:
    bounds = get_bounds.GetBounds.Hyper(
        initial_bound=6.0,
        stddev_coeff=3.0,
        absdev_coeff=2.0,
        mix_coeff=0.5,
        granularity=quant_config.QuantGranularity.per_tensor)
  quant_act = quantization.QuantOps.ActHParams(
      input_distribution=QuantOps.ActHParams.InputDistribution.symmetric,
      prec=acts_prec,
      bounds=bounds,
      half_shift=False)
  rng = random.PRNGKey(0)
  x = jnp.ones((1, 3))
  embed_module = flax_layers.EmbedAqt(
      num_embeddings=4,
      features=3,
      dtype=jnp.float32,
      hparams=flax_layers.EmbedAqt.HParams(
          weight_prec=weight_prec,
          quant_act=quant_act,
          quant_type=QuantType.fake_quant,
          weight_half_shift=False),
      quant_context=quant_config.QuantContext(update_bounds=False),
      paxis_name=None,
      train=False)
  init_state = embed_module.init(
      rng, x, method=embed_module.attend, padding_mask=None)

  round_with_gradient.reset_mock()
  floor_with_gradient.reset_mock()

  embed_module.apply(
      init_state, x, padding_mask=None, method=embed_module.attend)
  round_with_gradient.assert_called_with(mock.ANY)
  self.assertEqual(round_with_gradient.call_count, 1)
  floor_with_gradient.assert_not_called()
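# Illustrative sketch (not the library's implementation): the mocked
# round_with_gradient / floor_with_gradient above are straight-through
# estimators, i.e. they round (or floor) in the forward pass but pass the
# gradient through unchanged in the backward pass. A minimal self-contained
# version of that technique with jax.custom_vjp; the names round_ste,
# _round_ste_fwd and _round_ste_bwd are hypothetical and exist only for this
# sketch.
import jax
import jax.numpy as jnp


@jax.custom_vjp
def round_ste(x):
  return jnp.round(x)


def _round_ste_fwd(x):
  # No residuals are needed for the identity backward pass.
  return jnp.round(x), None


def _round_ste_bwd(_, g):
  return (g,)  # gradient passes straight through the rounding


round_ste.defvjp(_round_ste_fwd, _round_ste_bwd)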
def __call__(self,
             encoded,
             src_padding_mask,
             targets,
             targets_positions=None,
             inputs_segmentation=None,
             targets_segmentation=None,
             tgt_padding_mask=None):
  """Applies Transformer model on the inputs.

  Args:
    encoded: encoded input data from encoder.
    src_padding_mask: padding mask for inputs.
    targets: target inputs.
    targets_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.
    tgt_padding_mask: target tokens padding mask.

  Returns:
    output of a transformer decoder.
  """
  batch_size, sequence_length, channel_size = encoded.shape  # pylint: disable=unused-variable
  target_batch_size, target_sequence_length = targets.shape  # pylint: disable=unused-variable
  shape_utils.assert_shapes_equal(targets.shape,
                                  (batch_size, target_sequence_length))

  # Padding Masks
  if tgt_padding_mask is None:
    tgt_padding_mask = (targets > 0)[Ellipsis, None]
  shape_utils.assert_shapes_equal(tgt_padding_mask.shape,
                                  (batch_size, target_sequence_length, 1))

  if self.use_bfloat16:
    dtype = jnp.bfloat16
  else:
    dtype = jnp.float32

  # Target Embedding
  if self.shared_embedding is None:
    output_embed = aqt_flax_layers.EmbedAqt(
        num_embeddings=self.output_vocab_size,
        features=self.emb_dim,
        hparams=self.hparams.embedding,
        embedding_init=nn.initializers.normal(stddev=self.emb_dim**-0.5),
        dtype=dtype,
        name='target_embed',
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch')
  else:
    output_embed = self.shared_embedding

  y = targets.astype('int32')
  if not self.decode:
    y = shift_right(y)
  y = output_embed(y) * jnp.sqrt(self.emb_dim)
  y = AddPositionEmbs(
      name='posembed_targets',
      max_len=self.max_len,
      decode=self.decode,
      min_timescale=1.0,
      max_timescale=10000.0)(
          y, inputs_positions=targets_positions)
  y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not self.train)

  if self.use_bfloat16:
    y = y.astype(jnp.bfloat16)

  # Target-Input Decoder
  num_layers = len(self.hparams.encoder_decoder_1d_blocks)
  for lyr in range(num_layers):
    y = EncoderDecoder1DBlock(
        train=self.train,
        quant_context=self.quant_context,
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        hparams=self.hparams.encoder_decoder_1d_blocks[lyr],
        dtype=dtype,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        deterministic=not self.train,
        name=f'encoderdecoderblock_{lyr}',
        decode=self.decode)(
            y,
            encoded,
            padding_mask=tgt_padding_mask,
            key_padding_mask=src_padding_mask,
            inputs_segmentation=inputs_segmentation,
            targets_segmentation=targets_segmentation)
  y = aqt_flax_layers.LayerNormAqt(
      dtype=dtype,
      name='encoderdecoder_norm',
      hparams=self.hparams.layer_norm,
      quant_context=self.quant_context)(y)
  y = y.reshape((batch_size * target_sequence_length, channel_size))
  tgt_padding_mask = tgt_padding_mask.reshape(
      (batch_size * target_sequence_length, 1))

  # Decoded Logits
  if self.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(
        query=y,
        padding_mask=tgt_padding_mask,
        paxis_name=self.paxis_name,
        train=self.train)
  else:
    if self.hparams.logits is None:
      raise ValueError('If logits_via_embedding is False, then the hparams '
                       'for the logits layer have to be provided.')
    logits = aqt_flax_layers.DenseAqt(
        features=self.output_vocab_size,
        dtype=dtype,
        paxis_name='batch',
        train=self.train,
        quant_context=self.quant_context,
        hparams=self.hparams.logits,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        name='logits_dense')(
            y, padding_mask=tgt_padding_mask)
  return logits
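# Illustrative sketch of the teacher-forcing right shift applied above when
# `self.decode` is False: the targets are padded with a zero (the BOS/pad id)
# on the left of the time axis and the last position is dropped, so position
# t predicts token t from tokens < t. The repo's actual `shift_right` helper
# may differ in details; this standalone version assumes 2-D int inputs of
# shape (batch, length).
def shift_right_sketch(targets):
  """Shifts token ids right along the time axis, padding with zeros."""
  padded = jnp.pad(targets, ((0, 0), (1, 0)), mode='constant',
                   constant_values=0)
  return padded[:, :-1]


# Example: shift_right_sketch(jnp.array([[5, 6, 7]])) -> [[0, 5, 6]]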
def __call__(self, inputs, inputs_positions=None, inputs_segmentation=None):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    output of a transformer encoder.
  """
  batch_size, sequence_length = inputs.shape

  # Padding Masks
  src_padding_mask = (inputs > 0)[Ellipsis, None]
  shape_utils.assert_shapes_equal(src_padding_mask.shape,
                                  (batch_size, sequence_length, 1))

  if self.use_bfloat16:
    dtype = jnp.bfloat16
  else:
    dtype = jnp.float32

  # Input Embedding
  if self.shared_embedding is None:
    input_embed = aqt_flax_layers.EmbedAqt(
        num_embeddings=self.vocab_size,
        features=self.emb_dim,
        hparams=self.hparams.embedding,
        embedding_init=nn.initializers.normal(stddev=self.emb_dim**-0.5),
        dtype=dtype,
        name='input_embed',
        paxis_name='batch',
        train=self.train,
        quant_context=self.quant_context)
  else:
    input_embed = self.shared_embedding
  x = inputs.astype('int32')
  x = input_embed(x) * jnp.sqrt(self.emb_dim)
  x = AddPositionEmbs(
      name='posembed_input',
      max_len=self.max_len,
      min_timescale=1.0,
      max_timescale=10000.0,
      decode=False)(
          x, inputs_positions=inputs_positions)
  x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not self.train)

  if self.use_bfloat16:
    x = x.astype(jnp.bfloat16)

  # Input Encoder
  num_layers = len(self.hparams.encoder_1d_blocks)
  for lyr in range(num_layers):
    x = Encoder1DBlock(
        train=self.train,
        quant_context=self.quant_context,
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        hparams=self.hparams.encoder_1d_blocks[lyr],
        dtype=dtype,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        deterministic=not self.train,
        name=f'encoderblock_{lyr}')(
            x,
            padding_mask=src_padding_mask,
            inputs_segmentation=inputs_segmentation)
  encoded = aqt_flax_layers.LayerNormAqt(
      dtype=dtype,
      name='encoder_norm',
      hparams=self.hparams.layer_norm,
      quant_context=self.quant_context)(x)
  shape_utils.assert_shapes_equal(encoded.shape,
                                  (batch_size, sequence_length, self.emb_dim))
  return encoded
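# Illustrative note on the padding convention used by both the encoder and
# decoder above: token id 0 is treated as padding, so the source mask is just
# `(inputs > 0)[..., None]`, giving shape (batch, length, 1). A tiny
# self-contained check with arbitrary example inputs:
def padding_mask_sketch():
  inputs = jnp.array([[12, 7, 0, 0]])              # (batch=1, length=4), 0 = pad
  src_padding_mask = (inputs > 0)[Ellipsis, None]  # -> shape (1, 4, 1)
  return src_padding_mask  # [[[True], [True], [False], [False]]]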