Example #1
0
    def forward(self, tgt_seq, src_seq, da, enc_out):
        tgt_seq, src_seq = tgt_seq.to(self.device), src_seq.to(self.device)
        da, enc_out = da.to(self.device), enc_out.to(self.device)
        non_pad_mask = get_non_pad_mask(tgt_seq)
        slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
        dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)

        dec_inp = self.word_emb(tgt_seq)
        da_domain = da[:, :len(domains)]
        da_func = da[:, len(domains):len(domains) + len(functions)]
        da_argument = da[:, len(domains) + len(functions):]
        dec_inp = self.prior_layer(dec_inp, enc_out, da_domain,
                                   non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask,
                                   dec_enc_attn_mask=dec_enc_attn_mask)
        dec_inp = self.middle_layer(dec_inp, enc_out, da_func,
                                    non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask,
                                    dec_enc_attn_mask=dec_enc_attn_mask)
        dec_inp = self.post_layer(dec_inp, enc_out, da_argument,
                                  non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask,
                                  dec_enc_attn_mask=dec_enc_attn_mask)
        dec_inp = self.final_layer(dec_inp, enc_out,
                                   non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask,
                                   dec_enc_attn_mask=dec_enc_attn_mask)
        return dec_inp
Example #2
0
    def forward(self, tgt_seq, src_seq, da, enc_out):
        tgt_seq, src_seq = tgt_seq.to(self.device), src_seq.to(self.device)
        enc_out = enc_out.to(self.device)
        non_pad_mask = get_non_pad_mask(tgt_seq)
        slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
        dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)

        pos_seq = torch.arange(0.0, tgt_seq.size(1), 1.0).to(self.device)
        dec_inp = self.word_emb(tgt_seq) + self.pos_emb(pos_seq, tgt_seq.size(0))
        for layer in self.layers:
            dec_inp = layer(dec_inp, enc_out, non_pad_mask=non_pad_mask, slf_attn_mask=slf_attn_mask,
                            dec_enc_attn_mask=dec_enc_attn_mask)
        return dec_inp
Example #3
0
 def forward(self, src_seq):
     src_seq = src_seq.to(self.device)
     non_pad_mask = get_non_pad_mask(src_seq)
     slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
     pos_seq = torch.arange(0.0, src_seq.size(1), 1.0).to(self.device)
     enc_inp = self.word_emb(src_seq) + self.position_emb(pos_seq, src_seq.size(0))
     attn = None
     for layer in self.layers:
         enc_inp, attn = layer(enc_inp, enc_inp, enc_inp, slf_attn_mask=slf_attn_mask, non_pad_mask=non_pad_mask)
     enc_output = enc_inp
     attn = attn.detach()
     return enc_output, attn
Example #4
0
    def forward(self, src_seq):
        """

        :param src_seq: B * len
        :return:
        """
        src_seq = src_seq.to(self.device)
        non_pad_mask = get_non_pad_mask(src_seq)
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        enc_inp = self.word_emb(src_seq)
        attn = None
        for i, layer in enumerate(self.layers):
            enc_inp, attn = layer(enc_inp, enc_inp, enc_inp, slf_attn_mask=slf_attn_mask, non_pad_mask=non_pad_mask)
        enc_out = enc_inp
        attn = attn.detach()
        return enc_out, attn