def batch_to_batch(batch, batcher, dis_batcher):
    """Repackage a generator batch as a discriminator batch.

    Each original review output in *batch* is wrapped in a discriminator
    ``bd.Example`` carrying a fixed score of -0.01, and the whole set is
    packed into a ``bd.Batch`` built with the discriminator's vocab/hps.
    (*batcher* is accepted for signature symmetry with output_to_batch
    but is not consulted here.)
    """
    examples = [
        bd.Example(batch.original_review_output[k], -0.01,
                   dis_batcher._vocab, dis_batcher._hps)
        for k in range(FLAGS.batch_size)
    ]
    return bd.Batch(examples, dis_batcher._hps, dis_batcher._vocab)
def output_to_batch(current_batch, result, batcher, dis_batcher): example_list= [] db_example_list = [] for i in range(FLAGS.batch_size): decoded_words_all = [] encode_words = current_batch.original_review_inputs[i] for j in range(FLAGS.max_dec_sen_num): output_ids = [int(t) for t in result['generated'][i][j]][1:] decoded_words = data.outputids2words(output_ids, batcher._vocab, None) # Remove the [STOP] token from decoded_words, if necessary try: fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol decoded_words = decoded_words[:fst_stop_idx] except ValueError: decoded_words = decoded_words if len(decoded_words) < 2: continue if len(decoded_words_all) > 0: new_set1 = set(decoded_words_all[len(decoded_words_all) - 1].split()) new_set2 = set(decoded_words) if len(new_set1 & new_set2) > 0.5 * len(new_set2): continue if decoded_words[-1] != '.' and decoded_words[-1] != '!' and decoded_words[-1] != '?': decoded_words.append('.') decoded_output = ' '.join(decoded_words).strip() # single string decoded_words_all.append(decoded_output) decoded_words_all = ' '.join(decoded_words_all).strip() try: fst_stop_idx = decoded_words_all.index( data.STOP_DECODING_DOCUMENT) # index of the (first) [STOP] symbol decoded_words_all = decoded_words_all[:fst_stop_idx] except ValueError: decoded_words_all = decoded_words_all decoded_words_all = decoded_words_all.replace("[UNK] ", "") decoded_words_all = decoded_words_all.replace("[UNK]", "") decoded_words_all, _ = re.subn(r"(! ){2,}", "", decoded_words_all) decoded_words_all, _ = re.subn(r"(\. 
){2,}", "", decoded_words_all) if decoded_words_all.strip() == "": '''tf.logging.info("decode") tf.logging.info(current_batch.original_reviews[i]) tf.logging.info("encode") tf.logging.info(encode_words)''' new_dis_example = bd.Example(current_batch.original_review_output[i], -0.0001, dis_batcher._vocab, dis_batcher._hps) new_example = Example(current_batch.original_review_output[i], batcher._vocab, batcher._hps,encode_words) else: '''tf.logging.info("decode") tf.logging.info(decoded_words_all) tf.logging.info("encode") tf.logging.info(encode_words)''' new_dis_example = bd.Example(decoded_words_all, 1, dis_batcher._vocab, dis_batcher._hps) new_example = Example(decoded_words_all, batcher._vocab, batcher._hps,encode_words) example_list.append(new_example) db_example_list.append(new_dis_example) return Batch(example_list, batcher._hps, batcher._vocab), bd.Batch(db_example_list, dis_batcher._hps, dis_batcher._vocab)