def forward(self, padded_input, input_lengths, return_attns=False):
    """
    Args:
        padded_input: N x T x D
        input_lengths: N

    Returns:
        enc_output: N x T x H
    """
    enc_slf_attn_list = []

    # Prepare masks
    non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
    length = padded_input.size(1)
    slf_attn_mask = get_attn_pad_mask(padded_input, input_lengths, length)

    # Forward
    enc_output = self.dropout(
        self.layer_norm_in(self.linear_in(padded_input)) +
        self.positional_encoding(padded_input))

    for enc_layer in self.layer_stack:
        enc_output, enc_slf_attn = enc_layer(enc_output, non_pad_mask, slf_attn_mask)
        if return_attns:
            enc_slf_attn_list += [enc_slf_attn]

    if return_attns:
        return enc_output, enc_slf_attn_list
    return (enc_output, )
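# NOTE: get_non_pad_mask and get_attn_pad_mask are not defined in this section.
# The sketch below is an assumption of what these length-based helpers typically
# look like (Speech-Transformer style); the repository's actual implementation may differ.
import torch

def get_non_pad_mask_sketch(padded_input, input_lengths):
    """N x T x D inputs and N lengths -> N x T x 1 float mask, 1.0 for real frames, 0.0 for padding."""
    N, T = padded_input.size(0), padded_input.size(1)
    mask = padded_input.new_ones(N, T)          # N x T
    for i, length in enumerate(input_lengths):
        mask[i, int(length):] = 0               # zero out the padded tail of each sequence
    return mask.unsqueeze(-1)                   # N x T x 1

def get_attn_pad_mask_sketch(padded_input, input_lengths, expand_length):
    """N x T_k x D inputs, N lengths, T_q -> N x T_q x T_k mask, True where the key position is padding."""
    non_pad_mask = get_non_pad_mask_sketch(padded_input, input_lengths)   # N x T_k x 1
    pad_mask = non_pad_mask.squeeze(-1).lt(1)                             # N x T_k, True on padding
    return pad_mask.unsqueeze(1).expand(-1, expand_length, -1)            # N x T_q x T_k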
def forward(self, src_seq, return_attns=False):
    '''
    src_seq: padded source token indices (N x T); the position indices used by
    the positional embedding are derived from it.
    '''
    enc_slf_attn_list = []

    src_pos = self.get_positional_input(src_seq).to(src_seq.device)

    # -- Prepare masks
    # Self-attention: keys and queries come from the same sequence.
    slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
    non_pad_mask = get_non_pad_mask(src_seq)

    # -- Forward
    # The encoder input is the sum of the word embedding and the positional embedding.
    enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

    for enc_layer in self.layer_stack:
        enc_output, enc_slf_attn = enc_layer(enc_output,
                                             non_pad_mask=non_pad_mask,
                                             slf_attn_mask=slf_attn_mask)
        if return_attns:
            enc_slf_attn_list += [enc_slf_attn]

    # The per-layer attention maps are only returned for inspection/visualisation.
    if return_attns:
        return enc_output, enc_slf_attn_list
    return enc_output,
def forward(self, inputs, pos, src_inputs, enc_outputs, return_att=False):
    non_pad_mask = utils.get_non_pad_mask(inputs)

    # Combine the subsequent (no-peeking-ahead) mask with the key padding mask.
    self_attn_mask_subseq = utils.get_subsequent_mask(inputs)
    self_attn_mask_keypad = utils.get_att_key_pad_mask(seq_k=inputs, seq_q=inputs)
    self_attn_mask = (self_attn_mask_keypad + self_attn_mask_subseq).gt(0)

    # Decoder-encoder attention masks out padded encoder positions.
    dec_enc_attn_mask = utils.get_att_key_pad_mask(seq_k=src_inputs, seq_q=inputs)

    self_att_list = []
    enc_att_list = []

    output = self.embed(inputs) + self.position_enc(pos)
    for layer in self.layers:
        output, self_att, enc_att = layer(output, enc_outputs,
                                          non_pad_mask=non_pad_mask,
                                          input_att_mask=self_attn_mask,
                                          enc_att_mask=dec_enc_attn_mask,
                                          )
        if return_att:
            self_att_list.append(self_att)
            enc_att_list.append(enc_att)

    output = self.proj(output)
    if return_att:
        return output, self_att_list, enc_att_list
    return output
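# NOTE: the subsequent/key-padding mask helpers used by the decoders in this section
# are not shown either. The following is a minimal sketch of the standard token-index
# implementations (as in common "Attention Is All You Need" PyTorch ports); the real
# utils module may differ in names and defaults.
import torch

def get_subsequent_mask_sketch(seq):
    """N x T token ids -> N x T x T mask, nonzero above the diagonal so position i cannot attend to j > i."""
    sz_b, len_s = seq.size()
    mask = torch.triu(torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    return mask.unsqueeze(0).expand(sz_b, -1, -1)            # N x T x T

def get_att_key_pad_mask_sketch(seq_k, seq_q, pad_idx=0):
    """Mask attention to padded key positions: N x T_q x T_k, True where the key token is padding."""
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(pad_idx)                         # N x T_k
    return padding_mask.unsqueeze(1).expand(-1, len_q, -1)   # N x T_q x T_k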
def forward_greedy(self, z, num_steps, temperature, x=None):
    predictions = []
    batch_size = z.size(0)

    # Start from an all-PAD sequence with <sos> in the first position.
    next_input = z.new_zeros(size=(batch_size, num_steps), dtype=torch.long, requires_grad=False)
    next_input[:, :] = self.PAD_TOKEN
    next_input[:, 0] = self.SOS_TOKEN  # <sos> token

    # Broadcast the latent code over every timestep.
    z = self.activation(self.z2h(z)).view(batch_size, 1, -1).repeat(1, num_steps, 1)

    for i in range(num_steps):
        input = next_input
        step_input = self.embedding(input)
        step_input = self.pos_embedding(step_input)
        step_input = torch.cat([step_input, z], dim=2)
        # step_input is of size [batch, seq_len, step_input_size]
        step_input = self.activation(self.s2h(step_input))

        non_pad_mask = get_non_pad_mask(input, self.PAD_TOKEN)
        slf_attn_mask_subseq = get_subsequent_mask(input)
        slf_attn_mask_keypad = get_attn_key_pad_mask(input, self.PAD_TOKEN)
        attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        out = self.transformer(step_input, non_pad_mask=non_pad_mask, attn_mask=attn_mask)
        out = out[:, i, :]
        out = self.activation(out)
        out = self.h2o(out)
        out = self.last_activation(out, temperature)

        if x is not None:
            # teacher forcing: feed the ground-truth token back in
            previous_output = x[:, i]
        else:
            # use the model's own prediction as the next input
            previous_output = torch.argmax(out, dim=-1)
            previous_output = previous_output.detach()

        # Write the predicted token into position i+1 of the next input.
        next_input = torch.cat([input[:, :i + 1], previous_output.view(-1, 1), input[:, i + 2:]], dim=1).detach()
        predictions.append(out)

    output = torch.stack(predictions).transpose(1, 0)
    return output
def forward(self, padded_input, encoder_padded_outputs,
            encoder_input_lengths, return_attns=False):
    """
    Args:
        padded_input: N x To
        encoder_padded_outputs: N x Ti x H

    Returns:
        pred: N x To x V logits (before softmax)
        gold: N x To target token ids
    """
    dec_slf_attn_list, dec_enc_attn_list = [], []

    # Get Decoder Input and Output
    ys_in_pad, ys_out_pad = self.preprocess(padded_input)

    # Prepare masks
    non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)

    slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
    slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                 seq_q=ys_in_pad,
                                                 pad_idx=self.eos_id)
    slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

    output_length = ys_in_pad.size(1)
    dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                          encoder_input_lengths,
                                          output_length)

    # Forward
    dec_output = self.dropout(
        self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
        self.positional_encoding(ys_in_pad))

    for dec_layer in self.layer_stack:
        dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
            dec_output, encoder_padded_outputs,
            non_pad_mask=non_pad_mask,
            slf_attn_mask=slf_attn_mask,
            dec_enc_attn_mask=dec_enc_attn_mask)
        if return_attns:
            dec_slf_attn_list += [dec_slf_attn]
            dec_enc_attn_list += [dec_enc_attn]

    # before softmax
    seq_logit = self.tgt_word_prj(dec_output)

    # Return
    pred, gold = seq_logit, ys_out_pad
    if return_attns:
        return pred, gold, dec_slf_attn_list, dec_enc_attn_list
    return pred, gold
def forward(self, inputs, pos, return_att=False):
    self_att_mask = utils.get_att_key_pad_mask(seq_k=inputs, seq_q=inputs)
    non_pad_mask = utils.get_non_pad_mask(inputs)
    att_weights_list = []

    output = self.embed(inputs)
    output = output + self.position_enc(pos)
    for enc_layer in self.layers:
        output, att_weights = enc_layer(output,
                                        non_pad_mask=non_pad_mask,
                                        mh_att_mask=self_att_mask)
        if return_att:
            att_weights_list.append(att_weights)

    if return_att:
        return output, att_weights_list
    return output
def forward_beam(self, z, num_steps, temperature):
    predictions = []
    batch_size = z.size(0)

    # Start every beam from an all-PAD sequence with <sos> in the first position.
    next_input = z.new_zeros(size=(batch_size * self.beam_width, num_steps), dtype=torch.long, requires_grad=False)
    next_input[:, :] = self.PAD_TOKEN
    next_input[:, 0] = self.SOS_TOKEN  # <sos> token

    # Broadcast the latent code over beams and timesteps.
    z = self.activation(self.z2h(z)).view(batch_size, 1, -1).repeat(self.beam_width, num_steps, 1)

    previous_output = z.new_zeros(size=(batch_size * self.beam_width,), dtype=torch.long)
    previous_output[:] = self.SOS_TOKEN  # <sos> token

    # a table for storing the scores
    scores = z.new_zeros(size=(batch_size * self.beam_width, self.output_size))

    # an array of offsets for displacement, i.e. if batch_size is 2 and beam_width is 3
    # then this is [0,0,0,3,3,3]; it is used later for indexing into the flattened beams
    beam_displacement = torch.arange(start=0, end=batch_size * self.beam_width, step=self.beam_width,
                                     dtype=torch.long, device=z.device).view(-1, 1).repeat(1, self.beam_width).view(-1)

    for i in range(num_steps):
        input = next_input.detach()
        step_input = self.embedding(input)
        step_input = self.pos_embedding(step_input)
        step_input = torch.cat([step_input, z], dim=2)
        # step_input is of size [batch, seq_len, step_input_size]
        step_input = self.activation(self.s2h(step_input))

        non_pad_mask = get_non_pad_mask(input, self.PAD_TOKEN)
        slf_attn_mask_subseq = get_subsequent_mask(input)
        slf_attn_mask_keypad = get_attn_key_pad_mask(input, self.PAD_TOKEN)
        attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        out = self.transformer(step_input, non_pad_mask=non_pad_mask, attn_mask=attn_mask)
        out = out[:, i, :]
        out = self.activation(out)
        out = self.h2o(out)
        out = self.last_activation(out, temperature)

        # compute new scores
        next_scores = scores + torch.log(out + 1e-8)
        # select top-k scores where k is the beam width
        score, outputs = next_scores.view(batch_size, -1).topk(self.beam_width, dim=-1)
        # flatten the output
        outputs = outputs.view(-1)
        # get the index in the vocabulary by taking the modulo of the vocab size
        indices = torch.fmod(outputs, self.output_size)
        # find the beam index within the batch; add the beam displacement to get the flattened index
        beam_indices = torch.div(outputs, self.output_size) + beam_displacement
        # check if some elements/words are repeated
        res = torch.eq(previous_output, indices).nonzero()

        # some elements/words are repeated
        retries = 0
        while res.shape[0] > 0:
            mask = torch.ones(size=(batch_size * self.beam_width, self.output_size),
                              requires_grad=False, device=z.device)
            # set the mask to zero where an option is not selectable
            mask[beam_indices[res], indices[res]] = 0
            # apply the mask
            out = out * mask
            # set the score for the repeated elements to be low
            next_scores = scores + torch.log(out + 1e-8)
            # select top-k scores where k is the beam width
            score, outputs = next_scores.view(batch_size, -1).topk(self.beam_width, dim=-1)
            # flatten the output
            outputs = outputs.view(-1)
            # get the index in the vocabulary by taking the modulo of the vocab size
            indices = torch.fmod(outputs, self.output_size)
            # find the beam index within the batch; add the beam displacement to get the flattened index
            beam_indices = torch.div(outputs, self.output_size) + beam_displacement
            # check if some elements/words are repeated
            res = torch.eq(previous_output, indices).nonzero()
            if retries > 10:
                break
            retries += 1

        # copy the score for each selected candidate
        scores = score.view(-1, 1).repeat(1, self.output_size)
        # renormalize the output
        out = out / out.sum(-1).view(-1, 1).repeat(1, self.output_size)
        # append the prediction to the output
        predictions.append(out[beam_indices, :])
        # detach the output so that we don't backpropagate through timesteps
        previous_output = indices.detach()
        next_input = torch.cat([input[:, :i + 1], previous_output.view(-1, 1), input[:, i + 2:]], dim=1)

    output = torch.stack(predictions).transpose(1, 0)
    # initialize an output mask so that we can filter out sentences
    output_mask = torch.zeros_like(output)
    # set the output mask of the selected sentences to 1
    output_mask[scores[:, 0].view(batch_size, -1).argmax(dim=-1) + beam_displacement.view(batch_size, -1)[:, 0]] = 1
    # collect the best prediction for each sample in the batch
    output = (output * output_mask).view(batch_size, self.beam_width, num_steps, self.output_size)
    # sum over beams; since the sentences that were not selected are zero, this does not change the selected sentences
    output = output.sum(1)
    return output
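# NOTE: a hypothetical usage sketch, assuming the decoder above is instantiated as
# `decoder` and that z is a (batch, latent_dim) latent code; the names, shapes and
# constructor are assumptions, not taken from this repository.
#
#   z = torch.randn(batch_size, latent_dim)
#   greedy = decoder.forward_greedy(z, num_steps=20, temperature=1.0)  # (batch, 20, vocab)
#   beam   = decoder.forward_beam(z, num_steps=20, temperature=1.0)    # (batch, 20, vocab)
#
# forward_greedy keeps one hypothesis per sample, while forward_beam keeps
# self.beam_width hypotheses per sample and finally zeros out all but the
# best-scoring beam before summing them back to one sequence per sample.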