Example #1
 def forward(
         self,
         token: th.Tensor,
         hidden: Optional[th.Tensor] = None,
         token_len: Optional[th.Tensor] = None
 ) -> Tuple[th.Tensor, th.Tensor]:
     """
     args:
         token: input token sequence, N x T
         hidden: previous sequence embeddings, T x N x E
         token_len: length of token, N or None
     return:
         output: N x T x V
         hidden: current sequence embeddings, T x N x E
     """
     # N x T => T x N x E
     t = 0 if hidden is None else hidden.shape[0]
     token_embed = self.abs_pos_enc(self.vocab_embed(token), t=t)
     # hidden is None: training, or decoding at time step 0
     hidden = token_embed if hidden is None else th.cat(
         [hidden, token_embed], dim=0)
     # tgt_mask: T x T
     tgt_mask = prep_sub_mask(hidden.shape[0], device=hidden.device)
     # src_pad_mask: N x T
     src_pad_mask = None if token_len is None else (padding_mask(token_len)
                                                    == 1)
     # Ti x N x D
     enc_out = self.encoder(hidden,
                            inj_pose=None,
                            src_mask=tgt_mask,
                            src_key_padding_mask=src_pad_mask)
     # Ti x N x V
     output = self.dist(enc_out)
     # N x Ti x V
     return output.transpose(0, 1), hidden
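A hedged sketch of what prep_sub_mask is assumed to compute, inferred only from the comments in these snippets ("tgt_mask: T x T", values -inf/0): a causal (subsequent) mask that stops position t from attending to later positions. The body below is an illustration, not the project's actual implementation.

import torch as th

def prep_sub_mask(T: int, device: str = "cpu") -> th.Tensor:
    # start from an all-zero T x T matrix
    mask = th.zeros(T, T, device=device)
    # put -inf strictly above the diagonal so future positions are masked out
    future = th.triu(th.ones(T, T, dtype=th.bool, device=device), diagonal=1)
    return mask.masked_fill(future, float("-inf"))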
Example #2
def test_aps_selfattn(index):
    S, L, N, E = 100, 100, 8, 256
    self_attn = ApsMultiheadAttention(E, 4, dropout=0)
    self_attn.train()
    query = th.rand(L, N, E)
    if index == 0:
        key, value = query, query
    elif index == 1:
        key = th.rand(S, N, E)
        value = key
    else:
        key = th.rand(S, N, E)
        value = th.rand(S, N, E)

    key_len = th.randint(S // 2, S, (N, ))
    key_len[0] = S
    key_padding_mask = padding_mask(key_len)
    attn_mask = prep_sub_mask(S)

    my1, my2 = self_attn(query,
                         key,
                         value,
                         None,
                         key_padding_mask=key_padding_mask,
                         attn_mask=attn_mask)
    th1, th2 = self_attn.torch_forward(query,
                                       key,
                                       value,
                                       key_padding_mask=key_padding_mask,
                                       attn_mask=attn_mask)
    assert my1.shape == th1.shape
    assert my2.shape == th2.shape
    th.testing.assert_allclose(my2, th2)
    th.testing.assert_allclose(my1, th1)
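A minimal sketch of the padding_mask semantics assumed throughout these examples: given per-utterance lengths (N,), it returns an N x T mask with 1/True at padded positions (frame index >= length), which is why the other snippets compare the result with == 1 before passing it as a *_key_padding_mask. Again an inference from the call sites, not the project source.

from typing import Optional
import torch as th

def padding_mask(lengths: th.Tensor, max_len: Optional[int] = None) -> th.Tensor:
    T = int(lengths.max()) if max_len is None else max_len
    # frame index grid broadcast against the lengths: N x T
    idx = th.arange(T, device=lengths.device)
    # 1 marks frames beyond the utterance length (i.e. padding)
    return (idx[None, :] >= lengths[:, None]).to(th.int64)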
Example #3
 def forward(
         self,
         token: th.Tensor,
         h: Optional[th.Tensor] = None,
         token_len: Optional[th.Tensor] = None
 ) -> Tuple[th.Tensor, th.Tensor]:
     """
     args:
         token: input token sequence, N x T
         h: previous sequence embeddings, T x N x E
         token_len: length of token, N or None
     return:
         output: N x T x V
         h: current sequence embeddings, T x N x E
     """
     # N x T => T x N x E
     t = 0 if h is None else h.shape[0]
     x = self.abs_pos_enc(self.vocab_embed(token), t=t)
     # h is None: training, or decoding at time step 0
     h = x if h is None else th.cat([h, x], dim=0)
     # src_pad_mask: N x T
     src_pad_mask = None if token_len is None else (padding_mask(token_len)
                                                    == 1)
     tgt_mask = prep_sub_mask(h.shape[0], device=x.device)
     # N x Ti x D
     enc_out = self.encoder(h,
                            mask=tgt_mask,
                            src_key_padding_mask=src_pad_mask)
     # N x Ti x V
     output = self.dist(enc_out)
     return output, h
Example #4
 def step(self,
          enc_out: th.Tensor,
          tgt_pad: th.Tensor,
          enc_len: Optional[th.Tensor] = None,
          tgt_len: Optional[th.Tensor] = None,
          pre_emb: Optional[th.Tensor] = None,
          out_idx: Optional[int] = None) -> Tuple[th.Tensor, th.Tensor]:
     """
     Args:
         enc_out (Tensor): T x N x D
         tgt_pad (Tensor): N x To
         enc_len (Tensor): N or None
         tgt_len (Tensor): N or None
         pre_emb (Tensor): T' x N x D
         out_idx (int): optional index to pick a single output frame
     Return:
         dec_out (Tensor): T+T' x N x V or N x V
         tgt_emb (Tensor): T+T' x N x E
     """
     # N x Ti
     offset = 0 if pre_emb is None else pre_emb.shape[0]
     mem_pad_mask = None if enc_len is None else (padding_mask(enc_len)
                                                  == 1)
     tgt_pad_mask = None if tgt_len is None else (padding_mask(tgt_len)
                                                  == 1)
     # N x T x E
     tgt_emb = self.vocab_embed(tgt_pad)
     # T x N x E
     tgt_emb = self.abs_pos_enc(tgt_emb, t=offset)
     # T+T' x N x E
     if pre_emb is not None:
         tgt_emb = th.cat([pre_emb, tgt_emb], dim=0)
     # T+T' x T+T'
     tgt_mask = prep_sub_mask(tgt_emb.shape[0], device=tgt_pad.device)
     # To+1 x N x D
     dec_out = self.decoder(tgt_emb,
                            enc_out,
                            tgt_mask=tgt_mask,
                            tgt_key_padding_mask=tgt_pad_mask,
                            memory_key_padding_mask=mem_pad_mask)
     if out_idx is not None:
         dec_out = dec_out[out_idx]
     # To+1 x N x V
     dec_out = self.output(dec_out)
     return dec_out, tgt_emb
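A hedged sketch of how the step() above could be driven for greedy incremental decoding: feed the newest token, cache the returned tgt_emb as pre_emb, and use out_idx=-1 to keep only the last output frame. The names greedy_decode, decoder, sos, eos and max_len are illustrative assumptions, not taken from the project.

import torch as th

def greedy_decode(decoder, enc_out, sos: int, eos: int, max_len: int = 100):
    # enc_out: T x N x D, as in the step() docstring; enc_len omitted (no padding)
    N = enc_out.shape[1]
    token = th.full((N, 1), sos, dtype=th.long)   # N x 1, start symbol
    pre_emb, hyps = None, []
    for _ in range(max_len):
        # out_idx=-1 selects the newest frame, so dec_out is N x V
        dec_out, pre_emb = decoder.step(enc_out,
                                        token,
                                        pre_emb=pre_emb,
                                        out_idx=-1)
        token = dec_out.argmax(-1, keepdim=True)  # N x 1, next token ids
        hyps.append(token)
        if (token == eos).all():
            break
    return th.cat(hyps, dim=-1)                   # N x L hypotheses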
Example #5
 def _process_mask(self, mask: Optional[th.Tensor],
                   x_len: Optional[th.Tensor]) -> Optional[th.Tensor]:
     """
     Process mask estimated by networks
     """
     if mask is None:
         return mask
     if x_len is not None:
         zero_mask = padding_mask(x_len)  # N x T
         mask = th.masked_fill(mask, zero_mask[..., None], 0)
     if self.mask_norm:
         max_abs = th.norm(mask, float("inf"), dim=1, keepdim=True)
         mask = mask / (max_abs + EPSILON)
     # N x T x F => N x F x T
     mask = th.transpose(mask, 1, 2)
     return mask
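A tiny demonstration of the inf-norm step above: th.norm(..., float("inf"), dim=1) takes the maximum absolute value along the time axis, so after division each time trajectory of the mask peaks at roughly 1. EPSILON is assumed to be a small constant defined elsewhere in the project.

import torch as th

EPSILON = 1e-8                        # assumed value
mask = th.tensor([[[0.2, 0.4],
                   [0.1, 0.8]]])      # N x T x F = 1 x 2 x 2
max_abs = th.norm(mask, float("inf"), dim=1, keepdim=True)  # 1 x 1 x 2
print(mask / (max_abs + EPSILON))     # each column now peaks at ~1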
Example #6
 def forward(self, enc_out: th.Tensor, tgt_pad: th.Tensor,
             tgt_len: Optional[th.Tensor]) -> th.Tensor:
     """
     Args:
         enc_out (Tensor): N x Ti x D
         tgt_pad (Tensor): N x To+1 (padding blank at time = 1)
         tgt_len (Tensor): N or None
     Return:
         output: N x Ti x To+1 x V
     """
     # N x To+1
     pad_mask = None if tgt_len is None else (padding_mask(tgt_len) == 1)
     # generate target mask (-inf/0)
     tgt_mask = prep_sub_mask(tgt_pad.shape[-1], device=tgt_pad.device)
     # To+1 x N x E
     tgt_pad = self.abs_pos_enc(self.vocab_embed(tgt_pad))
     # To+1 x N x D
     dec_out = self.decoder(tgt_pad,
                            src_mask=tgt_mask,
                            src_key_padding_mask=pad_mask)
     return self.pred(enc_out, dec_out.transpose(0, 1))
Example #7
File: encoder.py  Project: yt752/aps
    def forward(self, inp_pad: th.Tensor,
                inp_len: Optional[th.Tensor]) -> EncRetType:
        """
        Go through projection layer
        Args:
            inp_pad: N x Ti x F
            inp_len: N or None
        Return:
            enc_inp: N x Ti x D
            inp_len: N or None
            src_pad_mask: N x Ti or None
        """
        inp_len = self.proj.num_frames(inp_len)
        enc_inp = self.proj(inp_pad)
        src_pad_mask = None if inp_len is None else (padding_mask(inp_len)
                                                     == 1)

        if self.type == "abs":
            # enc_inp: N x Ti x D => Ti x N x D
            enc_inp = self.pose(enc_inp)
            inj_pose = None
        else:
            # enc_inp: N x Ti x D => Ti x N x D
            enc_inp = enc_inp.transpose(0, 1)
            nframes = enc_inp.shape[0]
            # 2Ti-1 x D
            if self.type == "rel":
                inj_pose = self.pose(
                    th.arange(-nframes + 1, nframes, device=enc_inp.device))
            else:
                inj_pose = self.pose(
                    th.arange(0, 2 * nframes - 1, 1.0, device=enc_inp.device))
        # Ti x N x D
        enc_out = self.encoder(enc_inp,
                               inj_pose=inj_pose,
                               src_key_padding_mask=src_pad_mask)
        # N x Ti x D
        return enc_out.transpose(0, 1), inp_len
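A quick check of the position indices fed to self.pose in the two non-"abs" branches above: for Ti frames, both th.arange(-Ti + 1, Ti) and th.arange(0, 2 * Ti - 1, 1.0) enumerate 2*Ti - 1 positions, matching the "2Ti-1 x D" comment; the first is centered around 0 (relative offsets), the second starts at 0.

import torch as th

Ti = 4
rel = th.arange(-Ti + 1, Ti)            # tensor([-3, -2, -1,  0,  1,  2,  3])
alt = th.arange(0, 2 * Ti - 1, 1.0)     # tensor([0., 1., ..., 6.])
assert rel.shape[0] == alt.shape[0] == 2 * Ti - 1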