Example #1
def predict(env):
    """Predict"""
    args = env.args

    logging.info("Load the dataset")
    if args.prob:
        env.fields = env.fields._replace(PHEAD=Field('prob'))
    predicts = Corpus.load(args.infer_data_path, env.fields)
    dataset = TextDataset(predicts, [env.WORD, env.FEAT], args.buckets)
    # set the data loader
    dataset.loader = batchify(dataset, args.batch_size)
    logging.info(f"{len(dataset)} sentences, "
                 f"{len(dataset.loader)} batches")

    logging.info("Load the model")
    model = load(args.model_path)
    model.args = args

    logging.info("Make predictions on the dataset")
    start = datetime.datetime.now()
    model.eval()
    pred_arcs, pred_rels, pred_probs = epoch_predict(env, args, model,
                                                     dataset.loader)
    total_time = datetime.datetime.now() - start
    # restore the order of sentences in the buckets
    indices = np.argsort(
        np.array([i for bucket in dataset.buckets.values() for i in bucket]))
    predicts.head = [pred_arcs[i] for i in indices]
    predicts.deprel = [pred_rels[i] for i in indices]
    if args.prob:
        predicts.prob = [pred_probs[i] for i in indices]
    logging.info(f"Save the predicted result to {args.infer_result_path}")
    predicts.save(args.infer_result_path)
    logging.info(f"{total_time}s elapsed, "
                 f"{len(dataset) / total_time.total_seconds():.2f} Sents/s")
Example #2
    def parse_seg(self, inputs):
        """
        Parse sentences that are already word-segmented.

        Args:
            inputs: list(list(str)), word-segmented sentences

        Returns:
            outputs: list, dependency parsing results

        Example:
        >>> ddp = DDParser()
        >>> inputs = [['百度', '是', '一家', '高科技', '公司'], ['他', '送', '了', '一本', '书']]
        >>> ddp.parse_seg(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']},
        {'word': ['他', '送', '了', '一本', '书'], 'head': [2, 0, 2, 5, 2], 'deprel': ['SBV', 'HED', 'MT', 'ATT', 'VOB']}]


        >>> ddp = DDParser(prob=True)
        >>> inputs = [['百度', '是', '一家', '高科技', '公司']]
        >>> ddp.parse_seg(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2],
        'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB'], 'prob': [1.0, 1.0, 1.0, 1.0, 1.0]}]
        """
        if not inputs:
            return
        if all([isinstance(i, list) and i and all(i) for i in inputs]):
            predicts = Corpus.load_word_segments(inputs, self.env.fields)
        else:
            logging.warning("please check the foramt of your inputs.")
            return
        dataset = TextDataset(predicts, [self.env.WORD, self.env.FEAT],
                              self.args.buckets)
        # set the data loader
        dataset.loader = batchify(
            dataset,
            self.args.batch_size,
            use_multiprocess=False,
            sequential_sampler=not self.args.buckets,
        )
        pred_arcs, pred_rels, pred_probs = epoch_predict(
            self.env, self.args, self.model, dataset.loader)

        if self.args.buckets:
            indices = np.argsort(
                np.array([
                    i for bucket in dataset.buckets.values() for i in bucket
                ]))
        else:
            indices = range(len(pred_arcs))
        predicts.head = [pred_arcs[i] for i in indices]
        predicts.deprel = [pred_rels[i] for i in indices]
        if self.args.prob:
            predicts.prob = [pred_probs[i] for i in indices]

        outputs = predicts.get_result()
        if outputs[0].get("postag", None):
            for output in outputs:
                del output["postag"]
        return outputs
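
A short usage sketch for parse_seg(), following the constructor shown in the docstring. Each inner list must be a non-empty list of non-empty strings; otherwise the method logs a warning and returns None, so the result is checked before use.

ddp = DDParser(prob=True)
outputs = ddp.parse_seg([["百度", "是", "一家", "高科技", "公司"]])
if outputs is not None:
    for sent in outputs:
        # head indices are 1-based; 0 marks the HED (root) token
        print(list(zip(sent["word"], sent["head"], sent["deprel"])))
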
Example #3
    def parse(self, inputs):
        """
        Parse raw (unsegmented) sentences.

        Args:
            inputs: list(str) | str, unsegmented sentence(s)

        Returns:
            outputs: list, dependency parsing results

        Example:
        >>> ddp = DDParser()
        >>> inputs = "百度是一家高科技公司"
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']}]

        >>> inputs = ["百度是一家高科技公司", "他送了一本书"]
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']}, 
         {'word': ['他', '送', '了', '一本', '书'], 'head': [2, 0, 2, 5, 2], 'deprel': ['SBV', 'HED', 'MT', 'ATT', 'VOB']}]

        >>> ddp = DDParser(prob=True, use_pos=True)
        >>> inputs = "百度是一家高科技公司"
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'postag': ['ORG', 'v', 'm', 'n', 'n'], 
        'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB'], 'prob': [1.0, 1.0, 1.0, 1.0, 1.0]}]
        """
        if not self.lac:
            self.lac = LAC.LAC(mode="lac" if self.use_pos else "seg")
        if not inputs:
            return
        if isinstance(inputs, str):
            inputs = [inputs]
        if all([isinstance(i, str) and i for i in inputs]):
            lac_results = self.lac.run(inputs)
            predicts = Corpus.load_lac_results(lac_results, self.env.fields)
        else:
            logging.warning("please check the foramt of your inputs.")
            return
        dataset = TextDataset(predicts, [self.env.WORD, self.env.FEAT])
        # set the data loader
        dataset.loader = batchify(dataset,
                                  self.args.batch_size,
                                  use_multiprocess=False,
                                  sequential_sampler=True)
        pred_arcs, pred_rels, pred_probs = epoch_predict(
            self.env, self.args, self.model, dataset.loader)
        predicts.head = pred_arcs
        predicts.deprel = pred_rels
        if self.args.prob:
            predicts.prob = pred_probs
        outputs = predicts.get_result()
        return outputs
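
A usage sketch for parse(), again following the docstring. A single str is wrapped into a one-element list before segmentation, so both call styles below return a list of result dicts.

ddp = DDParser(use_pos=True)
single = ddp.parse("百度是一家高科技公司")                   # str input
batch = ddp.parse(["百度是一家高科技公司", "他送了一本书"])  # list input
assert isinstance(single, list) and isinstance(batch, list)
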
Example #4
def predict_query(env):
    """Predict one query"""
    args = env.args
    logging.info("Load the model")
    model = load(args.model_path)
    model.eval()
    lac_mode = "seg" if args.feat != "pos" else "lac"
    lac = LAC.LAC(mode=lac_mode)
    if args.prob:
        env.fields = env.fields._replace(PHEAD=Field("prob"))

    while True:
        query = input()
        if not isinstance(query, six.text_type):
            query = query.decode("utf-8")
        if not query:
            logging.info("quit!")
            return
        if len(query) > 200:
            logging.info("The length of the query should be less than 200!")
            continue
        start = datetime.datetime.now()
        lac_results = lac.run([query])
        predicts = Corpus.load_lac_results(lac_results, env.fields)
        dataset = TextDataset(predicts, [env.WORD, env.FEAT])
        # set the data loader
        dataset.loader = batchify(dataset,
                                  args.batch_size,
                                  use_multiprocess=False,
                                  sequential_sampler=True)
        pred_arcs, pred_rels, pred_probs = epoch_predict(
            env, args, model, dataset.loader)
        predicts.head = pred_arcs
        predicts.deprel = pred_rels
        if args.prob:
            predicts.prob = pred_probs
        predicts._print()
        total_time = datetime.datetime.now() - start
        logging.info("{}s elapsed, {:.2f} Sents/s, {:.2f} ms/Sents".format(
            total_time,
            len(dataset) / total_time.total_seconds(),
            total_time.total_seconds() / len(dataset) * 1000))
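
predict_query() blocks on input() until an empty line arrives, so it works from a pipe as well as an interactive terminal. A hedged sketch, reusing the hypothetical Environment stand-in from the first example; the module name run.py is also an assumption.

# one query per line; the trailing empty line makes the loop return:
#   printf '百度是一家高科技公司\n\n' | python run.py
env = Environment(args)  # hypothetical, as above
predict_query(env)
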
Example #5
def predict(env):
    """Predict"""
    arguments = env.args

    logging.info("Load the dataset")
    if arguments.prob:
        env.fields = env.fields._replace(PHEAD=Field("prob"))
    predicts = Corpus.load(arguments.infer_data_path, env.fields)
    dataset = TextDataset(predicts, [env.WORD, env.FEAT],
                          arguments.buckets)  # only the word and feat fields are needed
    # set the data loader
    dataset.loader = batchify(dataset, arguments.batch_size)
    logging.info("{} sentences, {} batches".format(len(dataset),
                                                   len(dataset.loader)))

    logging.info("Load the model")
    model = load(arguments.model_path)
    model.args = arguments

    logging.info("Make predictions on the dataset")
    start = datetime.datetime.now()
    model.eval()
    connection_predicts, deprel_predicts, predict_prob = epoch_predict(
        env, arguments, model, dataset.loader)
    total_time = datetime.datetime.now() - start
    # restore the order of sentences in the buckets
    indices = np.argsort(
        np.array([i for bucket in dataset.buckets.values() for i in bucket]))
    predicts.head = [connection_predicts[i] for i in indices]
    predicts.deprel = [deprel_predicts[i] for i in indices]
    if arguments.prob:
        predicts.prob = [predict_prob[i] for i in indices]
    logging.info("Save the predicted result to {}".format(
        arguments.infer_result_path))
    predicts.save(arguments.infer_result_path)
    logging.info("{}s elapsed, {:.2f} Sents/s".format(
        total_time,
        len(dataset) / total_time.total_seconds()))
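
The argsort step above inverts the bucket permutation: flattening dataset.buckets yields, for each prediction position, the original sentence index, and argsort of that array maps predictions back to corpus order. A toy illustration with made-up buckets:

import numpy as np

buckets = {10: [2, 0], 20: [1]}  # sentence indices grouped by length bucket
order = [i for bucket in buckets.values() for i in bucket]  # [2, 0, 1]
indices = np.argsort(np.array(order))                       # [1, 2, 0]
preds = ["pred_for_2", "pred_for_0", "pred_for_1"]          # prediction order
restored = [preds[i] for i in indices]
assert restored == ["pred_for_0", "pred_for_1", "pred_for_2"]
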
Example #6
    def parse(self, inputs):
        """
        Parse raw (unsegmented) sentences.

        Args:
            inputs: list(str) | str, unsegmented sentence(s)

        Returns:
            outputs: list, dependency parsing results

        Example:
        >>> ddp = DDParser()
        >>> inputs = "百度是一家高科技公司"
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']}]

        >>> inputs = ["百度是一家高科技公司", "他送了一本书"]
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']},
         {'word': ['他', '送', '了', '一本', '书'], 'head': [2, 0, 2, 5, 2], 'deprel': ['SBV', 'HED', 'MT', 'ATT', 'VOB']}]

        >>> ddp = DDParser(prob=True, use_pos=True)
        >>> inputs = "百度是一家高科技公司"
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'postag': ['ORG', 'v', 'm', 'n', 'n'],
        'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB'], 'prob': [1.0, 1.0, 1.0, 1.0, 1.0]}]
        """
        if not self.lac:
            self.lac = LAC.LAC(mode="lac" if self.use_pos else "seg", use_cuda=self.args.use_cuda)
        if not inputs:
            return
        if isinstance(inputs, six.string_types):
            inputs = [inputs]
        if all([isinstance(i, six.string_types) and i for i in inputs]):
            lac_results = []
            position = 0
            try:
                inputs = [query if isinstance(query, six.text_type) else query.decode("utf-8") for query in inputs]
            except UnicodeDecodeError:
                logging.warning("encoding only supports UTF-8!")
                return

            while position < len(inputs):
                lac_results += self.lac.run(inputs[position:position + self.args.batch_size])
                position += self.args.batch_size
            predicts = Corpus.load_lac_results(lac_results, self.env.fields)
        else:
            logging.warning("please check the foramt of your inputs.")
            return
        dataset = TextDataset(predicts, [self.env.WORD, self.env.FEAT], self.args.buckets)
        # set the data loader
        dataset.loader = batchify(
            dataset,
            self.args.batch_size,
            use_multiprocess=False,
            sequential_sampler=not self.args.buckets,
        )
        pred_arcs, pred_rels, pred_probs = epoch_predict(self.env, self.args, self.model, dataset.loader)

        if self.args.buckets:
            indices = np.argsort(np.array([i for bucket in dataset.buckets.values() for i in bucket]))
        else:
            indices = range(len(pred_arcs))
        predicts.head = [pred_arcs[i] for i in indices]
        predicts.deprel = [pred_rels[i] for i in indices]
        if self.args.prob:
            predicts.prob = [pred_probs[i] for i in indices]

        outputs = predicts.get_result()
        return outputs
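
The while loop above feeds LAC at most batch_size sentences per call. The same chunking pattern, isolated into a self-contained sketch with a stand-in segmenter in place of self.lac.run():

def run_in_chunks(run, inputs, batch_size):
    """Call run() on consecutive batch_size slices and concatenate the results."""
    results = []
    position = 0
    while position < len(inputs):
        results += run(inputs[position:position + batch_size])
        position += batch_size
    return results

# stand-in segmenter: splits each sentence into characters
segment = lambda xs: [list(x) for x in xs]
assert run_in_chunks(segment, ["ab", "cd", "e"], 2) == [["a", "b"], ["c", "d"], ["e"]]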