def forward_greedy(self, z, num_steps, temperature, x=None):
    predictions = []
    batch_size = z.size(0)
    next_input = z.new_zeros(size=(batch_size, num_steps), dtype=torch.long, requires_grad=False)
    next_input[:, :] = self.PAD_TOKEN
    next_input[:, 0] = self.SOS_TOKEN  # <sos> token
    z = self.activation(self.z2h(z)).view(batch_size, 1, -1).repeat(1, num_steps, 1)
    for i in range(num_steps):
        input = next_input
        step_input = self.embedding(input)
        step_input = self.pos_embedding(step_input)
        # step_input is of size [batch, seq_len, step_input_size]
        step_input = torch.cat([step_input, z], dim=2)
        step_input = self.activation(self.s2h(step_input))
        non_pad_mask = get_non_pad_mask(input, self.PAD_TOKEN)
        slf_attn_mask_subseq = get_subsequent_mask(input)
        slf_attn_mask_keypad = get_attn_key_pad_mask(input, self.PAD_TOKEN)
        attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
        out = self.transformer(step_input, non_pad_mask=non_pad_mask, attn_mask=attn_mask)
        out = out[:, i, :]
        out = self.activation(out)
        out = self.h2o(out)
        out = self.last_activation(out, temperature)
        if x is not None:
            # teacher forcing
            previous_output = x[:, i]
        else:
            # use prediction as input
            previous_output = torch.argmax(out, dim=-1)
        previous_output = previous_output.detach()
        next_input = torch.cat([input[:, :i + 1], previous_output.view(-1, 1), input[:, i + 2:]], dim=1).detach()
        predictions.append(out)
    output = torch.stack(predictions).transpose(1, 0)
    return output
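# The mask helpers used above -- get_non_pad_mask, get_subsequent_mask and
# get_attn_key_pad_mask -- are not defined in this snippet. A minimal sketch,
# assuming the usual Transformer masking conventions; the signatures are
# inferred from the call sites and may differ from the real utilities.

import torch

def get_non_pad_mask(seq, pad_idx):
    # [batch, len, 1] float mask: 1.0 at real tokens, 0.0 at padding
    return seq.ne(pad_idx).float().unsqueeze(-1)

def get_subsequent_mask(seq):
    # [batch, len, len] mask: 1 at future (masked) positions
    batch_size, seq_len = seq.size()
    mask = torch.triu(
        torch.ones((seq_len, seq_len), dtype=torch.uint8, device=seq.device),
        diagonal=1)
    return mask.unsqueeze(0).expand(batch_size, -1, -1)

def get_attn_key_pad_mask(seq_k, pad_idx, seq_q=None):
    # [batch, len_q, len_k] mask: 1 where the key position is padding
    seq_q = seq_k if seq_q is None else seq_q
    mask = seq_k.eq(pad_idx).to(torch.uint8).unsqueeze(1)  # [batch, 1, len_k]
    return mask.expand(-1, seq_q.size(1), -1)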
def forward(self, inputs, pos, src_inputs, enc_outputs, return_att=False):
    non_pad_mask = utils.get_non_pad_mask(inputs)
    self_attn_mask_subseq = utils.get_subsequent_mask(inputs)
    self_attn_mask_keypad = utils.get_att_key_pad_mask(seq_k=inputs, seq_q=inputs)
    self_attn_mask = (self_attn_mask_keypad + self_attn_mask_subseq).gt(0)
    dec_enc_attn_mask = utils.get_att_key_pad_mask(seq_k=src_inputs, seq_q=inputs)
    self_att_list = []
    enc_att_list = []
    output = self.embed(inputs) + self.position_enc(pos)
    for layer in self.layers:
        output, self_att, enc_att = layer(
            output, enc_outputs,
            non_pad_mask=non_pad_mask,
            input_att_mask=self_attn_mask,
            enc_att_mask=dec_enc_attn_mask,
        )
        if return_att:
            self_att_list.append(self_att)
            enc_att_list.append(enc_att)
    output = self.proj(output)
    # return the collected attention maps when requested
    if return_att:
        return output, self_att_list, enc_att_list
    return output
def forward(self, padded_input, encoder_padded_outputs,
            encoder_input_lengths, return_attns=False):
    """
    Args:
        padded_input: N x To
        encoder_padded_outputs: N x Ti x H

    Returns:
        pred: N x To x V logits (before softmax), gold: N x To targets,
        plus the attention lists if return_attns is True.
    """
    dec_slf_attn_list, dec_enc_attn_list = [], []
    # Get decoder input and output
    ys_in_pad, ys_out_pad = self.preprocess(padded_input)
    # Prepare masks
    non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)
    slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
    slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                 seq_q=ys_in_pad,
                                                 pad_idx=self.eos_id)
    slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
    output_length = ys_in_pad.size(1)
    dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                          encoder_input_lengths,
                                          output_length)
    # Forward
    dec_output = self.dropout(
        self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
        self.positional_encoding(ys_in_pad))
    for dec_layer in self.layer_stack:
        dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
            dec_output, encoder_padded_outputs,
            non_pad_mask=non_pad_mask,
            slf_attn_mask=slf_attn_mask,
            dec_enc_attn_mask=dec_enc_attn_mask)
        if return_attns:
            dec_slf_attn_list += [dec_slf_attn]
            dec_enc_attn_list += [dec_enc_attn]
    # before softmax
    seq_logit = self.tgt_word_prj(dec_output)
    # Return
    pred, gold = seq_logit, ys_out_pad
    if return_attns:
        return pred, gold, dec_slf_attn_list, dec_enc_attn_list
    return pred, gold
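# get_attn_pad_mask above derives the decoder-to-encoder attention mask from
# the encoder lengths instead of a pad token. A minimal sketch, assuming
# input_lengths is a LongTensor of shape [N]; the signature is taken from
# the call site.

import torch

def get_attn_pad_mask(padded_input, input_lengths, expand_length):
    # [N, expand_length, Ti] mask: 1 where the encoder position is padding
    n_frames = padded_input.size(1)
    positions = torch.arange(n_frames, device=padded_input.device).unsqueeze(0)  # [1, Ti]
    pad_mask = positions >= input_lengths.unsqueeze(1)                           # [N, Ti]
    return pad_mask.unsqueeze(1).expand(-1, expand_length, -1)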
def predict_word(dec_seq, src_seq, enc_output, n_active_inst, n_bm):
    src_mask = get_pad_mask(src_seq, PAD)
    dec_mask = get_pad_mask(dec_seq, PAD) & get_subsequent_mask(dec_seq)
    dec_output, *_ = self.model.decoder(dec_seq, dec_mask, enc_output, src_mask)
    dec_output = dec_output[:, -1, :]  # pick the last step: (n_active_inst * n_bm) x d_h
    word_prob = F.log_softmax(self.model.trg_word_prj(dec_output), dim=1)
    word_prob = word_prob.view(n_active_inst, n_bm, -1)
    return word_prob
def forward(self, src_seq, trg_seq):
    src_mask = get_pad_mask(src_seq, self.src_pad_idx)
    trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq)
    enc_output, *_ = self.encoder(src_seq, src_mask)
    dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask)
    seq_logit = self.trg_word_prj(dec_output) * self.x_logit_scale
    # return seq_logit.view(-1, seq_logit.size(2))
    return seq_logit
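# The two snippets above combine masks with `&`, so this get_pad_mask and
# get_subsequent_mask must return broadcastable boolean tensors. A minimal
# sketch of that convention; note the polarity is inverted relative to the
# helpers sketched earlier (here True marks positions to keep):

import torch

def get_pad_mask(seq, pad_idx):
    # [batch, 1, len]: True at real tokens; broadcasts over the query axis
    return (seq != pad_idx).unsqueeze(-2)

def get_subsequent_mask(seq):
    # [1, len, len]: True at allowed (current and past) positions
    seq_len = seq.size(1)
    return (1 - torch.triu(
        torch.ones((1, seq_len, seq_len), device=seq.device), diagonal=1)).bool()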
def decoder(decoder_model, model_cls, decode_input_ids, encode_outputs, encode_attention_mask):
    decode_attention_mask = get_subsequent_mask(decode_input_ids)
    decode_outputs, *_ = decoder_model(
        decode_input_ids, decode_attention_mask,
        encoder_output=encode_outputs,
        encoder_attention_mask=encode_attention_mask)
    logits = model_cls(decode_outputs)
    return logits
def recognize_beam(self, encoder_outputs, char_list, args):
    """Beam search, decoding one utterance at a time.

    Args:
        encoder_outputs: T x H
        char_list: list of characters
        args: args.beam

    Returns:
        nbest_hyps:
    """
    # search params
    beam = args.beam_size
    nbest = args.nbest
    if args.decode_max_len == 0:
        maxlen = encoder_outputs.size(0)
    else:
        maxlen = args.decode_max_len

    encoder_outputs = encoder_outputs.unsqueeze(0)
    # prepare sos
    ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()  # yseq: 1 x T
    hyp = {'score': 0.0, 'yseq': ys}
    hyps = [hyp]
    ended_hyps = []

    for i in range(maxlen):
        hyps_best_kept = []
        for hyp in hyps:
            ys = hyp['yseq']  # 1 x i
            # -- Prepare masks
            non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1 x i x 1
            slf_attn_mask = get_subsequent_mask(ys)
            # -- Forward
            dec_output = self.dropout(
                self.tgt_word_emb(ys) * self.x_logit_scale +
                self.positional_encoding(ys))
            for dec_layer in self.layer_stack:
                dec_output, _, _ = dec_layer(
                    dec_output, encoder_outputs,
                    non_pad_mask=non_pad_mask,
                    slf_attn_mask=slf_attn_mask,
                    dec_enc_attn_mask=None)
            seq_logit = self.tgt_word_prj(dec_output[:, -1])
            local_scores = F.log_softmax(seq_logit, dim=1)
            # topk scores
            local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

            for j in range(beam):
                new_hyp = {}
                new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = torch.ones(1, (1 + ys.size(1))).type_as(encoder_outputs).long()
                new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
                new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(hyps_best_kept,
                                    key=lambda x: x['score'],
                                    reverse=True)[:beam]
        # end for hyp in hyps
        hyps = hyps_best_kept

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            for hyp in hyps:
                hyp['yseq'] = torch.cat([
                    hyp['yseq'],
                    torch.ones(1, 1).fill_(self.eos_id).type_as(encoder_outputs).long()
                ], dim=1)

        # move ended hypotheses to the final list and remove them from the
        # current set (so the number of hyps may drop below beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp['yseq'][0, -1] == self.eos_id:
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        hyps = remained_hyps
        if len(hyps) > 0:
            print('remained hypotheses: ' + str(len(hyps)))
        else:
            print('no hypothesis. Finish decoding.')
            break

        for hyp in hyps:
            print('hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][0, 1:]]))
    # end for i in range(maxlen)

    nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'],
                        reverse=True)[:min(len(ended_hyps), nbest)]
    # compatible with the LAS implementation
    for hyp in nbest_hyps:
        hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
    return nbest_hyps
def forward_beam(self, z, num_steps, temperature):
    predictions = []
    batch_size = z.size(0)
    next_input = z.new_zeros(size=(batch_size * self.beam_width, num_steps), dtype=torch.long, requires_grad=False)
    next_input[:, :] = self.PAD_TOKEN
    next_input[:, 0] = self.SOS_TOKEN  # <sos> token
    z = self.activation(self.z2h(z)).view(batch_size, 1, -1).repeat(self.beam_width, num_steps, 1)
    previous_output = z.new_zeros(size=(batch_size * self.beam_width,), dtype=torch.long)
    previous_output[:] = self.SOS_TOKEN  # <sos> token
    # a table for storing the scores
    scores = z.new_zeros(size=(batch_size * self.beam_width, self.output_size))
    # displacement offsets, e.g. batch_size 2 and beam_width 3 gives
    # [0, 0, 0, 3, 3, 3]; used later for indexing
    beam_displacement = torch.arange(start=0, end=batch_size * self.beam_width,
                                     step=self.beam_width, dtype=torch.long,
                                     device=z.device).view(-1, 1).repeat(1, self.beam_width).view(-1)
    for i in range(num_steps):
        input = next_input.detach()
        step_input = self.embedding(input)
        step_input = self.pos_embedding(step_input)
        # step_input is of size [batch, seq_len, step_input_size]
        step_input = torch.cat([step_input, z], dim=2)
        step_input = self.activation(self.s2h(step_input))
        non_pad_mask = get_non_pad_mask(input, self.PAD_TOKEN)
        slf_attn_mask_subseq = get_subsequent_mask(input)
        slf_attn_mask_keypad = get_attn_key_pad_mask(input, self.PAD_TOKEN)
        attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
        out = self.transformer(step_input, non_pad_mask=non_pad_mask, attn_mask=attn_mask)
        out = out[:, i, :]
        out = self.activation(out)
        out = self.h2o(out)
        out = self.last_activation(out, temperature)
        # compute new scores
        next_scores = scores + torch.log(out + 1e-8)
        # select top-k scores where k is the beam width
        score, outputs = next_scores.view(batch_size, -1).topk(self.beam_width, dim=-1)
        # flatten the output
        outputs = outputs.view(-1)
        # token index in the vocabulary: the remainder modulo the vocab size
        indices = torch.fmod(outputs, self.output_size)
        # beam index within the batch; add the displacement to get the flat index
        beam_indices = torch.div(outputs, self.output_size) + beam_displacement
        # check whether any word is repeated from the previous step
        res = torch.eq(previous_output, indices).nonzero()
        retries = 0
        # while some words are repeated, mask them out and re-select
        while res.shape[0] > 0:
            mask = torch.ones(size=(batch_size * self.beam_width, self.output_size),
                              requires_grad=False, device=z.device)
            # zero out the mask where an option is not selectable
            mask[beam_indices[res], indices[res]] = 0
            # apply the mask so the repeated words get a low score
            out = out * mask
            next_scores = scores + torch.log(out + 1e-8)
            # re-select top-k scores where k is the beam width
            score, outputs = next_scores.view(batch_size, -1).topk(self.beam_width, dim=-1)
            outputs = outputs.view(-1)
            indices = torch.fmod(outputs, self.output_size)
            beam_indices = torch.div(outputs, self.output_size) + beam_displacement
            res = torch.eq(previous_output, indices).nonzero()
            if retries > 10:
                break
            retries += 1
        # copy the score for each selected candidate
        scores = score.view(-1, 1).repeat(1, self.output_size)
        # renormalize the output
        out = out / out.sum(-1).view(-1, 1).repeat(1, self.output_size)
        # append the prediction to the output
        predictions.append(out[beam_indices, :])
        # detach so we don't backpropagate through timesteps
        previous_output = indices.detach()
        next_input = torch.cat([input[:, :i + 1], previous_output.view(-1, 1), input[:, i + 2:]], dim=1)
    output = torch.stack(predictions).transpose(1, 0)
    # initialize an output mask so we can filter out sentences
    output_mask = torch.zeros_like(output)
    # set the output_mask of the selected sentences to 1
    output_mask[scores[:, 0].view(batch_size, -1).argmax(dim=-1) + beam_displacement.view(batch_size, -1)[:, 0]] = 1
    # collect the best prediction for each sample in the batch
    output = (output * output_mask).view(batch_size, self.beam_width, num_steps, self.output_size)
    # sum over beams; the non-selected sentences are zero, so only the best one survives
    output = output.sum(1)
    return output
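# Usage sketch for the greedy and beam decoders above. The decoder object,
# batch size and latent size are hypothetical; only the call signatures come
# from the code itself.
#
#   z = torch.randn(8, 64)                                   # [batch, z_size]
#   probs = decoder.forward_greedy(z, num_steps=20, temperature=1.0)
#   tokens = probs.argmax(dim=-1)                            # [batch, num_steps]
#   beam_probs = decoder.forward_beam(z, num_steps=20, temperature=1.0)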
def recognize_beam(self, encoder_outputs, char_list, args):
    """Beam search, decoding one utterance at a time.

    Args:
        encoder_outputs: T x H  # e.g. 418 x 512
        char_list: list of characters  # e.g. 4233 entries
        args: args.beam  # e.g. 5

    Returns:
        nbest_hyps:
    """
    # search params
    beam = args.beam_size
    nbest = args.nbest
    if args.decode_max_len == 0:
        maxlen = encoder_outputs.size(0)
    else:
        maxlen = args.decode_max_len

    encoder_outputs = encoder_outputs.unsqueeze(0)
    # prepare sos
    ys = flow.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()
    hyp = {"score": 0.0, "yseq": ys}
    hyps = [hyp]
    ended_hyps = []

    for i in range(maxlen):
        hyps_best_kept = []
        for hyp in hyps:
            ys = hyp["yseq"]
            ys = ys.to(device=encoder_outputs.device)
            # -- Prepare masks
            non_pad_mask = flow.ones_like(ys).to(dtype=flow.float32).unsqueeze(-1)
            slf_attn_mask = get_subsequent_mask(ys)
            # -- Forward
            dec_output = self.dropout(
                self.tgt_word_emb(ys) * self.x_logit_scale +
                self.positional_encoding(ys))
            for dec_layer in self.layer_stack:
                dec_output, _, _ = dec_layer(
                    dec_output, encoder_outputs,
                    non_pad_mask=non_pad_mask,
                    slf_attn_mask=slf_attn_mask,
                    dec_enc_attn_mask=None,
                )
            seq_logit = self.tgt_word_prj(dec_output[:, -1])
            local_logit = F.softmax(seq_logit, dim=1)
            local_scores = flow.log(local_logit)
            # topk scores
            local_best_scores, local_best_ids = flow.topk(local_scores, beam, dim=1)

            for j in range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                new_hyp["yseq"] = flow.ones(1, (1 + ys.size(1))).type_as(encoder_outputs).long()
                new_hyp["yseq"][:, :ys.size(1)] = hyp["yseq"]
                new_hyp["yseq"][:, ys.size(1)] = int(float(local_best_ids[0, j].numpy()))
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(hyps_best_kept,
                                    key=lambda x: x["score"],
                                    reverse=True)[:beam]
        # end for hyp in hyps
        hyps = hyps_best_kept

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            for hyp in hyps:
                hyp["yseq"] = flow.cat(
                    [
                        hyp["yseq"],
                        flow.ones(1, 1).fill_(self.eos_id).type_as(encoder_outputs).long(),
                    ],
                    dim=1,
                )

        # move ended hypotheses to the final list and remove them from the
        # current set (so the number of hyps may drop below beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][0, -1] == self.eos_id:
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        hyps = remained_hyps
        if len(hyps) > 0:
            print("remained hypotheses: " + str(len(hyps)))
        else:
            print("no hypothesis. Finish decoding.")
            break

        for hyp in hyps:
            print("hypo: " + "".join(
                [char_list[int(x.numpy())] for x in hyp["yseq"][0, 1:]]))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"],
                        reverse=True)[:min(len(ended_hyps), nbest)]
    # compatible with the LAS implementation
    for hyp in nbest_hyps:
        hyp["yseq"] = hyp["yseq"][0].cpu().numpy().tolist()
    return nbest_hyps
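# Note on the scoring step above: softmax followed by log can underflow for
# very small probabilities, while log_softmax computes the same quantity
# stably. Assuming oneflow exposes F.log_softmax, the two lines
#
#   local_logit = F.softmax(seq_logit, dim=1)
#   local_scores = flow.log(local_logit)
#
# could be replaced by the single, numerically safer
#
#   local_scores = F.log_softmax(seq_logit, dim=1)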