Example #1
def process(corpus_folder, raw_file_name, save_folder):
    corpus_list = []
    for name in raw_file_name:
        raw_file = os.path.join(corpus_folder, name)
        with open(raw_file, "r") as fr:
            lines = fr.readlines()

        for i, line in enumerate(lines):
            source, target, label = line.strip().split("\t")
            if label == "0" or source == target:
                continue
            if label != "1":
                input(curLine() + line.strip())
            length = float(len(source) + len(target))

            source_length = len(source)
            if source_length > 8 and source_length < 38 and (
                    i + 1) % 2 > 0:  # construct a swap operation for 50% of the long sentences
                rand = random.uniform(0.4, 0.9)
                source_pre = source
                swag_location = int(source_length * rand)
                source = "%s%s" % (source[swag_location:],
                                   source[:swag_location])
                lcs1 = _compute_lcs(source, target)
                lcs_rate = len(lcs1) / length
                if (lcs_rate < 0.4):  # too different, revert the swap
                    source = source_pre
                else:
                    print(
                        curLine(), "source_pre:%s, source:%s, lcs_rate=%f" %
                        (source_pre, source, lcs_rate))

            lcs1 = _compute_lcs(source, target)
            lcs_rate = len(lcs1) / length
            if (lcs_rate < 0.2):
                continue  # change is too large, ignore this pair

            # if (lcs_rate<0.4):
            #   continue # change is too large, ignore this pair
            # if len(source)*1.15 < len(target):
            #   new_t = source
            #   source = target
            #   target = new_t
            #   print(curLine(), source, target, ",lcs1:",lcs1 , ",lcs_rate=", lcs_rate)
            corpus = "%s\t%s\t%f\n" % (source, target, lcs_rate)
            corpus_list.append(corpus)
        print(curLine(), len(corpus_list), "from %s" % raw_file)
    save_file = os.path.join(save_folder, "lcqmc.txt")
    with open(save_file, "w") as fw:
        fw.writelines(corpus_list)
    print(curLine(), "have save %d to %s" % (len(corpus_list), save_file))
Example #2
def process(corpus_folder, raw_file_name, save_folder):
  raw_file = os.path.join(corpus_folder, raw_file_name)
  with open(raw_file, "r") as fr:
    lines = fr.readlines()
  corpus_list = []
  for line in lines:
    sent_list = line.strip().split("&&")
    sent_num = len(sent_list)
    for i in range(1, sent_num, 2):
      source = sent_list[i - 1]
      target = sent_list[i]
      length = float(len(source) + len(target))
      lcs1 = _compute_lcs(source, target)
      lcs_rate = len(lcs1) / length
      if (lcs_rate < 0.3):
        continue  # change is too large, ignore this pair
      if len(source) * 1.15 < len(target):
        new_t = source
        source = target
        target = new_t
      corpus = "%s\t%s\t%f\n" % (source, target, lcs_rate)
      corpus_list.append(corpus)
  save_file = os.path.join(save_folder, "baoan_airport.txt")
  with open(save_file, "w") as fw:
    fw.writelines(corpus_list)
  print(curLine(), "have save %d to %s" % (len(corpus_list), save_file))
Example #3
def _get_added_phrases(source: Text, target: Text) -> Sequence[Text]:
    """Computes the phrases that need to be added to the source to get the target.

    This is done by aligning each token in the LCS to the first match in the
    target and checking which phrases in the target remain unaligned.

    TODO(b/142853960): The LCS tokens should ideally be aligned to consecutive
    target tokens whenever possible, instead of aligning them always to the first
    match. This should result in a more meaningful phrase vocabulary with a higher
    coverage.

    Note that the algorithm is case-insensitive and the resulting phrases are
    always lowercase.

    Args:
      source: Source text.
      target: Target text.

    Returns:
      List of added phrases.
    """
    sep = ' '  # English is split into words (sep=' '); Chinese is split into characters (sep='')
    source_tokens = utils.get_token_list(
        source.lower())  # list(source.lower())  # split the sentence into a list of characters
    target_tokens = utils.get_token_list(
        target.lower())  # list(target.lower())  # split the sentence into a list of characters
    #print("phrase_vocabulary_optimization.py source_tokens",source_tokens)
    #print("phrase_vocabulary_optimization.py target_tokens",target_tokens)
    kept_tokens = _compute_lcs(source_tokens, target_tokens)  # tokens shared by source and target
    #print("phrase_vocabulary_optimization.py kept_tokens",kept_tokens)
    added_phrases = []
    # Index of the `kept_tokens` element that we are currently looking for.
    kept_idx = 0
    phrase = []
    for token in target_tokens:
        if kept_idx < len(kept_tokens) and token == kept_tokens[kept_idx]:
            kept_idx += 1
            #print(phrase)
            if phrase:
                added_phrases.append(sep.join(phrase))
                phrase = []
        else:
            phrase.append(token)
    #print("phrase_vocabulary_optimization sep",sep)
    if phrase:
        added_phrases.append(sep.join(phrase))
    return added_phrases
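
A toy walk-through of the alignment described in the docstring, assuming utils.get_token_list simply splits English text on whitespace (the project's real tokenizer may differ):

source = "turing was born in 1912 . turing died in 1954 ."
target = "turing was born in 1912 and died in 1954 ."
# kept_tokens = _compute_lcs(source_tokens, target_tokens)
#             = ['turing', 'was', 'born', 'in', '1912', 'died', 'in', '1954', '.']
# Walking over the target tokens, every token matches the next kept token
# except "and", so the phrase buffer is flushed exactly once and
# _get_added_phrases(source, target) would return ['and'].
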
Example #4
def process(corpus_folder, raw_file_name):
    raw_file = os.path.join(corpus_folder, raw_file_name)
    # open the file and get the Excel workbook object
    workbook = xlrd.open_workbook(raw_file)  # file path

    # get the sheet object by its index
    worksheet = workbook.sheet_by_index(0)
    nrows = worksheet.nrows  # total number of rows in the sheet
    ncols = worksheet.ncols  # total number of columns in the sheet
    print(
        curLine(), "raw_file_name:%s, worksheet:%s nrows=%d, ncols=%d" %
        (raw_file_name, worksheet.name, nrows, ncols))
    assert ncols == 3
    assert nrows > 0
    col_data = worksheet.col_values(0)  # contents of the first column
    corpus_list = []
    for line in col_data:
        sent_list = line.strip().split("&&")
        sent_num = len(sent_list)
        for i in range(1, sent_num, 2):
            source = sent_list[i - 1]
            target = sent_list[i]
            # source_length = len(source)
            # if source_length > 8 and (i+1)%4>0: # random character deletion for 50% of the long sentences
            #   rand = random.uniform(0.1, 0.9)
            #   source_pre = source
            #   swag_location = int(source_length*rand)
            #   source = "%s%s" % (source[:swag_location], source[swag_location+1:])
            #   print(curLine(), "source_pre:%s, source:%s" % (source_pre, source))

            length = float(len(source) + len(target))
            lcs1 = _compute_lcs(source, target)
            lcs_rate = len(lcs1) / length
            if (lcs_rate < 0.2):
                continue  # change is too large, ignore this pair

            # if (lcs_rate<0.3):
            #   continue # change is too large, ignore this pair
            # if len(source)*1.15 < len(target):
            #   new_t = source
            #   source = target
            #   target = new_t
            corpus = "%s\t%s\t%f\n" % (source, target, lcs_rate)
            corpus_list.append(corpus)
    return corpus_list
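
The xlrd calls used above (open_workbook, sheet_by_index, col_values) read the first column of the spreadsheet into a plain Python list. A standalone sketch with a hypothetical file path; note that xlrd 2.x only reads legacy .xls files, so an older xlrd release (or openpyxl) is needed if the raw corpus is an .xlsx workbook:

import xlrd

workbook = xlrd.open_workbook("data/raw_corpus.xls")  # hypothetical path
worksheet = workbook.sheet_by_index(0)                # first sheet
print(worksheet.name, worksheet.nrows, worksheet.ncols)
first_column = worksheet.col_values(0)                # one value per row
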
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sourcesA_list = []
    sourcesB_list = []
    target_list = []
    with tf.gfile.GFile(FLAGS.input_file) as f:
        for line in f:
            sourceA, sourceB, label = line.rstrip('\n').split('\t')
            sourcesA_list.append([sourceA.strip(".")])
            sourcesB_list.append([sourceB.strip(".")])
            target_list.append(label)

    number = len(sourcesA_list)  # total number of examples
    predict_batch_size = min(32, number)
    batch_num = math.ceil(float(number) / predict_batch_size)

    start_time = time.time()
    num_predicted = 0
    prediction_list = []
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        for batch_id in range(batch_num):
            sources_batch = sourcesA_list[batch_id *
                                          predict_batch_size:(batch_id + 1) *
                                          predict_batch_size]
            batch_b = sourcesB_list[batch_id *
                                    predict_batch_size:(batch_id + 1) *
                                    predict_batch_size]
            location_batch = []
            sources_batch.extend(batch_b)
            for source in sources_batch:
                location = list()
                for char in source[0]:
                    if (char >= '0' and char <= '9') or char in '.- ' or (
                            char >= 'a' and char <= 'z') or (char >= 'A'
                                                             and char <= 'Z'):
                        location.append("1")  # TODO TODO
                    else:
                        location.append("0")
                location_batch.append("".join(location))
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch, location_batch=location_batch)
            current_batch_size = int(len(sources_batch) / 2)
            assert len(prediction_batch) == current_batch_size * 2

            for id in range(0, current_batch_size):
                target = target_list[num_predicted + id]
                prediction_A = prediction_batch[id]
                prediction_B = prediction_batch[current_batch_size + id]
                sourceA = "".join(sources_batch[id])
                sourceB = "".join(sources_batch[current_batch_size + id])
                if prediction_A == prediction_B:  # replace one of the two with its source
                    lcsA = len(_compute_lcs(sourceA, prediction_A))
                    if lcsA < 8:  # A has changed a lot
                        prediction_B = sourceB
                    else:
                        lcsB = len(_compute_lcs(sourceB, prediction_B))
                        if lcsA <= lcsB:  # A has changed more than B
                            prediction_B = sourceB
                        else:
                            prediction_A = sourceA
                            print(curLine(), batch_id, prediction_A,
                                  prediction_B, "target:", target,
                                  "current_batch_size=", current_batch_size,
                                  "lcsA=%d,lcsB=%d" % (lcsA, lcsB))
                writer.write(f'{prediction_A}\t{prediction_B}\t{target}\n')

                prediction_list.append("%s\t%s\n" % (sourceA, prediction_A))
                # print(curLine(), id,"sourceA:", sourceA, "sourceB:",sourceB, "target:", target)
                prediction_list.append("%s\t%s\n" % (sourceB, prediction_B))
            num_predicted += current_batch_size
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(curLine(), id, prediction_A, prediction_B, "target:",
                      target, "current_batch_size=", current_batch_size)
                print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB,
                      "target:", target)
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, batch_num, num_predicted,
                       number, cost_time))
    with open("prediction.txt", "w") as prediction_file:
        prediction_file.writelines(prediction_list)
        print(curLine(), "save to prediction_qa.txt.")
    cost_time = (time.time() - start_time) / 60.0
    print(curLine(), id, prediction_A, prediction_B, "target:", target,
          "current_batch_size=", current_batch_size)
    print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB, "target:",
          target)
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to {FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted * 60000}ms.'
    )
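
Two details of the prediction loop above are easy to miss: the A-sentences and B-sentences are concatenated into a single batch and split back into halves after prediction, and location_batch marks ASCII digits, letters and ".- " characters with "1" so the tagger can treat them specially. A toy sketch of both, using made-up data and an equivalent but more compact character check:

sourcesA = ["a1", "a2", "a3"]
sourcesB = ["b1", "b2", "b3"]
batch = sourcesA + sourcesB                                 # predictor sees all A inputs, then all B inputs
predictions = ["pa1", "pa2", "pa3", "pb1", "pb2", "pb3"]    # returned in the same order
half = len(batch) // 2
pairs = list(zip(predictions[:half], predictions[half:]))   # -> [('pa1', 'pb1'), ('pa2', 'pb2'), ('pa3', 'pb3')]

text = "ac 26度"
mask = "".join("1" if (c.isascii() and (c.isalnum() or c in ".- ")) else "0" for c in text)
# -> '111110'  (the Chinese character is marked "0")
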