def predict(self, checkpoint, x_test, y_test, true_test):
    """Restore a checkpoint, decode the test set, and report corpus BLEU-1 to BLEU-4."""
    pred_logits = []
    hypotheses_test = []
    references_test = []

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)

        for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                utils.get_batches_xy(x_test, y_test, self.batch_size)):
            result = sess.run(self.inference_logits,
                              feed_dict={self.input_data: input_batch,
                                         self.source_sentence_length: source_sent_lengths,
                                         self.keep_prob: 1.0,
                                         self.word_dropout_keep_prob: 1.0,
                                         self.z_temperature: self.z_temp})

            pred_logits.extend(result)

            # Map predicted token ids back to words, dropping PAD/EOS/-1 fillers
            for k, pred in enumerate(result):
                hypotheses_test.append(
                    word_tokenize(" ".join(
                        [self.decoder_idx_word[i] for i in pred if i not in [self.pad, -1, self.eos]])))
                references_test.append([word_tokenize(true_test[batch_i * self.batch_size + k])])

    bleu_scores = utils.calculate_bleu_scores(references_test, hypotheses_test)
    print('BLEU 1 to 4 : {}'.format(' | '.join(map(str, bleu_scores))))

    return pred_logits
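# A note on utils.calculate_bleu_scores: the helper itself is not shown in this
# listing. As a minimal sketch, assuming NLTK's corpus_bleu and the list shapes
# built above (references as a list of lists of token lists, hypotheses as a list
# of token lists), it could look like the standalone function below. The name and
# weighting scheme are illustrative, not the actual utils API.
import numpy as np
from nltk.translate.bleu_score import corpus_bleu

def calculate_bleu_scores_sketch(references, hypotheses):
    """Corpus-level BLEU-1 through BLEU-4, scaled to 0-100."""
    weight_sets = [(1.0, 0.0, 0.0, 0.0),
                   (0.5, 0.5, 0.0, 0.0),
                   (0.34, 0.33, 0.33, 0.0),
                   (0.25, 0.25, 0.25, 0.25)]
    return [np.round(100 * corpus_bleu(references, hypotheses, weights=w), 2)
            for w in weight_sets]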
def get_neighbourhood(self, checkpoint, x_test, temp=1.0, num_samples=10):
    """Decode num_samples sentences per source input to explore the latent neighbourhood."""
    answer_logits = []
    pred_sentences = []
    # Repeat each input so that num_samples sentences are decoded per source
    x_test_repeated = np.repeat(x_test, num_samples, axis=0)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)

        # The target side is unused at inference, so the repeated inputs are
        # passed for both arguments of get_batches_xy
        for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                utils.get_batches_xy(x_test_repeated, x_test_repeated, self.batch_size)):
            result = sess.run(self.inference_logits,
                              feed_dict={self.input_data: input_batch,
                                         self.source_sentence_length: source_sent_lengths,
                                         self.keep_prob: 1.0,
                                         self.word_dropout_keep_prob: 1.0,
                                         self.z_temperature: temp})
            answer_logits.extend(result)

    for idx, (actual, pred) in enumerate(zip(x_test_repeated, answer_logits)):
        pred_sentences.append(
            " ".join([self.decoder_idx_word[i] for i in pred if i not in [self.pad, self.eos]]))

    # Print each source sentence once, followed by its generated neighbourhood
    for j in range(len(pred_sentences)):
        if j % num_samples == 0:
            print('\nA: {}'.format(" ".join(
                [self.decoder_idx_word[i] for i in x_test_repeated[j] if i not in [self.pad, self.eos]])))
        print('G: {}'.format(pred_sentences[j]))
def validate(self, sess, x_val, y_val, true_val):
    # Calculate BLEU on validation data
    hypotheses_val = []
    references_val = []

    for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
            utils.get_batches_xy(x_val, y_val, self.batch_size)):
        answer_logits = sess.run(self.inference_logits,
                                 feed_dict={self.input_data: input_batch,
                                            self.source_sentence_length: source_sent_lengths,
                                            self.keep_prob: 1.0,
                                            self.word_dropout_keep_prob: 1.0,
                                            self.z_temperature: self.z_temp})

        for k, pred in enumerate(answer_logits):
            hypotheses_val.append(
                word_tokenize(
                    " ".join([self.decoder_idx_word[i] for i in pred if i not in [self.pad, -1, self.eos]])))
            references_val.append([word_tokenize(true_val[batch_i * self.batch_size + k])])

    self.val_pred = [" ".join(sent) for sent in hypotheses_val]
    self.val_ref = [" ".join(sent[0]) for sent in references_val]

    bleu_scores = utils.calculate_bleu_scores(references_val, hypotheses_val)
    self.epoch_bleu_score_val['1'].append(bleu_scores[0])
    self.epoch_bleu_score_val['2'].append(bleu_scores[1])
    self.epoch_bleu_score_val['3'].append(bleu_scores[2])
    self.epoch_bleu_score_val['4'].append(bleu_scores[3])
def validate(self, sess, input_val, output_val, label_val):
    # Calculate BLEU on validation data
    hypotheses_val = []
    references_val = []

    for batch_i, (input_batch, output_batch, label_batch, input_sent_lengths, output_sent_lengths) in enumerate(
            utils.get_batches_xy(input_val, output_val, label_val, self.batch_size)):
        pred_sentences, self._validate_logits = sess.run(
            [self.validate_sent, self.validate_logits],
            feed_dict={self.input_data: input_batch,
                       self.labels: label_batch,
                       self.source_sentence_length: input_sent_lengths,
                       self.keep_prob: 1.0})

        for pred, actual in zip(pred_sentences, output_batch):
            hypotheses_val.append(
                word_tokenize(
                    " ".join([self.idx_word[i] for i in pred if i not in [self.pad, -1, self.eos]])))
            references_val.append(
                [word_tokenize(" ".join([self.idx_word[i] for i in actual if i not in [self.pad, -1, self.eos]]))])

    self.val_pred = [" ".join(sent) for sent in hypotheses_val]
    self.val_ref = [" ".join(sent[0]) for sent in references_val]

    bleu_scores = utils.calculate_bleu_scores(references_val, hypotheses_val)
    self.epoch_bleu_score_val['1'].append(bleu_scores[0])
    self.epoch_bleu_score_val['2'].append(bleu_scores[1])
    self.epoch_bleu_score_val['3'].append(bleu_scores[2])
    self.epoch_bleu_score_val['4'].append(bleu_scores[3])
def train(self, x_train, y_train, x_val, y_val, true_val):
    """Run the training loop with per-epoch learning-rate decay and validation monitoring."""
    print('[INFO] Training process started')

    learning_rate = self.initial_learning_rate
    iter_i = 0

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(self.logs_dir, sess.graph)

        for epoch_i in range(1, self.epochs + 1):
            start_time = time.time()

            for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                    utils.get_batches_xy(x_train, y_train, self.batch_size)):
                try:
                    iter_i += 1
                    _, _summary, self.train_xent = sess.run(
                        [self.train_op, self.summary_op, self.xent_loss],
                        feed_dict={self.input_data: input_batch,
                                   self.target_data: output_batch,
                                   self.lr: learning_rate,
                                   self.source_sentence_length: source_sent_lengths,
                                   self.target_sentence_length: tar_sent_lengths,
                                   self.keep_prob: self.dropout_keep_prob,
                                   self.lambda_coeff: self.lambda_val})
                    writer.add_summary(_summary, iter_i)
                except Exception as e:
                    print(iter_i, e)

            # Reduce learning rate once per epoch, but not below its minimum value
            learning_rate = np.max([self.min_learning_rate, learning_rate * self.learning_rate_decay])

            time_consumption = time.time() - start_time
            self.monitor(x_val, y_val, true_val, sess, epoch_i, time_consumption)
def get_diversity_metrics(self, checkpoint, x_test, y_test, num_samples=10, num_iterations=3):
    """Sample num_samples generations per input and report entropy and Distinct-1/2."""
    x_test_repeated = np.repeat(x_test, num_samples, axis=0)
    y_test_repeated = np.repeat(y_test, num_samples, axis=0)

    entropy_list = []
    uni_diversity = []
    bi_diversity = []

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)

        for _ in tqdm(range(num_iterations)):
            total_ent = 0
            uni = 0
            bi = 0
            answer_logits = []
            pred_sentences = []

            for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                    utils.get_batches_xy(x_test_repeated, y_test_repeated, self.batch_size)):
                result = sess.run(self.inference_logits,
                                  feed_dict={self.input_data: input_batch,
                                             self.source_sentence_length: source_sent_lengths,
                                             self.keep_prob: 1.0,
                                             self.word_dropout_keep_prob: 1.0,
                                             self.z_temperature: self.z_temp})
                answer_logits.extend(result)

            # Group the generated samples per source sentence and score each group
            for idx, (actual, pred) in enumerate(zip(x_test_repeated, answer_logits)):
                pred_sentences.append(
                    " ".join([self.decoder_idx_word[i] for i in pred if i not in [self.pad, self.eos]]))

                if (idx + 1) % num_samples == 0:
                    word_list = [word_tokenize(p) for p in pred_sentences]
                    corpus = [item for sublist in word_list for item in sublist]
                    total_ent += utils.calculate_entropy(corpus)
                    diversity_result = utils.calculate_ngram_diversity(corpus)
                    uni += diversity_result[0]
                    bi += diversity_result[1]
                    pred_sentences = []

            entropy_list.append(total_ent / len(x_test))
            uni_diversity.append(uni / len(x_test))
            bi_diversity.append(bi / len(x_test))

    print('Entropy = {:>.3f} | Distinct-1 = {:>.3f} | Distinct-2 = {:>.3f}'.format(
        np.mean(entropy_list), np.mean(uni_diversity), np.mean(bi_diversity)))
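# utils.calculate_entropy and utils.calculate_ngram_diversity are also external to
# this listing. Under the usual definitions -- Shannon entropy of the unigram
# distribution, and Distinct-n as the ratio of unique n-grams to total n-grams --
# standalone sketches might look like this. The names and the log base are
# assumptions, not the actual utils API.
from collections import Counter
import numpy as np
from nltk.util import ngrams

def calculate_entropy_sketch(corpus):
    """Shannon entropy (in bits) of the unigram distribution of a token list."""
    counts = Counter(corpus)
    total = float(sum(counts.values()))
    return -sum((c / total) * np.log2(c / total) for c in counts.values())

def calculate_ngram_diversity_sketch(corpus):
    """(Distinct-1, Distinct-2): unique n-grams divided by total n-grams."""
    bigrams = list(ngrams(corpus, 2))
    distinct_1 = len(set(corpus)) / float(len(corpus))
    distinct_2 = len(set(bigrams)) / float(len(bigrams)) if bigrams else 0.0
    return distinct_1, distinct_2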
def train(self, x_train, y_train, x_val, y_val, true_val):
    """Training loop with KL-coefficient annealing and word-dropout annealing."""
    print('[INFO] Training process started')

    learning_rate = self.initial_learning_rate
    iter_i = 0

    if gl.config['anneal_type'] == 'none':
        lambda_val = gl.config['lambda_val']
    else:
        lambda_val = 0.0  # Start from zero and anneal upwards in tanh or linear fashion

    wd_anneal = 1.0

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(self.logs_dir, sess.graph)

        for epoch_i in range(1, self.epochs + 1):
            start_time = time.time()

            for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                    utils.get_batches_xy(x_train, y_train, self.batch_size)):
                try:
                    iter_i += 1
                    _, _summary, self.train_xent = sess.run(
                        [self.train_op, self.summary_op, self.xent_loss],
                        feed_dict={self.input_data: input_batch,
                                   self.target_data: output_batch,
                                   self.lr: learning_rate,
                                   self.source_sentence_length: source_sent_lengths,
                                   self.target_sentence_length: tar_sent_lengths,
                                   self.keep_prob: self.dropout_keep_prob,
                                   self.lambda_coeff: lambda_val,
                                   self.z_temperature: self.z_temp,
                                   self.word_dropout_keep_prob: wd_anneal})
                    writer.add_summary(_summary, iter_i)

                    # KL annealing up to iteration self.anneal_till
                    if iter_i <= self.anneal_till:
                        if gl.config['anneal_type'] == 'tanh':
                            lambda_val = np.round((np.tanh((iter_i - 4500) / 1000) + 1) / 2, decimals=6)
                            # lambda_val = np.round(logistic.cdf(iter_i / 4500) - 0.5, decimals=6)
                        elif gl.config['anneal_type'] == 'linear':
                            lambda_val = np.round(iter_i * 0.000005, decimals=6)
                except Exception as e:
                    print(iter_i, e)

            # Reduce learning rate once per epoch, but not below its minimum value
            learning_rate = np.max([self.min_learning_rate, learning_rate * self.learning_rate_decay])

            # Anneal word dropout keep probability from 1.0 down to its limit
            wd_anneal = np.max([self.word_dropout_keep_probability, wd_anneal - 0.05])

            time_consumption = time.time() - start_time
            self.monitor(x_val, y_val, true_val, sess, epoch_i, time_consumption)
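# The tanh schedule above pushes lambda_val from roughly 0 to roughly 1, crossing
# 0.5 at iteration 4500 and saturating within about 3000 iterations on either side.
# A standalone snippet to inspect the schedule (values follow directly from the
# formula in train; no session required):
import numpy as np

for it in [1000, 3000, 4500, 6000, 8000]:
    lam = np.round((np.tanh((it - 4500) / 1000.0) + 1) / 2, decimals=6)
    print('iter {:>5d} -> lambda = {:.6f}'.format(it, lam))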