def __call__(self, outputs, output_symbols, targets):
    '''
    Inputs:
        outputs: (seq_len, batch_size, label_size)
        output_symbols: (seq_len, batch_size) index of output symbol
            (sampled from the policy)
        targets: (batch_size, label_size)
    '''
    # Legacy implementation, kept for reference: trained the policy to
    # reproduce its own sampled symbols instead of per-step targets.
    # outputs = torch.stack(outputs)
    # output_symbols = torch.stack(output_symbols).squeeze(2)
    # seq_len, batch_size, label_size = outputs.shape
    # outputs = outputs.transpose(0, 1)  # batch_size * seq_len * label_size
    # outputs = outputs.transpose(1, 2)  # batch_size * label_size * seq_len
    # mask = torch.ones((seq_len, batch_size), dtype=torch.float32,
    #                   device=outputs.device)
    # mask[1:, :] = 1 - output_symbols[:-1, :].data.eq(self.eos_id).float()
    # losses = self.criterion(outputs, output_symbols.transpose(0, 1)) * mask
    # loss = torch.sum(losses) / torch.sum(mask)

    outputs = torch.stack(outputs)
    targets = targets.to(outputs.device)
    output_symbols = torch.stack(output_symbols).squeeze(2)
    seq_len, batch_size, label_size = outputs.shape
    target_time = torch.zeros((seq_len, batch_size),
                              dtype=torch.long,
                              device=outputs.device)
    mask = torch.ones((seq_len, batch_size),
                      dtype=torch.float32,
                      device=outputs.device)
    for i in range(seq_len):
        # Use the remaining target with the highest predicted probability
        # as the label for this step.
        target_time[i] = (torch.exp(outputs[i]) * targets).topk(
            1, dim=1)[1].squeeze(1)
        targets = targets - utils.to_one_hot(target_time[i], label_size)
        # Check whether all targets have been successfully predicted;
        # if so, EOS becomes the only remaining target.
        is_end_batch = torch.sum(targets, dim=1).eq(0)
        targets[:, self.eos_id] += is_end_batch.float()
        # Mask out steps after EOS appears in the per-step targets.
        if i > 0:
            eos_batches = target_time[i - 1, :].data.eq(self.eos_id).float()
            mask[i, :] = (1 - eos_batches) * mask[i - 1, :]
    # criterion takes (batch_size, label_size, seq_len) inputs and returns
    # (batch_size, seq_len) losses, so the (seq_len, batch_size) mask must
    # be transposed to match.
    losses = self.criterion(outputs.permute(1, 2, 0),
                            target_time.transpose(0, 1)) * mask.transpose(0, 1)
    loss = torch.sum(losses) / torch.sum(mask)
    return loss
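# Every loss and decoder in this file relies on utils.to_one_hot(indices,
# size). The helper itself is not shown here; the following is a minimal
# sketch, assumed (not confirmed by the source) to match its call sites,
# where indices of any shape gain one trailing one-hot dimension.
import torch

def to_one_hot(indices, label_size):
    """Return a float tensor with a trailing one-hot dim of label_size."""
    one_hot = torch.zeros(*indices.shape, label_size, device=indices.device)
    return one_hot.scatter_(-1, indices.unsqueeze(-1), 1.0)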
def __call__(self, outputs, output_symbols, targets):
    '''
    Inputs:
        outputs: (seq_len, batch_size, label_size)
        output_symbols: (seq_len, batch_size) index of output symbol
            (sampled from the policy)
        targets: (batch_size, label_size)
    '''
    # Some details:
    #   - directly minimize the specific score
    #   - give SOS a low score
    outputs = torch.stack(outputs)
    targets = targets.to(outputs.device)
    output_symbols = torch.stack(output_symbols).squeeze(2)
    seq_len, batch_size, label_size = outputs.shape
    outputs_one_hot = utils.to_one_hot(output_symbols,
                                       label_size).to(outputs.device)
    q_values = torch.zeros(outputs.shape,
                           dtype=torch.float32,
                           device=outputs.device)
    mask = torch.ones((seq_len, batch_size),
                      dtype=torch.float32,
                      device=outputs.device)
    q_values[0, :, :] = -1 + targets
    for i in range(1, seq_len):
        # batch_size * label_size
        is_correct = targets * outputs_one_hot[i - 1, :, :]
        targets = targets - is_correct
        q_values[i, :, :] = (q_values[i - 1, :, :] - is_correct +
                             torch.sum(is_correct, dim=1).unsqueeze(1) - 1)
        # Check whether all targets have been successfully predicted.
        is_end_batch = torch.sum(targets, dim=1).eq(0)
        q_values[i, :, self.eos_id] += is_end_batch.float()
        # Check for EOS in the output tokens.
        eos_batches = output_symbols[i - 1, :].data.eq(self.eos_id).float()
        mask[i, :] = (1 - eos_batches) * mask[i - 1, :]
    optimal_policy = torch.softmax(q_values / self.temperature, dim=2)
    # KL divergence between the optimal policy and the model's policy.
    # softmax version:
    # losses = torch.mean(optimal_policy * torch.log(
    #     optimal_policy / (outputs + 1e-8) + 1e-8), dim=2) * mask
    # log_softmax version (outputs are log-probabilities):
    losses = torch.mean(optimal_policy *
                        (torch.log(optimal_policy + 1e-8) - outputs),
                        dim=2) * mask
    loss = torch.sum(losses) / torch.sum(mask)
    return loss
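# A minimal usage sketch for the Q-value loss above (the optimal policy is
# softmax(Q / temperature), as in Optimal Completion Distillation). The
# instance name `ocd_loss` and the input construction are assumptions for
# illustration; shapes follow the docstring.
import torch

seq_len, batch_size, label_size = 4, 2, 6
# Per-step log-probabilities from the decoder.
outputs = [torch.log_softmax(torch.randn(batch_size, label_size), dim=1)
           for _ in range(seq_len)]
# Symbols sampled from the policy at each step, shape (batch_size, 1).
output_symbols = [torch.randint(label_size, (batch_size, 1))
                  for _ in range(seq_len)]
# Multi-hot target set: labels 2 and 4 should be emitted (in any order).
targets = torch.zeros(batch_size, label_size)
targets[:, 2] = 1.0
targets[:, 4] = 1.0
# loss = ocd_loss(outputs, output_symbols, targets)  # hypothetical instance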
def __call__(self, outputs, output_symbols, targets):
    '''
    Inputs:
        outputs: (seq_len, batch_size, label_size)
        output_symbols: (seq_len, batch_size) index of output symbol
            (sampled from the policy)
        targets: (batch_size, label_size)
    '''
    outputs = torch.stack(outputs)
    targets = targets.to(outputs.device)
    output_symbols = torch.stack(output_symbols).squeeze(2)
    seq_len, batch_size, label_size = outputs.shape
    outputs_one_hot = utils.to_one_hot(output_symbols,
                                       label_size).to(outputs.device)
    target_each_time = torch.zeros(outputs.shape,
                                   dtype=torch.float32,
                                   device=outputs.device)
    mask = torch.ones((seq_len, batch_size),
                      dtype=torch.float32,
                      device=outputs.device)
    target_each_time[0, :, :] = targets
    for i in range(1, seq_len):
        # batch_size * label_size
        is_correct = targets * outputs_one_hot[i - 1, :, :]
        targets = targets - is_correct
        target_each_time[i, :, :] = targets
        # Check whether all targets have been successfully predicted.
        is_end_batch = torch.sum(targets, dim=1).eq(0)
        target_each_time[i, :, self.eos_id] += is_end_batch.float()
        # Check for EOS in the output tokens.
        eos_batches = output_symbols[i - 1, :].data.eq(self.eos_id).float()
        mask[i, :] = (1 - eos_batches) * mask[i - 1, :]
    # Renormalize the predicted distribution over the non-target labels and
    # penalize its negative entropy (sum p log p).
    prob_outputs = torch.exp(outputs)
    new_probs = prob_outputs * (1 - target_each_time)
    new_probs = new_probs / torch.sum(new_probs, dim=-1).unsqueeze(-1)
    neg_entropy = torch.sum(new_probs * torch.log(new_probs + 1e-8),
                            dim=-1) * mask
    loss = torch.sum(neg_entropy) / torch.sum(mask)
    return loss
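# Note on the loss above: sum(p * log p) is the *negative* entropy of the
# renormalized non-target distribution, so minimizing the loss maximizes
# that entropy. A toy check of the renormalization step, with made-up
# numbers:
import torch

probs = torch.tensor([[0.5, 0.3, 0.2]])            # model probabilities
remaining_targets = torch.tensor([[0., 1., 1.]])   # labels 1, 2 still wanted
new_probs = probs * (1 - remaining_targets)        # keep non-target mass
new_probs = new_probs / new_probs.sum(dim=-1, keepdim=True)  # -> [1., 0., 0.]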
def forward(self,
            inputs=None,
            encoder_hidden=None,
            encoder_outputs=None,
            dataset=None,
            teacher_forcing_ratio=0,
            candidates=None,
            logit_output=None):
    ret_dict = dict()
    if self.use_attention:
        ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

    inputs, batch_size, max_length = self._validate_args(
        inputs, encoder_hidden, encoder_outputs, teacher_forcing_ratio,
        candidates)
    decoder_hidden = self._init_state(encoder_hidden)
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    decoder_outputs = []
    sequence_symbols = []
    lengths = np.array([max_length] * batch_size)

    def post_decode(step_output, step_symbols, step_attn):
        decoder_outputs.append(step_output)
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)
        sequence_symbols.append(step_symbols)
        # Record the length of each sequence the first time it emits EOS.
        eos_batches = step_symbols.data.eq(self.eos_id)
        if eos_batches.dim() > 0:
            eos_batches = eos_batches.cpu().view(-1).numpy()
            update_idx = ((lengths > di) & eos_batches) != 0
            lengths[update_idx] = len(sequence_symbols)

    # Manual unrolling is used to support random teacher forcing.
    # If teacher_forcing_ratio is True or False instead of a probability,
    # the unrolling can be done in graph.
    if use_teacher_forcing and self.loss_type == 'vanilla':
        # Only for the vanilla RNN loss.
        decoder_input = inputs[:, :-1]
        context, decoder_hidden, attn = self.forward_step(
            decoder_input, decoder_hidden, encoder_outputs)
        decoder_output, symbols = self.decoder(context)
        decoder_output = decoder_output.log()
        for di in range(decoder_output.size(1)):
            step_output = decoder_output[:, di, :]
            step_symbols = symbols[:, di]
            if attn is not None:
                step_attn = attn[:, di, :]
            else:
                step_attn = None
            post_decode(step_output, step_symbols, step_attn)
    else:
        decoder_input = inputs[:, 0].unsqueeze(1)
        mask = torch.zeros((batch_size, self.output_size),
                           dtype=torch.float32).to(inputs.device)
        for di in range(max_length):
            context, decoder_hidden, attn = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs)
            if 'candidates' not in self.decoder.sampling_type:
                decoder_output, symbols = self.decoder(
                    context, mask, logit_output=logit_output)
            else:
                if use_teacher_forcing and self.loss_type == 'order_free':
                    # Order-free RNN: during teacher forcing, take the
                    # argmax over the remaining candidates.
                    ori = self.decoder.sampling_type
                    self.set_sampling_type('max_from_candidates')
                    decoder_output, symbols = self.decoder(
                        context, mask, candidates)
                    self.set_sampling_type(ori)
                else:
                    # Order free + SS / vanilla + SS / OCD.
                    decoder_output, symbols = self.decoder(
                        context, mask, candidates)
                # Remove the emitted symbol from the candidate set; once a
                # batch entry has no candidates left, only EOS stays valid.
                candidates -= utils.to_one_hot(
                    symbols.squeeze(2).squeeze(1), self.output_size).float()
                is_eos_batch = torch.sum(candidates, dim=1).eq(0)
                candidates[:, self.eos_id] = is_eos_batch.float()
            decoder_output = decoder_output.log()
            step_output = decoder_output.squeeze(1)
            step_symbols = symbols.squeeze(1)
            post_decode(step_output, step_symbols, attn)
            decoder_input = step_symbols
            if self.add_mask:
                # mask is one if a symbol has already been predicted.
                # This can fail if the loss became NaN.
                try:
                    mask += utils.to_one_hot(step_symbols.squeeze(1),
                                             self.output_size).float()
                except Exception:
                    print(torch.max(mask))

    ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
    ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()
    return decoder_outputs, decoder_hidden, ret_dict
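# The candidate bookkeeping in the loop above can be checked in isolation:
# emitting a symbol removes it from the multi-hot candidate set, and once a
# row is exhausted only EOS stays valid. Toy values below; eos_id = 0 and
# the to_one_hot sketch earlier in this file are assumptions.
import torch

eos_id, output_size = 0, 4
candidates = torch.tensor([[0., 1., 0., 1.]])   # labels 1 and 3 remain
symbols = torch.tensor([3])                     # decoder emits label 3
candidates -= to_one_hot(symbols, output_size).float()
is_eos_batch = torch.sum(candidates, dim=1).eq(0)
candidates[:, eos_id] = is_eos_batch.float()    # label 1 left, so EOS stays 0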
def forward(self,
            inputs=None,
            encoder_hidden=None,
            encoder_outputs=None,
            dataset=None,
            teacher_forcing_ratio=0,
            retain_output_probs=True,
            candidates=None,
            logit_output=None):
    """
    Forward rnn for MAX_LENGTH steps.
    Look at :func:`seq2seq.models.DecoderRNN.DecoderRNN.forward_rnn`
    for details.
    """
    inputs, batch_size, max_length = self.rnn._validate_args(
        inputs, encoder_hidden, encoder_outputs, teacher_forcing_ratio)

    self.pos_index = Variable(
        torch.LongTensor(range(batch_size)) * self.k).view(-1, 1)

    # Inflate the initial hidden states to be of size: b*k x h
    encoder_hidden = self.rnn._init_state(encoder_hidden)
    if encoder_hidden is None:
        hidden = None
    else:
        if isinstance(encoder_hidden, tuple):
            # n_layer * direction
            n_layer_bidirection = encoder_hidden[0].size(0)
            hidden = tuple([
                _inflate(h, self.k, 2).view(n_layer_bidirection,
                                            batch_size * self.k, -1)
                for h in encoder_hidden
            ])
        else:
            # TODO: should double-check the _inflate dimension here; the
            # trailing -1 is assumed to mirror the tuple branch above.
            hidden = _inflate(encoder_hidden, self.k,
                              2).view(1, batch_size * self.k, -1)

    # ... same idea for encoder_outputs and decoder_outputs
    if self.rnn.use_attention:
        _, encoder_length, encoder_output_size = encoder_outputs.shape
        inflated_encoder_outputs = _inflate(encoder_outputs, self.k, 1).view(
            batch_size * self.k, encoder_length, encoder_output_size)
    else:
        inflated_encoder_outputs = None

    # Logit output.
    if logit_output is not None:
        label_size = logit_output.shape[-1]
        logit_output = _inflate(logit_output, self.k,
                                1).view(batch_size * self.k, label_size)

    # Initialize the scores; for the first step,
    # ignore the inflated copies to avoid duplicate entries in the top k.
    sequence_scores = torch.zeros((batch_size * self.k, 1),
                                  dtype=torch.float32)
    sequence_scores.fill_(-1000)
    sequence_scores.index_fill_(
        0, torch.LongTensor([i * self.k for i in range(0, batch_size)]), 0.0)
    sequence_scores = Variable(sequence_scores)

    # Initialize the input vector.
    input_var = Variable(
        torch.transpose(torch.LongTensor([[self.SOS] * batch_size * self.k]),
                        0, 1))
    # Initialize the mask.
    mask = torch.zeros((batch_size * self.k, self.V), dtype=torch.float32)
    # Initialize lengths.
    lengths = torch.ones((batch_size * self.k, 1), dtype=torch.float32)
    # Initialize EOS bookkeeping.
    eos_indices = input_var.data.eq(self.EOS)
    eos_score = sequence_scores * eos_indices.float()  # bk * 1

    # Assign all vars to CUDA if available.
    if CUDA:
        self.pos_index = self.pos_index.cuda()
        input_var = input_var.cuda()
        sequence_scores = sequence_scores.cuda()
        mask = mask.cuda()
        lengths = lengths.cuda()
        eos_indices = eos_indices.cuda()
        eos_score = eos_score.cuda()

    # Store decisions for backtracking.
    stored_outputs = list()
    stored_scores = list()
    stored_predecessors = list()
    stored_emitted_symbols = list()
    stored_hidden = list()

    for t in range(max_length):
        # Run the RNN one step forward.
        context, hidden, attn = self.rnn.forward_step(
            input_var, hidden, inflated_encoder_outputs)
        softmax_output, _ = self.rnn.decoder(context,
                                             logit_output=logit_output)
        log_softmax_output = softmax_output.log().squeeze(1)  # bk * v

        # If doing local backprop (e.g. supervised training),
        # retain the output layer.
        if retain_output_probs:
            stored_outputs.append(log_softmax_output)  # bk * v

        # To get the full sequence scores for the new candidates, add the
        # local scores for t_i to the predecessor scores for t_(i-1).
        sequence_scores = _inflate(sequence_scores, self.V, 1)  # bk * V
        sequence_scores += log_softmax_output + mask
        # A terminated sequence can only produce the EOS token.
        eos_mask = eos_indices.squeeze().float()
        sequence_scores[:, self.EOS] = \
            sequence_scores[:, self.EOS] * (1 - eos_mask) + \
            eos_score.squeeze() * eos_mask  # [bk]

        # Calculate the new scores.
        if self.beam_score_type == 'sum':
            # b * kV
            scores, top_candidates = sequence_scores.view(batch_size,
                                                          -1).topk(self.k)
            input_var = (top_candidates % self.V).view(batch_size * self.k, 1)
            # Reshape input = (bk, 1) and sequence_scores = (bk, 1).
            sequence_scores = scores.view(batch_size * self.k, 1)
        elif self.beam_score_type == 'mean':
            # Mean of scores over the time steps, b * kV.
            scores, top_candidates = (sequence_scores / lengths).view(
                batch_size, -1).topk(self.k, dim=1)
            input_var = (top_candidates % self.V).view(batch_size * self.k, 1)
            # Reshape input = (bk, 1) and sequence_scores = (bk, 1).
            sequence_scores = scores.view(batch_size * self.k, 1) * lengths

        # Update fields for the next timestep. Integer division recovers
        # which beam each candidate came from; b * k.
        predecessors = (top_candidates // self.V +
                        self.pos_index.expand_as(top_candidates)).view(
                            batch_size * self.k, 1)

        # Update the mask.
        mask = mask[predecessors.squeeze(), :] - utils.to_one_hot(
            input_var.squeeze(), self.V).float() * INF
        mask[:, self.EOS] = 0

        if isinstance(hidden, tuple):
            hidden = tuple([
                h.index_select(1, predecessors.squeeze()) for h in hidden
            ])
        else:
            hidden = hidden.index_select(1, predecessors.squeeze())

        # Update sequence scores and erase scores for the end-of-sentence
        # symbol so that they aren't expanded.
        stored_scores.append(sequence_scores.clone())
        eos_indices = input_var.data.eq(self.EOS)  # bk * 1
        eos_score = sequence_scores * eos_indices.float()

        # Update lengths.
        if t < max_length - 1:
            sequence_scores.data.masked_fill_(eos_indices, -1000)
            lengths = lengths[predecessors.squeeze(), 0].view(
                batch_size * self.k, 1) + (1 - eos_indices.float())

        # Cache results for backtracking.
        stored_predecessors.append(predecessors)
        stored_emitted_symbols.append(input_var)
        stored_hidden.append(hidden)

    # Do backtracking to return the optimal values.
    t = max_length - 1
    outputs = []
    output_symbols = []
    step_scores = []
    now_indexes = torch.arange(batch_size * self.k)
    while t >= 0:
        prev_indexes = now_indexes
        now_indexes = stored_predecessors[t].squeeze()[now_indexes]
        current_symbol = stored_emitted_symbols[t][prev_indexes, 0].view(
            batch_size, self.k)
        current_output = stored_outputs[t][now_indexes].view(
            batch_size, self.k, self.V)
        # score[i][j][0] = output[i][j][symbol[i][j][0]]
        current_score = current_output.gather(
            2, current_symbol.unsqueeze(2)).view(batch_size, self.k)
        # Record the backtracked results.
        step_scores.append(current_score)
        outputs.append(current_output)
        output_symbols.append(current_symbol.unsqueeze(2))
        t -= 1
    outputs.reverse()  # [b, k, V]
    output_symbols.reverse()  # [b, k]
    step_scores.reverse()

    # Build return objects.
    decoder_outputs = [step[:, 0, :] for step in outputs]
    decoder_hidden = None
    metadata = {}
    metadata['output'] = outputs  # seq_len [batch_size * k * V]
    if self.beam_score_type == 'sum':
        metadata['topk_score'] = sequence_scores.view(
            batch_size, self.k)  # [batch_size * k]
    elif self.beam_score_type == 'mean':
        metadata['topk_score'] = (sequence_scores / lengths).view(
            batch_size, self.k)  # [batch_size * k]
    metadata['topk_sequence'] = output_symbols  # seq_len [batch_size * k, 1]
    metadata['topk_length'] = lengths.view(
        batch_size, self.k)  # seq_len [batch_size * k]
    metadata['step_score'] = step_scores  # seq_len [batch_size * k]
    metadata['sequence'] = [seq[:, 0]
                            for seq in output_symbols]  # seq_len [batch_size]
    return decoder_outputs, decoder_hidden, metadata
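# The beam-search decoder above assumes an _inflate helper that repeats a
# tensor `times` along dimension `dim`, so the k beam copies of each batch
# entry sit next to each other. A minimal sketch consistent with its call
# sites here (e.g. _inflate(sequence_scores, self.V, 1) turning (b*k, 1)
# into (b*k, V)):
import torch

def _inflate(tensor, times, dim):
    """Repeat `tensor` `times` times along dimension `dim`."""
    repeat_dims = [1] * tensor.dim()
    repeat_dims[dim] = times
    return tensor.repeat(*repeat_dims)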
def forward(self,
            inputs=None,
            encoder_hidden=None,
            encoder_outputs=None,
            dataset=None,
            teacher_forcing_ratio=0,
            candidates=None,
            logit_output=None):
    ret_dict = dict()
    if self.use_attention:
        ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

    ori_inputs = inputs
    inputs, batch_size, max_length = self._validate_args(
        inputs, encoder_hidden, encoder_outputs, teacher_forcing_ratio,
        candidates)
    decoder_hidden = self._init_state(encoder_hidden)

    decoder_outputs = []
    sequence_symbols = []
    lengths = np.array([max_length] * batch_size)

    def post_decode(step_output, step_symbols, step_attn):
        decoder_outputs.append(step_output)
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)
        sequence_symbols.append(step_symbols)
        # Record the length of each sequence the first time it emits EOS.
        eos_batches = step_symbols.data.eq(self.eos_id)
        if eos_batches.dim() > 0:
            eos_batches = eos_batches.cpu().view(-1).numpy()
            update_idx = ((lengths > di) & eos_batches) != 0
            lengths[update_idx] = len(sequence_symbols)

    decoder_input = inputs[:, 0].unsqueeze(1)
    mask = torch.zeros((batch_size, self.output_size),
                       dtype=torch.float32).to(inputs.device)
    for di in range(max_length - 1):
        context, decoder_hidden, attn = self.forward_step(
            decoder_input, decoder_hidden, encoder_outputs)
        decoder_output, symbols = self.decoder(context,
                                               mask,
                                               candidates,
                                               logit_output=logit_output)
        decoder_output = decoder_output.log()
        if teacher_forcing_ratio < 1.0:
            # Scheduled sampling: per batch element, feed either the
            # correct symbol or one sampled from the model.
            ran = torch.rand(symbols.shape).to(symbols.device)
            is_ss = ran.gt(teacher_forcing_ratio).float()
            if ori_inputs is not None:
                # Vanilla + SS.
                corrects = inputs[:, di + 1].unsqueeze(1).unsqueeze(2)
            else:
                # Order free + SS.
                corrects = symbols
            # Sample from the model's distribution.
            ori = self.decoder.sampling_type
            self.set_sampling_type('sample')
            _, sample_symbols = self.decoder(context, mask, candidates)
            self.set_sampling_type(ori)
            step_symbols = (is_ss * sample_symbols.float() +
                            (1 - is_ss) * corrects.float()).squeeze(1).long()
        else:
            if ori_inputs is not None:
                step_symbols = inputs[:, di + 1].unsqueeze(1)
            else:
                step_symbols = symbols.squeeze(1)

        if 'candidates' in self.decoder.sampling_type:
            # Remove the emitted symbol from the candidate set; once a
            # batch entry has no candidates left, only EOS stays valid.
            candidates -= utils.to_one_hot(
                symbols.squeeze(2).squeeze(1), self.output_size).float()
            is_eos_batch = torch.sum(candidates, dim=1).eq(0)
            candidates[:, self.eos_id] = is_eos_batch.float()

        step_output = decoder_output.squeeze(1)
        post_decode(step_output, step_symbols, attn)
        decoder_input = step_symbols
        if self.add_mask:
            # mask is one if a symbol has already been predicted.
            # This can fail if the loss became NaN.
            mask[range(batch_size), step_symbols.squeeze(1)] = 1
            mask[:, self.eos_id] = 0
            # mask += utils.to_one_hot(step_symbols.squeeze(1),
            #                          self.output_size).float()

    ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
    ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()
    return decoder_outputs, decoder_hidden, ret_dict
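# The scheduled-sampling mix above chooses, per batch element, between the
# ground-truth next symbol and a symbol sampled from the model. A toy
# illustration with made-up values (teacher_forcing_ratio = 0.75):
import torch

teacher_forcing_ratio = 0.75
corrects = torch.tensor([[5], [7]])           # ground-truth next symbols
sample_symbols = torch.tensor([[2], [3]])     # symbols sampled from the model
ran = torch.rand(corrects.shape)
is_ss = ran.gt(teacher_forcing_ratio).float() # 1 -> take the model's sample
step_symbols = (is_ss * sample_symbols.float() +
                (1 - is_ss) * corrects.float()).long()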