def train_batch(self, samples, use_cuda, tf_ratio=0.5, backprop=True, coverage_lambda=-1):
    start = time.time()
    if len(samples) == 0:
        return 0, 0

    target_length = min(self.config['target_length'],
                        max([len(pair.masked_target_tokens) for pair in samples]))
    nb_unks = max([len(s.unknown_tokens) for s in samples])
    input_variable, full_input_variable, target_variable, full_target_variable, decoder_input = \
        utils.get_batch_variables(samples, self.config['input_length'], target_length, use_cuda,
                                  self.vocab.word2index['SOS'])
    encoder_hidden = self.encoder.init_hidden(len(samples), use_cuda)

    self.encoder_optimizer.zero_grad()
    self.decoder_optimizer.zero_grad()
    loss = 0

    encoder_outputs, encoder_hidden = self.encoder(input_variable, encoder_hidden)
    # Concatenate the forward and backward final states of the bidirectional encoder
    decoder_hidden = torch.cat((encoder_hidden[0], encoder_hidden[1]), -1)
    decoder_h_states = torch.cat((encoder_hidden[0], encoder_hidden[1]), -1).unsqueeze(1)
    previous_att = None

    for token_i in range(target_length):
        p_final, p_gen, p_vocab, att_dist, decoder_h_states, decoder_hidden, previous_att = \
            self.decoder(decoder_input, decoder_h_states, decoder_hidden, encoder_outputs,
                         full_input_variable, previous_att, nb_unks, use_cuda)

        if coverage_lambda < 0 or token_i == 0:
            # Plain NLL loss on the final (pointer-generator) distribution
            loss += self.criterion(torch.log(p_final.clamp(min=1e-8)),
                                   full_target_variable.narrow(1, token_i, 1).squeeze(-1))
        else:
            # Coverage is the sum of the attention distributions from all previous steps
            coverage = previous_att.narrow(1, 0, previous_att.size()[1] - 1).sum(dim=1)
            coverage_min, _ = torch.cat((att_dist.unsqueeze(1), coverage.unsqueeze(1)), dim=1).min(dim=1)
            coverage_loss = coverage_min.sum(-1)
            loss += self.criterion(torch.log(p_final.clamp(min=1e-8)),
                                   full_target_variable.narrow(1, token_i, 1).squeeze(-1)) \
                + (coverage_lambda * coverage_loss)  # this needs to be fixed

        if random.uniform(0, 1) < tf_ratio:
            # Teacher forcing: feed the ground-truth token as the next decoder input
            decoder_input = target_variable.narrow(1, token_i, 1)
        else:
            # Feed back the model's own prediction, mapping OOV ids to UNK
            _, max_tokens = p_final.max(1)
            for i in range(max_tokens.size()[0]):
                if max_tokens.data[i] >= self.vocab.vocab_size:
                    max_tokens.data[i] = self.vocab.word2index['UNK']
            decoder_input = max_tokens.unsqueeze(1)

    if backprop:
        loss.backward()
        # Clip gradients to a max norm of 2 before the optimizer steps
        torch.nn.utils.clip_grad_norm(self.encoder.parameters(), 2)
        torch.nn.utils.clip_grad_norm(self.decoder.parameters(), 2)
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

    '''
    print(" ", [[t for t in pair.full_target_tokens
                 if t not in pair.full_source_tokens and t >= self.vocab.vocab_size]
                for pair in samples])
    '''

    return loss.data[0] / target_length, time.time() - start
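
# Illustrative sketch (not part of the model code): the coverage penalty used in the
# training loop above, written as a standalone helper. It assumes See et al. (2017)
# style coverage, where the penalty at step t is sum_i min(a_i^t, c_i^t) and c^t is the
# sum of attention from all previous decoder steps. The name, tensor shapes, and the
# element-wise torch.min formulation are assumptions for illustration only.
import torch

def _coverage_penalty_example(att_history):
    """att_history: FloatTensor (batch, steps_so_far, src_len) of attention distributions."""
    att_t = att_history[:, -1, :]                  # current step's attention a^t
    coverage = att_history[:, :-1, :].sum(dim=1)   # accumulated coverage c^t (zeros at step 0)
    return torch.min(att_t, coverage).sum(-1)      # per-example penalty sum_i min(a_i^t, c_i^t)
# The loss above then adds coverage_lambda times this penalty to the NLL term.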
def train_batch(self, samples, use_cuda, backprop=True):
    start = time.time()
    input_variable, full_input_variable, target_variable, full_target_variable, decoder_input = \
        utils.get_batch_variables(samples, 400, 50, use_cuda, self.vocab.word2index['SOS'])

    # Regression target: the trigram novelty score of each sample
    novelty_score = Variable(torch.FloatTensor(self.compute_novelty_n_grams(samples, 3)))
    if use_cuda:
        novelty_score = novelty_score.cuda()

    loss = self.criterion(self.model(full_input_variable, full_target_variable), novelty_score)
    if backprop:
        # Clear gradients accumulated from the previous batch before backpropagating
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    return loss.data[0], time.time() - start
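
# Illustrative sketch (not part of the model code): compute_novelty_n_grams is defined
# elsewhere in the repository and is not shown here. A common definition of n-gram
# novelty, assumed purely for illustration, is the fraction of target n-grams that do
# not occur in the source; the repository's actual implementation may differ.
def _ngram_novelty_example(source_tokens, target_tokens, n=3):
    def ngrams(tokens):
        return set(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))
    target_ngrams = ngrams(target_tokens)
    if not target_ngrams:
        return 0.0
    novel = target_ngrams - ngrams(source_tokens)
    return len(novel) / len(target_ngrams)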
def predict(self, samples, target_length, beam_size, use_cuda):
    # Beam search only works with one sample at a time
    nb_unks = max([len(s.unknown_tokens) for s in samples])
    input_variable, full_input_variable, target_variable, full_target_variable, decoder_input = \
        utils.get_batch_variables(samples, self.config['input_length'], target_length, use_cuda,
                                  self.vocab.word2index['SOS'])

    encoder_hidden = self.encoder.init_hidden(len(samples), use_cuda)
    encoder_outputs, encoder_hidden = self.encoder(input_variable, encoder_hidden)
    decoder_hidden = torch.cat((encoder_hidden[0], encoder_hidden[1]), -1)
    decoder_h_states = torch.cat((encoder_hidden[0], encoder_hidden[1]), -1).unsqueeze(1)
    previous_att = None

    if not beam_size:
        # Greedy decoding
        result = []
        for token_i in range(target_length):
            p_final, p_gen, p_vocab, att_dist, decoder_h_states, decoder_hidden, previous_att = \
                self.decoder(decoder_input, decoder_h_states, decoder_hidden, encoder_outputs,
                             full_input_variable, previous_att, nb_unks, use_cuda)

            p_vocab_word, vocab_word_idx = p_final.max(1)
            result.append([{'token_idx': vocab_word_idx.data[i],
                            'word': utils.translate_word(vocab_word_idx.data[i], samples[i], self.vocab),
                            'p_gen': round(p_gen.data[i][0], 3)}
                           for i in range(len(samples))])

            # Feed the predicted token back in, mapping OOV ids to UNK
            _, max_tokens = p_final.max(1)
            for i in range(max_tokens.size()[0]):
                if max_tokens.data[i] >= self.vocab.vocab_size:
                    max_tokens.data[i] = self.vocab.word2index['UNK']
            decoder_input = max_tokens.unsqueeze(1)

        return result

    else:
        # Beam search decoding
        search_complete = False
        top_beams = [Beam(decoder_input, decoder_h_states, decoder_hidden, previous_att, [], [])]

        while not search_complete:
            new_beams = []
            for beam in top_beams:
                if beam.complete:
                    new_beams.append(beam)
                else:
                    p_final, p_gen, p_vocab, att_dist, decoder_h_states, decoder_hidden, previous_att = \
                        self.decoder(beam.decoder_input, beam.decoder_h_states, beam.decoder_hidden,
                                     encoder_outputs, full_input_variable, beam.previous_att,
                                     nb_unks, use_cuda)

                    for k in range(beam_size):
                        # Take the current best token, then zero its probability so the
                        # next iteration of k picks the next-best candidate
                        p_vocab_word, vocab_word_idx = p_final.max(1)
                        _, max_tokens = p_final.max(1)
                        if max_tokens.data[0] >= self.vocab.vocab_size:
                            max_tokens.data[0] = self.vocab.word2index['UNK']

                        new_beams.append(Beam(max_tokens.unsqueeze(1), decoder_h_states, decoder_hidden,
                                              previous_att,
                                              beam.log_probs + [p_vocab_word.data[0]],
                                              beam.sequence + [vocab_word_idx.data[0]]))
                        p_final[0, vocab_word_idx.data[0]] = 0

                        if len(new_beams[-1].sequence) == target_length \
                                or vocab_word_idx.data[0] == self.vocab.word2index['EOS']:
                            new_beams[-1].complete = True

            # Keep only the beam_size best-scoring beams
            all_beams = sorted([(b, b.compute_score()) for b in new_beams], key=lambda tup: tup[1])
            if len(all_beams) > beam_size:
                all_beams = all_beams[:beam_size]
            top_beams = [beam[0] for beam in all_beams]

            if len([True for b in top_beams if b.complete]) == beam_size:
                search_complete = True

        return [[" ".join([utils.translate_word(t, samples[0], self.vocab) for t in b.sequence]),
                 b.compute_score()] for b in top_beams]
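
# Illustrative usage sketch (not part of the model code): how predict might be called
# in both decoding modes. `model` and `samples` are assumed to be a trained instance of
# this class and a batch prepared as for training; the names and arguments shown are
# assumptions for illustration only.
def _predict_usage_example(model, samples):
    # Greedy decoding: one list of per-sample dicts ({'token_idx', 'word', 'p_gen'}) per step.
    greedy_steps = model.predict(samples, target_length=50, beam_size=None, use_cuda=False)
    first_sample_words = [step[0]['word'] for step in greedy_steps]

    # Beam search (single sample only): a list of [summary_text, score] pairs, one per beam.
    beams = model.predict(samples[:1], target_length=50, beam_size=4, use_cuda=False)
    best_summary, best_score = beams[0]
    return " ".join(first_sample_words), best_summary, best_score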
def predict(self, samples, target_length, beam_size, use_cuda):
    # Beam search only works with one sample at a time
    nb_unks = max([len(s.unknown_tokens) for s in samples])
    input_variable, full_input_variable, target_variable, full_target_variable, decoder_input = \
        utils.get_batch_variables(samples, self.config['input_length'], target_length, use_cuda,
                                  self.vocab.word2index['SOS'])

    encoder_hidden = self.encoder.init_hidden(len(samples), use_cuda)
    encoder_outputs, encoder_hidden = self.encoder(input_variable, encoder_hidden)
    decoder_hidden = torch.cat((encoder_hidden[0], encoder_hidden[1]), -1)
    decoder_h_states = torch.cat((encoder_hidden[0], encoder_hidden[1]), -1).unsqueeze(1)
    previous_att = None

    if not beam_size:
        # Greedy decoding
        result = []
        for token_i in range(target_length):
            p_final, p_gen, p_vocab, att_dist, decoder_h_states, decoder_hidden, previous_att = \
                self.decoder(decoder_input, decoder_h_states, decoder_hidden, encoder_outputs,
                             full_input_variable, previous_att, nb_unks, use_cuda)

            p_vocab_word, vocab_word_idx = p_final.max(1)
            result.append([{'token_idx': vocab_word_idx.data[i],
                            'word': utils.translate_word(vocab_word_idx.data[i], samples[i], self.vocab),
                            'p_gen': round(p_gen.data[i][0], 3)}
                           for i in range(len(samples))])

            # Feed the predicted token back in, mapping OOV ids to UNK
            _, max_tokens = p_final.max(1)
            for i in range(max_tokens.size()[0]):
                if max_tokens.data[i] >= self.vocab.vocab_size:
                    max_tokens.data[i] = self.vocab.word2index['UNK']
            decoder_input = max_tokens.unsqueeze(1)

        return result

    else:
        # Beam search decoding, with all beams scored in a single decoder call
        search_complete = False
        top_beams = [Beam(decoder_input, decoder_h_states, decoder_hidden, previous_att, [], [])]

        def predict_for_beams(beams, encoder_outputs, full_input_variable):
            # Stack the states of all incomplete beams along the batch dimension, run the
            # decoder once, then split the outputs back into per-beam slices
            results = []
            encoder_outputs = torch.stack(
                [encoder_outputs[i] for beam in beams for i in range(len(samples))], 0)
            full_input_variable = torch.stack(
                [full_input_variable[i] for beam in beams for i in range(len(samples))], 0)
            decoder_input = torch.stack(
                [beam.decoder_input[i] for beam in beams for i in range(len(samples))], 0)
            decoder_h_states = torch.stack(
                [beam.decoder_h_states[i] for beam in beams for i in range(len(samples))], 0)
            decoder_hidden = torch.stack(
                [beam.decoder_hidden[i] for beam in beams for i in range(len(samples))], 0)
            if beams[0].previous_att is not None:
                previous_att = torch.stack(
                    [beam.previous_att[i] for beam in beams for i in range(len(samples))], 0)
            else:
                previous_att = None

            p_final, p_gen, p_vocab, att_dist, decoder_h_states, decoder_hidden, previous_att = \
                self.decoder(decoder_input, decoder_h_states, decoder_hidden, encoder_outputs,
                             full_input_variable, previous_att, nb_unks, use_cuda)

            for b in range(len(beams)):
                results.append([beams[b]] + [tensor.narrow(0, b * len(samples), len(samples))
                                             for tensor in [p_final, decoder_h_states,
                                                            decoder_hidden, previous_att]])
            return results

        while not search_complete:
            new_beams = []
            beams_to_predict = []
            for beam in top_beams:
                if beam.complete:
                    new_beams.append(beam)
                else:
                    beams_to_predict.append(beam)

            predictions = predict_for_beams(beams_to_predict, encoder_outputs, full_input_variable)
            for b in predictions:
                beam = b[0]
                p_final, decoder_h_states, decoder_hidden, previous_att = b[1], b[2], b[3], b[4]

                # Expand this beam with its beam_size most probable next tokens
                p_top_words, top_indexes = p_final.topk(beam_size)
                for k in range(beam_size):
                    non_masked_word = top_indexes.data[0][k]
                    if top_indexes.data[0][k] >= self.vocab.vocab_size:
                        top_indexes.data[0][k] = self.vocab.word2index['UNK']

                    new_beams.append(Beam(top_indexes.narrow(1, k, 1), decoder_h_states, decoder_hidden,
                                          previous_att,
                                          beam.log_probs + [p_top_words.data[0][k]],
                                          beam.sequence + [non_masked_word]))

                    if len(new_beams[-1].sequence) == target_length \
                            or top_indexes.data[0][k] == self.vocab.word2index['EOS']:
                        new_beams[-1].complete = True

            # Keep only the beam_size best-scoring beams
            all_beams = sorted([(b, b.compute_score()) for b in new_beams], key=lambda tup: tup[1])
            if len(all_beams) > beam_size:
                all_beams = all_beams[:beam_size]
            top_beams = [beam[0] for beam in all_beams]

            if len([True for b in top_beams if b.complete]) == beam_size:
                search_complete = True

        return [[" ".join([utils.translate_word(t, samples[0], self.vocab) for t in b.sequence]),
                 b.compute_score()] for b in top_beams]
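
# Illustrative sketch (not part of the model code): the batching trick used by
# predict_for_beams above. Per-beam tensors are stacked along the batch dimension so a
# single forward call scores every beam, and the output is split back into per-beam
# slices with narrow(). The helper name, shapes, and the stand-in `scorer` module are
# assumptions for illustration only.
import torch

def _batched_beam_example(beam_inputs, scorer):
    """beam_inputs: list of tensors, one per beam, each of shape (1, hidden)."""
    batch = torch.stack([inp[0] for inp in beam_inputs], 0)   # (num_beams, hidden)
    output = scorer(batch)                                     # one call scores all beams
    # Split the batched output back into one slice per beam, as predict_for_beams does
    return [output.narrow(0, b, 1) for b in range(len(beam_inputs))]

# Example: _batched_beam_example([torch.rand(1, 8) for _ in range(4)], torch.nn.Linear(8, 11))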