Example #1
    def __init__(self,
                 dec_attn_config,
                 enc_dec_attn_config,
                 num_heads,
                 dim,
                 hidden_dim,
                 layer_i,
                 causal=True,
                 span=1,
                 dropout_p=0.1):
        ''' Initialize the transformer layer '''
        super(TransformerDecoderLayer, self).__init__()

        self.span = span
        self.causal = causal
        self.uuid = uuid.uuid4()

        self.enc_dec_attn_config = enc_dec_attn_config
        self.no_attn = dec_attn_config['no_attn']

        # Optional position-wise feed-forward sublayer, enabled per layer
        # by the 'ffn_layer' flags in the decoder attention config.
        if dec_attn_config['ffn_layer'][layer_i]:
            self.ffn = TransformerSublayer(TransformerFFN(dim, hidden_dim),
                                           dim, dropout_p)
            print('dec layer %i has ffn' % layer_i)

        if not self.no_attn:
            # Decoder self-attention sublayer, skipped entirely when
            # 'no_attn' is set in the decoder attention config.
            self.self_attention = TransformerSublayer(
                NewAttention(dec_attn_config, dim, num_heads), dim, dropout_p)

        # Source (encoder-decoder) attention is enabled either globally
        # ('enc_dec_attn_layer' == 1) or per layer via a list of flags.
        if self.enc_dec_attn_config['enc_dec_attn_layer'] == 1 or \
                (isinstance(self.enc_dec_attn_config['enc_dec_attn_layer'], list) and
                 self.enc_dec_attn_config['enc_dec_attn_layer'][layer_i] == 1):
            # Resolve the encoder-decoder head count: -1 inherits the decoder
            # head count, a scalar applies to every layer, and a list gives a
            # per-layer value.
            if self.enc_dec_attn_config['enc_dec_attn_num_heads'] == -1:
                src_num_heads = num_heads
            elif not isinstance(
                    self.enc_dec_attn_config['enc_dec_attn_num_heads'], list):
                src_num_heads = self.enc_dec_attn_config[
                    'enc_dec_attn_num_heads']
            else:
                src_num_heads = self.enc_dec_attn_config[
                    'enc_dec_attn_num_heads'][layer_i]
            assert src_num_heads != 0

            self.source_attention = TransformerSublayer(
                NewAttention(enc_dec_attn_config, dim, src_num_heads), dim,
                dropout_p)

            print('layer %i num of src heads %i' % (layer_i, src_num_heads))

    def __init__(self,
                 attn_config,
                 num_heads,
                 dim,
                 hidden_dim,
                 layer_i,
                 dropout_p=0.1):
        ''' Initialize the transformer layer '''
        super(TransformerEncoderLayer, self).__init__()

        # Optional position-wise feed-forward sublayer, enabled per layer
        # by the 'ffn_layer' flags in the attention config.
        if attn_config['ffn_layer'][layer_i]:
            self.ffn = TransformerSublayer(
                TransformerFFN(dim, hidden_dim),
                dim, dropout_p
            )
            print('enc layer %i has ffn' % layer_i)

        self.self_attention = TransformerSublayer(
            NewAttention(attn_config, dim, num_heads),
            dim, dropout_p
        )
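
Both constructors are driven by plain config dictionaries. Below is a minimal usage sketch: the module path, head counts, and dimensions are assumptions for illustration, and the dictionaries only fill in the keys that the code above actually reads ('no_attn', 'ffn_layer', 'enc_dec_attn_layer', 'enc_dec_attn_num_heads'); NewAttention very likely consumes further configuration keys that this excerpt does not show.

# Hypothetical usage sketch (module path and hyperparameters assumed).
from new_transformer import (  # assumed module path
    TransformerDecoderLayer, TransformerEncoderLayer)

num_layers = 6

dec_attn_config = {
    'no_attn': False,                  # keep decoder self-attention
    'ffn_layer': [True] * num_layers,  # FFN sublayer in every decoder layer
}
enc_dec_attn_config = {
    'enc_dec_attn_layer': 1,           # 1 = source attention in every layer
    'enc_dec_attn_num_heads': -1,      # -1 = reuse the decoder head count
}
enc_attn_config = {
    'ffn_layer': [True] * num_layers,  # FFN sublayer in every encoder layer
}

decoder_layer = TransformerDecoderLayer(
    dec_attn_config, enc_dec_attn_config,
    num_heads=8, dim=512, hidden_dim=2048, layer_i=0)
encoder_layer = TransformerEncoderLayer(
    enc_attn_config, num_heads=8, dim=512, hidden_dim=2048, layer_i=0)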