Example #1
 def forward(
         self,
         token: th.Tensor,
         hidden: Optional[th.Tensor] = None,
         token_len: Optional[th.Tensor] = None
 ) -> Tuple[th.Tensor, th.Tensor]:
     """
     args:
         token: input token sequence, N x T
         hidden: previous sequence embeddings, T x N x E
         token_len: length of token, N or None
     return:
         output: N x T x V
         hidden: current sequence embeddings, T x N x E
     """
     # N x T => T x N x E
     t = 0 if hidden is None else hidden.shape[0]
     token_embed = self.abs_pos_enc(self.vocab_embed(token), t=t)
     # hidden is None: training, or evaluation at time step 0
     hidden = token_embed if hidden is None else th.cat(
         [hidden, token_embed], dim=0)
     # tgt_mask: T x T
     tgt_mask = prep_sub_mask(hidden.shape[0], device=hidden.device)
     # src_pad_mask: N x T
     src_pad_mask = None if token_len is None else (padding_mask(token_len)
                                                    == 1)
     # Ti x N x D
     enc_out = self.encoder(hidden,
                            inj_pose=None,
                            src_mask=tgt_mask,
                            src_key_padding_mask=src_pad_mask)
     # Ti x N x V
     output = self.dist(enc_out)
     # N x Ti x V
     return output.transpose(0, 1), hidden
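
Both this example and the ones below call two mask helpers from the repository, prep_sub_mask and padding_mask, whose bodies are not shown here. A minimal sketch that is consistent with how they are used (a 0/-inf causal mask and a 0/1 padding indicator); these are assumptions, not the actual implementations:

import torch as th

def prep_sub_mask(T: int, device: str = "cpu") -> th.Tensor:
    # causal (subsequent) mask: 0 on/below the diagonal, -inf above it
    mask = th.triu(th.ones(T, T, device=device), diagonal=1)
    return mask.masked_fill(mask == 1, float("-inf"))

def padding_mask(lens: th.Tensor) -> th.Tensor:
    # lens: N lengths; returns an N x T matrix with 1 at padded positions
    T = int(lens.max().item())
    idx = th.arange(T, device=lens.device)
    return (idx[None, :] >= lens[:, None]).to(th.int64)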
Example #2
def test_aps_selfattn(index):
    S, L, N, E = 100, 100, 8, 256
    self_attn = ApsMultiheadAttention(E, 4, dropout=0)
    self_attn.train()
    query = th.rand(L, N, E)
    if index == 0:
        key, value = query, query
    elif index == 1:
        key = th.rand(S, N, E)
        value = key
    else:
        key = th.rand(S, N, E)
        value = th.rand(S, N, E)

    key_len = th.randint(S // 2, S, (N, ))
    key_len[0] = S
    key_padding_mask = padding_mask(key_len)
    attn_mask = prep_sub_mask(S)

    my1, my2 = self_attn(query,
                         key,
                         value,
                         None,
                         key_padding_mask=key_padding_mask,
                         attn_mask=attn_mask)
    th1, th2 = self_attn.torch_forward(query,
                                       key,
                                       value,
                                       key_padding_mask=key_padding_mask,
                                       attn_mask=attn_mask)
    assert my1.shape == th1.shape
    assert my2.shape == th2.shape
    th.testing.assert_allclose(my2, th2)
    th.testing.assert_allclose(my1, th1)
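
The index argument picks between pure self-attention (query reused as key and value), a shared key/value tensor, and fully distinct inputs. In the repository this is presumably driven by pytest parametrization over index; the three cases can also be exercised directly:

for index in range(3):
    test_aps_selfattn(index)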
Example #3
 def forward(
         self,
         token: th.Tensor,
         h: Optional[th.Tensor] = None,
         token_len: Optional[th.Tensor] = None
 ) -> Tuple[th.Tensor, th.Tensor]:
     """
     args:
         token: input token sequence, N x T
         h: previous sequence embeddings, T x N x E
         token_len: length of token, N or None
     return:
         output: N x T x V
         h: current sequence embeddings, T x N x E
     """
     # N x T => T x N x E
     t = 0 if h is None else h.shape[0]
     x = self.abs_pos_enc(self.vocab_embed(token), t=t)
     # h is None: training, or evaluation at time step 0
     h = x if h is None else th.cat([h, x], dim=0)
     # src_pad_mask: N x T
     src_pad_mask = None if token_len is None else (padding_mask(token_len)
                                                    == 1)
     tgt_mask = prep_sub_mask(t + 1, device=x.device)
     # N x Ti x D
     enc_out = self.encoder(h,
                            mask=tgt_mask,
                            src_key_padding_mask=src_pad_mask)
     # N x Ti x V
     output = self.dist(enc_out)
     return output, h
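
Since h accumulates the embeddings of every token seen so far (and the mask is built for t + 1 steps), this forward is intended to be fed one token per call during decoding. A minimal sketch of that incremental use, where the lm instance, the <sos> id and the step count are made-up placeholders:

import torch as th

h = None
token = th.full((1, 1), 1, dtype=th.int64)  # hypothetical <sos> id
for _ in range(8):
    output, h = lm(token, h=h)                           # output: 1 x Ti x V
    token = output[:, -1].argmax(dim=-1, keepdim=True)   # next token id, 1 x 1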
Example #4
 def step(
         self,
         pred_prev: th.Tensor,
         hidden: Optional[th.Tensor] = None) -> Tuple[th.Tensor, th.Tensor]:
     """
     Make one step for decoder
     Args:
         pred_prev: 1 x 1
         hidden: None or T x 1 x E
     Return:
         dec_out: 1 x D
         hidden: current sequence embeddings, (t+1) x 1 x E
     """
     t = 0 if hidden is None else hidden.shape[0]
     # 1 x 1 x E
     pred_prev_emb = self.abs_pos_enc(self.vocab_embed(pred_prev), t=t)
     hidden = pred_prev_emb if hidden is None else th.cat(
         [hidden, pred_prev_emb], dim=0)
     tgt_mask = prep_sub_mask(t + 1, device=pred_prev.device)
     dec_out = self.decoder(hidden, mask=tgt_mask)
     return dec_out[-1], hidden
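
A sketch of a greedy decoding loop on top of step(); the vocabulary projection model.dist (borrowed from Example #1) and the sos/eos ids are assumptions, not part of the snippet above:

import torch as th

def greedy_decode(model, sos: int, eos: int, max_len: int = 64):
    hidden = None
    pred = th.full((1, 1), sos, dtype=th.int64)             # 1 x 1
    hypo = []
    for _ in range(max_len):
        dec_out, hidden = model.step(pred, hidden=hidden)   # dec_out: 1 x D
        logit = model.dist(dec_out)                         # 1 x V (assumed projection)
        pred = logit.argmax(dim=-1, keepdim=True)           # 1 x 1
        if pred.item() == eos:
            break
        hypo.append(pred.item())
    return hypo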
Example #5
 def forward(self, enc_out: th.Tensor, tgt_pad: th.Tensor,
             tgt_len: Optional[th.Tensor]) -> th.Tensor:
     """
     Args:
         enc_out (Tensor): N x Ti x D
         tgt_pad (Tensor): N x To+1 (padding blank at time = 1)
         tgt_len (Tensor): N or None
     Return:
         output: N x Ti x To+1 x V
     """
     # N x To+1
     pad_mask = None if tgt_len is None else (padding_mask(tgt_len) == 1)
     # generate target mask (-inf/0)
     tgt_mask = prep_sub_mask(tgt_pad.shape[-1], device=tgt_pad.device)
     # To+1 x N x E
     tgt_pad = self.abs_pos_enc(self.vocab_embed(tgt_pad))
     # To+1 x N x D
     dec_out = self.decoder(tgt_pad,
                            src_mask=tgt_mask,
                            src_key_padding_mask=pad_mask)
     return self.pred(enc_out, dec_out.transpose(0, 1))
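
self.pred combines encoder and decoder states into the N x Ti x To+1 x V transducer lattice. Its implementation is not shown here; a minimal additive joint network with that interface might look like the sketch below (layer names and dimensions are assumptions):

import torch as th
import torch.nn as nn

class JointNet(nn.Module):
    def __init__(self, enc_dim: int, dec_dim: int, join_dim: int, vocab_size: int):
        super().__init__()
        self.enc_proj = nn.Linear(enc_dim, join_dim)
        self.dec_proj = nn.Linear(dec_dim, join_dim)
        self.out = nn.Linear(join_dim, vocab_size)

    def forward(self, enc_out: th.Tensor, dec_out: th.Tensor) -> th.Tensor:
        # enc_out: N x Ti x D_enc, dec_out: N x To+1 x D_dec
        joint = self.enc_proj(enc_out)[:, :, None] + self.dec_proj(dec_out)[:, None]
        # N x Ti x To+1 x V
        return self.out(th.tanh(joint))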