def __init__(self, d_k, d_v, d_model, n_heads, dropout):
    """Create the query/key/value projections and the attention core.

    Args:
        d_k: per-head size of the query and key projections.
        d_v: per-head size of the value projection.
        d_model: width of the input features being projected.
        n_heads: number of heads; each projection outputs ``d_k * n_heads``
            (queries/keys) or ``d_v * n_heads`` (values) features.
        dropout: forwarded unchanged to ScaledDotProductAttention.
    """
    super(_MultiHeadAttention, self).__init__()
    self.d_k = d_k
    self.d_v = d_v
    self.d_model = d_model
    self.n_heads = n_heads

    # Linear is constructed as Linear(in_features, out_features) everywhere
    # else in this file; pass the two sizes as separate arguments, not as a
    # single list.
    self.w_q = Linear(d_model, d_k * n_heads)
    self.w_k = Linear(d_model, d_k * n_heads)
    self.w_v = Linear(d_model, d_v * n_heads)

    self.attention = ScaledDotProductAttention(d_k, dropout)
def __init__(self, d_k, d_v, d_model, n_heads, dropout):
    """Build the multi-head attention sub-layer.

    The sub-layer consists of the inner projection/attention module, an output
    projection back to ``d_model``, a dropout module and a layer norm.

    Args:
        d_k: per-head query/key size.
        d_v: per-head value size.
        d_model: model width; the output projection maps to this size.
        n_heads: number of attention heads.
        dropout: used as the nn.Dropout probability and also forwarded to
            the inner _MultiHeadAttention.
    """
    super(MultiHeadAttention, self).__init__()
    self.n_heads = n_heads
    self.multihead_attn = _MultiHeadAttention(
        d_k, d_v, d_model, n_heads, dropout)

    # The output projection's input width is n_heads * d_v, i.e. all value
    # heads side by side (presumably concatenated in forward — confirm).
    concat_width = n_heads * d_v
    self.proj = Linear(concat_width, d_model)

    self.dropout = nn.Dropout(p=dropout)
    self.layer_norm = LayerNormalization(d_model)
def __init__(self, d_k, d_v, d_model, d_ff, n_branches, dropout):
    """Build the multi-branch (weighted) attention sub-layer.

    Args:
        d_k: per-branch query/key size.
        d_v: per-branch value size; each branch output projection maps
            ``d_v -> d_model``.
        d_model: model width.
        d_ff: total feed-forward width; each branch's feed-forward net gets
            ``d_ff // n_branches`` of it.
        n_branches: number of branches (used as the head count of the inner
            _MultiHeadAttention).
        dropout: used as the nn.Dropout probability and forwarded to the
            inner attention and feed-forward modules.
    """
    super(MultiBranchAttention, self).__init__()
    self.d_k = d_k
    self.d_v = d_v
    self.d_model = d_model
    self.d_ff = d_ff
    self.n_branches = n_branches

    self.multihead_attn = _MultiHeadAttention(
        d_k, d_v, d_model, n_branches, dropout)
    # additional parameters for BranchedAttention
    self.w_o = nn.ModuleList(
        [Linear(d_v, d_model) for _ in range(n_branches)])

    # One learnable scalar per branch for w_kp and for w_a, each initialised
    # uniformly at random and normalised so the entries sum to 1.
    w_kp_init = torch.rand(n_branches)
    self.w_kp = nn.Parameter(w_kp_init / w_kp_init.sum())
    w_a_init = torch.rand(n_branches)
    self.w_a = nn.Parameter(w_a_init / w_a_init.sum())

    self.pos_ffn = nn.ModuleList([
        PoswiseFeedForwardNet(d_model, d_ff // n_branches, dropout)
        for _ in range(n_branches)
    ])
    self.dropout = nn.Dropout(dropout)
    self.layer_norm = LayerNormalization(d_model)

    # Xavier init needs a tensor with at least 2 dimensions, so it cannot be
    # applied to the ModuleList itself. Initialise every weight matrix of the
    # branch projections instead. 1-D parameters (biases) are skipped because
    # fan-in/fan-out is undefined for them.
    # Newer PyTorch names the in-place initialiser ``xavier_normal_`` and
    # deprecates ``xavier_normal``; use whichever this install provides.
    xavier_normal = getattr(init, 'xavier_normal_', None) or init.xavier_normal
    for param in self.w_o.parameters():
        if param.dim() >= 2:
            xavier_normal(param)
def __init__(self, opt):
    """Assemble the encoder, decoder, output projection and mu/var heads.

    Args:
        opt: options object. This constructor reads n_layers, d_k, d_v,
            d_model, d_ff, n_heads, max_src_seq_len, max_tgt_seq_len,
            src_vocab_size, tgt_vocab_size, dropout, weighted_model,
            share_proj_weight and share_embs_weight from it.

    Raises:
        ValueError: if ``opt.share_embs_weight`` is set but
            ``opt.src_vocab_size`` differs from ``opt.tgt_vocab_size``.
    """
    super(Transformer, self).__init__()
    self.encoder = Encoder(opt.n_layers, opt.d_k, opt.d_v, opt.d_model,
                           opt.d_ff, opt.n_heads, opt.max_src_seq_len,
                           opt.src_vocab_size, opt.dropout,
                           opt.weighted_model)
    self.decoder = Decoder(opt.n_layers, opt.d_k, opt.d_v, opt.d_model,
                           opt.d_ff, opt.n_heads, opt.max_tgt_seq_len,
                           opt.tgt_vocab_size, opt.dropout,
                           opt.weighted_model)
    self.tgt_proj = Linear(opt.d_model, opt.tgt_vocab_size, bias=False)
    self.weighted_model = opt.weighted_model

    # Hard-coded hidden width for the mu/var heads and the projection back to
    # d_model. The names suggest a VAE latent size — confirm with forward().
    vae_width = 1024
    self.mu_linear = nn.Linear(opt.d_model, vae_width)
    self.var_linear = nn.Linear(opt.d_model, vae_width)
    self.vae_z_linear = nn.Linear(vae_width, opt.d_model)

    if opt.share_proj_weight:
        print('Sharing target embedding and projection..')
        # NOTE(review): if the project's Linear wraps an inner nn.Linear,
        # assigning `.weight` on the wrapper may not tie the weight that
        # forward() actually uses — confirm against Linear's definition.
        self.tgt_proj.weight = self.decoder.tgt_emb.weight

    if opt.share_embs_weight:
        print('Sharing source and target embedding..')
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let mismatched vocabularies silently share
        # one embedding matrix.
        if opt.src_vocab_size != opt.tgt_vocab_size:
            raise ValueError(
                'To share word embeddings, the vocabulary size of src/tgt '
                'should be the same')
        self.encoder.src_emb.weight = self.decoder.tgt_emb.weight