def _get_tgt_emb(self, tgt, device):
    # De mask
    tgt_mask = (_get_pad_mask(tgt).to(device=device).type(torch.uint8)
                & _get_subsequent_mask(tgt.size(-1)).type(torch.uint8).to(device=device))
    if self.dec_emb_proj_flag:
        emb_tgt = self.dec_emb_proj(self.embedding_dropout(self.dec_embedder(tgt)))
    else:
        emb_tgt = self.embedding_dropout(self.dec_embedder(tgt))
    return tgt_mask, emb_tgt
def _get_src_emb(self, src, device):
    # En mask
    src_mask_input = _get_pad_mask(src).to(device=device).type(torch.uint8)
    src_mask = (_get_pad_mask(src).to(device=device).type(torch.uint8)
                & _get_subsequent_mask(src.size(-1)).type(torch.uint8).to(device=device))
    if self.enc_emb_proj_flag:
        emb_src = self.enc_emb_proj(self.embedding_dropout(self.enc_embedder(src)))
    else:
        emb_src = self.embedding_dropout(self.enc_embedder(src))
    return src_mask, emb_src, src_mask_input
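# For reference, a minimal sketch of the two mask helpers used throughout this
# class (an assumption, not the repo's actual implementation; PAD is assumed to
# be the integer id of the padding token):
#
#     def _get_pad_mask(seq):
#         # b x len -> b x 1 x len; nonzero where the token is not padding
#         return (seq != PAD).unsqueeze(-2)
#
#     def _get_subsequent_mask(length):
#         # 1 x len x len lower-triangular (causal) mask
#         return torch.tril(torch.ones((1, length, length), dtype=torch.uint8))
#
# The `&` of the two broadcasts to a b x len x len combined mask.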
def forward_train(self, src, tgt, debug_flag=False, use_gpu=True):
    """
    train enc + dec
    note: all output is useful only up to the second-last element,
        i.e. b x (len-1), e.g. [b, :-1] for preds
    src: w1 w2 w3 <EOS> <PAD> <PAD> <PAD>
    ref: BOS w1 w2 w3 <EOS> <PAD> <PAD>
    tgt: w1 w2 w3 <EOS> <PAD> <PAD> dummy
    ref starts with BOS; the last element has no ref!
    """
    # note: adding .type(torch.uint8) to be compatible with pytorch 1.1!

    # check gpu
    global device
    device = check_device(use_gpu)

    # run transformer
    src_mask = _get_pad_mask(src).to(device=device).type(torch.uint8)  # b x 1 x len
    tgt_mask = (_get_pad_mask(tgt).to(device=device).type(torch.uint8)
                & _get_subsequent_mask(tgt.size(-1)).type(torch.uint8).to(device=device))

    # b x len x dim_model
    if self.enc_emb_proj_flag:
        emb_src = self.enc_emb_proj(self.embedding_dropout(self.enc_embedder(src)))
    else:
        emb_src = self.embedding_dropout(self.enc_embedder(src))
    if self.dec_emb_proj_flag:
        emb_tgt = self.dec_emb_proj(self.embedding_dropout(self.dec_embedder(tgt)))
    else:
        emb_tgt = self.embedding_dropout(self.dec_embedder(tgt))

    enc_outputs, *_ = self.enc(emb_src, src_mask=src_mask)  # b x len x dim_model
    dec_outputs, *_ = self.dec(emb_tgt, enc_outputs,
                               tgt_mask=tgt_mask, src_mask=src_mask)

    logits = self.out(dec_outputs)  # b x len x vocab_size
    logps = torch.log_softmax(logits, dim=2)
    preds = logps.data.topk(1)[1]

    return preds, logps, dec_outputs
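# Usage sketch (assumption): wiring forward_train into a training step.
# `model`, `src`, `tgt` and `ref` are hypothetical names; `ref` is assumed to
# be the per-position reference aligned with the decoder outputs. Per the
# docstring above, only the first len-1 output positions carry a valid
# reference, so the loss is taken over logps[:, :-1]:
#
#     preds, logps, dec_outputs = model.forward_train(src, tgt)
#     loss = torch.nn.functional.nll_loss(
#         logps[:, :-1].reshape(-1, logps.size(-1)),
#         ref[:, :-1].reshape(-1),
#         ignore_index=PAD)
#     loss.backward()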
def _get_src_emb(self, src, emb_src_dyn, device):
    # En mask
    src_mask_input = _get_pad_mask(src).to(device=device).type(torch.uint8)
    src_mask = (_get_pad_mask(src).to(device=device).type(torch.uint8)
                & _get_subsequent_mask(src.size(-1)).type(torch.uint8).to(device=device))

    # cat dynamic + static embeddings
    emb_src_static = self.enc_embedder(src)
    emb_src_comb = torch.cat((emb_src_static, emb_src_dyn), dim=2)

    # map
    if self.enc_emb_proj_flag:
        emb_src = self.enc_emb_proj(self.embedding_dropout(emb_src_comb))
    else:
        emb_src = self.embedding_dropout(emb_src_comb)

    return src_mask, emb_src, src_mask_input
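# Shape note (assumption): with static embeddings of width d_static and dynamic
# embeddings of width d_dyn, emb_src_comb is b x len x (d_static + d_dyn), and
# enc_emb_proj is assumed to be a linear map back to the model width, e.g.
# nn.Linear(d_static + d_dyn, dim_model), so the encoder always sees
# b x len x dim_model.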
def forward_translate_fast(self, src, beam_width=1, penalty_factor=1, use_gpu=True):
    """ beam search decoding - requires large memory; run on cpu if needed """

    # check gpu
    global device
    device = check_device(use_gpu)

    # run encoder
    src_mask = _get_pad_mask(src).type(torch.uint8).to(device=device)  # b x 1 x len
    if self.enc_emb_proj_flag:
        emb_src = self.enc_emb_proj(self.enc_embedder(src))
    else:
        emb_src = self.enc_embedder(src)
    enc_outputs, *_ = self.enc(emb_src, src_mask=src_mask)  # b x len x dim_model

    batch = src.size(0)
    length_in = src.size(1)
    length_out = self.max_seq_len

    eos_mask = torch.BoolTensor([False]).repeat(batch * beam_width).to(device=device)
    len_map = torch.Tensor([1]).repeat(batch * beam_width).to(device=device)
    preds = torch.Tensor([BOS]).repeat(batch, 1).type(torch.LongTensor).to(device=device)

    # repeat beam_width times: a b c d -> aaa bbb ccc ddd
    # b x 1 x len -> (b x beam_width) x 1 x len
    src_mask_expand = src_mask.repeat(1, beam_width, 1).view(-1, 1, length_in)
    # b x len x dim_model -> (b x beam_width) x len x dim_model
    enc_outputs_expand = enc_outputs.repeat(1, beam_width, 1).view(-1, length_in, self.dim_model)
    # (b x beam_width) x len
    preds_expand = preds.repeat(1, beam_width).view(-1, preds.size(-1))
    # (b x beam_width)
    scores_expand = torch.Tensor([0]).repeat(batch * beam_width).type(torch.FloatTensor).to(device=device)

    # loop over sequence length
    for i in range(1, self.max_seq_len):  # gen: 0-30; ref: 1-31
        # get k candidates for each beam, k^2 candidates in total (k = beam_width)
        tgt_mask_expand = (_get_pad_mask(preds_expand).type(torch.uint8).to(device=device)
                           & _get_subsequent_mask(preds_expand.size(-1)).type(torch.uint8).to(device=device))
        if self.dec_emb_proj_flag:
            emb_tgt_expand = self.dec_emb_proj(self.dec_embedder(preds_expand))
        else:
            emb_tgt_expand = self.dec_embedder(preds_expand)

        if i == 1:
            cache_decslf = None
            cache_encdec = None
        dec_output_expand, *_, cache_decslf, cache_encdec = self.dec(
            emb_tgt_expand, enc_outputs_expand,
            tgt_mask=tgt_mask_expand, src_mask=src_mask_expand,
            decode_speedup=True, cache_decslf=cache_decslf, cache_encdec=cache_encdec)

        logit_expand = self.out(dec_output_expand)
        # (b x beam_width) x len x vocab_size
        logp_expand = torch.log_softmax(logit_expand, dim=2)
        # (b x beam_width) x len x beam_width
        score_expand, pred_expand = logp_expand.data.topk(beam_width)

        # select current slice
        dec_output = dec_output_expand[:, i - 1]  # (b x beam_width) x dim_model - unused
        logp = logp_expand[:, i - 1, :]           # (b x beam_width) x vocab_size - unused
        pred = pred_expand[:, i - 1]              # (b x beam_width) x beam_width
        score = score_expand[:, i - 1]            # (b x beam_width) x beam_width

        # select k candidates from the k^2 candidates
        if i == 1:
            # initial state: keep the first k candidates
            # b x (beam_width x beam_width) -> b x beam_width -> (b x beam_width) x 1
            score_select = (scores_expand
                            + score.reshape(batch, -1)[:, :beam_width].contiguous().view(-1))
            scores_expand = score_select
            pred_select = pred.reshape(batch, -1)[:, :beam_width].contiguous().view(-1)
            preds_expand = torch.cat((preds_expand, pred_select.unsqueeze(-1)), dim=1)
        else:
            # keep only 1 candidate per beam once it has hit eos
            # (b x beam_width) x beam_width
            eos_mask_expand = eos_mask.reshape(-1, 1).repeat(1, beam_width)
            eos_mask_expand[:, 0] = False
            # (b x beam_width) x beam_width
            score_temp = (scores_expand.reshape(-1, 1)
                          + score.masked_fill(eos_mask.reshape(-1, 1), 0)
                                 .masked_fill(eos_mask_expand, -1e9))
            # length penalty
            score_temp = score_temp / (len_map.reshape(-1, 1) ** penalty_factor)
            # select top k from the k^2 candidates
            # b x beam_width^2 -> b x beam_width
            score_select, pos = score_temp.reshape(batch, -1).topk(beam_width)
            scores_expand = (score_select.view(-1)
                             * (len_map.reshape(-1, 1) ** penalty_factor).view(-1))

            # select the correct elements according to pos
            pos = (pos + torch.arange(0, batch * beam_width ** 2, beam_width ** 2)
                   .to(device=device).reshape(batch, 1)).long()
            r_idxs, c_idxs = pos // beam_width, pos % beam_width  # b x beam_width
            # b x beam_width -> (b x beam_width)
            pred_select = pred[r_idxs, c_idxs].view(-1)

            # copy the corresponding previous tokens
            preds_expand[:, :i] = preds_expand[r_idxs.view(-1), :i]  # (b x beam_width) x i
            # set the best tokens in this beam search step
            preds_expand = torch.cat((preds_expand, pred_select.unsqueeze(-1)), dim=1)

        # locate the eos in the generated sequences
        # eos_mask = (pred_select == EOS) + eos_mask  # >=pt1.3
        eos_mask = ((pred_select == EOS).type(torch.uint8)
                    + eos_mask.type(torch.uint8)).type(torch.bool).type(torch.uint8)  # >=pt1.1
        len_map = len_map + torch.Tensor([1]).repeat(
            batch * beam_width).to(device=device).masked_fill(eos_mask, 0)

        # early stop once every beam has produced eos
        if sum(eos_mask.int()) == eos_mask.size(0):
            break

    # select the best candidate
    preds = preds_expand.reshape(batch, -1)[:, :self.max_seq_len].contiguous()  # b x len
    scores = scores_expand.reshape(batch, -1)[:, 0].contiguous()  # b

    # select the worst candidate
    # preds = preds_expand.reshape(batch, -1)
    #     [:, (beam_width - 1) * length:beam_width * length].contiguous()  # b x len
    # scores = scores_expand.reshape(batch, -1)[:, -1].contiguous()  # b

    return preds
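# Worked example of the top-k index arithmetic above (illustrative values):
# with batch = 2 and beam_width = 3, score_temp.reshape(batch, -1) is 2 x 9,
# so `pos` holds flat indices in [0, 9). Adding the per-batch offsets 0 and 9
# (torch.arange(0, batch * beam_width**2, beam_width**2)) turns them into flat
# indices over the stacked (b x beam_width) x beam_width `pred` tensor, where
#     r_idxs = pos // beam_width   # which beam the candidate came from
#     c_idxs = pos %  beam_width   # which of that beam's k candidates it is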
def forward_eval_fast(self, src, debug_flag=False, use_gpu=True):
    """ greedy decoding with cached attention - requires large memory; run on cpu if needed """

    # check gpu
    global device
    device = check_device(use_gpu)

    batch = src.size(0)
    length_out = self.max_seq_len

    # run enc dec
    src_mask = _get_pad_mask(src).type(torch.uint8).to(device=device)
    if self.enc_emb_proj_flag:
        emb_src = self.enc_emb_proj(self.embedding_dropout(self.enc_embedder(src)))
    else:
        emb_src = self.embedding_dropout(self.enc_embedder(src))
    enc_outputs, enc_var = self.enc(emb_src, src_mask=src_mask)

    # record
    logps = torch.Tensor([-1e-4]).repeat(
        batch, length_out, self.dec_vocab_size).type(torch.FloatTensor).to(device=device)
    dec_outputs = torch.Tensor([0]).repeat(
        batch, length_out, self.dim_model).type(torch.FloatTensor).to(device=device)
    # used to update pred history
    preds_save = torch.Tensor([PAD]).repeat(
        batch, length_out).type(torch.LongTensor).to(device=device)

    # start from length = 1
    preds = torch.Tensor([BOS]).repeat(batch, 1).type(torch.LongTensor).to(device=device)
    preds_save[:, 0] = preds[:, 0]

    for i in range(1, self.max_seq_len):  # gen: 0-30; ref: 1-31
        tgt_mask = (_get_pad_mask(preds).type(torch.uint8).to(device=device)
                    & _get_subsequent_mask(preds.size(-1)).type(torch.uint8).to(device=device))
        if self.dec_emb_proj_flag:
            emb_tgt = self.dec_emb_proj(self.dec_embedder(preds))
        else:
            emb_tgt = self.dec_embedder(preds)

        if i == 1:
            cache_decslf = None
            cache_encdec = None
        dec_output, dec_var, *_, cache_decslf, cache_encdec = self.dec(
            emb_tgt, enc_outputs, tgt_mask=tgt_mask, src_mask=src_mask,
            decode_speedup=True, cache_decslf=cache_decslf, cache_encdec=cache_encdec)

        logit = self.out(dec_output)
        logp = torch.log_softmax(logit, dim=2)
        pred = logp.data.topk(1)[1]  # b x :i

        # b x len x dim_model - [:, 0, :] is dummy 0's
        dec_outputs[:, i, :] = dec_output[:, i - 1]
        # b x len x vocab_size - [:, 0, :] is dummy -1e-4's (individual logps)
        logps[:, i, :] = logp[:, i - 1, :]
        # b x len - [:, 0] is BOS
        preds_save[:, i] = pred[:, i - 1].view(-1)

        # append current pred, length + 1
        preds = torch.cat((preds, pred[:, i - 1]), dim=1)

    if not debug_flag:
        return preds, logps, dec_outputs
    else:
        return preds, logps, dec_outputs, enc_var, dec_var
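# Note on the decode_speedup path (an assumption about the decoder's contract):
# cache_decslf / cache_encdec are taken to be per-layer key/value caches for
# decoder self-attention and encoder-decoder attention respectively. They are
# reset to None at step i == 1 and then threaded through the loop, so each
# step only computes attention for the newly appended token instead of
# re-running the full prefix - hence the "fast" in forward_eval_fast.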
def forward_eval(self, src, debug_flag=False, use_gpu=True):
    """
    eval enc + dec (beam_width = 1)
    all outputs follow:
        tgt: <BOS> w1 w2 w3 <EOS> <PAD>
        gen: w1 w2 w3 <EOS> <PAD> <PAD>
    shifted by 1, i.e.
        used input = <BOS> w1 <PAD> <PAD>
        gen output = dummy w2 dummy
    update prediction: assign w2 (output[1]) to be input[2]
    """
    # check gpu
    global device
    device = check_device(use_gpu)

    batch = src.size(0)
    length_out = self.max_seq_len

    # run enc dec
    eos_mask = torch.BoolTensor([False]).repeat(batch).to(device=device)
    src_mask = _get_pad_mask(src).type(torch.uint8).to(device=device)
    if self.enc_emb_proj_flag:
        emb_src = self.enc_emb_proj(self.embedding_dropout(self.enc_embedder(src)))
    else:
        emb_src = self.embedding_dropout(self.enc_embedder(src))
    enc_outputs, enc_var = self.enc(emb_src, src_mask=src_mask)

    # record
    logps = torch.Tensor([-1e-4]).repeat(
        batch, length_out, self.dec_vocab_size).type(torch.FloatTensor).to(device=device)
    dec_outputs = torch.Tensor([0]).repeat(
        batch, length_out, self.dim_model).type(torch.FloatTensor).to(device=device)
    # used to update pred history
    preds_save = torch.Tensor([PAD]).repeat(
        batch, length_out).type(torch.LongTensor).to(device=device)

    # start from length = 1
    preds = torch.Tensor([BOS]).repeat(batch, 1).type(torch.LongTensor).to(device=device)
    preds_save[:, 0] = preds[:, 0]

    for i in range(1, self.max_seq_len):  # gen: 0-30; ref: 1-31
        tgt_mask = (_get_pad_mask(preds).type(torch.uint8).to(device=device)
                    & _get_subsequent_mask(preds.size(-1)).type(torch.uint8).to(device=device))
        if self.dec_emb_proj_flag:
            emb_tgt = self.dec_emb_proj(self.dec_embedder(preds))
        else:
            emb_tgt = self.dec_embedder(preds)

        dec_output, dec_var, *_ = self.dec(emb_tgt, enc_outputs,
                                           tgt_mask=tgt_mask, src_mask=src_mask)
        logit = self.out(dec_output)
        logp = torch.log_softmax(logit, dim=2)
        pred = logp.data.topk(1)[1]  # b x :i

        # eos_mask = (pred[:, i-1].squeeze(1) == EOS) + eos_mask  # >=pt1.3
        eos_mask = ((pred[:, i - 1].squeeze(1) == EOS).type(torch.uint8)
                    + eos_mask.type(torch.uint8)).type(torch.bool).type(torch.uint8)  # >=pt1.1

        # b x len x dim_model - [:, 0, :] is dummy 0's
        dec_outputs[:, i, :] = dec_output[:, i - 1]
        # b x len x vocab_size - [:, 0, :] is dummy -1e-4's (individual logps)
        logps[:, i, :] = logp[:, i - 1, :]
        # b x len - [:, 0] is BOS
        preds_save[:, i] = pred[:, i - 1].view(-1)

        # append current pred, length + 1
        preds = torch.cat((preds, pred[:, i - 1]), dim=1)

        # early stop once every sequence has produced eos
        if sum(eos_mask.int()) == eos_mask.size(0):
            if length_out != preds.size(1):
                # pad to max length
                dummy = torch.Tensor([PAD]).repeat(
                    batch, length_out - preds.size(1)).type(torch.LongTensor).to(device=device)
                preds = torch.cat((preds, dummy), dim=1)
            break

    if not debug_flag:
        return preds, logps, dec_outputs
    else:
        return preds, logps, dec_outputs, enc_var, dec_var
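# Usage sketch (assumption; `model` and `src_batch` are hypothetical names):
#
#     model.eval()
#     with torch.no_grad():
#         preds, logps, dec_outputs = model.forward_eval(src_batch, use_gpu=False)
#     # preds: b x max_seq_len, [:, 0] is BOS, padded with PAD after EOS;
#     # strip BOS and truncate at the first EOS before detokenising.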