Code Example #1
File: utils.py Project: kiminh/bert_music_correct
def _nlu_examples(input_file, target_domain_name=None):
    with tf.gfile.GFile(input_file) as f:
        reader = csv.reader(f)
        session_list = []
        for row_id, (sessionId, raw_query, domain_intent,
                     param) in enumerate(reader):
            query = normal_transformer(raw_query)
            param = normal_transformer(param)
            sources = []
            if row_id > 0 and sessionId == session_list[row_id - 1][0]:
                sources.append(session_list[row_id - 1][1])  # last query
            sources.append(query)
            if domain_intent == other_tag:
                domain = other_tag
                intent = other_tag  # no "domain.intent" pair to split for the catch-all tag
            else:
                domain, intent = domain_intent.split(".")
            session_list.append((sessionId, query))
            if target_domain_name is not None and target_domain_name != domain:
                continue
            yield sources, (intent, param)
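Each CSV row here is (sessionId, raw_query, domain_intent, param), and the running session_list lets a query be paired with the previous turn of the same session. A minimal sketch of that windowing, with a stand-in for the project's normal_transformer (an assumption; the real helper lives elsewhere in the repo):

import csv
import io

def normal_transformer(s):  # stand-in for the project's text normalizer
    return s.strip()

rows = "s1,play a song,music.play,\ns1,next one,music.next,\ns2,call mom,OTHERS,\n"
session_list = []
for row_id, (sessionId, raw_query, domain_intent, param) in enumerate(csv.reader(io.StringIO(rows))):
    query = normal_transformer(raw_query)
    sources = []
    if row_id > 0 and sessionId == session_list[row_id - 1][0]:
        sources.append(session_list[row_id - 1][1])  # last query of the same session
    sources.append(query)
    session_list.append((sessionId, query))
    print(sources)  # ['play a song'] -> ['play a song', 'next one'] -> ['call mom']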
Code Example #2
def add_to_ac(ac, entity_type, entity_before, entity_after, pri):
    entity_before = normal_transformer(entity_before)
    flag = "ignore"
    if entity_type == "song" and entity_after in frequentSong and entity_after not in {"李白"}:
        return flag
    elif entity_type == "singer" and entity_after in frequentSinger:
        return flag
    elif entity_type == "toplist" and entity_before == "首张":
        return flag
    elif entity_type == "emotion" and entity_before in {"high歌", "相思", "喜欢"}:  # this is how train and dev are annotated
        return flag
    elif entity_type == "language" and entity_before in {"中国"}:  # this is how train and dev are annotated
        return flag
    ac.add(keywords=entity_before, meta_data=(entity_after, pri))
    return "add success"
Code Example #3
    def process_sample(self, text1):
        text = normal_transformer(text1)
        if len(text) < 1:
            print(curLine(), "text:%s, text1:%s" % (text, text1))
        if self.args.language.lower() == "chinese":
            # Chinese text is indexed character by character
            processed_text = [
                self.vocab.index(w) for w in list(text)[:self.args.max_len]
            ]
        else:
            # other languages are split on whitespace
            processed_text = [
                self.vocab.index(w) for w in text.split()[:self.args.max_len]
            ]
        processed_len = len(processed_text)

        return processed_text, processed_len
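self.vocab is evidently a list here, so vocab.index(w) is a linear scan per token and raises ValueError on out-of-vocabulary words; a {word: id} dict is the usual constant-time alternative. A toy run of the per-character indexing from the Chinese branch:

vocab = ["[PAD]", "我", "想", "听", "歌"]       # toy list vocabulary
word2id = {w: i for i, w in enumerate(vocab)}  # dict variant with O(1) lookup
text, max_len = "我想听歌", 8
processed_text = [word2id[w] for w in list(text)[:max_len]]
print(processed_text, len(processed_text))     # [1, 2, 3, 4] 4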
Code Example #4
File: utils.py Project: kiminh/nlu_match
def _nlu_examples(input_file):
  with tf.gfile.GFile(input_file) as f:
    reader = csv.reader(f)
    session_list = []
    for row_id, (sessionId, raw_query, domain_intent, param) in enumerate(reader):
      query = normal_transformer(raw_query)
      sources = []
      # consider at most 2 history queries, i.e. a three-turn dialogue
      if row_id > 1 and sessionId == session_list[row_id - 2][0]:
        sources.append(session_list[row_id-2][1])  # last last query
      if row_id > 0 and sessionId == session_list[row_id - 1][0]:
        sources.append(session_list[row_id-1][1])  # last query
      sources.append(query)
      if domain_intent == other_tag:
        domain = other_tag
      else:
        domain, intent = domain_intent.split(".")
      session_list.append((sessionId, query))
      yield sources, domain
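Compared with example #1, this variant also looks two rows back, so sources can carry up to three turns of the same session. The windowing in isolation:

session_list = [("s1", "q1"), ("s1", "q2")]  # two earlier turns of session s1
row_id, sessionId, query = 2, "s1", "q3"
sources = []
if row_id > 1 and sessionId == session_list[row_id - 2][0]:
    sources.append(session_list[row_id - 2][1])  # last last query
if row_id > 0 and sessionId == session_list[row_id - 1][0]:
    sources.append(session_list[row_id - 1][1])  # last query
sources.append(query)
print(sources)  # ['q1', 'q2', 'q3']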
Code Example #5
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')
    label_map = utils.read_label_map(FLAGS.label_map_file)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s saved_model:%s" % (curLine(), FLAGS.saved_model), "red"))

    ##### test
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    domain_list = []
    slot_info_list = []
    intent_list = []
    sources_list = []
    predict_batch_size = 32
    limit = predict_batch_size * 1500  # cap on the number of rows read
    with tf.gfile.GFile(FLAGS.input_file) as f:
        reader = csv.reader(f)
        session_list = []
        for row_id, line in enumerate(reader):
            if len(line) > 2:  # labeled rows also carry domain_intent and slot
                (sessionId, raw_query, domain_intent, slot) = line
            else:
                (sessionId, raw_query) = line
            query = normal_transformer(raw_query)
            sources = []
            if row_id > 1 and sessionId == session_list[row_id - 2][0]:
                sources.append(session_list[row_id - 2][1])  # last last query
            if row_id > 0 and sessionId == session_list[row_id - 1][0]:
                sources.append(session_list[row_id - 1][1])  # last query
            sources.append(query)
            session_list.append((sessionId, raw_query))
            sources_list.append(sources)
            if len(line) > 2:  # labeled
                if domain_intent == other_tag:
                    domain = other_tag
                    intent = other_tag
                else:
                    domain, intent = domain_intent.split(".")
                domain_list.append(domain)
                intent_list.append(intent)
                slot_info_list.append(slot)
            if len(sources_list) >= limit:
                print(
                    colored(
                        "%s stop reading at %d to save time" %
                        (curLine(), limit), "red"))
                break

    number = len(sources_list)  # total number of examples
    predict_domain_list = []
    predict_intent_list = []
    predict_slot_list = []
    pred_domainMap_list = []
    predict_batch_size = min(predict_batch_size, number)
    batch_num = math.ceil(float(number) / predict_batch_size)
    start_time = time.time()
    num_predicted = 0
    modemode = 'a'  # append when the input is unlabeled
    if len(domain_list) > 0:  # labeled: overwrite
        modemode = 'w'
    previous_sessionId = None
    domain_history = []
    with tf.gfile.Open(FLAGS.output_file, modemode) as writer:
        if len(domain_list) > 0:  # labeled
            writer.write("\t".join([
                "sessionId", "query", "predDomain", "predIntent", "predSlot",
                "domain", "intent", "Slot"
            ]) + "\n")
        for batch_id in range(batch_num):
            sources_batch = sources_list[batch_id *
                                         predict_batch_size:(batch_id + 1) *
                                         predict_batch_size]
            prediction_batch, pred_domainMap_batch = predictor.predict_batch(
                sources_batch=sources_batch)
            assert len(prediction_batch) == len(sources_batch)
            num_predicted += len(prediction_batch)
            for id, [current_predict_domain, pred_domainMap,
                     sources] in enumerate(
                         zip(prediction_batch, pred_domainMap_batch,
                             sources_batch)):
                sessionId, raw_query = session_list[batch_id *
                                                    predict_batch_size + id]
                if sessionId != previous_sessionId:  # a new session starts
                    domain_history = []
                    previous_sessionId = sessionId
                predict_domain, predict_intent, slot_info = rules(
                    raw_query, current_predict_domain, domain_history)
                pred_domainMap_list.append(pred_domainMap)
                domain_history.append((predict_domain, predict_intent))  # record the multi-turn history
                predict_domain_list.append(predict_domain)
                predict_intent_list.append(predict_intent)
                predict_slot_list.append(slot_info)
                if len(domain_list) > 0:  # labeled
                    domain = domain_list[batch_id * predict_batch_size + id]
                    intent = intent_list[batch_id * predict_batch_size + id]
                    slot = slot_info_list[batch_id * predict_batch_size + id]
                    writer.write("\t".join([
                        sessionId, raw_query, predict_domain, predict_intent,
                        slot_info, domain, intent, slot
                    ]) + "\n")
            if batch_id % 5 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, batch_num, num_predicted,
                       number, cost_time))
    cost_time = (time.time() - start_time) / 60.0
    print(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time:.2f} min, ave {cost_time/num_predicted*60:.2f} s.'
    )

    if FLAGS.submit_file is not None:
        domain_counter = collections.Counter()
        if os.path.exists(path=FLAGS.submit_file):
            os.remove(FLAGS.submit_file)
        with open(FLAGS.submit_file, 'w', encoding='UTF-8') as f:
            writer = csv.writer(f, dialect='excel')
            # writer.writerow(["session_id", "query", "intent", "slot_annotation"])  # TODO
            for example_id, sources in enumerate(sources_list):
                sessionId, raw_query = session_list[example_id]
                predict_domain = predict_domain_list[example_id]
                predict_intent = predict_intent_list[example_id]
                predict_domain_intent = other_tag
                domain_counter.update([predict_domain])
                if predict_domain != other_tag:
                    predict_domain_intent = "%s.%s" % (predict_domain,
                                                       predict_intent)
                line = [
                    sessionId, raw_query, predict_domain_intent,
                    predict_slot_list[example_id]
                ]
                writer.writerow(line)
        print(curLine(), "example_id=", example_id)
        print(curLine(), "domain_counter:", domain_counter)
        cost_time = (time.time() - start_time) / 60.0
        num_predicted = example_id + 1
        print(curLine(), "domain cost %f s" % (cost_time))
        print(
            f'{curLine()} {num_predicted} predictions saved to:{FLAGS.submit_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.'
        )
        domain_score_file = "%s/submit_domain_score.json" % (
            FLAGS.domain_score_folder)
    else:
        domain_score_file = "%s/predict_domain_score.json" % (
            FLAGS.domain_score_folder)

    with open(domain_score_file, "w") as fw:
        json.dump(pred_domainMap_list, fw, ensure_ascii=False, indent=4)
    print(curLine(),
          "dump %d to %s" % (len(pred_domainMap_list), domain_score_file))
Code Example #6
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')
    label_map = utils.read_label_map(FLAGS.label_map_file)
    slot_label_map = utils.read_label_map(FLAGS.slot_label_map_file)
    target_domain_name = FLAGS.domain_name
    print(curLine(), "target_domain_name:", target_domain_name)
    assert target_domain_name in ["navigation", "phone_call", "music"]
    entity_type_list = utils.read_label_map(FLAGS.entity_type_list_file)[FLAGS.domain_name]

    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, slot_label_map=slot_label_map,
                                              entity_type_list=entity_type_list, get_entity_func=exacter_acmation.get_all_entity)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map, slot_label_map, target_domain_name=target_domain_name)
    print(colored("%s saved_model:%s" % (curLine(), FLAGS.saved_model), "red"))

    ##### test
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))


    domain_list = []
    slot_info_list = []
    intent_list = []

    predict_domain_list = []
    previous_pred_slot_list = []
    previous_pred_intent_list = []
    sources_list = []
    predict_batch_size = 64
    limit = predict_batch_size * 1500  # cap on the number of rows read
    with tf.gfile.GFile(FLAGS.input_file) as f:
        reader = csv.reader(f)
        session_list = []
        for row_id, line in enumerate(reader):
            if len(line) == 1:  # the whole row landed in one field; re-split on tabs
                line = line[0].strip().split("\t")
            if len(line) > 4:  # labeled
                (sessionId, raw_query, predDomain, predIntent, predSlot, domain, intent, slot) = line
                domain_list.append(domain)
                intent_list.append(intent)
                slot_info_list.append(slot)
            else:
                (sessionId, raw_query, predDomainIntent, predSlot) = line
                if "." in predDomainIntent:
                    predDomain, predIntent = predDomainIntent.split(".")
                else:
                    predDomain, predIntent = predDomainIntent, predDomainIntent
            if "忘记电话" in raw_query:
                predDomain = "phone_call" # rule
            if "专用道" in raw_query:
                predDomain = "navigation" # rule
            predict_domain_list.append(predDomain)
            previous_pred_slot_list.append(predSlot)
            previous_pred_intent_list.append(predIntent)
            query = normal_transformer(raw_query)
            if query != raw_query:
                print(curLine(), len(query),     "query:    ", query)
                print(curLine(), len(raw_query), "raw_query:", raw_query)

            sources = []
            if row_id > 0 and sessionId == session_list[row_id - 1][0]:
                sources.append(session_list[row_id - 1][1])  # last query
            sources.append(query)
            session_list.append((sessionId, raw_query))
            sources_list.append(sources)

            if len(sources_list) >= limit:
                print(colored("%s stop reading at %d to save time" %(curLine(), limit), "red"))
                break

    number = len(sources_list)  # total number of examples

    predict_intent_list = []
    predict_slot_list = []
    predict_batch_size = min(predict_batch_size, number)
    batch_num = math.ceil(float(number) / predict_batch_size)
    start_time = time.time()
    num_predicted = 0
    modemode = 'a'  # append when the input is unlabeled
    if len(domain_list) > 0:  # labeled: overwrite
        modemode = 'w'
    with tf.gfile.Open(FLAGS.output_file, modemode) as writer:
        # if len(domain_list) > 0:  # 有标注
        #     writer.write("\t".join(["sessionId", "query", "predDomain", "predIntent", "predSlot", "domain", "intent", "Slot"]) + "\n")
        for batch_id in range(batch_num):
            sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
            predict_domain_batch = predict_domain_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
            predict_intent_batch, predict_slot_batch = predictor.predict_batch(
                sources_batch=sources_batch, target_domain_name=target_domain_name,
                predict_domain_batch=predict_domain_batch)
            assert len(predict_intent_batch) == len(sources_batch)
            num_predicted += len(predict_intent_batch)
            for id, [predict_intent, predict_slot_info, sources] in enumerate(zip(predict_intent_batch, predict_slot_batch, sources_batch)):
                sessionId, raw_query = session_list[batch_id * predict_batch_size + id]
                predict_domain = predict_domain_list[batch_id * predict_batch_size + id]
                # if predict_domain == "music":
                #     predict_slot_info = raw_query
                #     if predict_intent == "play":  # 模型分类到播放意图,但没有找到槽位,这时用ac自动机提高召回
                #         predict_intent_rule, predict_slot_info = rules(raw_query, predict_domain, target_domain_name)
                        # # if predict_intent_rule in {"pause", "next"}:
                        # #     predict_intent = predict_intent_rule
                        # if "<" in predict_slot_info_rule : # and "<" not in predict_slot_info:
                        #     predict_slot_info = predict_slot_info_rule
                        #     print(curLine(), "predict_slot_info_rule:", predict_slot_info_rule)
                        #     print(curLine())

                if predict_domain != target_domain_name:  # not the current model's domain; keep the earlier rule-based predictions
                    predict_intent = previous_pred_intent_list[batch_id * predict_batch_size + id]
                    predict_slot_info = previous_pred_slot_list[batch_id * predict_batch_size + id]
                predict_intent_list.append(predict_intent)
                predict_slot_list.append(predict_slot_info)
                if len(domain_list) > 0:  # labeled
                    domain = domain_list[batch_id * predict_batch_size + id]
                    intent = intent_list[batch_id * predict_batch_size + id]
                    slot = slot_info_list[batch_id * predict_batch_size + id]
                    domain_flag = "right"
                    if domain != predict_domain:
                        domain_flag = "wrong"
                    writer.write("\t".join([sessionId, raw_query, predict_domain, predict_intent, predict_slot_info, domain, intent, slot]) + "\n") # , domain_flag
            if batch_id % 5 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
                      (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time))
    cost_time = (time.time() - start_time) / 60.0
    print(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time:.2f} min, ave {cost_time/num_predicted*60:.2f} s.')


    if FLAGS.submit_file is not None:
        import collections
        import os
        domain_counter = collections.Counter()
        if os.path.exists(path=FLAGS.submit_file):
            os.remove(FLAGS.submit_file)
        with open(FLAGS.submit_file, 'w', encoding='UTF-8') as f:
            writer = csv.writer(f, dialect='excel')
            # writer.writerow(["session_id", "query", "intent", "slot_annotation"])  # TODO
            for example_id, sources in enumerate(sources_list):
                sessionId, raw_query = session_list[example_id]
                predict_domain = predict_domain_list[example_id]
                predict_intent = predict_intent_list[example_id]
                predict_domain_intent = other_tag
                domain_counter.update([predict_domain])
                slot = raw_query
                if predict_domain != other_tag:
                    predict_domain_intent = "%s.%s" % (predict_domain, predict_intent)
                    slot = predict_slot_list[example_id]
                # if predict_domain == "navigation": # TODO  TODO
                #     predict_domain_intent = other_tag
                #     slot = raw_query
                line = [sessionId, raw_query, predict_domain_intent, slot]
                writer.writerow(line)
        print(curLine(), "example_id=", example_id)
        print(curLine(), "domain_counter:", domain_counter)
        cost_time = (time.time() - start_time) / 60.0
        num_predicted = example_id + 1
        print(curLine(), "%s cost %f min" % (target_domain_name, cost_time))
        print(
            f'{curLine()} {num_predicted} predictions saved to:{FLAGS.submit_file}, cost {cost_time:.2f} min, ave {cost_time/num_predicted*60:.2f} s.')
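This script re-predicts only rows whose domain matches target_domain_name and keeps the earlier pass's predictions for everything else. The merge logic, reduced to a self-contained sketch (all values are illustrative):

target_domain_name = "music"
predict_domain_list = ["music", "navigation", "music"]
previous_pred_intent_list = ["play", "navigate", "pause"]  # from the earlier pass
new_intent_batch = ["play", "ignored", "next"]             # from this domain-specific model

merged = []
for i, predict_intent in enumerate(new_intent_batch):
    if predict_domain_list[i] != target_domain_name:   # not this model's domain
        predict_intent = previous_pred_intent_list[i]  # keep the earlier prediction
    merged.append(predict_intent)
print(merged)  # ['play', 'navigate', 'next']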