Beispiel #1
0
    def validate(self, sess, x_val):
        """Decode the validation data and record BLEU-1..4 for this call.

        Evaluates ``self.validate_sent`` and ``self.validate_logits`` batch by
        batch with ``self.keep_prob`` fed as 1.0, maps predicted and target
        ids to words through ``self.idx_word``, tokenizes both with
        ``word_tokenize`` and scores them with
        ``utils.calculate_bleu_scores``.

        Args:
            sess: Object whose ``run`` method evaluates the graph tensors.
            x_val: Validation data handed to ``utils.get_batches``.

        Side effects:
            * ``self._validate_logits`` is overwritten on every batch, so it
              ends up holding the logits of the last batch.
            * ``self.val_pred`` / ``self.val_ref`` are set to the space-joined
              hypothesis / reference sentences of the whole validation set
              (empty lists when no batch is produced).
            * One score is appended to each of
              ``self.epoch_bleu_score_val['1']`` .. ``['4']``.
        """
        # Ids dropped from the decoded text. Built once rather than once per
        # token; a tuple keeps the same ``==``-based membership test as the
        # list literal it replaces.
        skipped_ids = (self.pad, -1, self.eos)

        hypotheses_val = []
        references_val = []

        for input_batch, output_batch, sent_lengths in utils.get_batches(
                x_val, self.batch_size):
            pred_sentences, self._validate_logits = sess.run(
                [self.validate_sent, self.validate_logits],
                feed_dict={
                    self.input_data: input_batch,
                    self.source_sentence_length: sent_lengths,
                    self.keep_prob: 1.0,
                })

            for pred, actual in zip(pred_sentences, output_batch):
                hypotheses_val.append(word_tokenize(" ".join(
                    [self.idx_word[i] for i in pred if i not in skipped_ids])))
                # Every hypothesis is paired with a list holding exactly one
                # tokenized reference.
                references_val.append([word_tokenize(" ".join(
                    [self.idx_word[i] for i in actual if i not in skipped_ids]))])

        # Joined once after all batches are decoded. Doing this inside the
        # batch loop re-joined every earlier sentence on each iteration and
        # left the attributes unset when there were no batches.
        self.val_pred = [" ".join(sent) for sent in hypotheses_val]
        self.val_ref = [" ".join(sent[0]) for sent in references_val]

        bleu_scores = utils.calculate_bleu_scores(references_val, hypotheses_val)

        self.epoch_bleu_score_val['1'].append(bleu_scores[0])
        self.epoch_bleu_score_val['2'].append(bleu_scores[1])
        self.epoch_bleu_score_val['3'].append(bleu_scores[2])
        self.epoch_bleu_score_val['4'].append(bleu_scores[3])
Beispiel #2
0
    def predict(self, checkpoint, x_test):
        """Restore ``checkpoint``, decode ``x_test`` and print BLEU 1 to 4.

        A new ``tf.Session`` is opened, the global variables initializer is
        run, and a ``tf.train.Saver`` restores ``checkpoint`` into the session.
        Each batch from ``utils.get_batches`` is then run through
        ``self.validate_sent`` with ``self.keep_prob`` fed as 1.0.

        Args:
            checkpoint: Value passed to ``saver.restore`` as the save path.
            x_test: Test data handed to ``utils.get_batches``.

        Returns:
            list: The elements of every ``self.validate_sent`` result,
            accumulated across batches in iteration order.
        """
        special_ids = (self.pad, -1, self.eos)

        def to_tokens(ids):
            # Map ids to words, skipping the special ids, then tokenize the
            # space-joined sentence.
            words = [self.idx_word[i] for i in ids if i not in special_ids]
            return word_tokenize(" ".join(words))

        decoded = []
        hyp_tokens = []
        ref_tokens = []

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint)

            batches = utils.get_batches(x_test, self.batch_size)
            for input_batch, output_batch, sent_lengths in batches:
                feed = {
                    self.input_data: input_batch,
                    self.source_sentence_length: sent_lengths,
                    self.keep_prob: 1.0,
                }
                result = sess.run(self.validate_sent, feed_dict=feed)
                decoded.extend(result)

                for pred, actual in zip(result, output_batch):
                    hyp_tokens.append(to_tokens(pred))
                    # One reference per hypothesis, wrapped in a list.
                    ref_tokens.append([to_tokens(actual)])

            bleu_scores = utils.calculate_bleu_scores(ref_tokens, hyp_tokens)

        score_text = ' | '.join(map(str, bleu_scores))
        print('BLEU 1 to 4 : {}'.format(score_text))

        return decoded
Beispiel #3
0
    def validate(self, sess, x_val, y_val, true_val):
        """Decode the validation data and record BLEU-1..4 for this call.

        Each batch from ``utils.get_batches_xy`` is run through
        ``self.inference_logits`` with ``self.keep_prob`` and
        ``self.word_dropout_keep_prob`` fed as 1.0 and ``self.z_temperature``
        fed with ``self.z_temp``. Predicted ids are mapped to words through
        ``self.decoder_idx_word``; references come from ``true_val``.

        Args:
            sess: Object whose ``run`` method evaluates the graph tensors.
            x_val: Source-side data handed to ``utils.get_batches_xy``.
            y_val: Target-side data handed to ``utils.get_batches_xy``; the
                target batches it yields are not used here.
            true_val: Indexable collection of reference sentences, tokenized
                with ``word_tokenize``. Entry ``batch * batch_size + k`` is
                used for prediction ``k`` of a batch.
                NOTE(review): assumes batches arrive in ``true_val`` order and
                that only the last batch may be short - confirm.

        Side effects:
            * ``self.val_pred`` / ``self.val_ref`` are set to the space-joined
              hypothesis / reference sentences.
            * One score is appended to each of
              ``self.epoch_bleu_score_val['1']`` .. ``['4']``.
        """
        ignored_ids = (self.pad, -1, self.eos)
        hyp_tokens = []
        ref_tokens = []

        batches = utils.get_batches_xy(x_val, y_val, self.batch_size)
        for batch_no, (input_batch, _targets, source_sent_lengths, _target_lengths) in enumerate(batches):
            feed = {
                self.input_data: input_batch,
                self.source_sentence_length: source_sent_lengths,
                self.keep_prob: 1.0,
                self.word_dropout_keep_prob: 1.0,
                self.z_temperature: self.z_temp,
            }
            answer_logits = sess.run(self.inference_logits, feed_dict=feed)

            for k, pred in enumerate(answer_logits):
                words = [self.decoder_idx_word[i] for i in pred if i not in ignored_ids]
                hyp_tokens.append(word_tokenize(" ".join(words)))
                # One reference per hypothesis, wrapped in a list.
                ref_tokens.append([word_tokenize(true_val[batch_no * self.batch_size + k])])

        self.val_pred = [" ".join(tokens) for tokens in hyp_tokens]
        self.val_ref = [" ".join(refs[0]) for refs in ref_tokens]

        bleu_scores = utils.calculate_bleu_scores(ref_tokens, hyp_tokens)
        for position, order in enumerate(('1', '2', '3', '4')):
            self.epoch_bleu_score_val[order].append(bleu_scores[position])