Example #1
    def __init__(
        self,
        use_cuda=False,
        tree=True,
        prob=False,
        use_pos=False,
        model_files_path=None,
        buckets=False,
        batch_size=None,
        encoding_model="ernie-lstm",
    ):
        if model_files_path is None:
            if encoding_model in ["lstm", "transformer", "ernie-1.0", "ernie-tiny", "ernie-lstm"]:
                model_files_path = self._get_abs_path(os.path.join("./model_files/", encoding_model))
            else:
                raise KeyError("Unknown encoding model.")

            if not os.path.exists(model_files_path):
                try:
                    utils.download_model_from_url(model_files_path, encoding_model)
                except Exception as e:
                    logging.error("Failed to download model, please try again")
                    logging.error("error: {}".format(e))
                    raise e

        args = [
            "--model_files={}".format(model_files_path), "--config_path={}".format(self._get_abs_path('config.ini')),
            "--encoding_model={}".format(encoding_model)
        ]

        if use_cuda:
            args.append("--use_cuda")
        if tree:
            args.append("--tree")
        if prob:
            args.append("--prob")
        if batch_size:
            args.append("--batch_size={}".format(batch_size))

        args = ArgConfig(args)
        # Don't instantiate the log handle
        args.log_path = None
        self.env = Environment(args)
        self.args = self.env.args
        fluid.enable_imperative(self.env.place)
        self.model = load(self.args.model_path)
        self.model.eval()
        self.lac = None
        self.use_pos = use_pos
        # buckets=None if not buckets else defaults
        if not buckets:
            self.args.buckets = None
        if args.prob:
            self.env.fields = self.env.fields._replace(PHEAD=Field("prob"))
        if self.use_pos:
            self.env.fields = self.env.fields._replace(CPOS=Field("postag"))
        # set default batch size if batch_size is None and not buckets
        if batch_size is None and not buckets:
            self.args.batch_size = 50
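
The keyword arguments above correspond to the public DDParser entry class of the pip-installable ddparser package. A minimal, hedged usage sketch (the settings, sample sentence, and result keys are illustrative and assume the documented parse output format):

# Hedged usage sketch; assumes `pip install ddparser`.
from ddparser import DDParser

ddp = DDParser(prob=True, use_pos=True)          # return arc probabilities and POS tags
results = ddp.parse("百度是一家高科技公司")       # accepts a raw string (or a list of strings)
print(results[0]["head"], results[0]["deprel"])  # head indices and relation labels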
Example #2
    def __init__(self,
                 use_cuda=False,
                 tree=True,
                 prob=False,
                 use_pos=False,
                 model_files_path=None,
                 buckets=False,
                 batch_size=None):
        if model_files_path is None:
            model_files_path = self._get_abs_path('./model_files/baidu')
            if not os.path.exists(model_files_path):
                try:
                    utils.download_model_from_url(model_files_path)
                except Exception as e:
                    logging.error("Failed to download model, please try again")
                    logging.error(f"error: {e}")
                    return

        args = [
            f"--model_files={model_files_path}",
            f"--config_path={self._get_abs_path('config.ini')}"
        ]

        if use_cuda:
            args.append("--use_cuda")
        if tree:
            args.append("--tree")
        if prob:
            args.append("--prob")
        if batch_size:
            args.append(f"--batch_size={batch_size}")

        args = ArgConfig(args)
        # Don't instantiate the log handle
        args.log_path = None
        self.env = Environment(args)
        self.args = self.env.args
        fluid.enable_imperative(self.env.place)
        self.model = load(self.args.model_path)
        self.lac = None
        self.use_pos = use_pos
        # buckets=None if not buckets else defaults
        if not buckets:
            self.args.buckets = None
        if args.prob:
            self.env.fields = self.env.fields._replace(PHEAD=Field('prob'))
        if self.use_pos:
            self.env.fields = self.env.fields._replace(CPOS=Field('postag'))
        # set default batch size if batch_size is None and not buckets
        if batch_size is None and not buckets:
            self.args.batch_size = 50
Example #3
File: run.py  Project: zw331/DDParser
def predict(env):
    """Predict"""
    args = env.args

    logging.info("Load the dataset")
    if args.prob:
        env.fields = env.fields._replace(PHEAD=Field('prob'))
    predicts = Corpus.load(args.infer_data_path, env.fields)
    dataset = TextDataset(predicts, [env.WORD, env.FEAT], args.buckets)
    # set the data loader
    dataset.loader = batchify(dataset, args.batch_size)
    logging.info(f"{len(dataset)} sentences, "
                 f"{len(dataset.loader)} batches")

    logging.info("Load the model")
    model = load(args.model_path)
    model.args = args

    logging.info("Make predictions on the dataset")
    start = datetime.datetime.now()
    model.eval()
    pred_arcs, pred_rels, pred_probs = epoch_predict(env, args, model,
                                                     dataset.loader)
    total_time = datetime.datetime.now() - start
    # restore the order of sentences in the buckets
    indices = np.argsort(
        np.array([i for bucket in dataset.buckets.values() for i in bucket]))
    predicts.head = [pred_arcs[i] for i in indices]
    predicts.deprel = [pred_rels[i] for i in indices]
    if args.prob:
        predicts.prob = [pred_probs[i] for i in indices]
    logging.info(f"Save the predicted result to {args.infer_result_path}")
    predicts.save(args.infer_result_path)
    logging.info(f"{total_time}s elapsed, "
                 f"{len(dataset) / total_time.total_seconds():.2f} Sents/s")
Example #4
    def load(cls, data_file_path, fields):
        """
            Load ddparser_data from given path to generate corpus

            <data_file_path>: str, absolute/relative path of ddparser_data file
            <fields>: CoNLL,

            return: Corpus

            <fields>: list, 各元素为Field类,分别对应各列的name和相关配置(如tokenizer, bos, eos等)
            <sentences>: list,数据中句子,每个句子为一个list,其中元素为tuple,分别对应各列
        """
        start, sentences = 0, []
        fields = [fd if fd is not None else Field(str(i)) for i, fd in enumerate(fields)]  # replace None entries with a default Field
        with open(data_file_path, 'r', encoding='utf-8') as f:  # read the file into a list, one line per element, keeping blank lines and skipping comment ('#') lines
            lines = [line.strip() for line in f.readlines()
                     if not line.startswith('#') and (len(line) == 1 or line.split()[0].isdigit())]
        for i, line in enumerate(lines):  # group lines into sentences: each sentence is a list of tuples, one tuple per column
            if not line:
                values = list(zip(*[j.split('\t') for j in lines[start:i]]))
                if values:
                    sentences.append(Sentence(fields, values))
                start = i + 1

        return cls(fields, sentences)
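
The loader expects tab-separated CoNLL-style rows, with blank lines closing sentences and '#' comment lines skipped. A stand-alone sketch of the grouping it performs, run on an in-memory sample (the rows and columns are illustrative):

# Stand-alone sketch of the sentence grouping performed by `load`
# (sample rows are illustrative; real files are tab-separated CoNLL-style columns).
raw = [
    "# comment lines like this one are filtered out",
    "1\t百度\t百度\tORG",
    "2\t是\t是\tv",
    "",                    # a blank line closes the current sentence
    "1\t好\t好\ta",
    "",
]
lines = [line for line in raw if not line.startswith("#")]
sentences, start = [], 0
for i, line in enumerate(lines):
    if not line:
        # zip(*) turns per-token rows into per-column tuples
        sentences.append(list(zip(*[row.split("\t") for row in lines[start:i]])))
        start = i + 1
print(sentences[0][1])  # ('百度', '是') -- the FORM column of the first sentence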
Example #5
    def load_word_segments(cls, inputs, fields):
        """Load ddparser_data from word segmentation results to generate corpus"""
        fields = [fd if fd is not None else Field(str(i)) for i, fd in enumerate(fields)]
        sentences = []
        for tokens in inputs:
            values = [list(range(1, len(tokens) + 1)), tokens, tokens] + [['-'] * len(tokens) for _ in range(7)]

            sentences.append(Sentence(fields, values))
        return cls(fields, sentences)
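
This loader backs the pre-segmented entry point: each token list becomes a ten-column sentence, with '-' placeholders for the columns that segmentation cannot provide. A hedged sketch of the corresponding public API, assuming the pip-installed ddparser package:

# Hedged sketch: parsing already-segmented input (assumes `pip install ddparser`).
from ddparser import DDParser

ddp = DDParser()
results = ddp.parse_seg([["百度", "是", "一家", "高科技", "公司"]])
print(results[0]["head"])  # predicted head index for each token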
Example #6
    def load_lac_results(cls, inputs, fields):
        """Load ddparser_data from lac results to generate corpus"""
        sentences = []
        fields = [fd if fd is not None else Field(str(i)) for i, fd in enumerate(fields)]
        for _input in inputs:
            if isinstance(_input[0], list):
                tokens, poss = _input
            else:
                tokens = _input
                poss = ['-'] * len(tokens)
            values = ([list(range(1, len(tokens) + 1)), tokens, tokens, poss, poss] +
                      [['-'] * len(tokens) for _ in range(5)])

            sentences.append(Sentence(fields, values))
        return cls(fields, sentences)
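
load_lac_results accepts either plain token lists or (tokens, postags) pairs, i.e. the output shape of Baidu's LAC toolkit that Example #7 instantiates. A hedged sketch of producing such input, assuming the pip-installed LAC package:

# Hedged sketch: LAC output in the shape load_lac_results expects (assumes `pip install LAC`).
from LAC import LAC

lac = LAC(mode="lac")                  # "lac" mode returns words plus POS tags
lac_results = lac.run(["百度是一家高科技公司"])
print(lac_results[0])                  # [tokens, postags] for the first sentence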
Example #7
def predict_query(env):
    """Predict one query"""
    args = env.args
    logging.info("Load the model")
    model = load(args.model_path)
    model.eval()
    lac_mode = "seg" if args.feat != "pos" else "lac"
    lac = LAC.LAC(mode=lac_mode)
    if args.prob:
        env.fields = env.fields._replace(PHEAD=Field("prob"))

    while True:
        query = input()
        if isinstance(query, six.text_type):
            pass
        else:
            query = query.decode("utf-8")
        if not query:
            logging.info("quit!")
            return
        if len(query) > 200:
            logging.info("The length of the query should be less than 200!")
            continue
        start = datetime.datetime.now()
        lac_results = lac.run([query])
        predicts = Corpus.load_lac_results(lac_results, env.fields)
        dataset = TextDataset(predicts, [env.WORD, env.FEAT])
        # set the ddparser_data loader
        dataset.loader = batchify(dataset,
                                  args.batch_size,
                                  use_multiprocess=False,
                                  sequential_sampler=True)
        pred_arcs, pred_rels, pred_probs = epoch_predict(
            env, args, model, dataset.loader)
        predicts.head = pred_arcs
        predicts.deprel = pred_rels
        if args.prob:
            predicts.prob = pred_probs
        predicts._print()
        total_time = datetime.datetime.now() - start
        logging.info("{}s elapsed, {:.2f} Sents/s, {:.2f} ms/Sents".format(
            total_time,
            len(dataset) / total_time.total_seconds(),
            total_time.total_seconds() / len(dataset) * 1000))
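
The final log line reports both sentences per second and milliseconds per sentence from the same timedelta. A small worked example of those two conversions (the duration is made up):

# Worked example of the throughput figures logged above (made-up duration).
import datetime

total_time = datetime.timedelta(milliseconds=40)  # elapsed time for one query
n_sents = 1
print("{:.2f} Sents/s".format(n_sents / total_time.total_seconds()))          # 25.00 Sents/s
print("{:.2f} ms/Sents".format(total_time.total_seconds() / n_sents * 1000))  # 40.00 ms/Sents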
Example #8
File: corpus.py  Project: baidu/DDParser
    def load(cls, path, fields):
        """Load data from path to generate corpus"""
        start, sentences = 0, []
        fields = [
            fd if fd is not None else Field(str(i))
            for i, fd in enumerate(fields)
        ]
        with open(path, 'r', encoding='utf-8') as f:
            lines = [
                line.strip() for line in f.readlines()
                if not line.startswith('#') and (
                    len(line) == 1 or line.split()[0].isdigit())
            ]
        for i, line in enumerate(lines):
            if not line:
                values = list(zip(*[j.split('\t') for j in lines[start:i]]))
                if values:
                    sentences.append(Sentence(fields, values))
                start = i + 1

        return cls(fields, sentences)
Example #9
def predict(env):
    """Predict"""
    arguments = env.args

    logging.info("Load the dataset")
    if arguments.prob:
        env.fields = env.fields._replace(PHEAD=Field("prob"))
    predicts = Corpus.load(arguments.infer_data_path, env.fields)
    dataset = TextDataset(predicts, [env.WORD, env.FEAT],
                          arguments.buckets)  # only the word and feat columns are needed
    # set the ddparser_data loader
    dataset.loader = batchify(dataset, arguments.batch_size)
    logging.info("{} sentences, {} batches".format(len(dataset),
                                                   len(dataset.loader)))

    logging.info("Load the model")
    model = load(arguments.model_path)
    model.args = arguments

    logging.info("Make predictions on the dataset")
    start = datetime.datetime.now()
    model.eval()
    connection_predicts, deprel_predicts, predict_prob = epoch_predict(
        env, arguments, model, dataset.loader)
    total_time = datetime.datetime.now() - start
    # restore the order of sentences in the buckets
    indices = np.argsort(
        np.array([i for bucket in dataset.buckets.values() for i in bucket]))
    predicts.head = [connection_predicts[i] for i in indices]
    predicts.deprel = [deprel_predicts[i] for i in indices]
    if arguments.prob:
        predicts.prob = [predict_prob[i] for i in indices]
    logging.info("Save the predicted result to {}".format(
        arguments.infer_result_path))
    predicts.save(arguments.infer_result_path)
    logging.info("{}s elapsed, {:.2f} Sents/s".format(
        total_time,
        len(dataset) / total_time.total_seconds()))
Example #10
File: config.py  Project: baidu/DDParser
    def __init__(self, args):
        self.args = args
        # init log
        if args.log_path:
            utils.init_log(args.log_path, args.local_rank, args.log_level)
        # init seed
        fluid.default_main_program().random_seed = args.seed
        np.random.seed(args.seed)
        # init place
        if args.use_cuda:
            self.place = "gpu"

        else:
            self.place = "cpu"

        os.environ["FLAGS_paddle_num_threads"] = str(args.threads)
        if not os.path.exists(self.args.model_files):
            os.makedirs(self.args.model_files)
        if not os.path.exists(args.fields_path) or args.preprocess:
            logging.info("Preprocess the data")
            if args.encoding_model in [
                    "ernie-1.0", "ernie-tiny", "ernie-lstm"
            ]:
                tokenizer = ErnieTokenizer.from_pretrained(args.encoding_model)
                args["ernie_vocabs_size"] = len(tokenizer.vocab)
                self.WORD = ErnieField(
                    "word",
                    pad=tokenizer.pad_token,
                    unk=tokenizer.unk_token,
                    bos=tokenizer.cls_token,
                    eos=tokenizer.sep_token,
                    fix_len=args.fix_len,
                    tokenizer=tokenizer,
                )
                self.WORD.vocab = tokenizer.vocab
                args.feat = None
            else:
                self.WORD = Field(
                    "word",
                    pad=utils.pad,
                    unk=utils.unk,
                    bos=utils.bos,
                    eos=utils.eos,
                    lower=True,
                )
            if args.feat == "char":
                self.FEAT = SubwordField(
                    "chars",
                    pad=utils.pad,
                    unk=utils.unk,
                    bos=utils.bos,
                    eos=utils.eos,
                    fix_len=args.fix_len,
                    tokenize=list,
                )
            elif args.feat == "pos":
                self.FEAT = Field("postag", bos=utils.bos, eos=utils.eos)
            else:
                self.FEAT = None
            self.ARC = Field(
                "head",
                bos=utils.bos,
                eos=utils.eos,
                use_vocab=False,
                fn=utils.numericalize,
            )
            self.REL = Field("deprel", bos=utils.bos, eos=utils.eos)
            if args.feat == "char":
                self.fields = CoNLL(FORM=(self.WORD, self.FEAT),
                                    HEAD=self.ARC,
                                    DEPREL=self.REL)
            else:
                self.fields = CoNLL(FORM=self.WORD,
                                    CPOS=self.FEAT,
                                    HEAD=self.ARC,
                                    DEPREL=self.REL)

            train = Corpus.load(args.train_data_path, self.fields)

            if not args.encoding_model.startswith("ernie"):
                self.WORD.build(train, args.min_freq)
                self.FEAT.build(train)

            self.REL.build(train)
            if args.local_rank == 0:
                with open(args.fields_path, "wb") as f:
                    logging.info("dumping fileds to disk.")
                    pickle.dump(self.fields, f, protocol=2)
        else:
            logging.info("loading the fields.")
            with open(args.fields_path, "rb") as f:
                self.fields = pickle.load(f)

            if isinstance(self.fields.FORM, tuple):
                self.WORD, self.FEAT = self.fields.FORM
            else:
                self.WORD, self.FEAT = self.fields.FORM, self.fields.CPOS
            self.ARC, self.REL = self.fields.HEAD, self.fields.DEPREL

        if args.encoding_model.startswith("ernie"):
            vocab_items = self.WORD.vocab.items()
        else:
            vocab_items = self.WORD.vocab.stoi.items()

        self.puncts = np.array([i for s, i in vocab_items if utils.ispunct(s)],
                               dtype=np.int64)

        self.args.update({
            "n_words": len(self.WORD.vocab),
            "n_feats": self.FEAT and len(self.FEAT.vocab),
            "n_rels": len(self.REL.vocab),
            "pad_index": self.WORD.pad_index,
            "unk_index": self.WORD.unk_index,
            "bos_index": self.WORD.bos_index,
            "eos_index": self.WORD.eos_index,
            "feat_pad_index": self.FEAT and self.FEAT.pad_index,
        })
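
The puncts array gathers the vocabulary indices of punctuation tokens, presumably so punctuation can be excluded later (e.g. from evaluation). A self-contained sketch of the same extraction over a toy vocabulary, with unicodedata standing in for utils.ispunct, which is assumed to behave similarly:

# Self-contained sketch of the punctuation-index extraction above.
# `ispunct` here is a stand-in for utils.ispunct and is assumed to behave similarly.
import unicodedata
import numpy as np

def ispunct(token):
    return all(unicodedata.category(ch).startswith("P") for ch in token)

stoi = {"<pad>": 0, "<unk>": 1, "百度": 2, ",": 3, "。": 4}  # toy string-to-index vocab
puncts = np.array([i for s, i in stoi.items() if ispunct(s)], dtype=np.int64)
print(puncts)  # [3 4]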
Example #11
File: config.py  Project: zw331/DDParser
    def __init__(self, args):
        self.args = args
        # init log
        if self.args.log_path:
            utils.init_log(self.args.log_path, self.args.local_rank,
                           self.args.log_level)
        # init seed
        fluid.default_main_program().random_seed = self.args.seed
        np.random.seed(self.args.seed)
        # init place
        if self.args.use_cuda:
            if self.args.use_data_parallel:
                self.place = fluid.CUDAPlace(
                    fluid.dygraph.parallel.Env().dev_id)
            else:
                self.place = fluid.CUDAPlace(0)
        else:
            self.place = fluid.CPUPlace()

        os.environ['FLAGS_paddle_num_threads'] = str(self.args.threads)
        os.makedirs(self.args.model_files, exist_ok=True)

        if not os.path.exists(self.args.fields_path) or self.args.preprocess:
            logging.info("Preprocess the data")
            self.WORD = Field('word',
                              pad=utils.pad,
                              unk=utils.unk,
                              bos=utils.bos,
                              lower=True)
            if self.args.feat == 'char':
                self.FEAT = SubwordField('chars',
                                         pad=utils.pad,
                                         unk=utils.unk,
                                         bos=utils.bos,
                                         fix_len=self.args.fix_len,
                                         tokenize=list)
            else:
                self.FEAT = Field('postag', bos=utils.bos)
            self.ARC = Field('head',
                             bos=utils.bos,
                             use_vocab=False,
                             fn=utils.numericalize)
            self.REL = Field('deprel', bos=utils.bos)
            if self.args.feat == 'char':
                self.fields = CoNLL(FORM=(self.WORD, self.FEAT),
                                    HEAD=self.ARC,
                                    DEPREL=self.REL)
            else:
                self.fields = CoNLL(FORM=self.WORD,
                                    CPOS=self.FEAT,
                                    HEAD=self.ARC,
                                    DEPREL=self.REL)

            train = Corpus.load(self.args.train_data_path, self.fields)
            if self.args.pretrained_embedding_dir:
                logging.info("loading pretrained embedding from file.")
                embed = Embedding.load(self.args.pretrained_embedding_dir,
                                       self.args.unk)
            else:
                embed = None
            self.WORD.build(train, self.args.min_freq, embed)
            self.FEAT.build(train)
            self.REL.build(train)
            if self.args.local_rank == 0:
                with open(self.args.fields_path, "wb") as f:
                    logging.info("dumping fileds to disk.")
                    pickle.dump(self.fields, f, protocol=2)
        else:
            logging.info("loading the fields.")
            with open(self.args.fields_path, "rb") as f:
                self.fields = pickle.load(f)

            if isinstance(self.fields.FORM, tuple):
                self.WORD, self.FEAT = self.fields.FORM
            else:
                self.WORD, self.FEAT = self.fields.FORM, self.fields.CPOS
            self.ARC, self.REL = self.fields.HEAD, self.fields.DEPREL
        self.puncts = np.array(
            [i for s, i in self.WORD.vocab.stoi.items() if utils.ispunct(s)],
            dtype=np.int64)

        if self.WORD.embed is not None:
            self.args["pretrained_embed_shape"] = self.WORD.embed.shape
        else:
            self.args["pretrained_embed_shape"] = None

        self.args.update({
            'n_words': self.WORD.vocab.n_init,
            'n_feats': len(self.FEAT.vocab),
            'n_rels': len(self.REL.vocab),
            'pad_index': self.WORD.pad_index,
            'unk_index': self.WORD.unk_index,
            'bos_index': self.WORD.bos_index,
            'feat_pad_index': self.FEAT.pad_index
        })
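
Both Environment versions cache the constructed fields with pickle so subsequent runs can skip preprocessing. A minimal, generic sketch of that caching pattern (the path and the cached object are illustrative; protocol=2 matches the calls above):

# Minimal sketch of the fields-caching pattern used by both Environment versions.
import os
import pickle

fields_path = "model_files/fields.pkl"  # illustrative path
fields = {"WORD": "built-from-train"}   # stand-in for the real CoNLL fields object

if not os.path.exists(fields_path):
    os.makedirs(os.path.dirname(fields_path), exist_ok=True)
    with open(fields_path, "wb") as f:
        pickle.dump(fields, f, protocol=2)  # protocol=2 keeps the file loadable by Python 2
else:
    with open(fields_path, "rb") as f:
        fields = pickle.load(f)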