def get_transformer_params(self, embedding_size):
  """Returns Transformer parameters."""
  num_heads = self.model_config.heads

  def get_size(frac):
    if frac <= 3.0:
      dim = int(embedding_size * frac)
      if dim % num_heads > 0:
        # Making sure that the Transformer input dimension is divisible by
        # the number of heads.
        dim = math.ceil(float(dim) / num_heads) * num_heads
      return dim
    else:
      return int(frac)

  attn_act_fn = common_ht.get_transformer_activation(self.model_config)
  return transformer.TransformerParams(
      query_key_dim=get_size(self.model_config.query_key_dim_frac),
      internal_dim=get_size(self.model_config.internal_dim_frac),
      value_dim=get_size(self.model_config.value_dim_frac),
      num_layers=self.model_config.num_layers,
      mha_output_dim=embedding_size,
      heads=num_heads,
      dropout_rate=self.model_config.dropout_rate,
      attention_activation_fn=attn_act_fn,
      activation_fn=util.nonlinearity(
          self.model_config.transformer_nonlinearity),
  )
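# A standalone sketch (hypothetical helper name `_rounded_size`) of the sizing
# rule implemented by get_size above: fractions <= 3.0 are interpreted as a
# multiple of the embedding size and rounded up to a multiple of the head
# count, while values > 3.0 are taken as absolute dimensions.
import math


def _rounded_size(embedding_size, frac, num_heads):
  if frac <= 3.0:
    dim = int(embedding_size * frac)
    if dim % num_heads > 0:
      # Round up so the dimension splits evenly across attention heads.
      dim = math.ceil(dim / num_heads) * num_heads
    return dim
  return int(frac)


assert _rounded_size(512, 0.3, 8) == 160   # 153.6 -> 153 -> next multiple of 8
assert _rounded_size(512, 16.0, 8) == 16   # > 3.0 is an absolute dimension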
def _make_encoder(self):
  """Creates an encoder layer and applies it to a random input batch."""
  params = transformer.TransformerParams(
      mha_output_dim=512,
      heads=8,
      internal_dim=1024,
      num_layers=1,
      query_key_dim=128)
  sample_encoder_layer = transformer.EncoderLayer(params)
  # Apply the layer to a random (batch, seq_len, d_model) = (64, 43, 512)
  # input with training=False and no padding mask.
  return (sample_encoder_layer(tf.random.uniform((64, 43, 512)), False, None),
          params)
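# A minimal usage sketch of the helper above; the expected output shape
# (batch, seq_len, mha_output_dim) is an assumption based on a standard
# Transformer encoder layer, not a guarantee from transformer.EncoderLayer.
def test_encoder_layer_shape(self):
  output, params = self._make_encoder()
  self.assertEqual(output.shape, (64, 43, params.mha_output_dim))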
def test_mha(self):
  """Verifies multi-head attention layer output shapes."""
  params = transformer.TransformerParams(
      mha_output_dim=4,
      heads=2,
      num_layers=1,
      query_key_dim=2,
      internal_dim=2)
  mha = transformer.MultiHeadAttention(params)
  q, k, v, _, _ = _make_qkv_2()
  output, attn_weights = mha(q, k, v)
  # Output is (batch, seq_len, mha_output_dim); attention weights are
  # (batch, heads, seq_len_q, seq_len_k).
  self.assertEqual(output.shape, (1, 2, 4))
  self.assertEqual(attn_weights.shape, (1, 2, 2, 2))
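# A self-contained shape check (assumed input shapes, independent of
# _make_qkv_2, which is defined elsewhere): per-head attention logits are
# (batch, heads, seq_len_q, seq_len_k), which matches the (1, 2, 2, 2)
# assertion above for heads=2 and seq_len=2.
import tensorflow as tf

batch, heads, seq_len, depth = 1, 2, 2, 1  # depth = query_key_dim // heads
q = tf.random.uniform((batch, heads, seq_len, depth))
k = tf.random.uniform((batch, heads, seq_len, depth))
logits = tf.matmul(q, k, transpose_b=True)  # scaled dot-product logits
assert logits.shape == (1, 2, 2, 2)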
def get_decoder_params(self, embedding_size):
  """Returns Transformer decoder parameters."""

  def get_size(frac):
    # Unlike get_transformer_params above, fractional sizes are not rounded
    # up to a multiple of the head count.
    if frac <= 3.0:
      return int(embedding_size * frac)
    else:
      return int(frac)

  attn_act_fn = common_ht.get_transformer_activation(self.model_config)
  return transformer.TransformerParams(
      query_key_dim=get_size(self.model_config.query_key_dim_frac),
      internal_dim=get_size(self.model_config.internal_dim_frac),
      value_dim=get_size(self.model_config.value_dim_frac),
      num_layers=self.model_config.num_layers,
      mha_output_dim=embedding_size,
      heads=self.model_config.heads,
      dropout_rate=self.model_config.dropout_rate,
      attention_activation_fn=attn_act_fn,
      activation_fn=util.nonlinearity(
          self.model_config.transformer_nonlinearity),
  )
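# A minimal usage sketch (hypothetical `model` instance whose config uses
# fractional dim settings <= 3.0) contrasting the two builders:
# get_transformer_params rounds fractional sizes up to a multiple of the head
# count, while get_decoder_params keeps the truncated size, so its dimensions
# need not be divisible by `heads`.
encoder_params = model.get_transformer_params(embedding_size=512)
decoder_params = model.get_decoder_params(embedding_size=512)
assert encoder_params.query_key_dim % encoder_params.heads == 0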