def __init__(self, dec_attn_config, enc_dec_attn_config, num_heads, dim,
             hidden_dim, layer_i, causal=True, span=1, dropout_p=0.1):
    ''' Initialize the transformer decoder layer '''
    super(TransformerDecoderLayer, self).__init__()

    self.span = span
    self.causal = causal
    self.uuid = uuid.uuid4()
    self.enc_dec_attn_config = enc_dec_attn_config
    self.no_attn = dec_attn_config['no_attn']

    # Optional position-wise feed-forward sublayer, enabled per layer.
    if dec_attn_config['ffn_layer'][layer_i]:
        self.ffn = TransformerSublayer(
            TransformerFFN(dim, hidden_dim), dim, dropout_p)
        print('dec layer %i has ffn' % layer_i)

    # Decoder self-attention, unless disabled for this stack.
    if not self.no_attn:
        self.self_attention = TransformerSublayer(
            NewAttention(dec_attn_config, dim, num_heads), dim, dropout_p)

    # Encoder-decoder (source) attention: enabled globally (scalar 1) or
    # per layer (list indexed by layer_i). Note the original check read
    # `type(... is list)`, which always evaluates truthy; the parentheses
    # are corrected here.
    enc_dec_attn_layer = self.enc_dec_attn_config['enc_dec_attn_layer']
    if enc_dec_attn_layer == 1 or \
            (type(enc_dec_attn_layer) is list and enc_dec_attn_layer[layer_i] == 1):
        # Resolve the number of source-attention heads: -1 reuses num_heads,
        # a scalar applies to all layers, and a list is indexed per layer.
        if self.enc_dec_attn_config['enc_dec_attn_num_heads'] == -1:
            src_num_heads = num_heads
        elif type(self.enc_dec_attn_config['enc_dec_attn_num_heads']) is not list:
            src_num_heads = self.enc_dec_attn_config['enc_dec_attn_num_heads']
        else:
            src_num_heads = self.enc_dec_attn_config['enc_dec_attn_num_heads'][layer_i]
        assert src_num_heads != 0

        self.source_attention = TransformerSublayer(
            NewAttention(enc_dec_attn_config, dim, src_num_heads), dim, dropout_p)
        print('layer %i num of src heads %i' % (layer_i, src_num_heads))
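# --- Standalone illustration (not part of the original module) ---
# A minimal sketch of how the constructor above resolves the number of
# encoder-decoder attention heads from `enc_dec_attn_config`: -1 reuses the
# decoder's own head count, a scalar applies to every layer, and a list is
# indexed per layer. The helper name and the config values in __main__ are
# hypothetical and exist only for this example.

def _resolve_src_num_heads(enc_dec_attn_config, layer_i, num_heads):
    ''' Mirrors the head-count resolution in TransformerDecoderLayer.__init__ '''
    configured = enc_dec_attn_config['enc_dec_attn_num_heads']
    if configured == -1:
        return num_heads                 # reuse the decoder's own head count
    if not isinstance(configured, list):
        return configured                # one value shared by all layers
    return configured[layer_i]           # per-layer list, indexed by layer_i

if __name__ == '__main__':
    example_config = {'enc_dec_attn_layer': [1, 0, 1],      # hypothetical per-layer switch
                      'enc_dec_attn_num_heads': [4, 0, 8]}  # hypothetical per-layer heads
    print(_resolve_src_num_heads(example_config, layer_i=2, num_heads=8))  # prints 8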
def __init__(self, attn_config, num_heads, dim, hidden_dim, layer_i, dropout_p=0.1):
    ''' Initialize the transformer encoder layer '''
    super(TransformerEncoderLayer, self).__init__()

    # Optional position-wise feed-forward sublayer, enabled per layer.
    if attn_config['ffn_layer'][layer_i]:
        self.ffn = TransformerSublayer(
            TransformerFFN(dim, hidden_dim), dim, dropout_p)
        print('enc layer %i has ffn' % layer_i)

    # Encoder self-attention is always present.
    self.self_attention = TransformerSublayer(
        NewAttention(attn_config, dim, num_heads), dim, dropout_p)
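# --- Standalone illustration (not part of the original module) ---
# A small sketch of how the per-layer 'ffn_layer' list in attn_config gates the
# feed-forward sublayer: layer_i gets an FFN only when
# attn_config['ffn_layer'][layer_i] is truthy. The schedule below is hypothetical.

if __name__ == '__main__':
    example_attn_config = {'ffn_layer': [1, 0, 1]}  # hypothetical three-layer schedule
    for layer_i, ffn_flag in enumerate(example_attn_config['ffn_layer']):
        print('enc layer %i has ffn: %s' % (layer_i, bool(ffn_flag)))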