def validate(self, sess, x_val):
    """Decode the validation set and record this epoch's corpus BLEU-1..4.

    Iterates over ``utils.get_batches(x_val, self.batch_size)``, runs
    ``self.validate_sent`` and ``self.validate_logits`` on each batch with
    dropout disabled (``keep_prob`` = 1.0), maps predicted and reference id
    sequences back to words via ``self.idx_word``, and scores them with
    ``utils.calculate_bleu_scores``.

    Args:
        sess: Session used to evaluate the decoding ops.
        x_val: Validation data in whatever form ``utils.get_batches`` accepts.

    Side effects:
        * ``self._validate_logits`` is overwritten on every batch, so it ends
          up holding the value fetched for the last batch.
        * ``self.val_pred`` / ``self.val_ref`` are set to the space-joined
          hypothesis / reference sentences.
        * BLEU-1..4 are appended to ``self.epoch_bleu_score_val['1']`` ..
          ``['4']``.
    """
    # Ids dropped before detokenizing. Built once per call rather than once
    # per token. A tuple keeps the exact `in` semantics of the original list.
    # -1 is presumably padding emitted by the decoder -- confirm.
    skip_ids = (self.pad, -1, self.eos)

    def to_tokens(ids):
        # id sequence -> words (minus pad/-1/eos) -> re-tokenized token list.
        return word_tokenize(
            " ".join([self.idx_word[i] for i in ids if i not in skip_ids]))

    hypotheses_val = []
    references_val = []
    for input_batch, output_batch, sent_lengths in utils.get_batches(
            x_val, self.batch_size):
        pred_sentences, self._validate_logits = sess.run(
            [self.validate_sent, self.validate_logits],
            feed_dict={self.input_data: input_batch,
                       self.source_sentence_length: sent_lengths,
                       self.keep_prob: 1.0,
                       })
        for pred, actual in zip(pred_sentences, output_batch):
            hypotheses_val.append(to_tokens(pred))
            # Each hypothesis is paired with a list of references (one here).
            references_val.append([to_tokens(actual)])

    self.val_pred = [" ".join(sent) for sent in hypotheses_val]
    self.val_ref = [" ".join(sent[0]) for sent in references_val]

    bleu_scores = utils.calculate_bleu_scores(references_val, hypotheses_val)
    # bleu_scores[0..3] are stored under the string keys '1'..'4'.
    for order in range(4):
        self.epoch_bleu_score_val[str(order + 1)].append(bleu_scores[order])
def predict(self, checkpoint, x_test):
    """Restore ``checkpoint``, decode ``x_test`` and print corpus BLEU-1..4.

    Opens a new ``tf.Session`` (default graph) and initializes all global
    variables. It then restores them from ``checkpoint`` with a default
    ``tf.train.Saver`` and runs ``self.validate_sent`` batch by batch over
    ``utils.get_batches(x_test, self.batch_size)`` with dropout disabled.

    Args:
        checkpoint: Checkpoint path passed to ``tf.train.Saver.restore``.
        x_test: Test data in whatever form ``utils.get_batches`` accepts.

    Returns:
        A list of the per-example outputs of ``self.validate_sent``, in batch
        order. These are decoded id sequences (they are mapped through
        ``self.idx_word`` below), not logits.
    """
    # Ids dropped before detokenizing. Built once per call rather than once
    # per token. A tuple keeps the exact `in` semantics of the original list.
    # -1 is presumably padding emitted by the decoder -- confirm.
    skip_ids = (self.pad, -1, self.eos)

    def to_tokens(ids):
        # id sequence -> words (minus pad/-1/eos) -> re-tokenized token list.
        return word_tokenize(
            " ".join([self.idx_word[i] for i in ids if i not in skip_ids]))

    predictions = []
    hypotheses_test = []
    references_test = []
    with tf.Session() as sess:
        # Initialize every global variable, then overwrite with saved values.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)
        for input_batch, output_batch, sent_lengths in utils.get_batches(
                x_test, self.batch_size):
            result = sess.run(self.validate_sent,
                              feed_dict={self.input_data: input_batch,
                                         self.source_sentence_length: sent_lengths,
                                         self.keep_prob: 1.0,
                                         })
            predictions.extend(result)
            for pred, actual in zip(result, output_batch):
                hypotheses_test.append(to_tokens(pred))
                # Each hypothesis is paired with a list of references (one here).
                references_test.append([to_tokens(actual)])

    bleu_scores = utils.calculate_bleu_scores(references_test, hypotheses_test)
    print('BLEU 1 to 4 : {}'.format(' | '.join(map(str, bleu_scores))))
    return predictions
def validate(self, sess, x_val, y_val, true_val):
    """Decode the validation set and record this epoch's corpus BLEU-1..4.

    Iterates over ``utils.get_batches_xy(x_val, y_val, self.batch_size)``.
    Only the source side of each batch is fed to ``self.inference_logits``,
    with dropout and word dropout disabled and ``self.z_temp`` as the
    latent temperature. Predictions are mapped to words via
    ``self.decoder_idx_word``. References are taken from ``true_val``, whose
    entries are passed directly to ``word_tokenize`` (presumably raw
    reference strings -- confirm).

    Args:
        sess: Session used to evaluate ``self.inference_logits``.
        x_val: Validation source data, as accepted by ``utils.get_batches_xy``.
        y_val: Validation target data, as accepted by ``utils.get_batches_xy``
            (its batches are not used here).
        true_val: Reference sentences indexed by example position.

    Side effects:
        * ``self.val_pred`` / ``self.val_ref`` are set to the space-joined
          hypothesis / reference sentences.
        * BLEU-1..4 are appended to ``self.epoch_bleu_score_val['1']`` ..
          ``['4']``.
    """
    # Ids dropped before detokenizing. Built once per call rather than once
    # per token. A tuple keeps the exact `in` semantics of the original list.
    # -1 is presumably padding emitted by the decoder -- confirm.
    skip_ids = (self.pad, -1, self.eos)

    hypotheses_val = []
    references_val = []
    batches = utils.get_batches_xy(x_val, y_val, self.batch_size)
    # Target batches and target lengths are yielded but unused here.
    for batch_i, (input_batch, _, source_sent_lengths, _) in enumerate(batches):
        # Despite its name, the fetched value is mapped through
        # decoder_idx_word below, i.e. it is treated as token ids.
        answer_ids = sess.run(self.inference_logits,
                              feed_dict={self.input_data: input_batch,
                                         self.source_sentence_length: source_sent_lengths,
                                         self.keep_prob: 1.0,
                                         self.word_dropout_keep_prob: 1.0,
                                         self.z_temperature: self.z_temp})
        # NOTE(review): this offset into true_val assumes every batch before
        # the current one had exactly self.batch_size examples -- confirm
        # against utils.get_batches_xy.
        offset = batch_i * self.batch_size
        for k, pred in enumerate(answer_ids):
            hypotheses_val.append(word_tokenize(
                " ".join([self.decoder_idx_word[i] for i in pred
                          if i not in skip_ids])))
            # Each hypothesis is paired with a list of references (one here).
            references_val.append([word_tokenize(true_val[offset + k])])

    self.val_pred = [" ".join(sent) for sent in hypotheses_val]
    self.val_ref = [" ".join(sent[0]) for sent in references_val]

    bleu_scores = utils.calculate_bleu_scores(references_val, hypotheses_val)
    # bleu_scores[0..3] are stored under the string keys '1'..'4'.
    for order in range(4):
        self.epoch_bleu_score_val[str(order + 1)].append(bleu_scores[order])