def test(self):
    self.model.eval()
    batch_loss_history = []
    n_total_words = 0
    for batch_i, (conversations, conversation_length, sentence_length) \
            in enumerate(tqdm(self.eval_data_loader, ncols=80)):
        # conversations: (batch_size) list of conversations
        # conversation: list of sentences
        # sentence: list of tokens
        # conversation_length: list of int
        # sentence_length: (batch_size) list of conversation list of sentence_lengths
        input_conversations = [conv[:-1] for conv in conversations]
        target_conversations = [conv[1:] for conv in conversations]

        # flatten input and target conversations
        input_sentences = [sent for conv in input_conversations for sent in conv]
        target_sentences = [sent for conv in target_conversations for sent in conv]
        input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]
        target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
        input_conversation_length = [l - 1 for l in conversation_length]

        with torch.no_grad():
            input_sentences = to_var(torch.LongTensor(input_sentences))
            target_sentences = to_var(torch.LongTensor(target_sentences))
            input_sentence_length = to_var(torch.LongTensor(input_sentence_length))
            target_sentence_length = to_var(torch.LongTensor(target_sentence_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

        sentence_logits = self.model(input_sentences, input_sentence_length,
                                     input_conversation_length, target_sentences)

        batch_loss, n_words = masked_cross_entropy(
            sentence_logits, target_sentences, target_sentence_length)

        assert not isnan(batch_loss.item())
        batch_loss_history.append(batch_loss.item())
        n_total_words += n_words.item()

    epoch_loss = np.sum(batch_loss_history) / n_total_words

    print(f'Number of words: {n_total_words}')
    print(f'Bits per word: {epoch_loss:.3f}')
    word_perplexity = np.exp(epoch_loss)

    print_str = f'Word perplexity : {word_perplexity:.3f}\n'
    print(print_str)

    return word_perplexity
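
# NOTE: `to_var`, `masked_cross_entropy`, `isnan`, and `get_type` are assumed to be
# imported from the project's utility modules; they are not defined in this file.
# A minimal sketch of what `to_var` presumably does, under that assumption only
# (hypothetical name so it does not shadow the real helper):
def _to_var_sketch(tensor):
    """Move a tensor onto the GPU when one is available."""
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor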

def evaluate(self):
    self.model.eval()
    batch_loss_history = []
    recon_loss_history = []
    kl_div_history = []
    n_total_words = 0
    for batch_i, (conversations, conversation_length, sentence_length) \
            in enumerate(tqdm(self.eval_data_loader, ncols=80)):
        # conversations: (batch_size) list of conversations
        # conversation: list of sentences
        # sentence: list of tokens
        # conversation_length: list of int
        # sentence_length: (batch_size) list of conversation list of sentence_lengths
        target_conversations = [conv[1:] for conv in conversations]

        utterances = [utter for conv in conversations for utter in conv]
        target_utterances = [utter for conv in target_conversations for utter in conv]
        utterance_length = [l for len_list in sentence_length for l in len_list]
        target_utterance_length = [l for len_list in sentence_length for l in len_list[1:]]
        input_conversation_length = [conv_len - 1 for conv_len in conversation_length]

        with torch.no_grad():
            utterances = to_var(torch.LongTensor(utterances))
            utterance_length = to_var(torch.LongTensor(utterance_length))
            target_utterances = to_var(torch.LongTensor(target_utterances))
            target_utterance_length = to_var(torch.LongTensor(target_utterance_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

        sentence_logits, kl_div = self.model(utterances, utterance_length,
                                             input_conversation_length, target_utterances)

        recon_loss, n_words = masked_cross_entropy(
            sentence_logits, target_utterances, target_utterance_length)

        batch_loss = recon_loss + kl_div

        batch_loss_history.append(batch_loss.item())
        recon_loss_history.append(recon_loss.item())
        kl_div_history.append(kl_div.item())
        n_total_words += n_words.item()

    epoch_loss = np.sum(batch_loss_history) / n_total_words
    epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
    epoch_kl_div = np.sum(kl_div_history) / n_total_words

    print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
    print(print_str)
    print('\n')

    return epoch_loss

def test(self):
    self.model.eval()
    batch_loss_history = []
    n_total_words = 0
    for batch_i, (conversations, convs_length, utterances_length) \
            in enumerate(tqdm(self.eval_data_loader, ncols=80)):
        input_conversations = [conv[:-1] for conv in conversations]
        target_conversations = [conv[1:] for conv in conversations]

        input_utterances = [utter for conv in input_conversations for utter in conv]
        target_utterances = [utter for conv in target_conversations for utter in conv]
        input_utterance_length = [l for len_list in utterances_length for l in len_list[:-1]]
        target_utterance_length = [l for len_list in utterances_length for l in len_list[1:]]
        input_conversation_length = [conv_len - 1 for conv_len in convs_length]

        with torch.no_grad():
            input_utterances = to_var(torch.LongTensor(input_utterances))
            target_utterances = to_var(torch.LongTensor(target_utterances))
            input_utterance_length = to_var(torch.LongTensor(input_utterance_length))
            target_utterance_length = to_var(torch.LongTensor(target_utterance_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

        utterances_logits = self.model(input_utterances, input_utterance_length,
                                       input_conversation_length, target_utterances)

        batch_loss, n_words = masked_cross_entropy(
            utterances_logits, target_utterances, target_utterance_length)

        assert not isnan(batch_loss.item())
        batch_loss_history.append(batch_loss.item())
        n_total_words += n_words.item()

    epoch_loss = np.sum(batch_loss_history) / n_total_words

    print(f'Number of words: {n_total_words}')
    print(f'Bits per word: {epoch_loss:.3f}')
    word_perplexity = np.exp(epoch_loss)
    print(f'Word perplexity : {word_perplexity:.3f}\n')

    return word_perplexity

def importance_sample(self):
    ''' Perform importance sampling to get tighter bound '''
    self.model.eval()
    weight_history = []
    n_total_words = 0
    kl_div_history = []
    for batch_i, (conversations, conversation_length, sentence_length) \
            in enumerate(tqdm(self.eval_data_loader, ncols=80)):
        # conversations: (batch_size) list of conversations
        # conversation: list of sentences
        # sentence: list of tokens
        # conversation_length: list of int
        # sentence_length: (batch_size) list of conversation list of sentence_lengths
        target_conversations = [conv[1:] for conv in conversations]

        # flatten input and target conversations
        sentences = [sent for conv in conversations for sent in conv]
        input_conversation_length = [l - 1 for l in conversation_length]
        target_sentences = [sent for conv in target_conversations for sent in conv]
        target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
        sentence_length = [l for len_list in sentence_length for l in len_list]

        # n_words += sum([len([word for word in sent if word != PAD_ID]) for sent in target_sentences])

        with torch.no_grad():
            sentences = to_var(torch.LongTensor(sentences))
            sentence_length = to_var(torch.LongTensor(sentence_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))
            target_sentences = to_var(torch.LongTensor(target_sentences))
            target_sentence_length = to_var(torch.LongTensor(target_sentence_length))

        # treat whole batch as one data sample
        weights = []
        for j in range(self.config.importance_sample):
            sentence_logits, kl_div, log_p_z, log_q_zx = self.model(
                sentences, sentence_length, input_conversation_length, target_sentences)

            recon_loss, n_words = masked_cross_entropy(
                sentence_logits, target_sentences, target_sentence_length)

            log_w = (-recon_loss.sum() + log_p_z - log_q_zx).data
            weights.append(log_w)

            if j == 0:
                n_total_words += n_words.item()
                kl_div_history.append(kl_div.item())

        # weights: [n_samples]
        weights = torch.stack(weights, 0)
        m = np.floor(weights.max())
        weights = np.log(torch.exp(weights - m).type(get_type()).sum())
        weights = m + weights - np.log(self.config.importance_sample)
        weight_history.append(weights)

    print(f'Number of words: {n_total_words}')
    bits_per_word = -np.sum(weight_history) / n_total_words
    print(f'Bits per word: {bits_per_word:.3f}')
    word_perplexity = np.exp(bits_per_word)

    epoch_kl_div = np.sum(kl_div_history) / n_total_words

    print_str = f'Word perplexity upperbound using {self.config.importance_sample} importance samples: {word_perplexity:.3f}, kl_div: {epoch_kl_div:.3f}\n'
    print(print_str)

    return word_perplexity
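
# The loop above combines the importance weights with a log-mean-exp:
# log(1/K * sum_k exp(w_k)) = m + log(sum_k exp(w_k - m)) - log(K), with m = max_k w_k,
# which avoids overflow when exponentiating large log-weights. A standalone,
# illustrative sketch of that reduction (hypothetical helper, not the code path above):
def _log_mean_exp_sketch(log_weights):
    """Numerically stable log of the mean of exp(log_weights) for a 1-D tensor."""
    m = log_weights.max()
    return m + torch.log(torch.exp(log_weights - m).sum()) - np.log(len(log_weights))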

def evaluate(self):
    self.model.eval()
    batch_loss_history = []
    recon_loss_history = []
    kl_div_history = []
    bow_loss_history = []
    bleu_history = []
    sequences_history = []
    levenshteins_history = []
    n_total_words = 0
    for batch_i, (conversations, conversation_length, sentence_length) \
            in enumerate(tqdm(self.eval_data_loader, ncols=80)):
        # conversations: (batch_size) list of conversations
        # conversation: list of sentences
        # sentence: list of tokens
        # conversation_length: list of int
        # sentence_length: (batch_size) list of conversation list of sentence_lengths
        target_conversations = [conv[1:] for conv in conversations]

        # flatten input and target conversations
        sentences = [sent for conv in conversations for sent in conv]
        input_conversation_length = [l - 1 for l in conversation_length]
        target_sentences = [sent for conv in target_conversations for sent in conv]
        target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
        sentence_length = [l for len_list in sentence_length for l in len_list]

        with torch.no_grad():
            sentences = to_var(torch.LongTensor(sentences))
            sentence_length = to_var(torch.LongTensor(sentence_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))
            target_sentences = to_var(torch.LongTensor(target_sentences))
            target_sentence_length = to_var(torch.LongTensor(target_sentence_length))

        if batch_i == 0:
            input_conversations = [conv[:-1] for conv in conversations]
            input_sentences = [sent for conv in input_conversations for sent in conv]
            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
            scores = self.generate_sentence(sentences, sentence_length,
                                            input_conversation_length,
                                            input_sentences, target_sentences)
            bleu_history += scores["bleus"]
            sequences_history += scores["sequences"]
            levenshteins_history += scores["levenshteins"]

        sentence_logits, kl_div, _, _ = self.model(
            sentences, sentence_length, input_conversation_length, target_sentences)

        recon_loss, n_words = masked_cross_entropy(
            sentence_logits, target_sentences, target_sentence_length)

        batch_loss = recon_loss + kl_div

        if self.config.bow:
            bow_loss = self.model.compute_bow_loss(target_conversations)
            bow_loss_history.append(bow_loss.item())

        assert not isnan(batch_loss.item())
        batch_loss_history.append(batch_loss.item())
        recon_loss_history.append(recon_loss.item())
        kl_div_history.append(kl_div.item())
        n_total_words += n_words.item()

    epoch_loss = np.sum(batch_loss_history) / n_total_words
    epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
    epoch_kl_div = np.sum(kl_div_history) / n_total_words

    print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
    if bow_loss_history:
        epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
        print_str += f', bow_loss = {epoch_bow_loss:.3f}'
    print(print_str)
    print('\n')

    self.average_bleu = sum(bleu_history) / len(bleu_history)
    self.average_sequences = sum(sequences_history) / len(sequences_history)
    self.average_levenshteins = sum(levenshteins_history) / len(levenshteins_history)
    print("Average scores:")
    print(" -> bleu:", self.average_bleu)
    print(" -> sequence:", self.average_sequences)
    print(" -> levenshteins:", self.average_levenshteins)

    return epoch_loss

def train(self):
    epoch_loss_history = []
    kl_mult = 0.0
    conv_kl_mult = 0.0
    for epoch_i in range(self.epoch_i, self.config.n_epoch):
        self.epoch_i = epoch_i
        batch_loss_history = []
        recon_loss_history = []
        kl_div_history = []
        kl_div_sent_history = []
        kl_div_conv_history = []
        bow_loss_history = []
        self.model.train()
        n_total_words = 0

        # self.evaluate()

        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.train_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            # conversation: list of sentences
            # sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            sentences = [sent for conv in conversations for sent in conv]
            input_conversation_length = [l - 1 for l in conversation_length]
            target_sentences = [sent for conv in target_conversations for sent in conv]
            target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
            sentence_length = [l for len_list in sentence_length for l in len_list]

            sentences = to_var(torch.LongTensor(sentences))
            sentence_length = to_var(torch.LongTensor(sentence_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))
            target_sentences = to_var(torch.LongTensor(target_sentences))
            target_sentence_length = to_var(torch.LongTensor(target_sentence_length))

            # reset gradient
            self.optimizer.zero_grad()

            sentence_logits, kl_div, _, _ = self.model(
                sentences, sentence_length, input_conversation_length, target_sentences)

            recon_loss, n_words = masked_cross_entropy(
                sentence_logits, target_sentences, target_sentence_length)

            batch_loss = recon_loss + kl_mult * kl_div
            batch_loss_history.append(batch_loss.item())
            recon_loss_history.append(recon_loss.item())
            kl_div_history.append(kl_div.item())
            n_total_words += n_words.item()

            if self.config.bow:
                bow_loss = self.model.compute_bow_loss(target_conversations)
                batch_loss += bow_loss
                bow_loss_history.append(bow_loss.item())

            assert not isnan(batch_loss.item())

            if batch_i % self.config.print_every == 0:
                print_str = f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}, recon = {recon_loss.item() / n_words.item():.3f}, kl_div = {kl_div.item() / n_words.item():.3f}'
                if self.config.bow:
                    print_str += f', bow_loss = {bow_loss.item() / n_words.item():.3f}'
                tqdm.write(print_str)

            # Back-propagation
            batch_loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip)

            # Run optimizer
            self.optimizer.step()
            kl_mult = min(kl_mult + 1.0 / self.config.kl_annealing_iter, 1.0)

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        epoch_loss_history.append(epoch_loss)

        epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
        epoch_kl_div = np.sum(kl_div_history) / n_total_words

        self.kl_mult = kl_mult
        self.epoch_loss = epoch_loss
        self.epoch_recon_loss = epoch_recon_loss
        self.epoch_kl_div = epoch_kl_div

        print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
        if bow_loss_history:
            self.epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
            print_str += f', bow_loss = {self.epoch_bow_loss:.3f}'
        print(print_str)

        if epoch_i % self.config.save_every_epoch == 0:
            self.save_model(epoch_i + 1)

        print('\n<Validation>...')
        self.validation_loss = self.evaluate()

        if epoch_i % self.config.plot_every_epoch == 0:
            self.write_summary(epoch_i)

    return epoch_loss_history
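
# Note on KL annealing in train() above: `kl_mult` increases linearly by
# 1 / config.kl_annealing_iter after every optimizer step and is capped at 1.0,
# so the KL term is fully weighted only after `kl_annealing_iter` mini-batches
# (e.g. after 25,000 steps if kl_annealing_iter were 25000; the actual horizon
# comes from the project's config, not from this file).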

def train(self):
    epoch_loss_history = []
    for epoch_i in range(self.epoch_i, self.config.n_epoch):
        self.epoch_i = epoch_i
        batch_loss_history = []
        self.model.train()
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.train_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            # conversation: list of sentences
            # sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths
            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [sent for conv in input_conversations for sent in conv]
            target_sentences = [sent for conv in target_conversations for sent in conv]
            input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]
            target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
            input_conversation_length = [l - 1 for l in conversation_length]

            input_sentences = to_var(torch.LongTensor(input_sentences))
            target_sentences = to_var(torch.LongTensor(target_sentences))
            input_sentence_length = to_var(torch.LongTensor(input_sentence_length))
            target_sentence_length = to_var(torch.LongTensor(target_sentence_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

            # reset gradient
            self.optimizer.zero_grad()

            sentence_logits = self.model(input_sentences,
                                         input_sentence_length,
                                         input_conversation_length,
                                         target_sentences,
                                         decode=False)

            batch_loss, n_words = masked_cross_entropy(
                sentence_logits, target_sentences, target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()

            if batch_i % self.config.print_every == 0:
                tqdm.write(f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}')

            # Back-propagation
            batch_loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip)

            # Run optimizer
            self.optimizer.step()

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        epoch_loss_history.append(epoch_loss)
        self.epoch_loss = epoch_loss

        print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}'
        print(print_str)

        if epoch_i % self.config.save_every_epoch == 0:
            self.save_model(epoch_i + 1)

        print('\n<Validation>...')
        self.validation_loss = self.evaluate()

        if epoch_i % self.config.plot_every_epoch == 0:
            self.write_summary(epoch_i)

    self.save_model(self.config.n_epoch)

    return epoch_loss_history

def test(self):
    self.model.eval()
    batch_loss_history = []
    n_total_words = 0
    n_sentences = 0
    f1_total = []
    for batch_i, (conversations, conversation_length, sentence_length) \
            in enumerate(tqdm(self.test_data_loader, ncols=80)):
        # conversations: (batch_size) list of conversations
        # conversation: list of sentences
        # sentence: list of tokens
        # conversation_length: list of int
        # sentence_length: (batch_size) list of conversation list of sentence_lengths
        input_conversations = [conv[:-1] for conv in conversations]
        target_conversations = [conv[1:] for conv in conversations]

        # flatten input and target conversations
        input_sentences = [sent for conv in input_conversations for sent in conv]
        target_sentences = [sent for conv in target_conversations for sent in conv]
        input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]
        target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
        input_conversation_length = [l - 1 for l in conversation_length]

        with torch.no_grad():
            input_sentences = to_var(torch.LongTensor(input_sentences))
            target_sentences = to_var(torch.LongTensor(target_sentences))
            input_sentence_length = to_var(torch.LongTensor(input_sentence_length))
            target_sentence_length = to_var(torch.LongTensor(target_sentence_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

        if batch_i == 0:
            self.generate_sentence(input_sentences, input_sentence_length,
                                   input_conversation_length, target_sentences)

        generated_sentences = self.generate_conversations_with_gold_responses(
            input_sentences, input_sentence_length,
            input_conversation_length, target_sentences)

        conv_f1 = 0
        for target_sent, output_sent in zip(target_sentences, generated_sentences):
            target_sent = self.vocab.decode(target_sent)
            output_sent = self.vocab.decode(output_sent)
            f1 = metrics.f1_score(output_sent, target_sent)
            conv_f1 += f1
        conv_f1 = conv_f1 / target_sentences.shape[0]

        sentence_logits = self.model(input_sentences, input_sentence_length,
                                     input_conversation_length, target_sentences)

        batch_loss, n_words = masked_cross_entropy(
            sentence_logits, target_sentences, target_sentence_length)

        assert not isnan(batch_loss.item())
        batch_loss_history.append(batch_loss.item())
        n_total_words += n_words.item()
        f1_total.append(conv_f1)
        n_sentences += target_sentences.shape[0]

    epoch_loss = np.sum(batch_loss_history) / n_total_words
    f1_average = np.sum(f1_total) / n_sentences

    print(f'Number of words: {n_total_words}')
    print(f'Bits per word: {epoch_loss:.3f}')
    word_perplexity = np.exp(epoch_loss)

    return word_perplexity, f1_average

def evaluate(self):
    self.model.eval()
    batch_freq_loss_history = []
    n_total_freq_words = 0
    batch_rare_loss_history = []
    n_total_rare_words = 0
    for batch_i, (conversations, conversation_length, sentence_length) \
            in enumerate(tqdm(self.eval_freq_data_loader, ncols=80)):
        # conversations: (batch_size) list of conversations
        # conversation: list of sentences
        # sentence: list of tokens
        # conversation_length: list of int
        # sentence_length: (batch_size) list of conversation list of sentence_lengths
        input_conversations = [conv[:-1] for conv in conversations]
        target_conversations = [conv[1:] for conv in conversations]

        # flatten input and target conversations
        input_sentences = [sent for conv in input_conversations for sent in conv]
        target_sentences = [sent for conv in target_conversations for sent in conv]
        input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]
        target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
        input_conversation_length = [l - 1 for l in conversation_length]

        with torch.no_grad():
            input_sentences = to_var(torch.LongTensor(input_sentences))
            target_sentences = to_var(torch.LongTensor(target_sentences))
            input_sentence_length = to_var(torch.LongTensor(input_sentence_length))
            target_sentence_length = to_var(torch.LongTensor(target_sentence_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

        '''
        if batch_i == 0:
            self.generate_sentence(input_sentences, input_sentence_length,
                                   input_conversation_length, target_sentences)
        '''

        sentence_logits = self.model(input_sentences, input_sentence_length,
                                     input_conversation_length, target_sentences)

        batch_loss, n_words = masked_cross_entropy(
            sentence_logits, target_sentences, target_sentence_length)

        assert not isnan(batch_loss.item())
        batch_freq_loss_history.append(batch_loss.item())
        n_total_freq_words += n_words.item()

    epoch_freq_loss = np.sum(batch_freq_loss_history) / n_total_freq_words
    print_str = f'Validation freq loss: {epoch_freq_loss:.3f}\n'
    print(print_str)

    for batch_i, (conversations, conversation_length, sentence_length) \
            in enumerate(tqdm(self.eval_rare_data_loader, ncols=80)):
        # conversations: (batch_size) list of conversations
        # conversation: list of sentences
        # sentence: list of tokens
        # conversation_length: list of int
        # sentence_length: (batch_size) list of conversation list of sentence_lengths
        input_conversations = [conv[:-1] for conv in conversations]
        target_conversations = [conv[1:] for conv in conversations]

        # flatten input and target conversations
        input_sentences = [sent for conv in input_conversations for sent in conv]
        target_sentences = [sent for conv in target_conversations for sent in conv]
        input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]
        target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
        input_conversation_length = [l - 1 for l in conversation_length]

        with torch.no_grad():
            input_sentences = to_var(torch.LongTensor(input_sentences))
            target_sentences = to_var(torch.LongTensor(target_sentences))
            input_sentence_length = to_var(torch.LongTensor(input_sentence_length))
            target_sentence_length = to_var(torch.LongTensor(target_sentence_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

        '''
        if batch_i == 0:
            self.generate_sentence(input_sentences, input_sentence_length,
                                   input_conversation_length, target_sentences)
        '''

        sentence_logits = self.model(input_sentences, input_sentence_length,
                                     input_conversation_length, target_sentences)

        batch_loss, n_words = masked_cross_entropy(
            sentence_logits, target_sentences, target_sentence_length)

        assert not isnan(batch_loss.item())
        batch_rare_loss_history.append(batch_loss.item())
        n_total_rare_words += n_words.item()

    epoch_rare_loss = np.sum(batch_rare_loss_history) / n_total_rare_words
    print_str = f'Validation rare loss: {epoch_rare_loss:.3f}\n'
    print(print_str)

    return epoch_freq_loss, epoch_rare_loss

def train(self):
    epoch_loss_history = list()
    min_validation_loss = sys.float_info.max
    patience_cnt = self.config.patience
    for epoch_i in range(self.epoch_i, self.config.n_epoch):
        self.epoch_i = epoch_i
        batch_loss_history = list()
        self.model.train()
        n_total_words = 0
        for batch_i, (conversations, convs_length, utterances_length) \
                in enumerate(tqdm(self.train_data_loader, ncols=80)):
            # conversations: [batch_size, max_conv_len, max_utter_len] list of conversation
            # A conversation: [max_conv_len, max_utter_len] list of utterances
            # An utterance: [max_utter_len] list of tokens
            # convs_length: [batch_size] list of integer
            # utterances_length: [batch_size, max_conv_len] list of conversation that has a list of utterance length
            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            input_utterances = [utter for conv in input_conversations for utter in conv]
            target_utterances = [utter for conv in target_conversations for utter in conv]
            input_utterance_length = [l for len_list in utterances_length for l in len_list[:-1]]
            target_utterance_length = [l for len_list in utterances_length for l in len_list[1:]]
            input_conversation_length = [conv_len - 1 for conv_len in convs_length]

            input_utterances = to_var(torch.LongTensor(input_utterances))
            target_utterances = to_var(torch.LongTensor(target_utterances))
            input_utterance_length = to_var(torch.LongTensor(input_utterance_length))
            target_utterance_length = to_var(torch.LongTensor(target_utterance_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

            self.optimizer.zero_grad()

            utterances_logits = self.model(input_utterances,
                                           input_utterance_length,
                                           input_conversation_length,
                                           target_utterances,
                                           decode=False)

            batch_loss, n_words = masked_cross_entropy(
                utterances_logits, target_utterances, target_utterance_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()

            if batch_i % self.config.print_every == 0:
                tqdm.write(f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}')

            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip)
            self.optimizer.step()

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        epoch_loss_history.append(epoch_loss)
        self.epoch_loss = epoch_loss

        print(f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}')

        if epoch_i % self.config.save_every_epoch == 0:
            self.save_model(epoch_i + 1)

        print('\n<Validation>...')
        self.validation_loss = self.evaluate()

        if epoch_i % self.config.plot_every_epoch == 0:
            self.write_summary(epoch_i)

        if min_validation_loss > self.validation_loss:
            min_validation_loss = self.validation_loss
        else:
            patience_cnt -= 1
            self.save_model(epoch_i)

        if patience_cnt < 0:
            print(f'\nEarly stop at {epoch_i}')
            self.save_model(epoch_i)
            return epoch_loss_history

    self.save_model(self.config.n_epoch)
    return epoch_loss_history
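
# Early stopping in train() above: `patience_cnt` starts at `config.patience` and is
# decremented on every epoch whose validation loss fails to improve the best value seen
# so far; it is never reset on improvement, so `config.patience` bounds the total number
# of non-improving epochs over the whole run, not the number of consecutive ones.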

def evaluate(self):
    self.model.eval()
    batch_loss_history = []
    recon_loss_history = []
    kl_div_history = []
    bow_loss_history = []
    n_total_words = 0
    for batch_i, (conversations, conversation_length, sentence_length) \
            in enumerate(tqdm(self.eval_data_loader, ncols=80)):
        # conversations: (batch_size) list of conversations
        # conversation: list of sentences
        # sentence: list of tokens
        # conversation_length: list of int
        # sentence_length: (batch_size) list of conversation list of sentence_lengths
        target_conversations = [conv[1:] for conv in conversations]

        # flatten input and target conversations
        sentences = [sent for conv in conversations for sent in conv]
        input_conversation_length = [l - 1 for l in conversation_length]
        target_sentences = [sent for conv in target_conversations for sent in conv]
        target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
        sentence_length = [l for len_list in sentence_length for l in len_list]

        sentences = to_var(torch.LongTensor(sentences), eval=True)
        sentence_length = to_var(torch.LongTensor(sentence_length), eval=True)
        input_conversation_length = to_var(torch.LongTensor(input_conversation_length), eval=True)
        target_sentences = to_var(torch.LongTensor(target_sentences), eval=True)
        target_sentence_length = to_var(torch.LongTensor(target_sentence_length), eval=True)

        if batch_i == 0:
            self.generate_sentence(sentences, sentence_length,
                                   input_conversation_length, target_sentences)

        sentence_logits, kl_div, _, _ = self.model(
            sentences, sentence_length, input_conversation_length, target_sentences)

        recon_loss, n_words = masked_cross_entropy(
            sentence_logits, target_sentences, target_sentence_length)

        batch_loss = recon_loss + kl_div

        if self.config.bow:
            bow_loss = self.model.compute_bow_loss(target_conversations)
            bow_loss_history.append(bow_loss.item())

        assert not isnan(batch_loss.item())
        batch_loss_history.append(batch_loss.item())
        recon_loss_history.append(recon_loss.item())
        kl_div_history.append(kl_div.item())
        n_total_words += n_words.item()

    epoch_loss = np.sum(batch_loss_history) / n_total_words
    epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
    epoch_kl_div = np.sum(kl_div_history) / n_total_words

    print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
    if bow_loss_history:
        epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
        print_str += f', bow_loss = {epoch_bow_loss:.3f}'
    print(print_str)
    print('\n')

    return epoch_loss

def generate_file(self):
    batch_loss_history = []
    recon_loss_history = []
    kl_div_history = []
    bow_loss_history = []
    n_total_words = 0
    for batch_i, (conversations, conversation_length, sentence_length) \
            in enumerate(tqdm(self.eval_data_loader, ncols=80)):
        target_conversations = [conv[1:] for conv in conversations]

        # flatten input and target conversations
        sentences = [sent for conv in conversations for sent in conv]
        input_conversation_length = [l - 1 for l in conversation_length]
        target_sentences = [sent for conv in target_conversations for sent in conv]
        target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
        sentence_length = [l for len_list in sentence_length for l in len_list]

        with torch.no_grad():
            sentences = to_var(torch.LongTensor(sentences))
            sentence_length = to_var(torch.LongTensor(sentence_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))
            target_sentences = to_var(torch.LongTensor(target_sentences))
            target_sentence_length = to_var(torch.LongTensor(target_sentence_length))

        input_conversations = [conv[:-1] for conv in conversations]
        input_sentences = [sent for conv in input_conversations for sent in conv]
        with torch.no_grad():
            input_sentences = to_var(torch.LongTensor(input_sentences))

        self.generate_sentence(sentences, sentence_length,
                               input_conversation_length,
                               input_sentences, target_sentences)

        sentence_logits, kl_div, _, _ = self.model(
            sentences, sentence_length, input_conversation_length, target_sentences)

        recon_loss, n_words = masked_cross_entropy(
            sentence_logits, target_sentences, target_sentence_length)

        batch_loss = recon_loss + kl_div

        if self.config.bow:
            bow_loss = self.model_eval.compute_bow_loss(target_conversations)
            bow_loss_history.append(bow_loss.item())

        assert not isnan(batch_loss.item())
        batch_loss_history.append(batch_loss.item())
        recon_loss_history.append(recon_loss.item())
        kl_div_history.append(kl_div.item())
        n_total_words += n_words.item()

        # delete the cache data
        torch.cuda.empty_cache()

    epoch_loss = np.sum(batch_loss_history) / n_total_words
    epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
    epoch_kl_div = np.sum(kl_div_history) / n_total_words

    print_str = f'test loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
    with open(self.config.kl_log_dir, 'a+') as fout:
        fout.write(print_str)
        fout.write('\n')

    if bow_loss_history:
        epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
        print_str += f', bow_loss = {epoch_bow_loss:.3f}'
    print(print_str)
    print('\n')

    # self.model_eval = None
    torch.cuda.empty_cache()

    return epoch_loss

def evaluate(self):
    # try to copy a model to evaluate (currently disabled)
    self.model.eval()
    """
    self.model_eval = getattr(models, self.config.model)(self.config)
    if torch.cuda.is_available():
        self.model_eval = self.model_eval.cuda(3)
    ckpt_path = os.path.join(self.config.save_path, f'{self.epoch_i+1}.pkl')
    self.model_eval.load_state_dict(torch.load(ckpt_path))
    self.model_eval.eval()
    """
    batch_loss_history = []
    recon_loss_history = []
    kl_div_history = []
    bow_loss_history = []
    n_total_words = 0
    for batch_i, (conversations, conversation_length, sentence_length) \
            in enumerate(tqdm(self.eval_data_loader, ncols=80)):
        # conversations: (batch_size) list of conversations
        # conversation: list of sentences
        # sentence: list of tokens
        # conversation_length: list of int
        # sentence_length: (batch_size) list of conversation list of sentence_lengths
        target_conversations = [conv[1:] for conv in conversations]

        # flatten input and target conversations
        sentences = [sent for conv in conversations for sent in conv]
        input_conversation_length = [l - 1 for l in conversation_length]
        target_sentences = [sent for conv in target_conversations for sent in conv]
        target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
        sentence_length = [l for len_list in sentence_length for l in len_list]

        with torch.no_grad():
            sentences = to_var(torch.LongTensor(sentences))
            sentence_length = to_var(torch.LongTensor(sentence_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))
            target_sentences = to_var(torch.LongTensor(target_sentences))
            target_sentence_length = to_var(torch.LongTensor(target_sentence_length))

        if batch_i == -1:
            input_conversations = [conv[:-1] for conv in conversations]
            input_sentences = [sent for conv in input_conversations for sent in conv]
            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
            self.generate_sentence(sentences, sentence_length,
                                   input_conversation_length,
                                   input_sentences, target_sentences)

        sentence_logits, kl_div, _, _ = self.model(
            sentences, sentence_length, input_conversation_length, target_sentences)

        recon_loss, n_words = masked_cross_entropy(
            sentence_logits, target_sentences, target_sentence_length)

        batch_loss = recon_loss + kl_div

        if self.config.bow:
            bow_loss = self.model.compute_bow_loss(target_conversations)
            bow_loss_history.append(bow_loss.item())

        assert not isnan(batch_loss.item())
        batch_loss_history.append(batch_loss.item())
        recon_loss_history.append(recon_loss.item())
        kl_div_history.append(kl_div.item())
        n_total_words += n_words.item()

        # delete the cache data
        torch.cuda.empty_cache()

    epoch_loss = np.sum(batch_loss_history) / n_total_words
    epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
    epoch_kl_div = np.sum(kl_div_history) / n_total_words

    print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
    with open(self.config.kl_log_dir, 'a+') as fout:
        fout.write(print_str)
        fout.write('\n')

    if bow_loss_history:
        epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
        print_str += f', bow_loss = {epoch_bow_loss:.3f}'
    print(print_str)
    print('\n')

    # self.model_eval = None
    torch.cuda.empty_cache()

    return epoch_loss

def train(self):
    epoch_loss_history = list()
    kl_mult = 0.0
    min_validation_loss = sys.float_info.max
    patience_cnt = self.config.patience
    for epoch_i in range(self.epoch_i, self.config.n_epoch):
        self.epoch_i = epoch_i
        batch_loss_history = list()
        recon_loss_history = []
        kl_div_history = []
        bow_loss_history = []
        self.model.train()
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.train_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            # conversation: list of sentences
            # sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths
            target_conversations = [conv[1:] for conv in conversations]

            utterances = [utter for conv in conversations for utter in conv]
            target_utterances = [utter for conv in target_conversations for utter in conv]
            utterance_length = [l for len_list in sentence_length for l in len_list]
            target_utterance_length = [l for len_list in sentence_length for l in len_list[1:]]
            input_conversation_length = [conv_len - 1 for conv_len in conversation_length]

            utterances = to_var(torch.LongTensor(utterances))
            utterance_length = to_var(torch.LongTensor(utterance_length))
            target_utterances = to_var(torch.LongTensor(target_utterances))
            target_utterance_length = to_var(torch.LongTensor(target_utterance_length))
            input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

            self.optimizer.zero_grad()

            utterances_logits, kl_div = self.model(utterances,
                                                   utterance_length,
                                                   input_conversation_length,
                                                   target_utterances,
                                                   decode=False)

            recon_loss, n_words = masked_cross_entropy(
                utterances_logits, target_utterances, target_utterance_length)

            batch_loss = recon_loss + kl_mult * kl_div
            batch_loss_history.append(batch_loss.item())
            recon_loss_history.append(recon_loss.item())
            kl_div_history.append(kl_div.item())
            n_total_words += n_words.item()

            if batch_i % self.config.print_every == 0:
                print_str = (f'Epoch: {epoch_i + 1}, iter {batch_i}: '
                             f'loss = {batch_loss.item() / n_words.item():.3f}, '
                             f'recon = {recon_loss.item() / n_words.item():.3f}, '
                             f'kl_div = {kl_div.item() / n_words.item():.3f}')
                tqdm.write(print_str)

            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip)
            self.optimizer.step()
            kl_mult = min(kl_mult + 1.0 / self.config.kl_annealing_iter, 1.0)

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        epoch_loss_history.append(epoch_loss)

        epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
        epoch_kl_div = np.sum(kl_div_history) / n_total_words

        self.kl_mult = kl_mult
        self.epoch_loss = epoch_loss
        self.epoch_recon_loss = epoch_recon_loss
        self.epoch_kl_div = epoch_kl_div

        print_str = (f'Epoch {epoch_i + 1} loss average: {epoch_loss:.3f}, '
                     f'recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}')
        if bow_loss_history:
            self.epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
            print_str += f', bow_loss = {self.epoch_bow_loss:.3f}'
        print(print_str)

        if epoch_i % self.config.save_every_epoch == 0:
            self.save_model(epoch_i + 1)

        print('\n<Validation>...')
        self.validation_loss = self.evaluate()

        if epoch_i % self.config.plot_every_epoch == 0:
            self.write_summary(epoch_i)

        if min_validation_loss > self.validation_loss:
            min_validation_loss = self.validation_loss
        else:
            patience_cnt -= 1
            self.save_model(epoch_i)

        if patience_cnt < 0:
            print(f'\nEarly stop at {epoch_i}')
            self.save_model(epoch_i)
            return epoch_loss_history

    self.save_model(self.config.n_epoch)
    return epoch_loss_history