示例#1
0
def main():
    """Train and evaluate the model configured under ``dst/bert``.

    Steps, in order: merge the common config into the training config,
    fill in the device, batch-size and path entries, make sure the data and
    output directories exist, download every file from ``data_urls`` that is
    not on disk yet, then run ``Trainer.train``, ``Trainer.eval_test`` and
    ``get_recall``.
    """
    model_config_name = "dst/bert/train.json"
    common_config_name = "dst/bert/common.json"

    # Local file name -> download URL.
    data_urls = {
        "train4bert_dst.json":
        "http://src.xbot.bslience.cn/train4bert_dst.json",
        "dev4bert_dst.json": "http://src.xbot.bslience.cn/dev4bert_dst.json",
        "test4bert_dst.json": "http://src.xbot.bslience.cn/test4bert_dst.json",
        "cleaned_ontology.json":
        "http://src.xbot.bslience.cn/cleaned_ontology.json",
        "config.json":
        "http://src.xbot.bslience.cn/bert-base-chinese/config.json",
        "pytorch_model.bin":
        "http://src.xbot.bslience.cn/bert-base-chinese/pytorch_model.bin",
        "vocab.txt": "http://src.xbot.bslience.cn/bert-base-chinese/vocab.txt",
    }

    # load config
    root_path = get_root_path()
    common_config_path = os.path.join(get_config_path(), common_config_name)
    train_config_path = os.path.join(get_config_path(), model_config_name)
    # Context managers close the files right away instead of leaving the
    # handles open until garbage collection.
    with open(common_config_path) as config_file:
        common_config = json.load(config_file)
    with open(train_config_path) as config_file:
        train_config = json.load(config_file)
    # Keys from the common config override same-named training keys.
    train_config.update(common_config)
    train_config["n_gpus"] = torch.cuda.device_count()
    # Scale the configured batch size by the GPU count (at least 1).
    train_config["train_batch_size"] = (max(1, train_config["n_gpus"]) *
                                        train_config["train_batch_size"])
    train_config["device"] = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")

    train_config["data_path"] = os.path.join(get_data_path(),
                                             "crosswoz/dst_bert_data")
    train_config["output_dir"] = os.path.join(root_path,
                                              train_config["output_dir"])
    # exist_ok avoids the check-then-create race when several processes
    # start at the same time.
    os.makedirs(train_config["data_path"], exist_ok=True)
    os.makedirs(train_config["output_dir"], exist_ok=True)

    # download data
    for data_key, url in data_urls.items():
        dst = os.path.join(train_config["data_path"], data_key)
        # Record each file's local path under the text before its first dot,
        # e.g. "vocab.txt" -> train_config["vocab"].
        file_name = data_key.split(".")[0]
        train_config[file_name] = dst
        if not os.path.exists(dst):
            download_from_url(url, dst)

    # train
    trainer = Trainer(train_config)
    trainer.train()
    trainer.eval_test()
    get_recall(train_config["data_path"])
示例#2
0
    def load_config() -> dict:
        """Load config for inference.

        The common config is merged into the inference config (common keys
        win on conflict), then the ``device`` and ``data_path`` entries are
        filled in and the data directory is created if it is missing.

        Returns:
            config dict
        """
        config_dir = get_config_path()
        common_config_path = os.path.join(config_dir,
                                          BertPolicy.common_config_name)
        infer_config_path = os.path.join(config_dir,
                                         BertPolicy.inference_config_name)
        common_config = load_json(common_config_path)
        infer_config = load_json(infer_config_path)
        infer_config.update(common_config)
        infer_config["device"] = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        infer_config["data_path"] = os.path.join(get_data_path(),
                                                 "crosswoz/policy_bert_data")
        # exist_ok avoids the check-then-create race when several processes
        # load the config at the same time.
        os.makedirs(infer_config["data_path"], exist_ok=True)
        return infer_config
示例#3
0
def update_config(common_config_name, train_config_name, task_path):
    """Build the training config for one task.

    Args:
        common_config_name: path of the common config file, relative to
            ``get_config_path()``.
        train_config_name: path of the training config file, relative to
            ``get_config_path()``.
        task_path: data sub-directory of the task, relative to
            ``get_data_path()``.

    Returns:
        The training config dict, updated with the common config (common
        keys win on conflict) and with ``n_gpus``, ``train_batch_size``,
        ``device``, ``data_path`` and ``output_dir`` filled in. The data and
        output directories exist when this function returns.
    """
    root_path = get_root_path()
    common_config_path = os.path.join(get_config_path(), common_config_name)
    train_config_path = os.path.join(get_config_path(), train_config_name)
    # Context managers close the files right away instead of leaving the
    # handles open until garbage collection.
    with open(common_config_path) as config_file:
        common_config = json.load(config_file)
    with open(train_config_path) as config_file:
        train_config = json.load(config_file)
    train_config.update(common_config)
    train_config["n_gpus"] = torch.cuda.device_count()
    # Scale the configured batch size by the GPU count (at least 1).
    train_config["train_batch_size"] = (max(1, train_config["n_gpus"]) *
                                        train_config["train_batch_size"])
    train_config["device"] = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    train_config["data_path"] = os.path.join(get_data_path(), task_path)
    train_config["output_dir"] = os.path.join(root_path,
                                              train_config["output_dir"])
    # exist_ok avoids the check-then-create race between processes.
    os.makedirs(train_config["data_path"], exist_ok=True)
    os.makedirs(train_config["output_dir"], exist_ok=True)
    return train_config
示例#4
0
文件: mle.py 项目: zy12105228/xbot
    def __init__(self):
        """Load the config, fetch the model files and restore the policy.

        The common config is merged into the model config, every file in
        ``MLEPolicy.model_urls`` is downloaded when missing (or always, when
        ``use_cache`` is falsy), and the trained weights are loaded into a
        ``MultiDiscretePolicy`` that is put into eval mode.
        """
        super(MLEPolicy, self).__init__()
        # load config
        common_config_path = os.path.join(get_config_path(),
                                          MLEPolicy.common_config_name)
        # Context managers close the files right away instead of leaving the
        # handles open until garbage collection.
        with open(common_config_path) as config_file:
            common_config = json.load(config_file)
        model_config_path = os.path.join(get_config_path(),
                                         MLEPolicy.model_config_name)
        with open(model_config_path) as config_file:
            model_config = json.load(config_file)
        model_config.update(common_config)
        self.model_config = model_config
        self.model_config["data_path"] = os.path.join(
            get_data_path(), "crosswoz/policy_mle_data")
        # "device" is still the raw config value here, so it can be compared
        # with the string "cpu" before being turned into a torch.device.
        self.model_config["n_gpus"] = (0 if self.model_config["device"]
                                       == "cpu" else torch.cuda.device_count())
        self.model_config["device"] = torch.device(self.model_config["device"])

        # download data
        for model_key, url in MLEPolicy.model_urls.items():
            dst = os.path.join(self.model_config["data_path"], model_key)
            # The checkpoint is stored under "trained_model_path"; any other
            # file is stored under the text before its first dot.
            file_name = (model_key.split(".")[0]
                         if not model_key.endswith("pth") else
                         "trained_model_path")
            self.model_config[file_name] = dst
            if not os.path.exists(dst) or not self.model_config["use_cache"]:
                download_from_url(url, dst)

        self.vector = CrossWozVector(
            sys_da_voc_json=self.model_config["sys_da_voc"],
            usr_da_voc_json=self.model_config["usr_da_voc"],
        )

        policy = MultiDiscretePolicy(self.vector.state_dim,
                                     model_config["hidden_size"],
                                     self.vector.sys_da_dim)

        # map_location remaps the stored tensors onto the configured device,
        # so a checkpoint saved on GPU also loads on a CPU-only machine.
        policy.load_state_dict(
            torch.load(self.model_config["trained_model_path"],
                       map_location=self.model_config["device"]))

        self.policy = policy.to(self.model_config["device"]).eval()
        print(f'>>> {self.model_config["trained_model_path"]} loaded ...')
示例#5
0
    def load_config() -> dict:
        """Load config from common config and inference config from src/xbot/config/dst/bert .

        The common config is merged into the inference config (common keys
        win on conflict), then ``device``, ``data_path`` and ``output_dir``
        are filled in and both directories are created if missing.

        Returns:
            config dict
        """
        root_path = get_root_path()
        common_config_path = os.path.join(get_config_path(), BertDST.common_config_name)
        infer_config_path = os.path.join(get_config_path(), BertDST.infer_config_name)
        # Context managers close the files right away instead of leaving the
        # handles open until garbage collection.
        with open(common_config_path) as config_file:
            common_config = json.load(config_file)
        with open(infer_config_path) as config_file:
            infer_config = json.load(config_file)
        infer_config.update(common_config)
        infer_config["device"] = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        infer_config["data_path"] = os.path.join(
            get_data_path(), "crosswoz/dst_bert_data"
        )
        infer_config["output_dir"] = os.path.join(root_path, infer_config["output_dir"])
        # exist_ok avoids the check-then-create race between processes.
        os.makedirs(infer_config["data_path"], exist_ok=True)
        os.makedirs(infer_config["output_dir"], exist_ok=True)
        return infer_config
示例#6
0
    def __init__(self):
        """Load the config, intent vocabulary and best model for prediction.

        The model file is downloaded from
        ``IntentWithBertPredictor.default_model_url`` when it is not on disk
        yet. The restored model is moved to the configured device and put
        into eval mode.
        """
        # path
        root_path = get_root_path()

        config_file = os.path.join(
            get_config_path(), IntentWithBertPredictor.default_model_config
        )

        # load config
        # Context managers close the files right away instead of leaving the
        # handles open until garbage collection.
        with open(config_file) as f:
            config = json.load(f)
        self.device = config["DEVICE"]

        # load intent vocabulary and dataloader
        intent_vocab_path = os.path.join(
            get_data_path(), "crosswoz/nlu_intent_data/intent_vocab.json"
        )
        with open(intent_vocab_path, encoding="utf-8") as f:
            intent_vocab = json.load(f)
        dataloader = Dataloader(
            intent_vocab=intent_vocab,
            pretrained_weights=config["model"]["pretrained_weights"],
        )
        # load best model
        best_model_path = os.path.join(
            os.path.join(root_path, DEFAULT_MODEL_PATH),
            IntentWithBertPredictor.default_model_name,
        )
        if not os.path.exists(best_model_path):
            download_from_url(
                IntentWithBertPredictor.default_model_url, best_model_path
            )
        model = IntentWithBert(config["model"], self.device, dataloader.intent_dim)
        model.load_state_dict(torch.load(best_model_path, map_location=self.device))

        model.to(self.device)
        model.eval()
        self.model = model
        self.dataloader = dataloader
        print(f"{best_model_path} loaded - {best_model_path}")
示例#7
0
    np.random.seed(seed)
    torch.manual_seed(seed)


if __name__ == "__main__":
    # Local file name -> download URL of the intent data splits.
    data_urls = {
        "intent_train_data.json":
        "http://qiw2jpwfc.hn-bkt.clouddn.com/intent_train_data.json",
        "intent_val_data.json":
        "http://qiw2jpwfc.hn-bkt.clouddn.com/intent_val_data.json",
        "intent_test_data.json":
        "http://qiw2jpwfc.hn-bkt.clouddn.com/intent_test_data.json",
    }
    # load config
    root_path = get_root_path()
    config_path = os.path.join(get_config_path(), "nlu",
                               "crosswoz_all_context_nlu_intent.json")
    # A context manager closes the file right away instead of leaving the
    # handle open until garbage collection.
    with open(config_path) as config_fp:
        config = json.load(config_fp)
    data_path = os.path.join(get_data_path(), "crosswoz/nlu_intent_data/")
    # Output and log directories from the config are joined onto the
    # project root.
    output_dir = config["output_dir"]
    output_dir = os.path.join(root_path, output_dir)
    log_dir = config["log_dir"]
    log_dir = os.path.join(root_path, log_dir)
    device = config["DEVICE"]

    # download data
    for data_key, url in data_urls.items():
        dst = os.path.join(data_path, data_key)
        # Only fetch files that are not on disk yet.
        if not os.path.exists(dst):
            download_from_url(url, dst)
示例#8
0
def main():
    """Train (optionally) and evaluate the policy configured under ``policy/mle``.

    Steps, in order: merge the common config into the model config, fill in
    the device, batch-size and path entries, make sure the data and output
    directories exist, download the missing vocabulary files, seed the
    random generators, run the imitation epochs when ``do_train`` is set,
    and finally call ``Trainer.calc_metrics``.
    """
    model_config_name = "policy/mle/train.json"
    common_config_name = "policy/mle/common.json"

    # Local file name -> download URL. Each vocabulary is fetched from its
    # own file so the system vocabulary is not a copy of the user one.
    # NOTE(review): assumes the host serves sys_da_voc.json next to
    # usr_da_voc.json - confirm.
    data_urls = {
        "sys_da_voc.json":
        "http://qiw2jpwfc.hn-bkt.clouddn.com/sys_da_voc.json",
        "usr_da_voc.json":
        "http://qiw2jpwfc.hn-bkt.clouddn.com/usr_da_voc.json",
    }

    # load config
    root_path = get_root_path()
    common_config_path = os.path.join(get_config_path(), common_config_name)
    model_config_path = os.path.join(get_config_path(), model_config_name)
    # Context managers close the files right away instead of leaving the
    # handles open until garbage collection.
    with open(common_config_path) as config_file:
        common_config = json.load(config_file)
    with open(model_config_path) as config_file:
        model_config = json.load(config_file)
    model_config.update(common_config)

    model_config["n_gpus"] = torch.cuda.device_count()
    # Scale the configured batch size by the GPU count (at least 1).
    model_config["batch_size"] = (max(1, model_config["n_gpus"]) *
                                  model_config["batch_size"])
    model_config["device"] = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")

    model_config["data_path"] = os.path.join(get_data_path(),
                                             "crosswoz/policy_mle_data")
    model_config["raw_data_path"] = os.path.join(get_data_path(),
                                                 "crosswoz/raw")
    model_config["output_dir"] = os.path.join(root_path,
                                              model_config["output_dir"])
    # An empty model path means "start from scratch".
    if model_config["load_model_name"]:
        model_config["model_path"] = os.path.join(
            model_config["output_dir"], model_config["load_model_name"])
    else:
        model_config["model_path"] = ""
    # exist_ok avoids the check-then-create race between processes.
    os.makedirs(model_config["data_path"], exist_ok=True)
    os.makedirs(model_config["output_dir"], exist_ok=True)

    # download data
    for data_key, url in data_urls.items():
        dst = os.path.join(model_config["data_path"], data_key)
        # Record each file's local path under the text before its first dot.
        file_name = data_key.split(".")[0]
        model_config[file_name] = dst
        if not os.path.exists(dst):
            download_from_url(url, dst)

    print(">>> Train configs:")
    print("\t", model_config)

    set_seed(model_config["random_seed"])

    agent = Trainer(model_config)

    # train
    if model_config["do_train"]:
        # When resuming, the epoch to continue from is the third
        # "-"-separated field of the model path, plus one.
        start_epoch = (0 if not model_config["model_path"] else
                       int(model_config["model_path"].split("-")[2]) + 1)
        best = float("inf")
        for epoch in tqdm(range(start_epoch, model_config["num_epochs"]),
                          desc="Epoch"):
            agent.imitating(epoch)
            best = agent.imit_eval(epoch, best)

    agent.calc_metrics()
示例#9
0
    torch.manual_seed(seed)


if __name__ == "__main__":
    # Local file name -> download URL of the intent data splits.
    data_urls = {
        "intent_train_data.json":
        "http://qiw2jpwfc.hn-bkt.clouddn.com/intent_train_data.json",
        "intent_val_data.json":
        "http://qiw2jpwfc.hn-bkt.clouddn.com/intent_val_data.json",
        "intent_test_data.json":
        "http://qiw2jpwfc.hn-bkt.clouddn.com/intent_test_data.json",
    }

    # path
    root_path = get_root_path()
    config_file = os.path.join(get_config_path(),
                               IntentWithBertPredictor.default_model_config)
    # A context manager closes the file right away instead of leaving the
    # handle open until garbage collection.
    with open(config_file) as config_fp:
        config = json.load(config_fp)
    data_dir = os.path.join(get_data_path(), "crosswoz/nlu_intent_data/")
    # Output and log directories from the config are joined onto the
    # project root.
    output_dir = config["output_dir"]
    output_dir = os.path.join(root_path, output_dir)
    log_dir = config["log_dir"]
    log_dir = os.path.join(root_path, log_dir)
    device = config["DEVICE"]

    # download data
    for data_key, url in data_urls.items():
        dst = os.path.join(data_dir, data_key)
        # Only fetch files that are not on disk yet.
        if not os.path.exists(dst):
            download_from_url(url, dst)
示例#10
0
def main():
    """Train the intent classifier, keep the best checkpoint and zip it.

    Steps, in order: load the config, download missing data splits, seed
    the random generators, build the dataloader and model, run
    ``max_step`` optimisation steps with a validation pass every
    ``check_step`` steps (saving the weights whenever validation F1
    improves), then write the saved checkpoint into a zip archive.
    """
    # Local file name -> download URL of the intent data splits.
    data_urls = {
        "intent_train_data.json": "http://xbot.bslience.cn/intent_train_data.json",
        "intent_val_data.json": "http://xbot.bslience.cn/intent_val_data.json",
        "intent_test_data.json": "http://xbot.bslience.cn/intent_test_data.json",
    }
    # load config
    root_path = get_root_path()
    config_path = os.path.join(
        os.path.join(get_config_path(), "nlu"), "crosswoz_all_context_nlu_intent.json"
    )
    # Context managers close the files right away instead of leaving the
    # handles open until garbage collection.
    with open(config_path) as f:
        config = json.load(f)
    data_path = config["data_dir"]
    data_path = os.path.join(root_path, data_path)
    output_dir = config["output_dir"]
    output_dir = os.path.join(root_path, output_dir)
    log_dir = config["log_dir"]
    log_dir = os.path.join(root_path, log_dir)
    device = config["DEVICE"]

    # download data
    for data_key, url in data_urls.items():
        dst = os.path.join(data_path, data_key)
        if not os.path.exists(dst):
            download_from_url(url, dst)

    # seed
    set_seed(config["seed"])

    # load intent vocabulary and dataloader
    with open(os.path.join(data_path, "intent_vocab.json"), encoding="utf-8") as f:
        intent_vocab = json.load(f)
    dataloader = Dataloader(
        intent_vocab=intent_vocab,
        pretrained_weights=config["model"]["pretrained_weights"],
    )

    # load data
    for data_key in ["train", "val", "test"]:
        split_path = os.path.join(data_path, "intent_{}_data.json".format(data_key))
        with open(split_path, encoding="utf-8") as f:
            split_data = json.load(f)
        dataloader.load_data(
            split_data,
            data_key,
            cut_sen_len=config["cut_sen_len"],
            use_bert_tokenizer=config["use_bert_tokenizer"],
        )
        print("{} set size: {}".format(data_key, len(dataloader.data[data_key])))

    # output and log dir
    # exist_ok avoids the check-then-create race between processes.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    # model
    model = IntentWithBert(
        config["model"], device, dataloader.intent_dim, dataloader.intent_weight
    )
    model.to(device)

    # optimizer and scheduler
    if config["model"]["finetune"]:
        # Parameters whose name contains one of these substrings get no
        # weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": config["model"]["weight_decay"],
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=config["model"]["learning_rate"],
            eps=config["model"]["adam_epsilon"],
        )
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=config["model"]["warmup_steps"],
            num_training_steps=config["model"]["max_step"],
        )
    else:
        # Without fine-tuning, parameters whose name contains "bert_policy"
        # are frozen and only the remaining ones are optimised. No scheduler
        # exists in this branch; the training loop only steps it when
        # fine-tuning.
        for n, p in model.named_parameters():
            if "bert_policy" in n:
                p.requires_grad = False
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=config["model"]["learning_rate"],
        )

    max_step = config["model"]["max_step"]
    check_step = config["model"]["check_step"]
    batch_size = config["model"]["batch_size"]
    model.zero_grad()
    train_intent_loss = 0
    best_val_f1 = 0.0

    writer.add_text("config", json.dumps(config))

    for step in range(1, max_step + 1):
        model.train()
        batched_data = dataloader.get_train_batch(batch_size)
        batched_data = tuple(t.to(device) for t in batched_data)
        word_seq_tensor, word_mask_tensor, intent_tensor = batched_data
        intent_logits, intent_loss = model.forward(
            word_seq_tensor, word_mask_tensor, intent_tensor
        )

        train_intent_loss += intent_loss.item()
        loss = intent_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if config["model"]["finetune"]:
            scheduler.step()  # Update learning rate schedule

        model.zero_grad()
        if step % check_step == 0:
            # Average the training loss accumulated since the last check.
            train_intent_loss = train_intent_loss / check_step
            print("[%d|%d] step" % (step, max_step))
            print("\t intent loss:", train_intent_loss)

            predict_golden = {"intent": []}

            val_intent_loss = 0
            model.eval()
            for pad_batch, ori_batch, real_batch_size in dataloader.yield_batches(
                    batch_size, data_key="val"
            ):
                pad_batch = tuple(t.to(device) for t in pad_batch)
                word_seq_tensor, word_mask_tensor, intent_tensor = pad_batch

                with torch.no_grad():
                    intent_logits, intent_loss = model.forward(
                        word_seq_tensor, word_mask_tensor, intent_tensor
                    )

                # Weight each batch loss by its size so the division by the
                # sample count below yields a per-sample average.
                val_intent_loss += intent_loss.item() * real_batch_size
                for j in range(real_batch_size):
                    predicts = recover_intent(dataloader, intent_logits[j])
                    labels = ori_batch[j][1]

                    predict_golden["intent"].append(
                        {
                            "predict": [x for x in predicts],
                            "golden": [x for x in labels],
                        }
                    )

            total = len(dataloader.data["val"])
            val_intent_loss /= total
            print("%d samples val" % total)
            print("\t intent loss:", val_intent_loss)

            writer.add_scalar("intent_loss/train", train_intent_loss, global_step=step)
            writer.add_scalar("intent_loss/val", val_intent_loss, global_step=step)

            for x in ["intent"]:
                precision, recall, F1 = calculate_f1(predict_golden[x])
                print("-" * 20 + x + "-" * 20)
                print("\t Precision: %.2f" % (100 * precision))
                print("\t Recall: %.2f" % (100 * recall))
                print("\t F1: %.2f" % (100 * F1))

                writer.add_scalar(
                    "val_{}/precision".format(x), precision, global_step=step
                )
                writer.add_scalar("val_{}/recall".format(x), recall, global_step=step)
                writer.add_scalar("val_{}/F1".format(x), F1, global_step=step)

            # F1 still holds the value computed in the loop above.
            if F1 > best_val_f1:
                best_val_f1 = F1
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, "pytorch-intent-with-bert_policy.pt"),
                )
                print("best val F1 %.4f" % best_val_f1)
                print("save on", output_dir)

            train_intent_loss = 0

    writer.add_text("val intent F1", "%.2f" % (100 * best_val_f1))
    writer.close()

    # Path of the saved model.
    model_path = os.path.join(output_dir, "pytorch-intent-with-bert_policy.pt")
    zip_path = config["zipped_model_path"]
    zip_path = os.path.join(root_path, zip_path)
    print("zip model to", zip_path)

    # Write the saved model into the zip archive.
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.write(model_path)
示例#11
0
文件: trade.py 项目: zy12105228/xbot
    def __init__(self):
        """Load the config, fetch the model files and restore the Trade model.

        The common config is merged into the model config, every file in
        ``TradeDST.model_urls`` is downloaded when missing (or always, when
        ``use_cache`` is falsy), the ontology and the two pickled language
        objects are loaded, and the trained weights are restored into a
        ``Trade`` model that is put into eval mode. The dialogue state is
        initialised with ``default_state()``.
        """
        super(TradeDST, self).__init__()
        # load config
        common_config_path = os.path.join(get_config_path(),
                                          TradeDST.common_config_name)
        # Context managers close the files right away instead of leaving the
        # handles open until garbage collection.
        with open(common_config_path) as config_file:
            common_config = json.load(config_file)
        model_config_path = os.path.join(get_config_path(),
                                         TradeDST.model_config_name)
        with open(model_config_path) as config_file:
            model_config = json.load(config_file)
        model_config.update(common_config)
        self.model_config = model_config
        self.model_config["data_path"] = os.path.join(
            get_data_path(), "crosswoz/dst_trade_data")
        # "device" is still the raw config value here, so it can be compared
        # with the string "cpu" before being turned into a torch.device.
        self.model_config["n_gpus"] = (0 if self.model_config["device"]
                                       == "cpu" else torch.cuda.device_count())
        self.model_config["device"] = torch.device(self.model_config["device"])
        # The hidden size is forced to 300 when embeddings are loaded.
        if model_config["load_embedding"]:
            model_config["hidden_size"] = 300

        # download data
        for model_key, url in TradeDST.model_urls.items():
            dst = os.path.join(self.model_config["data_path"], model_key)
            if model_key.endswith("pth"):
                file_name = "trained_model_path"
            elif model_key.endswith("pkl"):
                # Pickle files are keyed by the text before their last "-".
                file_name = model_key.rsplit("-", maxsplit=1)[0]
            else:
                file_name = model_key.split(".")[0]  # ontology
            self.model_config[file_name] = dst
            if not os.path.exists(dst) or not self.model_config["use_cache"]:
                download_from_url(url, dst)

        # load data & model
        with open(self.model_config["ontology"], "r", encoding="utf8") as f:
            ontology = json.load(f)
        self.all_slots = get_slot_information(ontology)
        self.gate2id = {"ptr": 0, "none": 1}
        self.id2gate = {id_: gate for gate, id_ in self.gate2id.items()}
        # NOTE(security): these pickles come from the download step above;
        # unpickling runs arbitrary code, so the download source must be
        # trusted.
        with open(self.model_config["lang"], "rb") as f:
            self.lang = pickle.load(f)
        with open(self.model_config["mem-lang"], "rb") as f:
            self.mem_lang = pickle.load(f)

        model = Trade(
            lang=self.lang,
            vocab_size=len(self.lang.index2word),
            hidden_size=self.model_config["hidden_size"],
            dropout=self.model_config["dropout"],
            num_encoder_layers=self.model_config["num_encoder_layers"],
            num_decoder_layers=self.model_config["num_decoder_layers"],
            pad_id=self.model_config["pad_id"],
            slots=self.all_slots,
            num_gates=len(self.gate2id),
            unk_mask=self.model_config["unk_mask"],
        )

        # map_location remaps the stored tensors onto the configured device,
        # so a checkpoint saved on GPU also loads on a CPU-only machine.
        model.load_state_dict(
            torch.load(self.model_config["trained_model_path"],
                       map_location=self.model_config["device"]))

        self.model = model.to(self.model_config["device"]).eval()
        print(f'>>> {self.model_config["trained_model_path"]} loaded ...')
        self.state = default_state()
        print(">>> State initialized ...")
示例#12
0
文件: train.py 项目: zy12105228/xbot
def main():
    """Train the model configured under ``dst/trade``.

    Steps, in order: merge the common config into the model config, fill in
    the device, batch-size and path entries, make sure the data and output
    directories exist, download the missing data files, prepare the data,
    then run the epoch loop with periodic evaluation on the dev set and
    patience-based early stopping.
    """
    model_config_name = "dst/trade/train.json"
    common_config_name = "dst/trade/common.json"

    # Local file name -> download URL.
    data_urls = {
        "train_dials.json": "http://xbot.bslience.cn/train_dials.json",
        "dev_dials.json": "http://xbot.bslience.cn/dev_dials.json",
        "test_dials.json": "http://xbot.bslience.cn/test_dials.json",
        "ontology.json": "http://xbot.bslience.cn/ontology.json",
        "sgns.wiki.bigram.bz2": "http://xbot.bslience.cn/sgns.wiki.bigram.bz2",
    }

    # load config
    root_path = get_root_path()
    common_config_path = os.path.join(get_config_path(), common_config_name)
    model_config_path = os.path.join(get_config_path(), model_config_name)
    # Context managers close the files right away instead of leaving the
    # handles open until garbage collection.
    with open(common_config_path) as config_file:
        common_config = json.load(config_file)
    with open(model_config_path) as config_file:
        model_config = json.load(config_file)
    model_config.update(common_config)
    model_config["n_gpus"] = torch.cuda.device_count()
    # Scale the configured batch size by the GPU count (at least 1).
    model_config["batch_size"] = (max(1, model_config["n_gpus"]) *
                                  model_config["batch_size"])
    model_config["device"] = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    # The hidden size is forced to 300 when embeddings are loaded.
    if model_config["load_embedding"]:
        model_config["hidden_size"] = 300

    model_config["data_path"] = os.path.join(get_data_path(),
                                             "crosswoz/dst_trade_data")
    model_config["output_dir"] = os.path.join(
        root_path, model_config["output_dir"])  # can be used to save model files
    # An empty model path means "start from scratch".
    if model_config["load_model_name"]:
        model_config["model_path"] = os.path.join(
            model_config["output_dir"], model_config["load_model_name"])
    else:
        model_config["model_path"] = ""
    # exist_ok avoids the check-then-create race between processes.
    os.makedirs(model_config["data_path"], exist_ok=True)
    os.makedirs(model_config["output_dir"], exist_ok=True)

    # download data
    for data_key, url in data_urls.items():
        dst = os.path.join(model_config["data_path"], data_key)
        # File names containing "_" and the ontology are keyed by the text
        # before their first dot; the wiki bigram archive gets a fixed key.
        if "_" in data_key:
            file_name = data_key.split(".")[0]
        elif "wiki.bigram" in data_key:
            file_name = "orig_pretrained_embedding"
        else:
            file_name = data_key.split(".")[0]  # ontology
        model_config[file_name] = dst
        if not os.path.exists(dst):
            download_from_url(url, dst)

    avg_best, cnt, acc = 0.0, 0, 0.0

    # data preprocessing
    train, dev, test, langs, slots, gating_dict = prepare_data_seq(
        model_config)
    lang = langs[0]
    model_config["pretrained_embedding_path"] = os.path.join(
        model_config["data_path"], f"emb{len(lang.index2word)}")

    print(">>> Train configs:")
    print("\t", model_config)

    # initialise training
    trainer = Trainer(config=model_config,
                      langs=langs,
                      gating_dict=gating_dict,
                      slots=slots)

    # train
    # When resuming, the epoch to continue from is the third "-"-separated
    # field of the model path, plus one.
    start_epoch = (0 if not model_config["model_path"] else
                   int(model_config["model_path"].split("-")[2]) + 1)

    for epoch in tqdm(range(start_epoch, model_config["num_epochs"]),
                      desc="Epoch"):
        progress_bar = tqdm(enumerate(train), total=len(train))

        for i, data in progress_bar:
            trainer.train_batch(data, slots, reset=(i == 0))
            trainer.optimize(int(model_config["grad_clip"]))
            progress_bar.set_description(trainer.print_loss())

        if (epoch + 1) % int(model_config["eval_steps"]) == 0:

            acc = trainer.evaluate(dev, avg_best, slots, epoch,
                                   model_config["early_stop"])
            trainer.scheduler.step(acc)

            # cnt counts consecutive evaluations without improvement.
            if acc >= avg_best:
                avg_best = acc
                cnt = 0
            else:
                cnt += 1

            if cnt == model_config["patience"] or (
                    acc == 1.0 and model_config["early_stop"] is None):
                print("Ran out of patient, early stop...")
                break