def validate(self, sess, x_val, y_val, true_val):
    """Decode the validation set and record its BLEU scores.

    Runs ``self.inference_logits`` batch by batch, turns each predicted id
    sequence into tokens, pairs it with the tokenized ground-truth sentence
    and appends the first four values returned by
    ``eval_utils.calculate_bleu_scores`` to ``self.epoch_bleu_score_val``
    under the keys ``'1'`` .. ``'4'``.

    Args:
        sess: Session object used to evaluate ``self.inference_logits``.
        x_val: Validation inputs, passed to ``data_utils.get_batches``.
        y_val: Validation targets, passed to ``data_utils.get_batches``.
        true_val: Indexable collection of reference sentences (raw text),
            looked up by ``batch index * self.batch_size + row index``.
    """
    # In the 'qgen' experiment every hypothesis gets a trailing '?' token.
    trailing_tokens = ['?'] if self.config['experiment'] == 'qgen' else []
    # Ids that must not be rendered as words (padding, -1, end-of-sequence).
    ignored_ids = (self.pad, -1, self.eos)

    hypotheses = []
    references = []
    batches = data_utils.get_batches(x_val, y_val, self.batch_size)
    for batch_no, (inputs, _, input_lengths, _) in enumerate(batches):
        feed = {
            self.input_data: inputs,
            self.source_sentence_length: input_lengths,
            # presumably a dropout keep probability; 1.0 for evaluation
            self.keep_prob: 1.0,
        }
        predictions = sess.run(self.inference_logits, feed_dict=feed)

        for row, token_ids in enumerate(predictions):
            words = [
                self.decoder_idx_word[idx]
                for idx in token_ids
                if idx not in ignored_ids
            ]
            hypotheses.append(word_tokenize(" ".join(words)) + trailing_tokens)

            # Position of this row in the (unbatched) reference list.
            truth = true_val[batch_no * self.batch_size + row]
            references.append([word_tokenize(truth)])

    bleu_scores = eval_utils.calculate_bleu_scores(references, hypotheses)
    # Store score i under the string key str(i + 1).
    for position in range(4):
        self.epoch_bleu_score_val[str(position + 1)].append(
            bleu_scores[position])
def predict(self, checkpoint, x_test, y_test, true_test):
    """Restore a checkpoint, decode the test set and print its BLEU scores.

    Opens a fresh ``tf.Session``, initializes the global variables, restores
    ``checkpoint`` through a ``tf.train.Saver`` and evaluates
    ``self.inference_logits`` batch by batch. Each predicted id sequence is
    converted to tokens and paired with the tokenized reference sentence;
    the values returned by ``eval_utils.calculate_bleu_scores`` are printed.

    Args:
        checkpoint: Checkpoint path handed to ``Saver.restore``.
        x_test: Test inputs, passed to ``data_utils.get_batches``.
        y_test: Test targets, passed to ``data_utils.get_batches``.
        true_test: Indexable collection of reference sentences (raw text),
            looked up by ``batch index * self.batch_size + row index``.

    Returns:
        A list holding every per-example prediction produced by
        ``self.inference_logits``, in batch order.
    """
    # In the 'qgen' experiment every hypothesis gets a trailing '?' token.
    trailing_tokens = ['?'] if self.config['experiment'] == 'qgen' else []
    # Ids that must not be rendered as words (padding, -1, end-of-sequence).
    ignored_ids = (self.pad, -1, self.eos)

    all_predictions = []
    hypotheses = []
    references = []

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        restorer = tf.train.Saver()
        restorer.restore(sess, checkpoint)

        batches = data_utils.get_batches(x_test, y_test, self.batch_size)
        for batch_no, (inputs, _, input_lengths, _) in enumerate(batches):
            feed = {
                self.input_data: inputs,
                self.source_sentence_length: input_lengths,
                # presumably a dropout keep probability; 1.0 for inference
                self.keep_prob: 1.0,
            }
            batch_predictions = sess.run(self.inference_logits, feed_dict=feed)
            all_predictions.extend(batch_predictions)

            for row, token_ids in enumerate(batch_predictions):
                words = [
                    self.decoder_idx_word[idx]
                    for idx in token_ids
                    if idx not in ignored_ids
                ]
                hypotheses.append(
                    word_tokenize(" ".join(words)) + trailing_tokens)

                # Position of this row in the (unbatched) reference list.
                truth = true_test[batch_no * self.batch_size + row]
                references.append([word_tokenize(truth)])

        bleu_scores = eval_utils.calculate_bleu_scores(references, hypotheses)
        rendered_scores = ' | '.join(str(score) for score in bleu_scores)
        print('BLEU 1 to 4 : {}'.format(rendered_scores))

    return all_predictions