def forward(self, tgt_seq, src_seq, da, enc_out): tgt_seq, src_seq = tgt_seq.to(self.device), src_seq.to(self.device) da, enc_out = da.to(self.device), enc_out.to(self.device) non_pad_mask = get_non_pad_mask(tgt_seq) slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) dec_inp = self.word_emb(tgt_seq) da_domain = da[:, :len(domains)] da_func = da[:, len(domains):len(domains) + len(functions)] da_argument = da[:, len(domains) + len(functions):] dec_inp = self.prior_layer(dec_inp, enc_out, da_domain, non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask, dec_enc_attn_mask=dec_enc_attn_mask) dec_inp = self.middle_layer(dec_inp, enc_out, da_func, non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask, dec_enc_attn_mask=dec_enc_attn_mask) dec_inp = self.post_layer(dec_inp, enc_out, da_argument, non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask, dec_enc_attn_mask=dec_enc_attn_mask) dec_inp = self.final_layer(dec_inp, enc_out, non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask, dec_enc_attn_mask=dec_enc_attn_mask) return dec_inp
def forward(self, tgt_seq, src_seq, da, enc_out): tgt_seq, src_seq = tgt_seq.to(self.device), src_seq.to(self.device) enc_out = enc_out.to(self.device) non_pad_mask = get_non_pad_mask(tgt_seq) slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) pos_seq = torch.arange(0.0, tgt_seq.size(1), 1.0).to(self.device) dec_inp = self.word_emb(tgt_seq) + self.pos_emb(pos_seq, tgt_seq.size(0)) for layer in self.layers: dec_inp = layer(dec_inp, enc_out, non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask, dec_enc_attn_mask=dec_enc_attn_mask) return dec_inp
def forward(self, src_seq): src_seq = src_seq.to(self.device) non_pad_mask = get_non_pad_mask(src_seq) slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) pos_seq = torch.arange(0.0, src_seq.size(1), 1.0).to(self.device) enc_inp = self.word_emb(src_seq) + self.position_emb(pos_seq, src_seq.size(0)) attn = None for layer in self.layers: enc_inp, attn = layer(enc_inp, enc_inp, enc_inp, slf_attn_mask=slf_attn_mask, non_pad_mask=non_pad_mask) enc_output = enc_inp attn = attn.detach() return enc_output, attn
def forward(self, src_seq): """ :param src_seq: B * len :return: """ src_seq = src_seq.to(self.device) non_pad_mask = get_non_pad_mask(src_seq) slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) enc_inp = self.word_emb(src_seq) attn = None for i, layer in enumerate(self.layers): enc_inp, attn = layer(enc_inp, enc_inp, enc_inp, slf_attn_mask=slf_attn_mask, non_pad_mask=non_pad_mask) enc_out = enc_inp attn = attn.detach() return enc_out, attn