예제 #1
0
def merge_raw_date(data_type: str) -> None:
    """Merge belief state data into user turn.

    Loads the TRADE-format dialogues ``{data_type}_dials.json`` and the raw
    CrossWOZ archive for the same split. It copies each TRADE turn's
    ``belief_state`` onto the matching user message of the raw dialogue.
    The merged dialogues are written to
    ``crosswoz/dst_bert_data/{data_type}4bert_dst.json``.

    Note that the dialogue dicts returned by ``read_zipped_json`` are
    modified in place (a ``belief_state`` key is added to user messages).

    Args:
        data_type: train, dev or test
    """
    data_path = get_data_path()
    output_dir = os.path.join(data_path, 'crosswoz/dst_bert_data')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)
    dials_path = os.path.join(data_path, 'crosswoz/dst_trade_data', f'{data_type}_dials.json')
    # The raw archive for the dev split is named "val".
    raw_filename = "val" if data_type == "dev" else data_type
    raw_path = os.path.join(data_path, 'crosswoz/raw', f'{raw_filename}.json.zip')
    # Context manager so the file handle is closed deterministically.
    with open(dials_path, 'r', encoding='utf8') as dials_file:
        dials = json.load(dials_file)
    raw = read_zipped_json(raw_path, f'{raw_filename}.json')

    merge_data = {}
    for dial in tqdm(dials, desc=f'Merging {data_type}'):
        dialogue_idx = dial['dialogue_idx']
        cur_raw = raw[dialogue_idx]
        merge_data[dialogue_idx] = cur_raw
        messages = cur_raw['messages']
        for turn_id, turn in enumerate(dial['dialogue']):
            # Assumes raw messages alternate usr/sys, so TRADE turn i is
            # raw message 2*i; the assert guards that assumption.
            usr_message = messages[2 * turn_id]
            assert usr_message['role'] == 'usr'
            usr_message['belief_state'] = turn['belief_state']

    with open(os.path.join(output_dir, f'{data_type}4bert_dst.json'), 'w', encoding='utf8') as f:
        json.dump(merge_data, f, ensure_ascii=False, indent=2)
예제 #2
0
def merge_raw_date(data_type: str) -> None:
    """Merge belief state data into user turn.

    Loads the TRADE-format dialogues ``{data_type}_dials.json`` and the raw
    CrossWOZ archive for the same split. It copies each TRADE turn's
    ``belief_state`` onto the matching user message of the raw dialogue.
    The merged dialogues are written to
    ``crosswoz/dst_bert_data/{data_type}4bert_dst.json``.

    Note that the dialogue dicts returned by ``read_zipped_json`` are
    modified in place (a ``belief_state`` key is added to user messages).

    Args:
        data_type: train, dev or tests
    """
    data_path = get_data_path()
    output_dir = os.path.join(data_path, "crosswoz/dst_bert_data")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)
    dials_path = os.path.join(data_path, "crosswoz/dst_trade_data",
                              f"{data_type}_dials.json")
    # The raw archive for the dev split is named "val".
    raw_filename = "val" if data_type == "dev" else data_type
    raw_path = os.path.join(data_path, "crosswoz/raw",
                            f"{raw_filename}.json.zip")
    # Context manager so the file handle is closed deterministically.
    with open(dials_path, "r", encoding="utf8") as dials_file:
        dials = json.load(dials_file)
    raw = read_zipped_json(raw_path, f"{raw_filename}.json")

    merge_data = {}
    for dial in tqdm(dials, desc=f"Merging {data_type}"):
        dialogue_idx = dial["dialogue_idx"]
        cur_raw = raw[dialogue_idx]
        merge_data[dialogue_idx] = cur_raw
        messages = cur_raw["messages"]
        for turn_id, turn in enumerate(dial["dialogue"]):
            # Assumes raw messages alternate usr/sys, so TRADE turn i is
            # raw message 2*i; the assert guards that assumption.
            usr_message = messages[2 * turn_id]
            assert usr_message["role"] == "usr"
            usr_message["belief_state"] = turn["belief_state"]

    with open(os.path.join(output_dir, f"{data_type}4bert_dst.json"),
              "w",
              encoding="utf8") as f:
        json.dump(merge_data, f, ensure_ascii=False, indent=2)
예제 #3
0
def get_single_domain_examples(input_file_path, output_file_path):
    """Filter a zipped JSON dataset down to its single-domain dialogues.

    The archive member read is the archive's basename with its last
    extension stripped (e.g. ``train.json.zip`` -> ``train.json``).
    Dialogues whose ``type`` field equals ``'单领域'`` ("single domain") are
    kept. Their count is printed, and the filtered mapping is written to
    ``output_file_path`` with ``dump_json``.
    """
    member_name = os.path.basename(input_file_path).rsplit('.', 1)[0]
    dataset = read_zipped_json(input_file_path, member_name)

    kept = {}
    for dialogue_id, dialogue in dataset.items():
        if dialogue['type'] != '单领域':
            continue
        kept[dialogue_id] = dialogue

    print(f'{member_name} total has {len(kept)} single domain examples')

    dump_json(kept, output_file_path)
예제 #4
0
파일: utils.py 프로젝트: zem9401/xbot_mle
    def get_act_ontology(raw_data_path: str, output_path: str) -> None:
        """Generate action ontology from raw train data.

        Every dialog act ``da`` found on a ``"sys"`` turn is encoded as the
        string ``"{da[1]}-{da[0]}-{da[2]}"``. The distinct strings are
        written to ``output_path`` via ``dump_json`` as a sorted list.

        Args:
            raw_data_path: raw train data path (zip archive that contains
                ``train.json``)
            output_path: save path of action ontology file
        """
        raw_data = read_zipped_json(raw_data_path, "train.json")

        act_ontology = set()
        for dial in tqdm(raw_data.values(),
                         desc="Generate action ontology ..."):
            for turn in dial["messages"]:
                if turn["role"] == "sys" and turn["dialog_act"]:
                    for da in turn["dialog_act"]:
                        act_ontology.add("-".join([da[1], da[0], da[2]]))

        # Sort so the written list is reproducible: iteration order of a set
        # of str changes between interpreter runs due to hash randomization,
        # which would reshuffle any position-based use of this ontology.
        dump_json(sorted(act_ontology), output_path)
예제 #5
0
파일: utils.py 프로젝트: zem9401/xbot_mle
def preprocess(raw_data_path: str, output_path: str,
               data_type: str) -> List[Dict[str, list]]:
    """Build model-input examples from a raw zipped dialogue file.

    One example is emitted for every non-user turn (role other than
    ``"usr"``) whose ``dialog_act`` is non-empty. Each example records the
    dialogue id, the turn index, a ``source`` built by ``get_source``, the
    most recent system and user utterances seen before that turn, and the
    ``label`` vector from ``get_label_and_domain``. The examples are handed
    to ``cache_data`` and also returned.

    NOTE(review): the system context only advances on non-user turns that
    carry dialog acts; a system turn with an empty ``dialog_act`` leaves it
    unchanged. Confirm this is intended.

    Args:
        raw_data_path: raw (train, dev, tests) data path
        output_path: save path of processed data file
        data_type: train, dev or tests (also passed to ``read_zipped_json``
            as the member name to read)

    Returns:
        processed data
    """
    raw_data = read_zipped_json(raw_data_path, data_type)

    examples = []
    for dial_id, dial in tqdm(raw_data.items(),
                              desc=f"Preprocessing {data_type}"):
        # Both contexts start with the placeholder "对话开始" ("dialogue start").
        prev_sys = prev_usr = "对话开始"
        for turn_id, turn in enumerate(dial["messages"]):
            if turn["role"] == "usr":
                prev_usr = turn["content"]
                continue
            if not turn["dialog_act"]:
                continue

            cur_domain, act_vecs = get_label_and_domain(turn)
            examples.append({
                "dial_id": dial_id,
                "turn_id": turn_id,
                "source": get_source(cur_domain, turn),
                "sys_utter": prev_sys,
                "usr_utter": prev_usr,
                "label": act_vecs,
            })
            prev_sys = turn["content"]

    cache_data(data_type, examples, output_path)
    return examples
예제 #6
0
            if value is not None:
                break
        return value


if __name__ == "__main__":
    from xbot.dm.dst.rule_dst.rule import RuleDST
    from xbot.util.path import get_data_path
    from xbot.util.file_util import read_zipped_json
    from script.policy.rule.rule_test import eval_metrics
    from tqdm import tqdm

    rule_dst = RuleDST()
    bert_policy = BertPolicy()
    train_path = os.path.join(get_data_path(), "crosswoz/raw/train.json.zip")
    train_examples = read_zipped_json(train_path, "train.json")

    sys_state_action_pairs = {}
    for id_, dialogue in tqdm(train_examples.items()):
        sys_state_action_pair = {}
        sess = dialogue["messages"]
        rule_dst.init_session()
        for i, turn in enumerate(sess):
            if turn["role"] == "usr":
                rule_dst.update(usr_da=turn["dialog_act"])
                rule_dst.state["user_action"].clear()
                rule_dst.state["user_action"].extend(turn["dialog_act"])
                rule_dst.state["history"].append(["usr", turn["content"]])
                if i + 2 == len(sess):
                    rule_dst.state["terminated"] = True
            else: