Example 1
    def get_transformer_params(self, embedding_size):
        """Returns Transformer parameters."""
        num_heads = self.model_config.heads

        def get_size(frac):
            if frac <= 3.0:
                dim = int(embedding_size * frac)
                if dim % num_heads > 0:
                    # Making sure that the Transformer input dimension is divisible by
                    # the number of heads.
                    dim = math.ceil(float(dim) / num_heads) * num_heads
                return dim
            else:
                return int(frac)

        attn_act_fn = common_ht.get_transformer_activation(self.model_config)

        return transformer.TransformerParams(
            query_key_dim=get_size(self.model_config.query_key_dim_frac),
            internal_dim=get_size(self.model_config.internal_dim_frac),
            value_dim=get_size(self.model_config.value_dim_frac),
            num_layers=self.model_config.num_layers,
            mha_output_dim=embedding_size,
            heads=num_heads,
            dropout_rate=self.model_config.dropout_rate,
            attention_activation_fn=attn_act_fn,
            activation_fn=util.nonlinearity(
                self.model_config.transformer_nonlinearity),
        )
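For intuition, here is a minimal standalone sketch of the get_size convention used above (the embedding_size and num_heads values are illustrative, not taken from the source): settings at or below 3.0 are read as a fraction of the embedding size and rounded up to a multiple of the head count, while larger settings are used as absolute dimensions.

import math

embedding_size = 100  # illustrative value, not from the source
num_heads = 8         # illustrative value

def get_size(frac):
    if frac <= 3.0:
        dim = int(embedding_size * frac)
        if dim % num_heads > 0:
            # Round up so the dimension divides evenly across the heads.
            dim = math.ceil(float(dim) / num_heads) * num_heads
        return dim
    return int(frac)

print(get_size(0.5))  # 50 -> 56, the next multiple of 8
print(get_size(512))  # 512, treated as an absolute dimension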
Example 2
    def _make_encoder(self):
        """Creates an encoder layer."""
        params = transformer.TransformerParams(
            mha_output_dim=512,
            heads=8,
            internal_dim=1024,
            num_layers=1,
            query_key_dim=128)
        sample_encoder_layer = transformer.EncoderLayer(params)
        # Run the layer on a random (batch=64, sequence=43, features=512)
        # input, with training=False and no mask.
        return (sample_encoder_layer(tf.random.uniform((64, 43, 512)), False,
                                     None), params)
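As a rough usage sketch (assuming, which the snippet does not state, that EncoderLayer returns a single tensor with the same shape as its input), the helper above could back a shape check like this:

output, params = self._make_encoder()
# Batch size, sequence length and mha_output_dim should be preserved.
self.assertEqual(output.shape, (64, 43, 512))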
Example 3
    def test_mha(self):
        """Verifies multi-headed attention layer output shapes."""
        params = transformer.TransformerParams(
            mha_output_dim=4,
            heads=2,
            num_layers=1,
            query_key_dim=2,
            internal_dim=2)
        mha = transformer.MultiHeadAttention(params)
        q, k, v, _, _ = _make_qkv_2()
        output, attn_weights = mha(q, k, v)
        # Output is projected to mha_output_dim; attention weights have shape
        # (batch, heads, seq_q, seq_k).
        self.assertEqual(output.shape, (1, 2, 4))
        self.assertEqual(attn_weights.shape, (1, 2, 2, 2))
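The expected attention-weight shape follows from scaled dot-product attention itself. A self-contained sketch with made-up tensors (independent of the _make_qkv_2 helper, whose definition is not shown here):

import tensorflow as tf

batch, heads, seq, depth = 1, 2, 2, 2  # illustrative sizes matching the test
q = tf.random.uniform((batch, heads, seq, depth))
k = tf.random.uniform((batch, heads, seq, depth))

# Attention logits compare every query position with every key position,
# giving a (batch, heads, seq_q, seq_k) tensor; softmax over the last axis
# turns each row into attention weights.
logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(depth, tf.float32))
weights = tf.nn.softmax(logits, axis=-1)
print(weights.shape)  # (1, 2, 2, 2)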
Example 4
    def get_decoder_params(self, embedding_size):
        """Returns Transformer parameters."""
        def get_size(frac):
            if frac <= 3.0:
                return int(embedding_size * frac)
            else:
                return int(frac)

        attn_act_fn = common_ht.get_transformer_activation(self.model_config)

        return transformer.TransformerParams(
            query_key_dim=get_size(self.model_config.query_key_dim_frac),
            internal_dim=get_size(self.model_config.internal_dim_frac),
            value_dim=get_size(self.model_config.value_dim_frac),
            num_layers=self.model_config.num_layers,
            mha_output_dim=embedding_size,
            heads=self.model_config.heads,
            dropout_rate=self.model_config.dropout_rate,
            attention_activation_fn=attn_act_fn,
            activation_fn=util.nonlinearity(
                self.model_config.transformer_nonlinearity),
        )
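Note that, unlike the get_size helper in Example 1, this version does not round the result up to a multiple of the head count, so a fractional setting can yield a dimension that is not divisible by the number of heads. A minimal illustration with made-up values:

embedding_size, heads = 100, 8  # illustrative values, not from the source

def get_size(frac):
    return int(embedding_size * frac) if frac <= 3.0 else int(frac)

print(get_size(0.5))          # 50
print(get_size(0.5) % heads)  # 2, i.e. not head-divisible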