def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None,
        cls_bank=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
        )

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    # cls_bank = None
    # cls_bank = memory_bank[:, 0:1, :]
    # dec_out, dec_attn = self.model.decoder(
    #     decoder_in, memory_bank, memory_lengths=memory_lengths,
    #     step=step, adapter=True, cls_bank=cls_bank)
    dec_out, dec_attn = self.model.decoder(
        decoder_in, memory_bank, memory_lengths=memory_lengths,
        step=step, adapter=True)

    # Generator forward.
    if not self.copy_attn:
        attn = dec_attn["std"]
        log_probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(
            scores,
            batch,
            self._tgt_vocab,
            src_vocabs,
            batch_dim=0,
            batch_offset=batch_offset
        )
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
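
# For orientation, a minimal sketch of how a _decode_and_generate variant
# like the one above is typically driven one step at a time at inference.
# It assumes a Translator-like object without copy attention; greedy_decode,
# bos_idx, eos_idx, and max_len are illustrative names, not original code.
import torch

def greedy_decode(translator, batch, memory_bank, src_lengths,
                  bos_idx, eos_idx, max_len=50):
    # Decoder input has shape [tgt_len=1, batch, nfeats=1],
    # matching the shape comment in the function above.
    batch_size = memory_bank.size(1)
    decoder_input = torch.full((1, batch_size, 1), bos_idx, dtype=torch.long)
    outputs = []
    for step in range(max_len):
        log_probs, _ = translator._decode_and_generate(
            decoder_input, memory_bank, batch,
            src_vocabs=None, memory_lengths=src_lengths, step=step)
        next_ids = log_probs.argmax(-1)           # [batch]
        outputs.append(next_ids)
        if (next_ids == eos_idx).all():
            break
        decoder_input = next_ids.view(1, -1, 1)   # feed prediction back in
    return torch.stack(outputs)
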
def _decode_and_generate(self, decoder_in, memory_bank, batch, src_vocabs,
                         memory_lengths, src_map=None, step=None,
                         batch_offset=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    self.model.decoder.set_copy_info(batch, self._tgt_vocab)
    dec_out, dec_attn = self.model.decoder(decoder_in, memory_bank,
                                           memory_lengths=memory_lengths,
                                           step=step)

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        log_probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        # print("DEC_OUT: ", dec_out.size())
        # print("ATTN: ", attn.size())
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(-1, batch.batch_size, scores.size(-1))
            scores = scores.transpose(0, 1).contiguous()
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        # print("TGT_VOCAB: ", self._tgt_vocab)
        scores = collapse_copy_scores(scores, batch, self._tgt_vocab,
                                      src_vocabs, batch_dim=0,
                                      batch_offset=batch_offset)
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # print(log_probs.size())
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        data,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None):
    tgt_field = self.fields["tgt"][0][1].base_field
    unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token]
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(len(tgt_field.vocab) - 1), unk_idx
        )

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    dec_out, dec_attn = self.model.decoder(
        decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
    )

    # Generator forward.
    if not self.copy_attn:
        attn = dec_attn["std"]
        log_probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(
            scores,
            batch,
            tgt_field.vocab,
            data.src_vocabs,
            batch_dim=0,
            batch_offset=batch_offset
        )
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
        )

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    dec_out, dec_attn = self.model.decoder(
        decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
    )

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        log_probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(
            scores,
            batch,
            self._tgt_vocab,
            src_vocabs,
            batch_dim=0,
            batch_offset=batch_offset
        )
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        memory_lengths,
        src_map=None,
        step=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
        )

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    dec_out, dec_attn = self.model.decoder(
        decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
    )

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn_key = 'std'
        else:
            attn_key = None
        log_probs = self.model.generator(dec_out.squeeze(0))
        # Unfortunate, but torch.distributions.Categorical wants
        # probabilities rather than log-probabilities.
        scores = log_probs.exp()
    else:
        attn = dec_attn["copy"]
        attn_key = 'copy'
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        scores = scores.view(batch.batch_size, -1, scores.size(-1))
        scores = collapse_copy_scores(
            scores,
            batch,
            self._tgt_vocab,
            src_vocabs=None,
            batch_dim=0,
            batch_offset=0
        )
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
    return scores, dec_attn, attn_key
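
# Because this variant returns normalized probabilities rather than
# log-probabilities (hence the torch.distributions.Categorical comment
# above), a caller can sample the next token directly. A minimal sketch
# under that assumption; translator and the surrounding variables are
# illustrative, not part of the original code.
import torch

# scores: [tgt_len(=1), batch, vocab] probabilities from the variant above.
scores, dec_attn, attn_key = translator._decode_and_generate(
    decoder_in, memory_bank, batch, memory_lengths,
    src_map=src_map, step=step)
dist = torch.distributions.Categorical(probs=scores.squeeze(0))
next_ids = dist.sample()        # [batch]: one sampled token per sequence
attn = dec_attn[attn_key] if attn_key is not None else None
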
def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
    """
    See StdRNNDecoder._run_forward_pass() for description
    of arguments and return values.
    """
    # Additional args check.
    input_feed = self.state["input_feed"].squeeze(0)
    input_feed_batch, _ = input_feed.size()
    _, tgt_batch, _ = tgt.size()
    aeq(tgt_batch, input_feed_batch)
    # END Additional args check.

    dec_outs = []
    attns = {}
    if self.attn is not None:
        attns["std"] = []
    if self.copy_attn is not None or self._reuse_copy_attn:
        attns["copy"] = []
    if self._coverage:
        attns["coverage"] = []

    emb = self.embeddings(tgt)
    assert emb.dim() == 3  # len x batch x embedding_dim

    dec_state = self.state["hidden"]
    coverage = self.state["coverage"].squeeze(0) \
        if self.state["coverage"] is not None else None

    # Input feed concatenates hidden state with input at every time step.
    # GN: This is the loop that needs to be modified
    # (originally: for emb_t in emb.split(1)).
    for t in range(len(tgt)):
        if t == 0 or self.eval_status:
            emb_t = emb.split(1)[t]  # Start symbol
        elif self.teacher_forcing == "teacher":
            emb_t = emb.split(1)[t]  # Use gold output
        else:
            # Use predicted output.
            if self.teacher_forcing == "random":
                rep_t = torch.ones(len(tgt[0]), dtype=int)
                for batch in range(len(tgt[0])):
                    # Randomly select a member of the vocab.
                    rep_t[batch] = random.randint(0, self.vocab_size - 1)
                rep_t = rep_t.unsqueeze(1)
                rep_t = rep_t.unsqueeze(1)
            elif self.teacher_forcing in ("student", "dist"):
                rep_t = torch.ones(len(tgt[0]), dtype=int)
                for batch_id in range(len(tgt[0])):
                    # "COPY" is also an option.
                    top_probs, top_labels = torch.topk(
                        log_probs[batch_id], self.vocab_size)
                    top_probs = top_probs.squeeze(0).tolist()
                    # Normalization is required due to some extra weight
                    # that is lost in the log/exp conversion.
                    top_probs = np.exp(top_probs)
                    top_probs /= np.sum(top_probs)
                    top_labels = top_labels.squeeze(0).tolist()
                    if self.teacher_forcing == "student":
                        t_value = top_labels[0]
                    elif self.teacher_forcing == "dist":
                        # Sample a label from the predicted distribution.
                        rand_val = random.uniform(0, 1)
                        rand_sum = 0.0
                        index = 0
                        while rand_sum < rand_val:  # and index < len(top_labels)
                            rand_sum += top_probs[index]
                            index += 1
                        t_value = top_labels[index - 1]
                    rep_t[batch_id] = t_value
                rep_t = rep_t.unsqueeze(1)
                rep_t = rep_t.unsqueeze(1)
            emb_t = self.embeddings(rep_t).squeeze(1)

        if emb_t.dim() > 2:
            decoder_input = torch.cat([emb_t.squeeze(0), input_feed], 1)
        else:
            decoder_input = torch.cat([emb_t, input_feed], 1)
        rnn_output, dec_state = self.rnn(decoder_input, dec_state)
        if self.attentional:
            decoder_output, p_attn = self.attn(
                rnn_output,
                memory_bank.transpose(0, 1),
                memory_lengths=memory_lengths)
            attns["std"].append(p_attn)
        else:
            decoder_output = rnn_output
        if self.context_gate is not None:
            # TODO: context gate should be employed
            # instead of second RNN transform.
            decoder_output = self.context_gate(
                decoder_input, rnn_output, decoder_output)
        decoder_output = self.dropout(decoder_output)
        input_feed = decoder_output

        dec_outs += [decoder_output]

        # Update the coverage attention.
        if self._coverage:
            coverage = p_attn if coverage is None else p_attn + coverage
            attns["coverage"] += [coverage]

        if self.copy_attn is not None:
            _, copy_attn = self.copy_attn(decoder_output,
                                          memory_bank.transpose(0, 1))
            attns["copy"] += [copy_attn]
        elif self._reuse_copy_attn:
            attns["copy"] = attns["std"]

        # Note: the generator below consumes the raw RNN output.
        decoder_output = rnn_output

        if not self.eval_status:
            if self.copy_attn is None:
                log_probs = self.generator(decoder_output.squeeze(0))
            else:
                scores = self.generator(
                    decoder_output.view(-1, decoder_output.size(1)),
                    copy_attn.view(-1, copy_attn.size(1)),
                    self.batch.src_map)
                # Not a beam search; batch_offset doesn't make sense here.
                scores = scores.view(-1, self.batch.batch_size,
                                     scores.size(-1))
                scores = scores.transpose(0, 1).contiguous()
                # With src_vocabs None, the collapse function backs off to
                # the batch source, which is fine.
                src_vocabs = None
                scores = collapse_copy_scores(scores, self.batch,
                                              self.tgt_vocab, src_vocabs,
                                              batch_dim=0)
                # decoder_input is still from the last t_value,
                # so its size is fine here.
                scores = scores.view(decoder_input.size(0), -1,
                                     scores.size(-1))
                log_probs = scores.squeeze(1).log()

    return dec_state, dec_outs, attns
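
# The hand-rolled cumulative-sum loop in the "dist" branch above is
# equivalent to sampling once from the renormalized top-k distribution.
# A small sketch of that equivalence using numpy; sample_from_topk is an
# illustrative helper, not part of the original code.
import numpy as np

def sample_from_topk(top_probs, top_labels, rng=np.random):
    # Draw one label from the renormalized distribution over the
    # top-k labels, as the while-loop in the "dist" branch does.
    p = np.asarray(top_probs, dtype=np.float64)
    p = p / p.sum()                 # renormalize, as the loop above does
    return rng.choice(top_labels, p=p)
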
def _decode_and_generate(
        self,
        decoder_in,
        # states,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None):
    # masks, expand_masks = states.get_mask()
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    dec_out, dec_attn = self.model.decoder(decoder_in, memory_bank,
                                           memory_lengths=memory_lengths,
                                           step=step,
                                           generator_id=self.generator_id)

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        log_probs = self.model.generator[self.generator_id](
            dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
        # log_probs = log_probs.exp().mul(masks).log()
    else:
        attn = dec_attn["copy"]
        # print(dec_out.size())
        # print(attn.size())
        scores = self.model.generator[self.generator_id](
            dec_out.view(-1, dec_out.size(2)),
            attn.view(-1, attn.size(2)),
            src_map)
        # print(scores.size())
        # print(batch_offset)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        # Some copied words are the same; the scores of the same word
        # should be summed.
        scores = collapse_copy_scores(
            scores,
            batch,
            self._tgt_vocab,
            src_vocabs,
            batch_dim=0,
            batch_offset=batch_offset
        )
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        # expand_masks = expand_masks.expand(
        #     expand_masks.size(0), scores.size(2) - masks.size(1))
        # masks = torch.cat([masks, expand_masks], 1)
        # scores = scores.mul(masks)
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None,
        verbose=False):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
        )

    # Decoder forward, takes [tgt_len, batch, nfeats] as input, e.g. [1, 20, 1],
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    dec_out, dec_attn = self.model.decoder(
        decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
    )
    # dec_out has shape [1, 20, 512], where 512 is the word-vector size;
    # dec_attn has shape [1, 20, 1311], where 1311 is the source length.

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        log_probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab];
        # when batch_offset is not None,
        # scores: [20, 50511], where 50511 is the extended vocab size.
        if batch_offset is None:
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(
            scores,
            batch,
            self._tgt_vocab,
            src_vocabs,
            batch_dim=0,
            batch_offset=batch_offset
        )
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
def _compute_loss(self, batch, output, target, copy_attn, align,
                  std_attn=None, coverage_attn=None):
    """Compute the loss.

    The args must match :func:`self._make_shard_state()`.

    Args:
        batch: the current batch.
        output: the predicted output from the model.
        target: the validation target to compare output with.
        copy_attn: the copy attention value.
        align: the align info.
    """
    target = target.view(-1)
    align = align.view(-1)

    src_map_list = list()
    for src_type in self.src_types:
        src_map_list.append(getattr(batch, f"src_map.{src_type}"))
    # end for

    scores = self.generator(self._bottle(output),
                            self._bottle(copy_attn),
                            src_map_list)
    loss = self.criterion(scores, align, target)

    if self.lambda_coverage != 0.0:
        coverage_loss = self._compute_coverage_loss(std_attn, coverage_attn)
        loss += coverage_loss

    # this block does not depend on the loss value computed above
    # and is used only for stats
    scores_data = collapse_copy_scores(
        self._unbottle(scores.clone(), batch.batch_size),
        batch, self.tgt_vocab, None)
    scores_data = self._bottle(scores_data)

    # Correct target copy token instead of <unk>
    # tgt[i] = align[i] + len(tgt_vocab)
    # for i such that tgt[i] == 0 and align[i] != 0
    target_data = target.clone()
    unk = self.criterion.unk_index
    correct_mask = (target_data == unk) & (align != unk)
    offset_align = align[correct_mask] + len(self.tgt_vocab)
    target_data[correct_mask] += offset_align

    # Compute sum of perplexities for stats
    stats = self._stats(loss.sum().clone(), scores_data, target_data)

    # this part looks like it belongs in CopyGeneratorLoss
    if self.normalize_by_length:
        # Compute Loss as NLL divided by seq length
        tgt_lens = batch.tgt[:, :, 0].ne(self.padding_idx).sum(0).float()
        # Compute Total Loss per sequence in batch
        loss = loss.view(-1, batch.batch_size).sum(0)
        # Divide by length of each sequence and sum
        loss = torch.div(loss, tgt_lens).sum()
    else:
        loss = loss.sum()

    return loss, stats
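
# A toy illustration of the target-correction rule in the comment above
# (tgt[i] = align[i] + len(tgt_vocab) where the gold token is <unk> but an
# alignment exists); all numbers here are made up for the example.
import torch

unk = 0
tgt_vocab_len = 6                      # extended (copy) ids start at 6
target = torch.tensor([3, 0, 0, 5])    # two UNKs in the gold target
align = torch.tensor([0, 2, 0, 1])     # second token copies src position 2
mask = (target == unk) & (align != unk)
target[mask] += align[mask] + tgt_vocab_len
# target is now [3, 8, 0, 5]: the aligned UNK points at extended id 6 + 2.
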
def _decode_and_generate(self,
                         decoder_in,
                         memory_bank,
                         batch,
                         src_vocabs,
                         memory_lengths,
                         src_map=None,
                         step=None,
                         batch_offset=None,
                         batch_indices=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    if self.user_emb:
        if batch_indices is not None:
            uid = batch.uid[batch_indices].unsqueeze(-1).expand_as(
                decoder_in.view(3, -1, 1)).reshape(-1)
        else:
            uid = batch.uid.unsqueeze(-1).expand_as(
                decoder_in.view(3, -1, 1)).reshape(-1)
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths,
            step=step, uid=uid)
    else:
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank, memory_lengths=memory_lengths,
            step=step)

    # Generator forward.
    if not self.copy_attn:
        attn = dec_attn["std"]
        log_probs = self.model.generator(dec_out.squeeze(0))
        if self.user_bias != 'none':
            sfm = torch.nn.LogSoftmax(dim=-1)
            out = log_probs.view(self.beam_size, -1, log_probs.shape[-1])
            if self.user_bias == 'factor_cell':
                if batch_indices is not None:
                    out = out + torch.matmul(
                        self.model.user_bias(batch.uid[batch_indices]),
                        self.model.user_global)
                else:
                    out = out + torch.matmul(
                        self.model.user_bias(batch.uid),
                        self.model.user_global)
            else:
                if batch_indices is not None:
                    out = out + self.model.user_bias(
                        batch.uid[batch_indices])
                else:
                    out = out + self.model.user_bias(batch.uid)
            log_probs = out.view(out.shape[0] * out.shape[1], -1)
            log_probs = sfm(log_probs).cuda()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(scores, batch, self._tgt_vocab,
                                      src_vocabs, batch_dim=0,
                                      batch_offset=batch_offset)
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None,
        batch_emotion=None,
        batch_tgt_concept_emb=None,
        batch_tgt_concept_words=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
        )

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    # batch_emotion = None
    # if hasattr(batch, "emotion"):
    #     batch_emotion = batch.emotion
    #     src_len, batch_size, hidden = memory_bank.shape
    #     # expand emotion beam_size times
    #     batch_emotion = tile(batch_emotion, batch_size // len(batch_emotion))
    if isinstance(self.model.decoder, KGTransformerDecoder):
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank,
            memory_lengths=memory_lengths,
            step=step,
            emotion=batch_emotion,
            tgt_concept_emb=batch_tgt_concept_emb,
            tgt_concept_words=batch_tgt_concept_words
        )
    else:
        dec_out, dec_attn = self.model.decoder(
            decoder_in, memory_bank,
            memory_lengths=memory_lengths,
            step=step,
            emotion=batch_emotion
        )

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        if self.model.decoder.__class__.__name__ == "EDSDecoder":
            output = dec_out
            tgt_len, batch_size, hidden_size = output.size()
            # if not hasattr(self.model.decoder, "eds_type"):
            #     self.model.decoder.eds_type = 0

            # Method 0: vanilla probs.
            if self.model.decoder.eds_type in [0]:
                log_probs = self.model.generator[0](output.squeeze(0))

            # Method 1: mask on logits; other emotional words get zero logits.
            elif self.model.decoder.eds_type in [1]:
                # relu refers to the official code in training;
                # the official code uses sigmoid when generating responses:
                # logits = torch.sigmoid(self.model.generator[0][:2](output))
                logits = F.relu(
                    self.model.generator[0][:2](output))  # [tgt_len, batch, vocab]
                type_controller = torch.sigmoid(
                    self.model.generator[1](output))  # [tgt_len, batch, 1]
                mask = (self.model.decoder.generic_mask
                        .unsqueeze(0).unsqueeze(0) * type_controller
                        + self.get_emotion_mask(batch_emotion).unsqueeze(0)
                        * (1 - type_controller))
                logits = mask * logits  # [tgt_len, batch, vocab]
                log_probs = self.model.generator[0][2](logits.squeeze(0))

            # Method 2: two separate generators for generic and emotional
            # words; generic vocab = vocab_size - emotion_vocab_size.
            elif self.model.decoder.eds_type in [2]:
                all_probs = torch.zeros(
                    (tgt_len, batch_size, self.model.decoder.vocab_size)
                ).type_as(output)
                generic_probs = torch.softmax(
                    self.model.generator[0][:2](output),
                    dim=-1)  # [tgt_len, batch, generic_vocab]
                emotion_probs = torch.softmax(
                    self.model.generator[1][:2](output),
                    dim=-1)  # [tgt_len, batch, emotion_vocab]
                type_controller = torch.sigmoid(
                    self.model.generator[2](output))  # [tgt_len, batch, 1]
                generic_probs = (1 - type_controller) * generic_probs
                emotion_probs = type_controller * emotion_probs
                for i in range(batch_size):
                    generic_indices = torch.cat(
                        [self.model.decoder.generic_vocab_indices,
                         self.model.decoder.other_emotion_indices[
                             batch_emotion[i]]])
                    all_probs[:, i, generic_indices] = \
                        generic_probs[:, i, :]  # [tgt_len, generic_vocab]
                    all_probs[:, i, self.model.decoder.emotion_vocab_indices[
                        batch_emotion[i]]] = \
                        emotion_probs[:, i, :]  # [tgt_len, emotion_vocab]
                log_probs = torch.log(all_probs.squeeze(0))

            # Method 3: two separate generators for generic and emotional
            # words; emotion vocab = vocab_size - generic_vocab.
            elif self.model.decoder.eds_type in [3]:
                all_probs = torch.zeros(
                    (tgt_len, batch_size, self.model.decoder.vocab_size)
                ).type_as(output)
                generic_probs = torch.softmax(
                    self.model.generator[0][:2](output), dim=-1)
                emotion_probs = torch.softmax(
                    self.model.generator[1][:2](output), dim=-1)
                type_controller = torch.sigmoid(
                    self.model.generator[2](output))
                generic_probs = (1 - type_controller) * generic_probs
                emotion_probs = type_controller * emotion_probs
                all_probs[:, :, self.model.decoder.generic_vocab_indices] = \
                    generic_probs  # [tgt_len, batch, generic_vocab]
                all_probs[:, :, self.model.decoder.all_emotion_indices] = \
                    emotion_probs  # [tgt_len, batch, emotion_vocab]
                log_probs = torch.log(all_probs.squeeze(0))
        else:
            log_probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(
            scores,
            batch,
            self._tgt_vocab,
            src_vocabs,
            batch_dim=0,
            batch_offset=batch_offset
        )
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
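
# The gating used by method 1 above interpolates between a
# generic-vocabulary mask and a per-emotion mask before the final
# log-softmax. A toy sketch of just that mixing step; the helper name and
# all tensors are illustrative, not part of the original code.
import torch

def mix_vocab_masks(g, generic_mask, emotion_mask, logits):
    # g: [tgt_len, batch, 1] gate from a sigmoid "type controller";
    # generic_mask / emotion_mask: [vocab] 0-1 indicator vectors.
    mask = generic_mask * g + emotion_mask * (1.0 - g)  # -> [tgt_len, batch, vocab]
    return torch.log_softmax(mask * logits, dim=-1)
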
def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        # yida translate
        tag_src,
        tag_decoder_in,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    tag_gen = (self.model.generators["tag"]
               if "tag" in self.model.generators else None)
    # yida translate
    dec_out, dec_attn, rnn_outs = self.model.decoder(
        decoder_in, memory_bank, tag_gen, tag_src, tag_decoder_in,
        memory_lengths=memory_lengths, step=step)

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        # yida translate
        tag_log_probs = None
        if "tag" in self.model.generators:
            tag_outputs = rnn_outs if not isinstance(rnn_outs, list) \
                else dec_out
            tag_log_probs = self.model.generators["tag"](
                tag_outputs.squeeze(0))
            tag_argmax = tag_log_probs.max(1)[1]
            # dist = torch.distributions.Multinomial(
            #     logits=tag_log_probs, total_count=1)
            # tag_argmax = torch.argmax(dist.sample(), dim=1)
            # tag_index = torch.multinomial(tag_log_probs, num_samples=1)
            if "high" in self.model.generators:
                vocab_size = sum(
                    v._modules["0"].out_features
                    for k, v in self.model.generators.items() if k != "tag")
                log_probs = torch.full(
                    [dec_out.squeeze(0).shape[0], vocab_size],
                    -float('inf'), dtype=torch.float, device=self._dev)
                for k, gen in self.model.generators.items():
                    if k == "tag":
                        continue
                    indices = tag_argmax.eq(self.tag_vocab[k])
                    if indices.any():
                        k_output = dec_out.squeeze(0)[indices]
                        k_logits = gen(k_output)
                        mask = indices.float().unsqueeze(-1).mm(
                            self.tag_mask[k])
                        log_probs.masked_scatter_(
                            mask.bool(), k_logits.log_softmax(dim=-1))
            else:
                logits = self.model.generators["generator"](
                    dec_out.squeeze(0))
                if self.mask_decode:
                    high_num = self.tag_mask["high"].sum().long().item()
                    high_indices = tag_argmax.eq(self.tag_vocab["high"])
                    low_indices = tag_argmax.eq(self.tag_vocab["low"])
                    # logits[high_indices, high_num:] = -float("inf")
                    logits[low_indices, :high_num] = -float("inf")
                log_probs = torch.log_softmax(logits, dim=-1)
        else:
            logits = self.model.generators["generator"](dec_out.squeeze(0))
            log_probs = torch.log_softmax(logits, dim=-1)
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(-1, batch.batch_size, scores.size(-1))
            scores = scores.transpose(0, 1).contiguous()
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(scores, batch, self._tgt_vocab,
                                      src_vocabs, batch_dim=0,
                                      batch_offset=batch_offset)
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    # yida translate
    return log_probs, attn, tag_log_probs
def _decode_and_generate(self, decoder_in, memory_bank, batch, src_vocabs,
                         memory_lengths, src_map=None, step=None,
                         batch_offset=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

    if show_debug_detail:
        print("debug in _decode_and_generate, memory_bank",
              memory_bank.size(), memory_bank.type())
        print("debug in _decode_and_generate, decoder_in",
              decoder_in.size(), decoder_in.type())
        print("debug in _decode_and_generate, memory_lengths",
              memory_lengths, memory_lengths.type())
        print("debug max_relative_positions",
              self.model.decoder.transformer_layers[0]
              .self_attn.max_relative_positions)  # always 0
        print("debug word_vec_size",
              self.model.decoder.embeddings.word_vec_size)
        print("debug word_padding_idx",
              self.model.decoder.embeddings.word_padding_idx)
        print("debug step", step, type(step))

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    tic = time.perf_counter()
    dec_out, dec_attn = self.model.decoder(decoder_in, memory_bank,
                                           memory_lengths=memory_lengths,
                                           step=step)
    toc = time.perf_counter()
    onmt_decoder_time = toc - tic

    tic = time.perf_counter()
    turbo_dec_out, turbo_dec_attn = self.turbo_decoder(
        decoder_in, memory_bank, memory_lengths=memory_lengths, step=step)
    toc = time.perf_counter()
    turbo_decoder_time = toc - tic

    global_timer.turbo_timer += turbo_decoder_time
    global_timer.torch_timer += onmt_decoder_time

    # Check correctness: every element must agree within tolerance
    # (torch.all, not torch.max, so a single mismatch fails the assert).
    assert torch.all(torch.abs(dec_out - turbo_dec_out) < 1e-3)
    assert torch.all(
        torch.abs(dec_attn["std"] - turbo_dec_attn["std"]) < 1e-3)
    # print("onmt time: ", onmt_decoder_time,
    #       " turbo time: ", turbo_decoder_time)
    # print("dec_out diff: ", torch.max(torch.abs(dec_out - turbo_dec_out)))
    # print("attn diff: ", torch.max(
    #     torch.abs(dec_attn["std"] - turbo_dec_attn["std"])))

    tic = time.perf_counter()
    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        log_probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(-1, batch.batch_size, scores.size(-1))
            scores = scores.transpose(0, 1).contiguous()
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(scores, batch, self._tgt_vocab,
                                      src_vocabs, batch_dim=0,
                                      batch_offset=batch_offset)
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    toc = time.perf_counter()
    others_time = toc - tic
    if show_profile_detail:
        print(f"others_time {others_time:0.4f}")
    return log_probs, attn
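
# The elementwise tolerance checks above can also be written with
# torch.allclose, which is the idiomatic one-call form; a sketch,
# assuming the same tensors as in the function above.
import torch

assert torch.allclose(dec_out, turbo_dec_out, atol=1e-3)
assert torch.allclose(dec_attn["std"], turbo_dec_attn["std"], atol=1e-3)
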
def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None,
        # we need src_origin to split the history/current content
        src_origin=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

    sep_id = batch.dataset.fields['src'].base_field.vocab.stoi['[SEP]']

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    dec_out, dec_attn = self.model.decoder(decoder_in, memory_bank,
                                           memory_lengths=memory_lengths,
                                           step=step, sep_id=sep_id,
                                           src_origin=src_origin,
                                           batch_offset=batch_offset)

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        log_probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        if "copy" in dec_attn:
            attn = dec_attn["copy"]
            scores = self.model.generator(
                dec_out.view(-1, dec_out.size(2)),
                attn.view(-1, attn.size(2)),
                src_map)
        elif "his_copy" in dec_attn:
            his_attn = dec_attn["his_copy"]
            cur_attn = dec_attn["cur_copy"]
            his_mid = dec_attn["his_mid"]
            cur_mid = dec_attn["cur_mid"]
            attn_shape = his_attn.shape
            scores, attn = self.model.generator(
                dec_out.view(-1, dec_out.size(2)),
                his_attn.view(-1, his_attn.size(2)),
                cur_attn.view(-1, cur_attn.size(2)),
                his_mid.view(-1, his_mid.size(2)),
                cur_mid.view(-1, cur_mid.size(2)),
                src_map)
            # convert attn to the same shape as normal copy attention
            attn = attn.view(attn_shape)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(-1, batch.batch_size, scores.size(-1))
            scores = scores.transpose(0, 1).contiguous()
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(scores, batch, self._tgt_vocab,
                                      src_vocabs, batch_dim=0,
                                      batch_offset=batch_offset)
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
def forward_pass(self, decoder_in, step, past=None, input_embeds=None,
                 tags=None, use_copy=True):
    memory_bank = self.memory_bank
    src_vocabs = self.src_vocabs
    memory_lengths = self.memory_lengths
    src_map = self.src_map
    batch = self.batch

    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

    decoder = self.model.decoder
    dec_out, all_hidden_states, past, dec_attn = decoder(
        decoder_in, memory_bank, memory_lengths=memory_lengths, step=step,
        past=past, input_embeds=input_embeds, pplm_return=True)

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        if self.simple_fusion:
            lm_dec_out, _ = self.model.lm_decoder(
                decoder_in, memory_bank.new_zeros(1, 1, 1), step=step)
            probs = self.model.generator(dec_out.squeeze(0),
                                         lm_dec_out.squeeze(0))
        else:
            probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores, p_copy = self.model.generator(
            dec_out.view(-1, dec_out.size(2)),
            attn.view(-1, attn.size(2)), src_map, tags=tags)
        scores = scores.view(batch.batch_size, -1, scores.size(-1))
        scores = collapse_copy_scores(scores, batch, self._tgt_vocab,
                                      src_vocabs, batch_dim=0,
                                      batch_offset=None)
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        # Probabilities rather than log-probabilities are returned here:
        # probs = scores.squeeze(0).log() would give log-probs.
        probs = scores.squeeze(0)
        if use_copy is False:
            # Strip the extended (copy) vocabulary entries.
            probs = probs[:, :50257]
            return probs, attn, all_hidden_states, past
        return probs, attn, all_hidden_states, past, p_copy
    return probs, attn, all_hidden_states, past
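
# Since forward_pass threads a `past` cache through successive calls, a
# caller feeds only the newest token each step. A hedged sketch of such a
# driver; wrapper, bos_idx, batch_size, and max_len are illustrative
# names, not part of the original code.
import torch

past = None
decoder_in = torch.full((1, batch_size, 1), bos_idx, dtype=torch.long)
for step in range(max_len):
    # use_copy=False selects the 4-tuple return path shown above.
    probs, attn, hidden_states, past = wrapper.forward_pass(
        decoder_in, step, past=past, use_copy=False)
    next_ids = probs.argmax(-1)
    decoder_in = next_ids.view(1, -1, 1)
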
def _decode_and_generate(self, decoder_in, memory_bank, batch, src_vocabs,
                         memory_lengths, src_map=None, step=None,
                         batch_offset=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    # print(">>>>> hidden state of the decoder: <<<<<")
    # print(self.model.decoder.state["hidden"][1])
    # before_state1, before_state2 = \
    #     (self.model.decoder.state["hidden"][0].clone(),
    #      self.model.decoder.state["hidden"][1].clone())
    # print(before_state == self.model.decoder.state["hidden"])
    dec_out, dec_attn = self.model.decoder(
        decoder_in, memory_bank, memory_lengths=memory_lengths, step=step,
        counterfactual_attention_method=self.counterfactual_attention_method)

    hack_dict = {}
    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        log_probs = self.model.generator(dec_out.squeeze(0))
        if self.counterfactual_attention_method == 'uniform_or_zero_out_max':
            hack_dict['log_probs_uniform'] = self.model.generator(
                dec_attn['uniform'][1].squeeze(0))
            hack_dict['attention_matrix_uniform'] = dec_attn['uniform'][0]
            hack_dict['log_probs_zero_out_max'] = self.model.generator(
                dec_attn['zero_out_max'][1].squeeze(0))
            hack_dict['attention_matrix_zero_out_max'] = \
                dec_attn['zero_out_max'][0]
            hack_dict['log_probs_permute'] = self.model.generator(
                dec_attn['permute'][1].squeeze(0))
            hack_dict['attention_matrix_permute'] = dec_attn['permute'][0]
        elif self.counterfactual_attention_method is not None:
            hack_dict['log_probs_counterfactual'] = self.model.generator(
                dec_attn[self.counterfactual_attention_method][1].squeeze(0))
            hack_dict['attention_matrix'] = dec_attn[
                self.counterfactual_attention_method][0]
        # if tvd_permute is True:
        #     dec_outs = dec_attn["std_tvd_permute"]
        #     distances = []
        #     for my_dec_out in dec_outs:
        #         my_dec_out = my_dec_out.squeeze(0)
        #         my_log_prob = self.model.generator(my_dec_out)
        #         # distance = tvd(torch.exp(log_probs), torch.exp(my_log_prob))
        #         distance = high_distance(torch.exp(log_probs),
        #                                  torch.exp(my_log_prob))[0]
        #         distances.append(distance)
        #     dist_change_median = torch.median(
        #         torch.stack(distances), dim=0).values
        #     if attn.size()[0] != 1:
        #         print(">>>> Target length in attention is more than 1 <<<<")
        #         assert False
        #     max_attention = attn[0].max(dim=1).values
        #     hack_dict['tvd_dist_change_median'] = dist_change_median
        #     hack_dict['tvd_max_attention'] = max_attention
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(scores, batch, self._tgt_vocab,
                                      src_vocabs, batch_dim=0,
                                      batch_offset=batch_offset)
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    if self.counterfactual_attention_method is not None:
        return log_probs, attn, hack_dict
    else:
        return log_probs, attn
def _decode_and_generate(self, decoder_in, memory_bank, batch, src_vocabs,
                         memory_lengths, src_map=None, step=None,
                         batch_offset=None, alive_seq=None, valid=False):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx)

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    dec_out, dec_attn = self.model.decoder(decoder_in, memory_bank,
                                           memory_lengths=memory_lengths,
                                           step=step)
    if not valid:
        dec_out = dec_out[-1, :, :].unsqueeze(0)
        dec_attn['std'] = dec_attn['std'][-1, :, :].unsqueeze(0)

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        log_probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        if valid:
            attn = dec_attn["copy"][:decoder_in.size(0), :, :]  # .unsqueeze(0)
        else:
            attn = dec_attn["copy"][-1, :, :].unsqueeze(0)
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(scores, batch, self._tgt_vocab,
                                      src_vocabs, batch_dim=0,
                                      batch_offset=batch_offset)
        # Original first argument: decoder_in.size(0)
        scores = scores.view(decoder_in.size(0) if valid else 1, -1,
                             scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
        )

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    dec_out, dec_attn = self.model.decoder(
        decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
    )

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        if self.trans_to is not None:
            dec_out_ = dec_out.view(-1, dec_out.shape[-1])  # squeeze(0)
            # lang_ids = torch.empty(
            #     dec_out_.shape[0],
            #     device=dec_out_.device).fill_(self.trans_to).view(-1)
            # self.model.generator.set_lang(lang_ids)
            log_probs = self.model.generator(dec_out_)
        else:
            dec_out_ = dec_out.view(-1, dec_out.shape[-1])
            log_probs = self.model.generator(dec_out_)  # .squeeze(0)
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(batch.batch_size, -1, scores.size(-1))
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(
            scores,
            batch,
            self._tgt_vocab,
            src_vocabs,
            batch_dim=0,
            batch_offset=batch_offset
        )
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
def _decode_and_generate(
        self,
        decoder_in,
        memory_bank,
        batch,
        src_vocabs,
        memory_lengths,
        src_map=None,
        step=None,
        batch_offset=None):
    if self.copy_attn:
        # Turn any copied words into UNKs.
        decoder_in = decoder_in.masked_fill(
            decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
        )

    # Decoder forward, takes [tgt_len, batch, nfeats] as input
    # and [src_len, batch, hidden] as memory_bank
    # in case of inference tgt_len = 1, batch = beam times batch_size
    # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
    dec_out, dec_attn = self.model.decoder(
        decoder_in, memory_bank, memory_lengths=memory_lengths, step=step
    )

    # MMM
    if len(self.length_model) > 0:
        # print('Using target length model.')
        # Set self.device once so every length-model branch can use it.
        self.device = 'cuda' if self._use_cuda else 'cpu'
        if self.length_model == 'oracle':
            pad = self._tgt_pad_idx
            eos = self._tgt_eos_idx
            # sequence has <s> and </s>
            t_lens = (batch.tgt != pad).sum(dim=0).squeeze(1)
            # # add noise to t_lens for experiments (just for test)
            # noisy_t_lens = torch.tensor([l + randint(-2, 2) for l in t_lens])
            # if self.cuda:
            #     noisy_t_lens = noisy_t_lens.to('cuda')
            # self.model.generator[-1].t_lens = noisy_t_lens
            self.model.generator[-1].t_lens = \
                t_lens if batch_offset is None \
                else t_lens.index_select(0, batch_offset.to(self.device))
            self.model.generator[-1].eos_ind = eos
            self.model.generator[-1].batch_max_len = batch.tgt.size(0)
        elif self.length_model == 'fixed_ratio':
            pad = self._tgt_pad_idx
            eos = self._tgt_eos_idx
            # sequence has <s> and </s>
            t_lens = torch.ceil(
                batch.src[1].float().cuda() * 1.0699071772279642)
            # # add noise to t_lens for experiments (just for test)
            # noisy_t_lens = torch.tensor([l + randint(-2, 2) for l in t_lens])
            # if self.cuda:
            #     noisy_t_lens = noisy_t_lens.to('cuda')
            # self.model.generator[-1].t_lens = noisy_t_lens
            self.model.generator[-1].t_lens = \
                t_lens if batch_offset is None \
                else t_lens.index_select(0, batch_offset.to(self.device))
            self.model.generator[-1].eos_ind = eos
            self.model.generator[-1].batch_max_len = batch.tgt.size(0)
        elif self.length_model == 'lstm':
            pad = self._tgt_pad_idx
            eos = self._tgt_eos_idx
            # sequence has <s> and </s>
            # TODO: the code itself must handle ratio and diff lstm
            # length models
            src_vocab = dict(self.fields)["src"].base_field.vocab
            ratios = onmt.utils.length_model.predict_length_ratio(
                self.l_model, self.device,
                batch.src[0].squeeze().transpose(0, 1), src_vocab)
            # diffs = torch.round(
            #     onmt.utils.length_model.predict_length_ratio(
            #         self.l_model, self.device,
            #         batch.src[0].squeeze().transpose(0, 1), src_vocab))
            # target sequence has <s> and </s>,
            # but source sequence doesn't have them
            t_lens = ratios * batch.src[1].type(
                torch.FloatTensor).to(self.device) + 2
            # t_lens = torch.max(
            #     (diffs + batch.src[1].type(torch.FloatTensor)
            #      .to(self.device)),
            #     torch.zeros(batch.tgt.size(1)).to(self.device)) + 2
            # # add noise to t_lens for experiments (just for test)
            # noisy_t_lens = torch.tensor([l + randint(-2, 2) for l in t_lens])
            # if self.cuda:
            #     noisy_t_lens = noisy_t_lens.to('cuda')
            # self.model.generator[-1].t_lens = noisy_t_lens
            self.model.generator[-1].t_lens = \
                t_lens if batch_offset is None \
                else t_lens.index_select(0, batch_offset.to(self.device))
            self.model.generator[-1].eos_ind = eos
            self.model.generator[-1].batch_max_len = batch.tgt.size(0)
    # /MMM

    # Generator forward.
    if not self.copy_attn:
        if "std" in dec_attn:
            attn = dec_attn["std"]
        else:
            attn = None
        log_probs = self.model.generator(dec_out.squeeze(0))
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    else:
        attn = dec_attn["copy"]
        scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                      attn.view(-1, attn.size(2)),
                                      src_map)
        # here we have scores [tgt_len x batch, vocab] or [beam x batch, vocab]
        if batch_offset is None:
            scores = scores.view(-1, batch.batch_size, scores.size(-1))
            scores = scores.transpose(0, 1).contiguous()
        else:
            scores = scores.view(-1, self.beam_size, scores.size(-1))
        scores = collapse_copy_scores(
            scores,
            batch,
            self._tgt_vocab,
            src_vocabs,
            batch_dim=0,
            batch_offset=batch_offset
        )
        scores = scores.view(decoder_in.size(0), -1, scores.size(-1))
        log_probs = scores.squeeze(0).log()
        # returns [(batch_size x beam_size), vocab] when 1 step
        # or [tgt_len, batch_size, vocab] when full sentence
    return log_probs, attn
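
# For reference, the collapse_copy_scores calls used throughout these
# variants fold probability mass from extended-vocabulary (copied) slots
# back onto matching in-vocabulary entries. A toy, single-example sketch
# of that idea (not the library routine itself); names are illustrative.
import torch

def collapse_toy(scores, src_words, tgt_stoi, tgt_vocab_len, unk=0):
    # scores: [steps, tgt_vocab_len + len(src_words)]; for each source
    # word that also exists in the target vocab, add its copy-slot score
    # to the in-vocab score and zero out the copy slot.
    for i, word in enumerate(src_words):
        ti = tgt_stoi.get(word, unk)
        if ti != unk:
            scores[:, ti] += scores[:, tgt_vocab_len + i]
            scores[:, tgt_vocab_len + i] = 0.0
    return scores
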