def __init__(
    self,
    use_cuda=False,
    tree=True,
    prob=False,
    use_pos=False,
    model_files_path=None,
    buckets=False,
    batch_size=None,
    encoding_model="ernie-lstm",
):
    """Create a parser instance backed by a pretrained encoding model.

    Args:
        use_cuda: When truthy, ``--use_cuda`` is passed to ``ArgConfig``.
        tree: When truthy, ``--tree`` is passed to ``ArgConfig``.
        prob: When truthy, ``--prob`` is passed to ``ArgConfig``. If the
            resulting config has ``prob`` set, the ``PHEAD`` field is
            replaced by ``Field("prob")``.
        use_pos: When truthy, the ``CPOS`` field is replaced by
            ``Field("postag")``. Stored as ``self.use_pos``.
        model_files_path: Directory with the model files. When None, the
            path ``./model_files/<encoding_model>`` is resolved through
            ``self._get_abs_path``, and the model is downloaded there if the
            directory does not exist.
        buckets: When falsy, ``self.args.buckets`` is forced to None.
        batch_size: When truthy, forwarded as ``--batch_size``. When None
            and ``buckets`` is falsy, ``self.args.batch_size`` becomes 50.
        encoding_model: Forwarded as ``--encoding_model``. When
            ``model_files_path`` is None it must be one of "lstm",
            "transformer", "ernie-1.0", "ernie-tiny", "ernie-lstm".

    Raises:
        KeyError: ``model_files_path`` is None and ``encoding_model`` is not
            a supported name.
        Exception: Any error from the model download, logged and re-raised.
    """
    supported_models = ("lstm", "transformer", "ernie-1.0", "ernie-tiny", "ernie-lstm")

    if model_files_path is None:
        # Only the bundled model names can be resolved to a default location.
        if encoding_model not in supported_models:
            raise KeyError("Unknown encoding model.")
        model_files_path = self._get_abs_path(os.path.join("./model_files/", encoding_model))

        # Fetch the bundled model the first time it is needed.
        if not os.path.exists(model_files_path):
            try:
                utils.download_model_from_url(model_files_path, encoding_model)
            except Exception as err:
                logging.error("Failed to download model, please try again")
                logging.error("error: {}".format(err))
                raise err

    # Command-line style arguments consumed by ArgConfig.
    cli_args = [
        "--model_files={}".format(model_files_path),
        "--config_path={}".format(self._get_abs_path('config.ini')),
        "--encoding_model={}".format(encoding_model),
    ]
    switches = (
        (use_cuda, "--use_cuda"),
        (tree, "--tree"),
        (prob, "--prob"),
    )
    cli_args.extend(flag for enabled, flag in switches if enabled)
    if batch_size:
        cli_args.append("--batch_size={}".format(batch_size))

    config = ArgConfig(cli_args)
    # Don't instantiate the log handle
    config.log_path = None
    self.env = Environment(config)
    self.args = self.env.args

    fluid.enable_imperative(self.env.place)
    self.model = load(self.args.model_path)
    self.model.eval()

    self.lac = None
    self.use_pos = use_pos

    # Without explicit buckets, disable bucketing entirely.
    if not buckets:
        self.args.buckets = None
    # Read the parsed config (not the raw argument) to decide on PHEAD.
    if config.prob:
        self.env.fields = self.env.fields._replace(PHEAD=Field("prob"))
    if self.use_pos:
        self.env.fields = self.env.fields._replace(CPOS=Field("postag"))
    # Fall back to a fixed batch size when neither batch_size nor buckets was given.
    if batch_size is None and not buckets:
        self.args.batch_size = 50
def __init__(self,
             use_cuda=False,
             tree=True,
             prob=False,
             use_pos=False,
             model_files_path=None,
             buckets=False,
             batch_size=None):
    """Create a parser instance backed by the pretrained "baidu" model.

    Args:
        use_cuda: When truthy, ``--use_cuda`` is passed to ``ArgConfig``.
        tree: When truthy, ``--tree`` is passed to ``ArgConfig``.
        prob: When truthy, ``--prob`` is passed to ``ArgConfig``. If the
            resulting config has ``prob`` set, the ``PHEAD`` field is
            replaced by ``Field('prob')``.
        use_pos: When truthy, the ``CPOS`` field is replaced by
            ``Field('postag')``. Stored as ``self.use_pos``.
        model_files_path: Directory with the model files. When None,
            ``./model_files/baidu`` is resolved through ``self._get_abs_path``
            and the model is downloaded there if the directory is missing.
        buckets: When falsy, ``self.args.buckets`` is forced to None.
        batch_size: When truthy, forwarded as ``--batch_size``. When None
            and ``buckets`` is falsy, ``self.args.batch_size`` becomes 50.

    Raises:
        Exception: Any error raised by ``utils.download_model_from_url`` while
            fetching the default model. It is logged and then re-raised.
    """
    if model_files_path is None:
        model_files_path = self._get_abs_path('./model_files/baidu')
        # Fetch the default model the first time it is needed.
        if not os.path.exists(model_files_path):
            try:
                utils.download_model_from_url(model_files_path)
            except Exception as e:
                logging.error("Failed to download model, please try again")
                logging.error(f"error: {e}")
                # Re-raise instead of returning: an early `return` from
                # __init__ would hand the caller an object with no env,
                # args or model, which then fails later with an unrelated
                # AttributeError.
                raise

    args = [
        f"--model_files={model_files_path}",
        f"--config_path={self._get_abs_path('config.ini')}"
    ]

    if use_cuda:
        args.append("--use_cuda")
    if tree:
        args.append("--tree")
    if prob:
        args.append("--prob")
    if batch_size:
        args.append(f"--batch_size={batch_size}")

    args = ArgConfig(args)
    # Don't instantiate the log handle
    args.log_path = None
    self.env = Environment(args)
    self.args = self.env.args
    fluid.enable_imperative(self.env.place)
    self.model = load(self.args.model_path)
    # NOTE(review): the loaded model is not explicitly switched to
    # evaluation mode here; confirm that happens before prediction.
    self.lac = None
    self.use_pos = use_pos
    # buckets=None if not buckets else defaults
    if not buckets:
        self.args.buckets = None
    # Read the parsed config (not the raw argument) to decide on PHEAD.
    if args.prob:
        self.env.fields = self.env.fields._replace(PHEAD=Field('prob'))
    if self.use_pos:
        self.env.fields = self.env.fields._replace(CPOS=Field('postag'))
    # set default batch size if batch_size is None and not buckets
    if batch_size is None and not buckets:
        self.args.batch_size = 50
outputs = predicts.get_result() if outputs[0].get("postag", None): for output in outputs: del output["postag"] return outputs def _get_abs_path(self, path): return os.path.normpath( os.path.join(os.path.dirname(os.path.abspath(__file__)), path)) if __name__ == '__main__': logging.info("init arguments.") args = ArgConfig() logging.info("init environment.") env = Environment(args) logging.info(f"Override the default configs\n{env.args}") logging.info(f"{env.WORD}\n{env.FEAT}\n{env.ARC}\n{env.REL}") logging.info(f"Set the max num of threads to {env.args.threads}") logging.info( f"Set the seed for generating random numbers to {env.args.seed}") logging.info(f"Run the subcommand in mode {env.args.mode}") fluid.enable_imperative(env.place) mode = env.args.mode if mode == "train": train(env) elif mode == "evaluate": evaluate(env) elif mode == "predict": predict(env)