Example #1
    def __init__(
        self,
        use_cuda=False,
        tree=True,
        prob=False,
        use_pos=False,
        model_files_path=None,
        buckets=False,
        batch_size=None,
        encoding_model="ernie-lstm",
    ):
        if model_files_path is None:
            if encoding_model in ["lstm", "transformer", "ernie-1.0", "ernie-tiny", "ernie-lstm"]:
                model_files_path = self._get_abs_path(os.path.join("./model_files/", encoding_model))
            else:
                raise KeyError("Unknown encoding model.")

            if not os.path.exists(model_files_path):
                try:
                    utils.download_model_from_url(model_files_path, encoding_model)
                except Exception as e:
                    logging.error("Failed to download model, please try again")
                    logging.error("error: {}".format(e))
                    raise e

        args = [
            "--model_files={}".format(model_files_path),
            "--config_path={}".format(self._get_abs_path("config.ini")),
            "--encoding_model={}".format(encoding_model),
        ]

        if use_cuda:
            args.append("--use_cuda")
        if tree:
            args.append("--tree")
        if prob:
            args.append("--prob")
        if batch_size:
            args.append("--batch_size={}".format(batch_size))

        args = ArgConfig(args)
        # Don't instantiate the log handle
        args.log_path = None
        self.env = Environment(args)
        self.args = self.env.args
        fluid.enable_imperative(self.env.place)
        self.model = load(self.args.model_path)
        self.model.eval()
        self.lac = None
        self.use_pos = use_pos
        # Disable bucketing when it is not requested
        if not buckets:
            self.args.buckets = None
        if args.prob:
            self.env.fields = self.env.fields._replace(PHEAD=Field("prob"))
        if self.use_pos:
            self.env.fields = self.env.fields._replace(CPOS=Field("postag"))
        # set default batch size if batch_size is None and not buckets
        if batch_size is None and not buckets:
            self.args.batch_size = 50
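
This __init__ builds a command-line style argument list, loads the saved model, and configures the data fields. A minimal usage sketch, assuming it belongs to the DDParser class exposed by the ddparser package (model files are downloaded automatically on first use):

from ddparser import DDParser

# CPU parsing with POS tags and arc probabilities enabled,
# matching the use_pos/prob flags handled in the constructor above.
ddp = DDParser(use_pos=True, prob=True)
results = ddp.parse("百度是一家高科技公司")
# Each result is a dict with "word", "head" and "deprel";
# "postag" and "prob" appear because of the flags above.
print(results)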
Example #2
    def __init__(self,
                 use_cuda=False,
                 tree=True,
                 prob=False,
                 use_pos=False,
                 model_files_path=None,
                 buckets=False,
                 batch_size=None):
        if model_files_path is None:
            model_files_path = self._get_abs_path('./model_files/baidu')
            if not os.path.exists(model_files_path):
                try:
                    utils.download_model_from_url(model_files_path)
                except Exception as e:
                    logging.error("Failed to download model, please try again")
                    logging.error(f"error: {e}")
                    # NOTE: returning here leaves the parser only partially initialized
                    return

        args = [
            f"--model_files={model_files_path}",
            f"--config_path={self._get_abs_path('config.ini')}"
        ]

        if use_cuda:
            args.append("--use_cuda")
        if tree:
            args.append("--tree")
        if prob:
            args.append("--prob")
        if batch_size:
            args.append(f"--batch_size={batch_size}")

        args = ArgConfig(args)
        # Don't instantiate the log handle
        args.log_path = None
        self.env = Environment(args)
        self.args = self.env.args
        fluid.enable_imperative(self.env.place)
        self.model = load(self.args.model_path)
        self.lac = None
        self.use_pos = use_pos
        # Disable bucketing when it is not requested
        if not buckets:
            self.args.buckets = None
        if args.prob:
            self.env.fields = self.env.fields._replace(PHEAD=Field('prob'))
        if self.use_pos:
            self.env.fields = self.env.fields._replace(CPOS=Field('postag'))
        # set default batch size if batch_size is None and not buckets
        if batch_size is None and not buckets:
            self.args.batch_size = 50
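
Both constructors swap fields on env.fields through namedtuple's _replace, which returns a copy with only the named fields changed. A self-contained sketch of that pattern (the container and field names here are illustrative stand-ins, not ddparser's actual classes):

from collections import namedtuple

# Stand-in for the fields container used by the Environment above.
Fields = namedtuple("Fields", ["FORM", "CPOS", "PHEAD"])
fields = Fields(FORM="form", CPOS=None, PHEAD=None)

# _replace returns a new namedtuple with just the given fields swapped,
# which is how the snippets attach "postag" and "prob" fields on demand.
fields = fields._replace(CPOS="postag", PHEAD="prob")
print(fields)  # Fields(FORM='form', CPOS='postag', PHEAD='prob')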
Example #3
File: run.py  Project: zw331/DDParser
        outputs = predicts.get_result()
        # Remove the postag field from every result when it is present.
        if outputs[0].get("postag", None):
            for output in outputs:
                del output["postag"]
        return outputs

    def _get_abs_path(self, path):
        # Resolve `path` relative to the directory containing this module,
        # so model files are found regardless of the current working directory.
        return os.path.normpath(
            os.path.join(os.path.dirname(os.path.abspath(__file__)), path))


if __name__ == '__main__':
    logging.info("init arguments.")
    args = ArgConfig()
    logging.info("init environment.")
    env = Environment(args)
    logging.info(f"Override the default configs\n{env.args}")
    logging.info(f"{env.WORD}\n{env.FEAT}\n{env.ARC}\n{env.REL}")
    logging.info(f"Set the max num of threads to {env.args.threads}")
    logging.info(
        f"Set the seed for generating random numbers to {env.args.seed}")
    logging.info(f"Run the subcommand in mode {env.args.mode}")

    fluid.enable_imperative(env.place)
    mode = env.args.mode
    if mode == "train":
        train(env)
    elif mode == "evaluate":
        evaluate(env)
    elif mode == "predict":
        predict(env)
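
The if/elif chain above selects the subcommand from env.args.mode. A self-contained sketch of the same dispatch written as a lookup table (the handlers are placeholders; DDParser's real train/evaluate/predict take an Environment):

def train(env):
    print("training with", env)

def evaluate(env):
    print("evaluating with", env)

def predict(env):
    print("predicting with", env)

def dispatch(mode, env):
    handlers = {"train": train, "evaluate": evaluate, "predict": predict}
    if mode not in handlers:
        raise ValueError("unknown mode: {}".format(mode))
    handlers[mode](env)

dispatch("predict", {"model_path": "model_files/baidu"})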