    def validate(self, sess, x_val, y_val, true_val):
        """Compute corpus-level BLEU on the validation data."""
        hypotheses_val = []
        references_val = []
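        # For question-generation runs, a trailing '?' is appended to every decoded hypothesis.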
        symbol = []
        if self.config['experiment'] == 'qgen':
            symbol.append('?')
        for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                data_utils.get_batches(x_val, y_val, self.batch_size)):
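            # Run the inference decoder on the batch; feeding keep_prob=1.0 disables dropout at eval time.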
            answer_logits = sess.run(self.inference_logits,
                                     feed_dict={self.input_data: input_batch,
                                                self.source_sentence_length: source_sent_lengths,
                                                self.keep_prob: 1.0})

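            # Map predicted token ids back to words, skipping PAD, EOS, and -1 ids.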
            for k, pred in enumerate(answer_logits):
                hypotheses_val.append(
                    word_tokenize(" ".join(
                        [self.decoder_idx_word[i] for i in pred if i not in [self.pad, -1, self.eos]])) + symbol)
                references_val.append([word_tokenize(true_val[batch_i * self.batch_size + k])])

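        # Record corpus-level BLEU-1 through BLEU-4 over the whole validation set.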
        bleu_scores = eval_utils.calculate_bleu_scores(references_val,
                                                       hypotheses_val)
        self.epoch_bleu_score_val['1'].append(bleu_scores[0])
        self.epoch_bleu_score_val['2'].append(bleu_scores[1])
        self.epoch_bleu_score_val['3'].append(bleu_scores[2])
        self.epoch_bleu_score_val['4'].append(bleu_scores[3])

    def predict(self, checkpoint, x_test, y_test, true_test):
        """Restore a checkpoint, decode the test set, and report corpus BLEU."""
        pred_logits = []
        hypotheses_test = []
        references_test = []
        symbol = []
        if self.config['experiment'] == 'qgen':
            symbol.append('?')

        with tf.Session() as sess:
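            # Initialize variables, then overwrite them with the trained weights from the checkpoint.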
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint)

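            # Decode each test batch with the restored model (dropout disabled via keep_prob=1.0).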
            for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                    data_utils.get_batches(x_test, y_test, self.batch_size)):
                result = sess.run(self.inference_logits, feed_dict={self.input_data: input_batch,
                                                                    self.source_sentence_length: source_sent_lengths,
                                                                    self.keep_prob: 1.0})

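                # Accumulate the predicted id sequences, which are returned to the caller at the end.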
                pred_logits.extend(result)

                for k, pred in enumerate(result):
                    hypotheses_test.append(
                        word_tokenize(" ".join(
                            [self.decoder_idx_word[i] for i in pred if i not in [self.pad, -1, self.eos]])) + symbol)
                    references_test.append([word_tokenize(true_test[batch_i * self.batch_size + k])])

            bleu_scores = eval_utils.calculate_bleu_scores(references_test, hypotheses_test)

        print('BLEU 1 to 4 : {}'.format(' | '.join(map(str, bleu_scores))))

        return pred_logits
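
    # Example usage (a sketch: the class name `Seq2SeqModel` and the checkpoint
    # path below are placeholders, not defined in this file):
    #
    #     model = Seq2SeqModel(config)
    #     predictions = model.predict('checkpoints/best_model.ckpt',
    #                                 x_test, y_test, true_test)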