Esempio n. 1
0
    def __init__(self, config):
        """Build the datasets, policy network, optimizer and loss.

        Args:
            config: Mapping from which the keys ``sys_da_voc``,
                ``usr_da_voc``, ``batch_size``, ``output_dir``,
                ``print_per_batch``, ``device``, ``hidden_size``,
                ``model_path``, ``learning_rate`` and ``n_gpus`` are read.
                The whole mapping is also handed to ``DataPreprocessor``.
        """
        vector = CrossWozVector(sys_da_voc_json=config["sys_da_voc"],
                                usr_da_voc_json=config["usr_da_voc"])

        # One dataset per split, all built with the same batch size.
        data_preprocessor = DataPreprocessor(config, vector)
        batch_size = config["batch_size"]
        self.data_train = data_preprocessor.create_dataset("train", batch_size)
        self.data_valid = data_preprocessor.create_dataset("val", batch_size)
        # NOTE(review): the split name here is "tests" while the others are
        # singular ("train", "val") -- confirm it matches what
        # DataPreprocessor.create_dataset expects.
        self.data_test = data_preprocessor.create_dataset("tests", batch_size)

        self.save_dir = config["output_dir"]
        self.print_per_batch = config["print_per_batch"]
        self.device = config["device"]

        self.policy = MultiDiscretePolicy(vector.state_dim,
                                          config["hidden_size"],
                                          vector.sys_da_dim)
        model_path = config["model_path"]
        if model_path:
            # map_location remaps the stored tensors onto the configured
            # device, so a checkpoint saved on a GPU can still be opened on
            # a CPU-only host.
            trained_model_params = torch.load(model_path,
                                              map_location=self.device)
            self.policy.load_state_dict(trained_model_params)
            # Report success only once the weights are actually in place.
            print(f"Model {model_path} loaded")

        self.optimizer = torch.optim.Adam(self.policy.parameters(),
                                          lr=config["learning_rate"])

        self.policy.to(self.device)
        if config["n_gpus"] > 0:
            self.policy = nn.DataParallel(self.policy)

        self.multi_entropy_loss = nn.MultiLabelSoftMarginLoss()
Esempio n. 2
0
    def __init__(self, config):
        """Set up data splits, the policy model, its optimizer and the loss.

        Args:
            config: Mapping providing ``sys_da_voc``, ``usr_da_voc``,
                ``batch_size``, ``output_dir``, ``print_per_batch``,
                ``device``, ``hidden_size``, ``model_path``,
                ``learning_rate`` and ``n_gpus``. It is also passed as a
                whole to ``DataPreprocessor``.
        """
        vector = CrossWozVector(sys_da_voc_json=config['sys_da_voc'],
                                usr_da_voc_json=config['usr_da_voc'])

        data_preprocessor = DataPreprocessor(config, vector)
        self.data_train = data_preprocessor.create_dataset('train', config['batch_size'])
        self.data_valid = data_preprocessor.create_dataset('val', config['batch_size'])
        self.data_test = data_preprocessor.create_dataset('test', config['batch_size'])

        self.save_dir = config['output_dir']
        self.print_per_batch = config['print_per_batch']
        self.device = config['device']

        self.policy = MultiDiscretePolicy(vector.state_dim, config['hidden_size'], vector.sys_da_dim)
        model_path = config['model_path']
        if model_path:
            # Without map_location, torch.load restores tensors onto the
            # device they were saved from, which fails when that device is
            # unavailable (e.g. a CUDA checkpoint on a CPU-only machine).
            trained_model_params = torch.load(model_path, map_location=self.device)
            self.policy.load_state_dict(trained_model_params)
            # Printed after loading so the message is never a false claim.
            print(f'Model {model_path} loaded')

        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=config['learning_rate'])

        self.policy.to(self.device)
        if config['n_gpus'] > 0:
            self.policy = nn.DataParallel(self.policy)

        self.multi_entropy_loss = nn.MultiLabelSoftMarginLoss()
Esempio n. 3
0
    def __init__(self):
        """Load the configuration, fetch the model files and build the policy.

        The common and model JSON configs are merged (common values win on
        key clashes), every entry of ``MLEPolicy.model_urls`` is downloaded
        when missing or when ``use_cache`` is falsy, and the trained weights
        are loaded into a ``MultiDiscretePolicy`` placed in eval mode.
        """
        super(MLEPolicy, self).__init__()
        # load config
        common_config_path = os.path.join(get_config_path(),
                                          MLEPolicy.common_config_name)
        # Context managers close the config files instead of leaking the
        # handles until garbage collection.
        with open(common_config_path) as common_file:
            common_config = json.load(common_file)
        model_config_path = os.path.join(get_config_path(),
                                         MLEPolicy.model_config_name)
        with open(model_config_path) as model_file:
            model_config = json.load(model_file)
        model_config.update(common_config)
        self.model_config = model_config
        self.model_config["data_path"] = os.path.join(
            get_data_path(), "crosswoz/policy_mle_data")
        # n_gpus must be derived while "device" is still the raw config
        # string; it is replaced by a torch.device on the next statement.
        self.model_config["n_gpus"] = (0 if self.model_config["device"]
                                       == "cpu" else torch.cuda.device_count())
        self.model_config["device"] = torch.device(self.model_config["device"])

        # download data
        for model_key, url in MLEPolicy.model_urls.items():
            dst = os.path.join(self.model_config["data_path"], model_key)
            # Weight files are registered as "trained_model_path"; any other
            # file is registered under its name up to the first dot.
            if model_key.endswith("pth"):
                file_name = "trained_model_path"
            else:
                file_name = model_key.split(".")[0]
            self.model_config[file_name] = dst
            if not os.path.exists(dst) or not self.model_config["use_cache"]:
                download_from_url(url, dst)

        self.vector = CrossWozVector(
            sys_da_voc_json=self.model_config["sys_da_voc"],
            usr_da_voc_json=self.model_config["usr_da_voc"],
        )

        policy = MultiDiscretePolicy(self.vector.state_dim,
                                     model_config["hidden_size"],
                                     self.vector.sys_da_dim)

        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be restored onto the configured one, such as a CPU-only host.
        policy.load_state_dict(
            torch.load(self.model_config["trained_model_path"],
                       map_location=self.model_config["device"]))

        self.policy = policy.to(self.model_config["device"]).eval()
        print(f'>>> {self.model_config["trained_model_path"]} loaded ...')
Esempio n. 4
0
class MLEPolicy(Policy):
    """Dialogue policy that picks system actions with a pre-trained network.

    On construction the JSON configuration is read, the files listed in
    ``model_urls`` are downloaded if needed, and the trained weights are
    loaded into a ``MultiDiscretePolicy`` that is switched to eval mode.
    """

    # Config file locations, joined onto get_config_path().
    model_config_name = "policy/mle/inference.json"
    common_config_name = "policy/mle/common.json"

    # File name -> download URL for everything the policy needs.
    # NOTE(review): "model.pth" has an empty URL, yet __init__ still passes
    # it to download_from_url when the file is missing or use_cache is
    # falsy -- confirm how the weights are meant to be provisioned.
    # NOTE(review): "sys_da_voc.json" maps to the usr_da_voc.json URL, the
    # same one as the "usr_da_voc.json" entry -- this looks like a
    # copy-paste slip; confirm the intended URL before relying on it.
    model_urls = {
        "model.pth": "",
        "sys_da_voc.json":
        "http://qiw2jpwfc.hn-bkt.clouddn.com/usr_da_voc.json",
        "usr_da_voc.json":
        "http://qiw2jpwfc.hn-bkt.clouddn.com/usr_da_voc.json",
    }

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path``, closing the handle afterwards."""
        with open(path) as handle:
            return json.load(handle)

    def __init__(self):
        """Load the configuration, fetch the model files and build the policy.

        The model config is updated with the common config (common values
        win on key clashes) and stored as ``self.model_config`` together
        with the derived ``data_path``, ``n_gpus`` and ``device`` entries.
        """
        super().__init__()
        # load config
        config_root = get_config_path()
        common_config = self._load_json(
            os.path.join(config_root, MLEPolicy.common_config_name))
        model_config = self._load_json(
            os.path.join(config_root, MLEPolicy.model_config_name))
        model_config.update(common_config)
        self.model_config = model_config
        self.model_config["data_path"] = os.path.join(
            get_data_path(), "crosswoz/policy_mle_data")
        # n_gpus must be derived while "device" is still the raw config
        # string; it is replaced by a torch.device on the next statement.
        self.model_config["n_gpus"] = (0 if self.model_config["device"]
                                       == "cpu" else torch.cuda.device_count())
        self.model_config["device"] = torch.device(self.model_config["device"])

        # download data
        for model_key, url in MLEPolicy.model_urls.items():
            dst = os.path.join(self.model_config["data_path"], model_key)
            # Weight files are registered as "trained_model_path"; any other
            # file is registered under its name up to the first dot.
            if model_key.endswith("pth"):
                file_name = "trained_model_path"
            else:
                file_name = model_key.split(".")[0]
            self.model_config[file_name] = dst
            if not os.path.exists(dst) or not self.model_config["use_cache"]:
                download_from_url(url, dst)

        self.vector = CrossWozVector(
            sys_da_voc_json=self.model_config["sys_da_voc"],
            usr_da_voc_json=self.model_config["usr_da_voc"],
        )

        policy = MultiDiscretePolicy(self.vector.state_dim,
                                     model_config["hidden_size"],
                                     self.vector.sys_da_dim)

        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be restored onto the configured one, such as a CPU-only host.
        policy.load_state_dict(
            torch.load(self.model_config["trained_model_path"],
                       map_location=self.model_config["device"]))

        self.policy = policy.to(self.model_config["device"]).eval()
        print(f'>>> {self.model_config["trained_model_path"]} loaded ...')

    def init_session(self):
        """Do nothing; this policy holds no per-session state to reset."""
        pass

    def predict(self, state):
        """Choose the system action for ``state`` and record it on the state.

        Args:
            state: Mutable mapping describing the dialogue state. Its
                ``"system_action"`` entry is overwritten with the result.

        Returns:
            The action produced by ``self.vector.action_devectorize``.
        """
        s_vec = torch.tensor(self.vector.state_vectorize(state),
                             device=self.model_config["device"])
        # sample=False is passed so the policy does not sample the action.
        a = self.policy.select_action(s_vec, sample=False).cpu().numpy()
        action = self.vector.action_devectorize(a)
        state["system_action"] = action
        return action