Example #1
    def predict(self, checkpoint, x_test, y_test, true_test):
        pred_logits = []
        hypotheses_test = []
        references_test = []

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint)

            for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                    utils.get_batches_xy(x_test, y_test, self.batch_size)):
                result = sess.run(self.inference_logits, feed_dict={self.input_data: input_batch,
                                                                    self.source_sentence_length: source_sent_lengths,
                                                                    self.keep_prob: 1.0,
                                                                    self.word_dropout_keep_prob: 1.0,
                                                                    self.z_temperature: self.z_temp})

                pred_logits.extend(result)

                for k, pred in enumerate(result):
                    hypotheses_test.append(
                        word_tokenize(" ".join(
                            [self.decoder_idx_word[i] for i in pred if i not in [self.pad, -1, self.eos]])))
                    references_test.append([word_tokenize(true_test[batch_i * self.batch_size + k])])

            bleu_scores = utils.calculate_bleu_scores(references_test, hypotheses_test)

        print('BLEU 1 to 4 : {}'.format(' | '.join(map(str, bleu_scores))))

        return pred_logits
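Example #1 delegates the scoring itself to utils.calculate_bleu_scores, which is not part of the snippet. Below is a minimal sketch of what such a helper could look like, assuming NLTK's corpus_bleu with cumulative 1- to 4-gram weights and the same (references, hypotheses) layout used above; the function body and the percentage scaling are assumptions, not code from the repository.

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def calculate_bleu_scores(references, hypotheses):
    # Cumulative BLEU-1 to BLEU-4 on tokenized corpora, reported as percentages.
    smoothing = SmoothingFunction().method1
    weights = [(1.0, 0.0, 0.0, 0.0),
               (0.5, 0.5, 0.0, 0.0),
               (1.0 / 3, 1.0 / 3, 1.0 / 3, 0.0),
               (0.25, 0.25, 0.25, 0.25)]
    return [round(100 * corpus_bleu(references, hypotheses, weights=w,
                                    smoothing_function=smoothing), 2)
            for w in weights]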
Example #2
    def get_neighbourhood(self, checkpoint, x_test, temp=1.0, num_samples=10):
        answer_logits = []
        pred_sentences = []
        x_test_repeated = np.repeat(x_test, num_samples, axis=0)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint)

            # Decode each repeated source sentence; targets are not used at inference time.
            for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                    utils.get_batches_xy(x_test_repeated, x_test_repeated, self.batch_size)):
                result = sess.run(self.inference_logits, feed_dict={self.input_data: input_batch,
                                                                    self.source_sentence_length: source_sent_lengths,
                                                                    self.keep_prob: 1.0,
                                                                    self.word_dropout_keep_prob: 1.0,
                                                                    self.z_temperature: temp})
                answer_logits.extend(result)

            for idx, (actual, pred) in enumerate(zip(x_test_repeated, answer_logits)):
                pred_sentences.append(" ".join([self.decoder_idx_word[i] for i in pred if i not in [self.pad, self.eos]]))

        for j in range(len(pred_sentences)):
            if j % num_samples == 0:
                print('\nA: {}'.format(" ".join([self.decoder_idx_word[i] for i in x_test_repeated[j] if i not in [self.pad, self.eos]])))
            print('G: {}'.format(pred_sentences[j]))
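The printing loop above relies on the row layout produced by np.repeat with axis=0: every block of num_samples consecutive rows in x_test_repeated comes from the same source sentence, which is why the source ('A:') is printed only when j % num_samples == 0. A small self-contained illustration with toy index sequences (the values are made up purely for illustration):

import numpy as np

x_test = np.array([[4, 7, 2],
                   [5, 9, 2]])          # two toy sentences as index sequences
x_test_repeated = np.repeat(x_test, 3, axis=0)
print(x_test_repeated)
# [[4 7 2]
#  [4 7 2]
#  [4 7 2]
#  [5 9 2]
#  [5 9 2]
#  [5 9 2]]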
Example #3
    def validate(self, sess, x_val, y_val, true_val):
        # Calculate BLEU on validation data
        hypotheses_val = []
        references_val = []

        for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                utils.get_batches_xy(x_val, y_val, self.batch_size)):
            answer_logits = sess.run(self.inference_logits,
                                     feed_dict={self.input_data: input_batch,
                                                self.source_sentence_length: source_sent_lengths,
                                                self.keep_prob: 1.0,
                                                self.word_dropout_keep_prob: 1.0,
                                                self.z_temperature: self.z_temp})

            for k, pred in enumerate(answer_logits):
                hypotheses_val.append(
                    word_tokenize(
                        " ".join([self.decoder_idx_word[i] for i in pred if i not in [self.pad, -1, self.eos]])))
                references_val.append([word_tokenize(true_val[batch_i * self.batch_size + k])])
                
        self.val_pred = ([" ".join(sent) for sent in hypotheses_val])
        self.val_ref  = ([" ".join(sent[0]) for sent in references_val])

        bleu_scores = utils.calculate_bleu_scores(references_val, hypotheses_val)
        self.epoch_bleu_score_val['1'].append(bleu_scores[0])
        self.epoch_bleu_score_val['2'].append(bleu_scores[1])
        self.epoch_bleu_score_val['3'].append(bleu_scores[2])
        self.epoch_bleu_score_val['4'].append(bleu_scores[3])
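Note that each reference is wrapped in an extra list before being appended. That matches the multiple-references-per-hypothesis layout expected by NLTK-style corpus BLEU, which the scoring helper presumably relies on. A minimal, hedged illustration of the expected shapes (word_tokenize is assumed to come from nltk.tokenize, as in the surrounding code):

from nltk.tokenize import word_tokenize

hypotheses_val = [word_tokenize("the cat sat on the mat")]
references_val = [[word_tokenize("a cat sat on the mat")]]  # one hypothesis -> one list of references
assert len(hypotheses_val) == len(references_val)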
Example #4
    def validate(self, sess, input_val, output_val, label_val):
        # Calculate BLEU on validation data
        hypotheses_val = []
        references_val = []

        for batch_i, (input_batch, output_batch, label_batch, input_sent_lengths, output_sent_lengths) in enumerate(
            utils.get_batches_xy(input_val, output_val, label_val, self.batch_size)):
            pred_sentences, self._validate_logits = sess.run(
                                 [self.validate_sent, self.validate_logits],
                                 feed_dict={self.input_data: input_batch,
                                            self.labels: label_batch,
                                            self.source_sentence_length: input_sent_lengths,
                                            self.keep_prob: 1.0,
                                            })


            for pred, actual in zip(pred_sentences, output_batch):
                hypotheses_val.append(
                    word_tokenize(
                        " ".join([self.idx_word[i] for i in pred if i not in [self.pad, -1, self.eos]])))
                references_val.append(
                    [word_tokenize(" ".join([self.idx_word[i] for i in actual if i not in [self.pad, -1, self.eos]]))])
            self.val_pred = ([" ".join(sent)    for sent in hypotheses_val])
            self.val_ref  = ([" ".join(sent[0]) for sent in references_val])

        bleu_scores = utils.calculate_bleu_scores(references_val, hypotheses_val)

        self.epoch_bleu_score_val['1'].append(bleu_scores[0])
        self.epoch_bleu_score_val['2'].append(bleu_scores[1])
        self.epoch_bleu_score_val['3'].append(bleu_scores[2])
        self.epoch_bleu_score_val['4'].append(bleu_scores[3])
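Both validate variants append one score per n-gram order to self.epoch_bleu_score_val, so after training the per-epoch curves can be inspected and the best epoch recovered. A hedged sketch, assuming the dictionary is initialised as four empty lists keyed '1' to '4' (the initialisation itself is not shown in these snippets):

import numpy as np

epoch_bleu_score_val = {'1': [], '2': [], '3': [], '4': []}
# validate() appends one value per epoch; the numbers below are illustrative only:
epoch_bleu_score_val['1'].extend([31.2, 33.5, 34.1])
best_epoch = int(np.argmax(epoch_bleu_score_val['1'])) + 1  # epochs are 1-indexed in train()
print('Best epoch by BLEU-1:', best_epoch)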
Example #5
    def train(self, x_train, y_train, x_val, y_val, true_val):

        print('[INFO] Training process started')

        learning_rate = self.initial_learning_rate
        iter_i = 0

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            writer = tf.summary.FileWriter(self.logs_dir, sess.graph)

            for epoch_i in range(1, self.epochs + 1):

                start_time = time.time()

                for batch_i, (input_batch, output_batch, source_sent_lengths,
                              tar_sent_lengths) in enumerate(
                                  utils.get_batches_xy(x_train, y_train,
                                                       self.batch_size)):

                    try:
                        iter_i += 1

                        _, _summary, self.train_xent = sess.run(
                            [self.train_op, self.summary_op, self.xent_loss],
                            feed_dict={
                                self.input_data: input_batch,
                                self.target_data: output_batch,
                                self.lr: learning_rate,
                                self.source_sentence_length:
                                source_sent_lengths,
                                self.target_sentence_length: tar_sent_lengths,
                                self.keep_prob: self.dropout_keep_prob,
                                self.lambda_coeff: self.lambda_val,
                            })

                        writer.add_summary(_summary, iter_i)

                    except Exception as e:
                        # Log the failing iteration and keep training on the next batch
                        print(iter_i, e)

                # Reduce learning rate, but not below its minimum value
                learning_rate = np.max([
                    self.min_learning_rate,
                    learning_rate * self.learning_rate_decay
                ])

                time_consumption = time.time() - start_time
                self.monitor(x_val, y_val, true_val, sess, epoch_i,
                             time_consumption)
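The per-epoch update learning_rate = max(min_learning_rate, learning_rate * learning_rate_decay) gives a geometric decay with a floor. A short stand-alone illustration of the schedule (the three constants are placeholder values, not the repository's defaults):

initial_learning_rate = 1e-3      # placeholder values for illustration
min_learning_rate = 1e-5
learning_rate_decay = 0.9

lr = initial_learning_rate
for epoch in range(1, 11):
    lr = max(min_learning_rate, lr * learning_rate_decay)
    print('epoch {:2d}: lr = {:.6f}'.format(epoch, lr))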
Example #6
    def get_diversity_metrics(self, checkpoint, x_test, y_test, num_samples=10, num_iterations=3):

        x_test_repeated = np.repeat(x_test, num_samples, axis=0)
        y_test_repeated = np.repeat(y_test, num_samples, axis=0)

        entropy_list = []
        uni_diversity = []
        bi_diversity = []

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint)

            for _ in tqdm(range(num_iterations)):
                total_ent = 0
                uni = 0
                bi = 0
                answer_logits = []
                pred_sentences = []

                for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                        utils.get_batches_xy(x_test_repeated, y_test_repeated, self.batch_size)):
                    result = sess.run(self.inference_logits, feed_dict={self.input_data: input_batch,
                                                                        self.source_sentence_length: source_sent_lengths,
                                                                        self.keep_prob: 1.0,
                                                                        self.word_dropout_keep_prob: 1.0,
                                                                        self.z_temperature: self.z_temp})
                    answer_logits.extend(result)

                for idx, (actual, pred) in enumerate(zip(x_test_repeated, answer_logits)):
                    pred_sentences.append(" ".join([self.decoder_idx_word[i] for i in pred if i not in [self.pad, self.eos]]))

                    if (idx + 1) % num_samples == 0:
                        word_list = [word_tokenize(p) for p in pred_sentences]
                        corpus = [item for sublist in word_list for item in sublist]
                        total_ent += utils.calculate_entropy(corpus)
                        diversity_result = utils.calculate_ngram_diversity(corpus)
                        uni += diversity_result[0]
                        bi += diversity_result[1]

                        pred_sentences = []

                entropy_list.append(total_ent / len(x_test))
                uni_diversity.append(uni / len(x_test))
                bi_diversity.append(bi / len(x_test))

        print('Entropy = {:>.3f} | Distinct-1 = {:>.3f} | Distinct-2 = {:>.3f}'.format(np.mean(entropy_list),
                                                                                       np.mean(uni_diversity),
                                                                                       np.mean(bi_diversity)))
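utils.calculate_entropy and utils.calculate_ngram_diversity are not shown in the snippet. The sketch below assumes the usual definitions: Shannon entropy of the unigram distribution, and Distinct-1/Distinct-2 as the number of unique unigrams/bigrams divided by the total count. The real helpers may differ.

import math
from collections import Counter

def calculate_entropy(corpus):
    # Shannon entropy (bits) of the unigram distribution over a token list.
    counts = Counter(corpus)
    total = sum(counts.values())
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

def calculate_ngram_diversity(corpus):
    # Distinct-1 and Distinct-2: unique n-grams divided by total n-grams.
    bigrams = list(zip(corpus, corpus[1:]))
    distinct_1 = len(set(corpus)) / max(len(corpus), 1)
    distinct_2 = len(set(bigrams)) / max(len(bigrams), 1)
    return distinct_1, distinct_2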
Example #7
    def train(self, x_train, y_train, x_val, y_val, true_val):

        print('[INFO] Training process started')

        learning_rate = self.initial_learning_rate
        iter_i = 0
        if gl.config['anneal_type'] == 'none':
            lambda_val = gl.config['lambda_val']
        else:
            lambda_val = 0.0 # Start from zero and anneal upwards in tanh or linear fashion

        wd_anneal = 1.0

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            writer = tf.summary.FileWriter(self.logs_dir, sess.graph)

            for epoch_i in range(1, self.epochs + 1):

                start_time = time.time()

                for batch_i, (input_batch, output_batch, source_sent_lengths, tar_sent_lengths) in enumerate(
                        utils.get_batches_xy(x_train, y_train, self.batch_size)):

                    try:
                        iter_i += 1

                        _, _summary, self.train_xent = sess.run(
                            [self.train_op, self.summary_op, self.xent_loss],
                            feed_dict={self.input_data: input_batch,
                                       self.target_data: output_batch,
                                       self.lr: learning_rate,
                                       self.source_sentence_length: source_sent_lengths,
                                       self.target_sentence_length: tar_sent_lengths,
                                       self.keep_prob: self.dropout_keep_prob,
                                       self.lambda_coeff: lambda_val,
                                       self.z_temperature: self.z_temp,
                                       self.word_dropout_keep_prob: wd_anneal,
                                       })

                        writer.add_summary(_summary, iter_i)

                        # KL Annealing till some iteration
                        if iter_i <= self.anneal_till:
                            if gl.config['anneal_type'] == 'tanh':                         
                                lambda_val = np.round((np.tanh((iter_i - 4500) / 1000) + 1) / 2, decimals=6)
                                # lambda_val = np.round(logistic.cdf(iter_i/4500) - 0.5, decimals=6)
                            elif gl.config['anneal_type'] == 'linear':
                                lambda_val = np.round(iter_i*0.000005, decimals=6)
                        
                    except Exception as e:
                        # Log the failing iteration and keep training on the next batch
                        print(iter_i, e)

                # Reduce learning rate, but not below its minimum value
                learning_rate = np.max([self.min_learning_rate, learning_rate * self.learning_rate_decay])

                # Anneal word dropout from 1.0 to the limit
                wd_anneal = np.max([self.word_dropout_keep_probability, wd_anneal - 0.05])
                
                time_consumption = time.time() - start_time
                self.monitor(x_val, y_val, true_val, sess, epoch_i, time_consumption)
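The tanh schedule above moves lambda_val smoothly from near 0 towards 1, crossing 0.5 at iteration 4500 (the centre of the tanh). A small stand-alone check of the schedule at a few iteration counts (the iteration values are chosen purely for illustration):

import numpy as np

for iter_i in [1, 1000, 3000, 4500, 6000, 9000]:
    lambda_val = np.round((np.tanh((iter_i - 4500) / 1000) + 1) / 2, decimals=6)
    print('iter {:>5d}: lambda = {}'.format(iter_i, lambda_val))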