def forward(
        self,
        token: th.Tensor,
        hidden: Optional[th.Tensor] = None,
        token_len: Optional[th.Tensor] = None
) -> Tuple[th.Tensor, th.Tensor]:
    """
    Run one (incremental) forward pass of the LM.

    Args:
        token: input token sequence, N x T
        hidden: previously accumulated sequence embeddings, T' x N x E
                (None during training or at decoding time = 0)
        token_len: length of each sequence, N or None
    Return:
        output: N x T x V
        hidden: updated sequence embeddings, T'+T x N x E
    """
    # time offset of the new tokens inside the running sequence
    offset = 0 if hidden is None else hidden.shape[0]
    # N x T => T x N x E, absolute positions shifted by the cached prefix
    embed = self.abs_pos_enc(self.vocab_embed(token), t=offset)
    # hidden == None: training or eval at time = 0
    if hidden is None:
        hidden = embed
    else:
        hidden = th.cat([hidden, embed], dim=0)
    # causal mask over the full accumulated length: T x T
    tgt_mask = prep_sub_mask(hidden.shape[0], device=hidden.device)
    # padding mask: N x T or None
    if token_len is None:
        src_pad_mask = None
    else:
        src_pad_mask = padding_mask(token_len) == 1
    # Ti x N x D
    enc_out = self.encoder(hidden,
                           inj_pose=None,
                           src_mask=tgt_mask,
                           src_key_padding_mask=src_pad_mask)
    # Ti x N x V => N x Ti x V
    return self.dist(enc_out).transpose(0, 1), hidden
def test_aps_selfattn(index):
    """
    Check ApsMultiheadAttention against the torch reference implementation:
    index == 0 -> pure self-attention, index == 1 -> shared key/value,
    otherwise fully distinct query/key/value.
    """
    S, L, N, E = 100, 100, 8, 256
    self_attn = ApsMultiheadAttention(E, 4, dropout=0)
    self_attn.train()
    query = th.rand(L, N, E)
    if index == 0:
        # self-attention: all three are the same tensor
        key = value = query
    elif index == 1:
        # cross-attention with shared key/value
        key = value = th.rand(S, N, E)
    else:
        # cross-attention with distinct key and value
        key = th.rand(S, N, E)
        value = th.rand(S, N, E)
    key_len = th.randint(S // 2, S, (N, ))
    # ensure at least one sequence covers the full length (no padding)
    key_len[0] = S
    key_padding_mask = padding_mask(key_len)
    attn_mask = prep_sub_mask(S)
    my1, my2 = self_attn(query,
                         key,
                         value,
                         None,
                         key_padding_mask=key_padding_mask,
                         attn_mask=attn_mask)
    th1, th2 = self_attn.torch_forward(query,
                                       key,
                                       value,
                                       key_padding_mask=key_padding_mask,
                                       attn_mask=attn_mask)
    assert my1.shape == th1.shape
    assert my2.shape == th2.shape
    th.testing.assert_allclose(my2, th2)
    th.testing.assert_allclose(my1, th1)
def forward(
        self,
        token: th.Tensor,
        h: Optional[th.Tensor] = None,
        token_len: Optional[th.Tensor] = None
) -> Tuple[th.Tensor, th.Tensor]:
    """
    Run one (incremental) forward pass of the LM.

    Args:
        token: input token sequence, N x T
        h: previously accumulated sequence embeddings, T' x N x E
           (None during training or at decoding time = 0)
        token_len: length of each sequence, N or None
    Return:
        output: N x T x V
        h: updated sequence embeddings, T'+T x N x E
    """
    # time offset of the new tokens in the running sequence
    t = 0 if h is None else h.shape[0]
    # N x T => T x N x E
    x = self.abs_pos_enc(self.vocab_embed(token), t=t)
    # h == None: training or eval in time = 0
    h = x if h is None else th.cat([h, x], dim=0)
    # src_pad_mask: N x T
    src_pad_mask = None if token_len is None else (padding_mask(token_len) == 1)
    # BUGFIX: the causal mask must span the full accumulated length.
    # h.shape[0] == t + T, not t + 1 — the two only coincide when decoding
    # a single token per step; training feeds T > 1 tokens with h == None,
    # where a (t+1) x (t+1) mask is too small. This also matches the
    # sibling forward() implementation that uses hidden.shape[0].
    tgt_mask = prep_sub_mask(h.shape[0], device=x.device)
    enc_out = self.encoder(h, mask=tgt_mask, src_key_padding_mask=src_pad_mask)
    # T x N x V
    output = self.dist(enc_out)
    return output, h
def step(self,
         enc_out: th.Tensor,
         tgt_pad: th.Tensor,
         enc_len: Optional[th.Tensor] = None,
         tgt_len: Optional[th.Tensor] = None,
         pre_emb: Optional[th.Tensor] = None,
         out_idx: Optional[int] = None) -> Tuple[th.Tensor]:
    """
    One decoder pass over the (possibly cached) target sequence.

    Args:
        enc_out (Tensor): T x N x D encoder output
        tgt_pad (Tensor): N x To padded target tokens
        enc_len (Tensor): N or None
        tgt_len (Tensor): N or None
        pre_emb (Tensor): T' x N x D target embeddings cached from
                          earlier decoding steps, or None
        out_idx (int): if given, return decoder output at this time step only
    Return:
        dec_out (Tensor): T+T' x N x V (or N x V when out_idx is set)
        tgt_emb (Tensor): T+T' x N x E
    """
    # position offset contributed by the cached prefix
    offset = 0 if pre_emb is None else pre_emb.shape[0]
    # padding masks (True marks padded positions), N x T / N x To or None
    mem_pad_mask = None if enc_len is None else (padding_mask(enc_len) == 1)
    tgt_pad_mask = None if tgt_len is None else (padding_mask(tgt_len) == 1)
    # N x To x E => To x N x E with positions shifted by the prefix
    tgt_emb = self.abs_pos_enc(self.vocab_embed(tgt_pad), t=offset)
    if pre_emb is not None:
        # prepend the cache: T+T' x N x E
        tgt_emb = th.cat([pre_emb, tgt_emb], dim=0)
    # causal mask: T+T' x T+T'
    tgt_mask = prep_sub_mask(tgt_emb.shape[0], device=tgt_pad.device)
    # T+T' x N x D
    dec_out = self.decoder(tgt_emb,
                           enc_out,
                           tgt_mask=tgt_mask,
                           tgt_key_padding_mask=tgt_pad_mask,
                           memory_key_padding_mask=mem_pad_mask)
    if out_idx is not None:
        # keep only the requested time step: N x D
        dec_out = dec_out[out_idx]
    # project to the vocabulary
    return self.output(dec_out), tgt_emb
def _process_mask(self, mask: th.Tensor, x_len: th.Tensor) -> th.Tensor:
    """
    Post-process the mask estimated by the network: zero out padded
    frames, optionally normalize by the per-utterance max magnitude
    along the time axis, and return the mask time-last.
    """
    if mask is None:
        return None
    if x_len is not None:
        # N x T, True on padded frames -> force those positions to 0
        pad = padding_mask(x_len)
        mask = th.masked_fill(mask, pad[..., None], 0)
    if self.mask_norm:
        # inf-norm (max abs) along time, kept for broadcasting
        denom = th.norm(mask, float("inf"), dim=1, keepdim=True) + EPSILON
        mask = mask / denom
    # N x T x F => N x F x T
    return th.transpose(mask, 1, 2)
def forward(self, enc_out: th.Tensor, tgt_pad: th.Tensor,
            tgt_len: Optional[th.Tensor]) -> th.Tensor:
    """
    Decode the target sequence and join it with the encoder output.

    Args:
        enc_out (Tensor): N x Ti x D
        tgt_pad (Tensor): N x To+1 (padding blank at time = 1)
        tgt_len (Tensor): N or None
    Return:
        output: N x Ti x To+1 x V
    """
    # target padding mask: N x To+1 or None
    pad_mask = None if tgt_len is None else (padding_mask(tgt_len) == 1)
    # generate causal (-inf/0) target mask: To+1 x To+1
    tgt_mask = prep_sub_mask(tgt_pad.shape[-1], device=tgt_pad.device)
    # embedding + positional encoding: To+1 x N x E
    tgt_emb = self.abs_pos_enc(self.vocab_embed(tgt_pad))
    # To+1 x N x D
    dec_out = self.decoder(tgt_emb,
                           src_mask=tgt_mask,
                           src_key_padding_mask=pad_mask)
    # joint prediction: N x Ti x To+1 x V
    return self.pred(enc_out, dec_out.transpose(0, 1))
def forward(self, inp_pad: th.Tensor,
            inp_len: Optional[th.Tensor]) -> EncRetType:
    """
    Project the input features and run the transformer encoder.

    Args:
        inp_pad: N x Ti x F
        inp_len: N or None
    Return:
        enc_out: N x Ti x D
        inp_len: N or None (frame count after the projection layer)
    """
    inp_len = self.proj.num_frames(inp_len)
    enc_inp = self.proj(inp_pad)
    # N x Ti padding mask (True on padded frames) or None
    src_pad_mask = None if inp_len is None else (padding_mask(inp_len) == 1)
    if self.type == "abs":
        # absolute encoding: N x Ti x D => Ti x N x D handled by self.pose
        enc_inp = self.pose(enc_inp)
        inj_pose = None
    else:
        # N x Ti x D => Ti x N x D
        enc_inp = enc_inp.transpose(0, 1)
        num_frames = enc_inp.shape[0]
        # injected positional encodings over 2Ti-1 offsets
        if self.type == "rel":
            pose_idx = th.arange(-num_frames + 1,
                                 num_frames,
                                 device=enc_inp.device)
        else:
            pose_idx = th.arange(0,
                                 2 * num_frames - 1,
                                 1.0,
                                 device=enc_inp.device)
        # 2Ti-1 x D
        inj_pose = self.pose(pose_idx)
    # Ti x N x D
    enc_out = self.encoder(enc_inp,
                           inj_pose=inj_pose,
                           src_key_padding_mask=src_pad_mask)
    # N x Ti x D
    return enc_out.transpose(0, 1), inp_len