# Example #1
# 0
def output_to_classification_batch(output, batch, batcher, cla_batcher, cc):
    """Convert generator token-id output into a classifier Batch plus BLEU scores.

    Args:
        output: per-example sequences of token ids, indexable by batch position.
        batch: source batch providing `original_reviews` and `score`.
        batcher: supplies the `_vocab` used to map ids back to words.
        cla_batcher: supplies `_vocab`/`_hps` for building classifier examples.
        cc: NLTK SmoothingFunction instance; its `method1` smooths BLEU.

    Returns:
        Tuple of (bc.Batch for the classifier, list of per-sentence BLEU scores).
    """
    example_list = []
    bleu = []
    for i in range(FLAGS.batch_size):
        output_ids = [int(t) for t in output[i]]
        decoded_words = data.outputids2words(output_ids, batcher._vocab, None)

        # Truncate at the first [STOP] token, if one was generated.
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            pass  # no [STOP] produced; keep the full sequence

        decoded_words_all = ' '.join(decoded_words).strip()  # single string

        # Strip [UNK] tokens (first with, then without, a trailing space).
        decoded_words_all = decoded_words_all.replace("[UNK] ", "")
        decoded_words_all = decoded_words_all.replace("[UNK]", "")

        if decoded_words_all.strip() == "":
            # Empty decode: score 0 BLEU and feed a placeholder sentence so
            # the classifier batch keeps a fixed size.
            bleu.append(0)
            new_dis_example = bc.Example(".", batch.score, cla_batcher._vocab,
                                         cla_batcher._hps)
        else:
            bleu.append(
                sentence_bleu([batch.original_reviews[i].split()],
                              decoded_words_all.split(),
                              smoothing_function=cc.method1))
            new_dis_example = bc.Example(decoded_words_all, batch.score,
                                         cla_batcher._vocab, cla_batcher._hps)
        example_list.append(new_dis_example)

    return bc.Batch(example_list, cla_batcher._hps, cla_batcher._vocab), bleu
# Example #2
# 0
def batch_classification_batch(batch, batcher, cla_batcher):
    """Rebuild a classification batch keeping only tokens with weight >= 1.

    For each review in `batch`, truncates it to the encoder's max input
    length, drops every word whose `batch.weight[i][j]` is below 1, and wraps
    the result in a classifier example.

    Args:
        batch: source batch with `original_reviews`, `weight`, and `score`.
        batcher: supplies `_hps.max_enc_steps` for truncation.
        cla_batcher: supplies `_vocab`/`_hps` for building classifier examples.

    Returns:
        A bc.Batch of the filtered examples.
    """
    db_example_list = []

    for i in range(FLAGS.batch_size):
        original_text = batch.original_reviews[i].split()
        # Truncate to the encoder's maximum input length.
        if len(original_text) > batcher._hps.max_enc_steps:
            original_text = original_text[:batcher._hps.max_enc_steps]

        # Keep only words whose weight is >= 1 (assumes weight[i] covers
        # every kept token position -- TODO confirm against Batch builder).
        kept_words = [
            word for word, weight in zip(original_text, batch.weight[i])
            if weight >= 1
        ]

        new_original_text = " ".join(kept_words)
        if new_original_text.strip() == "":
            # Placeholder sentence so the classifier never sees empty input.
            new_original_text = ". "

        db_example_list.append(
            bc.Example(new_original_text, batch.score, cla_batcher._vocab,
                       cla_batcher._hps))

    return bc.Batch(db_example_list, cla_batcher._hps, cla_batcher._vocab)
    def run_test_classification(self, data_path, model, batcher, sess):
        """Evaluate the classifier over `data_path` and log its accuracy.

        Builds full-size batches from the example queue (any trailing partial
        batch is dropped), runs `model.run_eval_step` on each, and logs the
        aggregate accuracy via tf.logging.

        Args:
            data_path: path handed to `add_example_queue` to load examples.
            model: classifier exposing `run_eval_step(sess, batch)`.
            batcher: supplies `_vocab`/`_hps` for batch construction.
            sess: TF session to run the eval steps in.
        """
        example_list = self.add_example_queue(data_path, batcher._vocab,
                                              batcher._hps)
        step = 0
        right = 0
        total = 0  # renamed from `all`, which shadowed the builtin

        while step < len(example_list) // FLAGS.batch_size:
            current_batch = bc.Batch(
                example_list[step * FLAGS.batch_size:(step + 1) *
                             FLAGS.batch_size], batcher._hps, batcher._vocab)
            step += 1
            right_s, number, error_list, error_label = model.run_eval_step(
                sess, current_batch)

            total += number
            right += right_s

        tf.logging.info("classification acc: " + str(right / (total * 1.0)))
# Example #4
# 0
    def generator_validation_negative_example(self, path, batcher, model_class,
                                              sess_cls, cla_batcher, mode):
        """Decode transfer batches, dump outputs to json, and score them.

        Runs the generator over the validation or test transfer set, writes
        each (original, generated, score) triple to `path` via
        write_negtive_to_json, classifies the generated sentences with
        `model_class`, and computes corpus BLEU against the originals for the
        pairs the classifier labeled with the batch's target score.

        Args:
            path: output directory; wiped and recreated on every call.
            batcher: not referenced in this body -- TODO confirm before removal.
            model_class: classifier model exposing `run_eval_step`.
            sess_cls: TF session for the classifier.
            cla_batcher: supplies `_vocab`/`_hps` for classifier examples.
            mode: 'valid-transfer' or 'test-transfer'. NOTE(review): any other
                value leaves `batches` unbound and raises NameError below.

        Returns:
            (transfer_acc, bleu): classification accuracy and BLEU, both in %.
        """
        # Recreate `path` as an empty directory (mkdir if missing, wipe, remake).
        if not os.path.exists(path): os.mkdir(path)
        shutil.rmtree(path)
        if not os.path.exists(path): os.mkdir(path)
        counter = 0  # running index of decoded examples written to json
        step = 0

        t0 = time.time()  # NOTE(review): t0 is never read afterwards
        if mode == 'valid-transfer':
            batches = self.valid_transfer
        elif mode == 'test-transfer':
            batches = self.test_transfer
            print("len test", len(batches))

        list_ref = []  # BLEU references: [[original tokens]] per kept example
        list_pre = []  # BLEU hypotheses: generated tokens per kept example
        right = 0
        all = 0  # NOTE(review): shadows the builtin `all`

        #while step < len(batches):
        # NOTE(review): loop bound is hard-coded to 310 instead of
        # len(batches) -- confirm this matches the dataset size.
        while step < 310:
            cla_input = []

            batch = batches[step]
            step += 1
            decode_result = self._model.max_generator(self._sess, batch)

            example_list = []
            #print("step", step)
            for i in range(FLAGS.batch_size):

                original_review = batch.original_reviews[i]  # string
                score = batch.score
                output_ids = [int(t) for t in decode_result['generated'][i]][:]
                decoded_words = data.outputids2words(output_ids, self._vocab,
                                                     None)
                # Remove the [STOP] token from decoded_words, if necessary
                try:
                    fst_stop_idx = decoded_words.index(
                        data.STOP_DECODING
                    )  # index of the (first) [STOP] symbol
                    decoded_words = decoded_words[:fst_stop_idx]
                except ValueError:
                    decoded_words = decoded_words
                decoded_output = ' '.join(decoded_words)  # single string
                self.write_negtive_to_json(original_review, decoded_output,
                                           score, counter, path)
                counter += 1  # this is how many examples we've decoded
                cla_input.append(decoded_output)
                '''
                if len(original_review.split())>2 and len(decoded_output.split())>2:
                    list_ref.append([original_review.split()])
                    list_pre.append(decoded_output.split())
                '''
                #bleu.append(sentence_bleu([batch.original_reviews[i]], decoded_words_all.split()))
                # Placeholder sentence so the classifier never sees empty input
                # (the raw decode was already written to json and cla_input).
                if decoded_output.strip() == "":
                    decoded_output = ". "
                new_dis_example = bc.Example(decoded_output, batch.score,
                                             cla_batcher._vocab,
                                             cla_batcher._hps)
                example_list.append(new_dis_example)

            cla_batch = bc.Batch(example_list, cla_batcher._hps,
                                 cla_batcher._vocab)
            right_s, all_s, _, pre = model_class.run_eval_step(
                sess_cls, cla_batch)
            right += right_s
            all += all_s

            # Keep only reasonably long pairs whose generated sentence the
            # classifier labeled with the batch's target score.
            for i in range(FLAGS.batch_size):
                if len(batch.original_reviews[i].split()) > 2 and len(
                        cla_input[i].split()) > 2 and batch.score == pre[i]:
                    list_ref.append([batch.original_reviews[i].split()])
                    list_pre.append(cla_input[i].split())

        transfer_acc = right * 1.0 / all * 100
        bleu = corpus_bleu(list_ref, list_pre) * 100

        tf.logging.info("valid transfer acc: " + str(transfer_acc))
        tf.logging.info("BLEU: " + str(bleu))

        return transfer_acc, bleu
# Example #5
# 0
def _eval_transfer_direction(cla_batcher, cnn_classifier, sess_cnn, gold_texts,
                             transfer_texts, label, list_ref, list_pre,
                             batch_size=64):
    """Classify one transfer direction and collect BLEU ref/hyp pairs.

    Batches `transfer_texts` (dropping any trailing partial batch), runs the
    CNN classifier, and appends to `list_ref`/`list_pre` the pairs that are
    longer than two tokens and were classified as `label`.

    Returns:
        (right, total) classification counts for this direction.
    """
    right = 0
    total = 0
    for i in range(len(gold_texts) // batch_size):
        example_list = [
            bc.Example(transfer_texts[i * batch_size + j], label,
                       cla_batcher._vocab, cla_batcher._hps)
            for j in range(batch_size)
        ]
        cla_batch = bc.Batch(example_list, cla_batcher._hps,
                             cla_batcher._vocab)

        right_s, all_s, _, pre = cnn_classifier.run_eval_step(
            sess_cnn, cla_batch)
        right += right_s
        total += all_s

        for j in range(batch_size):
            gold = gold_texts[i * batch_size + j].split()
            hyp = transfer_texts[i * batch_size + j].split()
            # Score only reasonably long pairs that the classifier labeled
            # with the target style.
            if len(gold) > 2 and len(hyp) > 2 and label == pre[j]:
                list_ref.append([gold])
                list_pre.append(hyp)
    return right, total


def run_test_our_method(cla_batcher, cnn_classifier, sess_cnn, filename):
    """Evaluate transfer outputs in `filename`: log CNN accuracy and BLEU.

    Reads the transferred sentences and their gold inputs for both directions
    (to-positive label 1, to-negative label 0), classifies them with the CNN,
    and logs overall accuracy plus smoothed corpus BLEU over the pairs the
    classifier judged correctly transferred.

    Args:
        cla_batcher: supplies `_vocab`/`_hps` for classifier examples.
        cnn_classifier: classifier exposing `run_eval_step(sess, batch)`.
        sess_cnn: TF session for the classifier.
        filename: file handed to read_test_result_our / read_test_input.
    """
    test_to_true = read_test_result_our(filename, 1)
    test_to_false = read_test_result_our(filename, 0)

    gold_to_true = read_test_input(filename, 1)
    gold_to_false = read_test_input(filename, 0)

    list_ref = []
    list_pre = []

    right_cnn = 0
    all_cnn = 0

    # Both directions share the same batching/eval logic; see helper above.
    r, t = _eval_transfer_direction(cla_batcher, cnn_classifier, sess_cnn,
                                    gold_to_true, test_to_true, 1,
                                    list_ref, list_pre)
    right_cnn += r
    all_cnn += t

    r, t = _eval_transfer_direction(cla_batcher, cnn_classifier, sess_cnn,
                                    gold_to_false, test_to_false, 0,
                                    list_ref, list_pre)
    right_cnn += r
    all_cnn += t

    tf.logging.info("cnn test acc: " + str(right_cnn * 1.0 / all_cnn))
    cc = SmoothingFunction()
    tf.logging.info(
        "BLEU: " +
        str(corpus_bleu(list_ref, list_pre, smoothing_function=cc.method1)))