Exemplo n.º 1
0
def batch_to_batch(batch, batcher, dis_batcher):
    """Wrap the original review inputs of ``batch`` in a ``bd.Batch``.

    One ``bd.Example`` is created for each index in
    ``range(FLAGS.batch_size)`` from ``batch.original_review_inputs[index]``.

    Args:
        batch: Object exposing an indexable ``original_review_inputs``.
        batcher: Not used by this function; kept so the signature is
            unchanged for callers.
        dis_batcher: Object whose ``_vocab`` and ``_hps`` attributes are
            forwarded to ``bd.Example`` and ``bd.Batch``.

    Returns:
        The ``bd.Batch`` built from the new examples.
    """
    # The constant -0.01 is the second positional argument of bd.Example;
    # presumably a score/label for these inputs -- confirm against bd.Example.
    dis_examples = [
        bd.Example(batch.original_review_inputs[index], -0.01,
                   dis_batcher._vocab, dis_batcher._hps)
        for index in range(FLAGS.batch_size)
    ]
    return bd.Batch(dis_examples, dis_batcher._hps, dis_batcher._vocab)
Exemplo n.º 2
0
def _decode_greedy_words(output_row, vocab):
    """Map one row of output ids to words, cut at the first stop token.

    Args:
        output_row: Iterable of ids; each element is converted with ``int``.
        vocab: Vocabulary forwarded to ``data.outputids2words``.

    Returns:
        The words from ``data.outputids2words`` that precede the first
        ``data.STOP_DECODING`` entry, or all of them if it never occurs.
    """
    output_ids = [int(t) for t in output_row]
    # The third positional argument is passed as None, as the caller always
    # did.
    words = data.outputids2words(output_ids, vocab, None)
    try:
        return words[:words.index(data.STOP_DECODING)]
    except ValueError:
        # No stop token was produced; keep the whole sequence.
        return words


def _clean_decoded_text(text):
    """Cut ``text`` at the document stop marker and drop noise tokens.

    Removes every "[UNK]" (with or without a trailing space) and deletes any
    run of two or more consecutive "! " or ". " groups entirely.
    """
    stop_idx = text.find(data.STOP_DECODING_DOCUMENT)
    if stop_idx != -1:
        text = text[:stop_idx]
    text = text.replace("[UNK] ", "").replace("[UNK]", "")
    text = re.sub(r"(! ){2,}", "", text)
    return re.sub(r"(\. ){2,}", "", text)


def output_to_batch(current_batch, results, batcher, dis_batcher):
    """Build a new batch and a new dis-batch from greedy decoder output.

    For every index ``i`` in ``range(FLAGS.batch_size)`` the ids in
    ``results['Greedy_outputs'][i]`` are decoded to text and cleaned. When
    non-blank text remains it is used for both new examples and ``1`` is
    passed as the second ``bd.Example`` argument; otherwise
    ``current_batch.original_review_output[i]`` is used with ``-0.0001``.
    (The meaning of those two constants is defined by ``bd.Example``;
    presumably a score/label -- confirm there.)

    Args:
        current_batch: Object exposing indexable ``original_review_inputs``
            and ``original_review_output``.
        results: Mapping with a ``'Greedy_outputs'`` entry indexable by
            batch position.
        batcher: Object whose ``_vocab`` and ``_hps`` are used for decoding
            and for the returned ``Batch``.
        dis_batcher: Object whose ``_vocab`` and ``_hps`` are used for the
            returned ``bd.Batch``.

    Returns:
        A tuple ``(Batch, bd.Batch)`` with one example per batch position.
    """
    example_list = []
    db_example_list = []

    for i in range(FLAGS.batch_size):
        encode_words = current_batch.original_review_inputs[i]

        # The decoded words do not depend on the sentence index, so decode
        # once per batch position instead of once per sentence slot.
        # NOTE(review): the sentence index was never used to select the
        # output row, so every pass saw the same words. This may be a
        # leftover from per-sentence decoding ([i][j]) -- confirm the shape
        # of results['Greedy_outputs'] before changing it.
        sentence_words = _decode_greedy_words(results['Greedy_outputs'][i],
                                              batcher._vocab)

        decoded_sentences = []
        if len(sentence_words) >= 2:
            current_tokens = set(sentence_words)
            for _ in range(FLAGS.max_dec_sen_num):
                if decoded_sentences:
                    previous_tokens = set(decoded_sentences[-1].split())
                    shared = previous_tokens & current_tokens
                    # Skip a sentence that shares more than half of its
                    # distinct tokens with the previously kept sentence.
                    if len(shared) > 0.5 * len(current_tokens):
                        continue
                words = list(sentence_words)
                # Make sure every kept sentence ends with punctuation.
                if words[-1] not in ('.', '!', '?'):
                    words.append('.')
                decoded_sentences.append(' '.join(words).strip())

        decoded_text = _clean_decoded_text(
            ' '.join(decoded_sentences).strip())

        if decoded_text.strip() == "":
            # Nothing usable was decoded: fall back to the original output.
            example_text = current_batch.original_review_output[i]
            dis_value = -0.0001
        else:
            example_text = decoded_text
            dis_value = 1

        new_dis_example = bd.Example(example_text, dis_value,
                                     dis_batcher._vocab, dis_batcher._hps)
        new_example = Example(example_text, batcher._vocab, batcher._hps,
                              encode_words)
        example_list.append(new_example)
        db_example_list.append(new_dis_example)

    return (Batch(example_list, batcher._hps, batcher._vocab),
            bd.Batch(db_example_list, dis_batcher._hps, dis_batcher._vocab))