Example #1
    def parse_seg(self, inputs):
        """
        Parse sentences that are already word-segmented.

        Args:
            inputs: list(list(str)), word-segmented sentences, one token list per sentence

        Returns:
            outputs: list, dependency parsing results

        Example:
        >>> ddp = DDParser()
        >>> inputs = [['百度', '是', '一家', '高科技', '公司'], ['他', '送', '了', '一本', '书']]
        >>> ddp.parse_seg(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']},
        {'word': ['他', '送', '了', '一本', '书'], 'head': [2, 0, 2, 5, 2], 'deprel': ['SBV', 'HED', 'MT', 'ATT', 'VOB']}]


        >>> ddp = DDParser(prob=True)
        >>> inputs = [['百度', '是', '一家', '高科技', '公司']]
        >>> ddp.parse_seg(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2],
        'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB'], 'prob': [1.0, 1.0, 1.0, 1.0, 1.0]}]
        """
        if not inputs:
            return
        if all([isinstance(i, list) and i and all(i) for i in inputs]):
            predicts = Corpus.load_word_segments(inputs, self.env.fields)
        else:
            logging.warning("please check the foramt of your inputs.")
            return
        dataset = TextDataset(predicts, [self.env.WORD, self.env.FEAT],
                              self.args.buckets)
        # set the data loader
        dataset.loader = batchify(
            dataset,
            self.args.batch_size,
            use_multiprocess=False,
            sequential_sampler=not self.args.buckets,
        )
        pred_arcs, pred_rels, pred_probs = epoch_predict(
            self.env, self.args, self.model, dataset.loader)

        if self.args.buckets:
            indices = np.argsort(
                np.array([
                    i for bucket in dataset.buckets.values() for i in bucket
                ]))
        else:
            indices = range(len(pred_arcs))
        predicts.head = [pred_arcs[i] for i in indices]
        predicts.deprel = [pred_rels[i] for i in indices]
        if self.args.prob:
            predicts.prob = [pred_probs[i] for i in indices]

        outputs = predicts.get_result()
        if outputs[0].get("postag", None):
            for output in outputs:
                del output["postag"]
        return outputs
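
Note: the np.argsort step above is what restores the original sentence order after bucketed batching. A minimal, self-contained sketch of the same trick, using plain numpy and made-up buckets/preds (not DDParser code):

import numpy as np

# buckets maps a bucket key to the original sentence indices grouped into it,
# so iterating over the buckets yields sentences in processing order.
buckets = {10: [2, 0], 20: [3, 1]}                 # processing order: 2, 0, 3, 1
preds = ["pred_2", "pred_0", "pred_3", "pred_1"]   # one prediction per processed sentence

order = np.array([i for bucket in buckets.values() for i in bucket])
indices = np.argsort(order)            # for each original position, where its prediction landed
restored = [preds[i] for i in indices]
print(restored)                        # ['pred_0', 'pred_1', 'pred_2', 'pred_3']
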
Example #2
File: run.py, Project: zw331/DDParser
def predict(env):
    """Predict"""
    args = env.args

    logging.info("Load the dataset")
    if args.prob:
        env.fields = env.fields._replace(PHEAD=Field('prob'))
    predicts = Corpus.load(args.infer_data_path, env.fields)
    dataset = TextDataset(predicts, [env.WORD, env.FEAT], args.buckets)
    # set the data loader
    dataset.loader = batchify(dataset, args.batch_size)
    logging.info(f"{len(dataset)} sentences, "
                 f"{len(dataset.loader)} batches")

    logging.info("Load the model")
    model = load(args.model_path)
    model.args = args

    logging.info("Make predictions on the dataset")
    start = datetime.datetime.now()
    model.eval()
    pred_arcs, pred_rels, pred_probs = epoch_predict(env, args, model,
                                                     dataset.loader)
    total_time = datetime.datetime.now() - start
    # restore the order of sentences in the buckets
    indices = np.argsort(
        np.array([i for bucket in dataset.buckets.values() for i in bucket]))
    predicts.head = [pred_arcs[i] for i in indices]
    predicts.deprel = [pred_rels[i] for i in indices]
    if args.prob:
        predicts.prob = [pred_probs[i] for i in indices]
    logging.info(f"Save the predicted result to {args.infer_result_path}")
    predicts.save(args.infer_result_path)
    logging.info(f"{total_time}s elapsed, "
                 f"{len(dataset) / total_time.total_seconds():.2f} Sents/s")
Example #3
def evaluate(env):
    """Evaluate"""
    arguments = env.args
    punctuation = dygraph.to_variable(env.puncts, zero_copy=False)

    logging.info("Load the dataset")
    evaluates = Corpus.load(arguments.test_data_path, env.fields)
    dataset = TextDataset(evaluates, env.fields, arguments.buckets)
    # set the data loader
    dataset.loader = batchify(dataset, arguments.batch_size)

    logging.info("{} sentences, ".format(len(dataset)) +
                 "{} batches, ".format(len(dataset.loader)) +
                 "{} buckets".format(len(dataset.buckets)))
    logging.info("Load the model")
    model = load(arguments.model_path)

    logging.info("Evaluate the dataset")
    start = datetime.datetime.now()
    loss, metric = epoch_evaluate(arguments, model, dataset.loader,
                                  punctuation)
    total_time = datetime.datetime.now() - start
    logging.info("Loss: {:.4f} {}".format(loss, metric))
    logging.info("{}s elapsed, {:.2f} Sents/s".format(
        total_time,
        len(dataset) / total_time.total_seconds()))
Example #4
    def parse(self, inputs):
        """
        Parse raw (unsegmented) sentences.

        Args:
            inputs: list(str) | str, unsegmented sentences, given as a str or a list of str

        Returns:
            outputs: list, dependency parsing results

        Example:
        >>> ddp = DDParser()
        >>> inputs = "百度是一家高科技公司"
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']}]

        >>> inputs = ["百度是一家高科技公司", "他送了一本书"]
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']}, 
         {'word': ['他', '送', '了', '一本', '书'], 'head': [2, 0, 2, 5, 2], 'deprel': ['SBV', 'HED', 'MT', 'ATT', 'VOB']}]

        >>> ddp = DDParser(prob=True, use_pos=True)
        >>> inputs = "百度是一家高科技公司"
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'postag': ['ORG', 'v', 'm', 'n', 'n'], 
        'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB'], 'prob': [1.0, 1.0, 1.0, 1.0, 1.0]}]
        """
        if not self.lac:
            self.lac = LAC.LAC(mode='lac' if self.use_pos else "seg")
        if not inputs:
            return
        if isinstance(inputs, str):
            inputs = [inputs]
        if all([isinstance(i, str) and i for i in inputs]):
            lac_results = self.lac.run(inputs)
            predicts = Corpus.load_lac_results(lac_results, self.env.fields)
        else:
            logging.warning("please check the foramt of your inputs.")
            return
        dataset = TextDataset(predicts, [self.env.WORD, self.env.FEAT])
        # set the data loader
        dataset.loader = batchify(dataset,
                                  self.args.batch_size,
                                  use_multiprocess=False,
                                  sequential_sampler=True)
        pred_arcs, pred_rels, pred_probs = epoch_predict(
            self.env, self.args, self.model, dataset.loader)
        predicts.head = pred_arcs
        predicts.deprel = pred_rels
        if self.args.prob:
            predicts.prob = pred_probs
        outputs = predicts.get_result()
        return outputs
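
The input handling in parse() boils down to: wrap a single string into a list, then require every element to be a non-empty string. A standalone Python 3 sketch of that normalization (illustrative only; normalize_inputs is not a DDParser function):

def normalize_inputs(inputs):
    """Return a list of non-empty strings, or None if the inputs are empty/invalid."""
    if not inputs:
        return None
    if isinstance(inputs, str):
        inputs = [inputs]
    if not all(isinstance(i, str) and i for i in inputs):
        return None
    return inputs

print(normalize_inputs("百度是一家高科技公司"))        # ['百度是一家高科技公司']
print(normalize_inputs(["百度是一家高科技公司", ""]))  # None, empty sentence rejected
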
Example #5
def predict_query(env):
    """Predict one query"""
    args = env.args
    logging.info("Load the model")
    model = load(args.model_path)
    model.eval()
    lac_mode = "seg" if args.feat != "pos" else "lac"
    lac = LAC.LAC(mode=lac_mode)
    if args.prob:
        env.fields = env.fields._replace(PHEAD=Field("prob"))

    while True:
        query = input()
        if not isinstance(query, six.text_type):
            query = query.decode("utf-8")
        if not query:
            logging.info("quit!")
            return
        if len(query) > 200:
            logging.info("The length of the query should be less than 200!")
            continue
        start = datetime.datetime.now()
        lac_results = lac.run([query])
        predicts = Corpus.load_lac_results(lac_results, env.fields)
        dataset = TextDataset(predicts, [env.WORD, env.FEAT])
        # set the data loader
        dataset.loader = batchify(dataset,
                                  args.batch_size,
                                  use_multiprocess=False,
                                  sequential_sampler=True)
        pred_arcs, pred_rels, pred_probs = epoch_predict(
            env, args, model, dataset.loader)
        predicts.head = pred_arcs
        predicts.deprel = pred_rels
        if args.prob:
            predicts.prob = pred_probs
        predicts._print()
        total_time = datetime.datetime.now() - start
        logging.info("{}s elapsed, {:.2f} Sents/s, {:.2f} ms/Sents".format(
            total_time,
            len(dataset) / total_time.total_seconds(),
            total_time.total_seconds() / len(dataset) * 1000))
Example #6
def predict(env):
    """Predict"""
    arguments = env.args

    logging.info("Load the dataset")
    if arguments.prob:
        env.fields = env.fields._replace(PHEAD=Field("prob"))
    predicts = Corpus.load(arguments.infer_data_path, env.fields)
    dataset = TextDataset(predicts, [env.WORD, env.FEAT],
                          arguments.buckets)  # only the word and feat fields are needed
    # set the data loader
    dataset.loader = batchify(dataset, arguments.batch_size)
    logging.info("{} sentences, {} batches".format(len(dataset),
                                                   len(dataset.loader)))

    logging.info("Load the model")
    model = load(arguments.model_path)
    model.args = arguments

    logging.info("Make predictions on the dataset")
    start = datetime.datetime.now()
    model.eval()
    connection_predicts, deprel_predicts, predict_prob = epoch_predict(
        env, arguments, model, dataset.loader)
    total_time = datetime.datetime.now() - start
    # restore the order of sentences in the buckets
    indices = np.argsort(
        np.array([i for bucket in dataset.buckets.values() for i in bucket]))
    predicts.head = [connection_predicts[i] for i in indices]
    predicts.deprel = [deprel_predicts[i] for i in indices]
    if arguments.prob:
        predicts.prob = [predict_prob[i] for i in indices]
    logging.info("Save the predicted result to {}".format(
        arguments.infer_result_path))
    predicts.save(arguments.infer_result_path)
    logging.info("{}s elapsed, {:.2f} Sents/s".format(
        total_time,
        len(dataset) / total_time.total_seconds()))
Example #7
File: run.py, Project: zw331/DDParser
def evaluate(env):
    """Evaluate"""
    args = env.args
    puncts = dygraph.to_variable(env.puncts, zero_copy=False)

    logging.info("Load the dataset")
    evaluates = Corpus.load(args.test_data_path, env.fields)
    dataset = TextDataset(evaluates, env.fields, args.buckets)
    # set the data loader
    dataset.loader = batchify(dataset, args.batch_size)

    logging.info(f"{len(dataset)} sentences, "
                 f"{len(dataset.loader)} batches, "
                 f"{len(dataset.buckets)} buckets")
    logging.info("Load the model")
    model = load(args.model_path)

    logging.info("Evaluate the dataset")
    start = datetime.datetime.now()
    loss, metric = epoch_evaluate(args, model, dataset.loader, puncts)
    total_time = datetime.datetime.now() - start
    logging.info(f"Loss: {loss:.4f} {metric}")
    logging.info(f"{total_time}s elapsed, "
                 f"{len(dataset) / total_time.total_seconds():.2f} Sents/s")
Example #8
File: run.py, Project: zw331/DDParser
def train(env):
    """Train"""
    args = env.args

    logging.info("loading data.")
    train = Corpus.load(args.train_data_path, env.fields)
    dev = Corpus.load(args.valid_data_path, env.fields)
    test = Corpus.load(args.test_data_path, env.fields)
    logging.info("init dataset.")
    train = TextDataset(train, env.fields, args.buckets)
    dev = TextDataset(dev, env.fields, args.buckets)
    test = TextDataset(test, env.fields, args.buckets)
    logging.info("set the data loaders.")
    train.loader = batchify(train, args.batch_size, args.use_data_parallel,
                            True)
    dev.loader = batchify(dev, args.batch_size)
    test.loader = batchify(test, args.batch_size)

    logging.info(f"{'train:':6} {len(train):5} sentences, "
                 f"{len(train.loader):3} batches, "
                 f"{len(train.buckets)} buckets")
    logging.info(f"{'dev:':6} {len(dev):5} sentences, "
                 f"{len(dev.loader):3} batches, "
                 f"{len(train.buckets)} buckets")
    logging.info(f"{'test:':6} {len(test):5} sentences, "
                 f"{len(test.loader):3} batches, "
                 f"{len(train.buckets)} buckets")

    logging.info("Create the model")
    model = Model(args, env.WORD.embed)

    # init parallel strategy
    if args.use_data_parallel:
        strategy = dygraph.parallel.prepare_context()
        model = dygraph.parallel.DataParallel(model, strategy)

    if args.use_cuda:
        grad_clip = fluid.clip.GradientClipByNorm(clip_norm=args.clip)
    else:
        grad_clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=args.clip)
    decay = dygraph.ExponentialDecay(learning_rate=args.lr,
                                     decay_steps=args.decay_steps,
                                     decay_rate=args.decay)
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=decay,
        beta1=args.mu,
        beta2=args.nu,
        epsilon=args.epsilon,
        parameter_list=model.parameters(),
        grad_clip=grad_clip)

    total_time = datetime.timedelta()
    best_e, best_metric = 1, Metric()

    puncts = dygraph.to_variable(env.puncts, zero_copy=False)
    logging.info("start training.")
    for epoch in range(1, args.epochs + 1):
        start = datetime.datetime.now()
        # train one epoch and update the parameter
        logging.info(f"Epoch {epoch} / {args.epochs}:")
        epoch_train(args, model, optimizer, train.loader, epoch)
        if args.local_rank == 0:
            loss, dev_metric = epoch_evaluate(args, model, dev.loader, puncts)
            logging.info(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
            loss, test_metric = epoch_evaluate(args, model, test.loader,
                                               puncts)
            logging.info(f"{'test:':6} Loss: {loss:.4f} {test_metric}")

            t = datetime.datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric and epoch > args.patience // 10:
                best_e, best_metric = epoch, dev_metric
                save(args.model_path, args, model, optimizer)
                logging.info(f"{t}s elapsed (saved)\n")
            else:
                logging.info(f"{t}s elapsed\n")
            total_time += t
            if epoch - best_e >= args.patience:
                break
    if args.local_rank == 0:
        model = load(args.model_path, model)
        loss, metric = epoch_evaluate(args, model, test.loader, puncts)
        logging.info(
            f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
        logging.info(
            f"the score of test at epoch {best_e} is {metric.score:.2%}")
        logging.info(f"average time of each epoch is {total_time / epoch}s")
        logging.info(f"{total_time}s elapsed")
Example #9
File: config.py, Project: baidu/DDParser
    def __init__(self, args):
        self.args = args
        # init log
        if args.log_path:
            utils.init_log(args.log_path, args.local_rank, args.log_level)
        # init seed
        fluid.default_main_program().random_seed = args.seed
        np.random.seed(args.seed)
        # init place
        if args.use_cuda:
            self.place = "gpu"
        else:
            self.place = "cpu"

        os.environ["FLAGS_paddle_num_threads"] = str(args.threads)
        if not os.path.exists(self.args.model_files):
            os.makedirs(self.args.model_files)
        if not os.path.exists(args.fields_path) or args.preprocess:
            logging.info("Preprocess the data")
            if args.encoding_model in [
                    "ernie-1.0", "ernie-tiny", "ernie-lstm"
            ]:
                tokenizer = ErnieTokenizer.from_pretrained(args.encoding_model)
                args["ernie_vocabs_size"] = len(tokenizer.vocab)
                self.WORD = ErnieField(
                    "word",
                    pad=tokenizer.pad_token,
                    unk=tokenizer.unk_token,
                    bos=tokenizer.cls_token,
                    eos=tokenizer.sep_token,
                    fix_len=args.fix_len,
                    tokenizer=tokenizer,
                )
                self.WORD.vocab = tokenizer.vocab
                args.feat = None
            else:
                self.WORD = Field(
                    "word",
                    pad=utils.pad,
                    unk=utils.unk,
                    bos=utils.bos,
                    eos=utils.eos,
                    lower=True,
                )
            if args.feat == "char":
                self.FEAT = SubwordField(
                    "chars",
                    pad=utils.pad,
                    unk=utils.unk,
                    bos=utils.bos,
                    eos=utils.eos,
                    fix_len=args.fix_len,
                    tokenize=list,
                )
            elif args.feat == "pos":
                self.FEAT = Field("postag", bos=utils.bos, eos=utils.eos)
            else:
                self.FEAT = None
            self.ARC = Field(
                "head",
                bos=utils.bos,
                eos=utils.eos,
                use_vocab=False,
                fn=utils.numericalize,
            )
            self.REL = Field("deprel", bos=utils.bos, eos=utils.eos)
            if args.feat == "char":
                self.fields = CoNLL(FORM=(self.WORD, self.FEAT),
                                    HEAD=self.ARC,
                                    DEPREL=self.REL)
            else:
                self.fields = CoNLL(FORM=self.WORD,
                                    CPOS=self.FEAT,
                                    HEAD=self.ARC,
                                    DEPREL=self.REL)

            train = Corpus.load(args.train_data_path, self.fields)

            if not args.encoding_model.startswith("ernie"):
                self.WORD.build(train, args.min_freq)
                self.FEAT.build(train)

            self.REL.build(train)
            if args.local_rank == 0:
                with open(args.fields_path, "wb") as f:
                    logging.info("dumping fileds to disk.")
                    pickle.dump(self.fields, f, protocol=2)
        else:
            logging.info("loading the fields.")
            with open(args.fields_path, "rb") as f:
                self.fields = pickle.load(f)

            if isinstance(self.fields.FORM, tuple):
                self.WORD, self.FEAT = self.fields.FORM
            else:
                self.WORD, self.FEAT = self.fields.FORM, self.fields.CPOS
            self.ARC, self.REL = self.fields.HEAD, self.fields.DEPREL

        if args.encoding_model.startswith("ernie"):
            vocab_items = self.WORD.vocab.items()
        else:
            vocab_items = self.WORD.vocab.stoi.items()

        self.puncts = np.array([i for s, i in vocab_items if utils.ispunct(s)],
                               dtype=np.int64)

        self.args.update({
            "n_words": len(self.WORD.vocab),
            "n_feats": self.FEAT and len(self.FEAT.vocab),
            "n_rels": len(self.REL.vocab),
            "pad_index": self.WORD.pad_index,
            "unk_index": self.WORD.unk_index,
            "bos_index": self.WORD.bos_index,
            "eos_index": self.WORD.eos_index,
            "feat_pad_index": self.FEAT and self.FEAT.pad_index,
        })
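
The `puncts` array collects the vocabulary ids of punctuation tokens so they can be masked out during evaluation. A minimal sketch of the same idea over a plain token-to-id dict (the `stoi` dict is hypothetical, and `utils.ispunct` is only approximated here with unicodedata):

import unicodedata
import numpy as np

def ispunct(token):
    # rough stand-in for utils.ispunct: every character is a Unicode punctuation mark
    return all(unicodedata.category(ch).startswith("P") for ch in token)

stoi = {"<pad>": 0, "<unk>": 1, "百度": 2, "，": 3, "。": 4, "公司": 5}
puncts = np.array([i for s, i in stoi.items() if ispunct(s)], dtype=np.int64)
print(puncts)  # [3 4]
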
Example #10
File: config.py, Project: zw331/DDParser
    def __init__(self, args):
        self.args = args
        # init log
        if self.args.log_path:
            utils.init_log(self.args.log_path, self.args.local_rank,
                           self.args.log_level)
        # init seed
        fluid.default_main_program().random_seed = self.args.seed
        np.random.seed(self.args.seed)
        # init place
        if self.args.use_cuda:
            if self.args.use_data_parallel:
                self.place = fluid.CUDAPlace(
                    fluid.dygraph.parallel.Env().dev_id)
            else:
                self.place = fluid.CUDAPlace(0)
        else:
            self.place = fluid.CPUPlace()

        os.environ['FLAGS_paddle_num_threads'] = str(self.args.threads)
        os.makedirs(self.args.model_files, exist_ok=True)

        if not os.path.exists(self.args.fields_path) or self.args.preprocess:
            logging.info("Preprocess the data")
            self.WORD = Field('word',
                              pad=utils.pad,
                              unk=utils.unk,
                              bos=utils.bos,
                              lower=True)
            if self.args.feat == 'char':
                self.FEAT = SubwordField('chars',
                                         pad=utils.pad,
                                         unk=utils.unk,
                                         bos=utils.bos,
                                         fix_len=self.args.fix_len,
                                         tokenize=list)
            else:
                self.FEAT = Field('postag', bos=utils.bos)
            self.ARC = Field('head',
                             bos=utils.bos,
                             use_vocab=False,
                             fn=utils.numericalize)
            self.REL = Field('deprel', bos=utils.bos)
            if self.args.feat == 'char':
                self.fields = CoNLL(FORM=(self.WORD, self.FEAT),
                                    HEAD=self.ARC,
                                    DEPREL=self.REL)
            else:
                self.fields = CoNLL(FORM=self.WORD,
                                    CPOS=self.FEAT,
                                    HEAD=self.ARC,
                                    DEPREL=self.REL)

            train = Corpus.load(self.args.train_data_path, self.fields)
            if self.args.pretrained_embedding_dir:
                logging.info("loading pretrained embedding from file.")
                embed = Embedding.load(self.args.pretrained_embedding_dir,
                                       self.args.unk)
            else:
                embed = None
            self.WORD.build(train, self.args.min_freq, embed)
            self.FEAT.build(train)
            self.REL.build(train)
            if self.args.local_rank == 0:
                with open(self.args.fields_path, "wb") as f:
                    logging.info("dumping fileds to disk.")
                    pickle.dump(self.fields, f, protocol=2)
        else:
            logging.info("loading the fields.")
            with open(self.args.fields_path, "rb") as f:
                self.fields = pickle.load(f)

            if isinstance(self.fields.FORM, tuple):
                self.WORD, self.FEAT = self.fields.FORM
            else:
                self.WORD, self.FEAT = self.fields.FORM, self.fields.CPOS
            self.ARC, self.REL = self.fields.HEAD, self.fields.DEPREL
        self.puncts = np.array(
            [i for s, i in self.WORD.vocab.stoi.items() if utils.ispunct(s)],
            dtype=np.int64)

        if self.WORD.embed is not None:
            self.args["pretrained_embed_shape"] = self.WORD.embed.shape
        else:
            self.args["pretrained_embed_shape"] = None

        self.args.update({
            'n_words': self.WORD.vocab.n_init,
            'n_feats': len(self.FEAT.vocab),
            'n_rels': len(self.REL.vocab),
            'pad_index': self.WORD.pad_index,
            'unk_index': self.WORD.unk_index,
            'bos_index': self.WORD.bos_index,
            'feat_pad_index': self.FEAT.pad_index
        })
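
Both config examples use the same caching pattern: preprocess once, pickle the field objects to fields_path, and reload them on later runs unless args.preprocess is set. A compact sketch of that pattern (build_fields is a hypothetical callable standing in for the preprocessing branch):

import os
import pickle

def get_fields(fields_path, build_fields, force_preprocess=False):
    if force_preprocess or not os.path.exists(fields_path):
        fields = build_fields()                 # expensive: builds vocabs from the training corpus
        with open(fields_path, "wb") as f:
            pickle.dump(fields, f, protocol=2)  # protocol=2 matches the examples above
    else:
        with open(fields_path, "rb") as f:
            fields = pickle.load(f)
    return fields
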
Example #11
def train(env):
    """Train"""
    args = env.args

    logging.info("loading data.")
    train = Corpus.load(args.train_data_path, env.fields)
    dev = Corpus.load(args.valid_data_path, env.fields)
    test = Corpus.load(args.test_data_path, env.fields)
    logging.info("init dataset.")
    train = TextDataset(train, env.fields, args.buckets)
    dev = TextDataset(dev, env.fields, args.buckets)
    test = TextDataset(test, env.fields, args.buckets)
    logging.info("set the data loaders.")
    train.loader = batchify(train, args.batch_size, args.use_data_parallel, True)
    dev.loader = batchify(dev, args.batch_size)
    test.loader = batchify(test, args.batch_size)

    logging.info("{:6} {:5} sentences, ".format('train:', len(train)) + "{:3} batches, ".format(len(train.loader)) +
                 "{} buckets".format(len(train.buckets)))
    logging.info("{:6} {:5} sentences, ".format('dev:', len(dev)) + "{:3} batches, ".format(len(dev.loader)) +
                 "{} buckets".format(len(dev.buckets)))
    logging.info("{:6} {:5} sentences, ".format('test:', len(test)) + "{:3} batches, ".format(len(test.loader)) +
                 "{} buckets".format(len(test.buckets)))

    logging.info("Create the model")
    model = Model(args)

    # init parallel strategy
    if args.use_data_parallel:
        dist.init_parallel_env()
        model = paddle.DataParallel(model)

    if (args.encoding_model.startswith("ernie")
            and args.encoding_model != "ernie-lstm") or args.encoding_model == 'transformer':
        args['lr'] = args.ernie_lr
    else:
        args['lr'] = args.lstm_lr

    if args.encoding_model.startswith("ernie") and args.encoding_model != "ernie-lstm":
        max_steps = 100 * len(train.loader)
        decay = LinearDecay(args.lr, int(args.warmup_proportion * max_steps), max_steps)
        clip = args.ernie_clip
    else:
        decay = dygraph.ExponentialDecay(learning_rate=args.lr, decay_steps=args.decay_steps, decay_rate=args.decay)
        clip = args.clip

    if args.use_cuda:
        grad_clip = fluid.clip.GradientClipByNorm(clip_norm=clip)
    else:
        grad_clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=clip)

    if args.encoding_model.startswith("ernie") and args.encoding_model != "ernie-lstm":
        optimizer = AdamW(
            learning_rate=decay,
            parameter_list=model.parameters(),
            weight_decay=args.weight_decay,
            grad_clip=grad_clip,
        )
    else:
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=decay,
            beta1=args.mu,
            beta2=args.nu,
            epsilon=args.epsilon,
            parameter_list=model.parameters(),
            grad_clip=grad_clip,
        )

    total_time = datetime.timedelta()
    best_e, best_metric = 1, Metric()

    puncts = dygraph.to_variable(env.puncts, zero_copy=False)
    logging.info("start training.")

    for epoch in range(1, args.epochs + 1):
        start = datetime.datetime.now()
        # train one epoch and update the parameter
        logging.info("Epoch {} / {}:".format(epoch, args.epochs))
        epoch_train(args, model, optimizer, train.loader, epoch)
        if args.local_rank == 0:
            loss, dev_metric = epoch_evaluate(args, model, dev.loader, puncts)
            logging.info("{:6} Loss: {:.4f} {}".format('dev:', loss, dev_metric))
            loss, test_metric = epoch_evaluate(args, model, test.loader, puncts)
            logging.info("{:6} Loss: {:.4f} {}".format('test:', loss, test_metric))

            t = datetime.datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric and epoch > args.patience // 10:
                best_e, best_metric = epoch, dev_metric
                save(args.model_path, args, model, optimizer)
                logging.info("{}s elapsed (saved)\n".format(t))
            else:
                logging.info("{}s elapsed\n".format(t))
            total_time += t
            if epoch - best_e >= args.patience:
                break
    if args.local_rank == 0:
        model = load(args.model_path, model)
        loss, metric = epoch_evaluate(args, model, test.loader, puncts)
        logging.info("max score of dev is {:.2%} at epoch {}".format(best_metric.score, best_e))
        logging.info("the score of test at epoch {} is {:.2%}".format(best_e, metric.score))
        logging.info("average time of each epoch is {}s".format(total_time / epoch))
        logging.info("{}s elapsed".format(total_time))
Example #12
    def parse(self, inputs):
        """
        Parse raw (unsegmented) sentences.

        Args:
            inputs: list(str) | str, unsegmented sentences, given as a str or a list of str

        Returns:
            outputs: list, dependency parsing results

        Example:
        >>> ddp = DDParser()
        >>> inputs = "百度是一家高科技公司"
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']}]

        >>> inputs = ["百度是一家高科技公司", "他送了一本书"]
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB']},
         {'word': ['他', '送', '了', '一本', '书'], 'head': [2, 0, 2, 5, 2], 'deprel': ['SBV', 'HED', 'MT', 'ATT', 'VOB']}]

        >>> ddp = DDParser(prob=True, use_pos=True)
        >>> inputs = "百度是一家高科技公司"
        >>> ddp.parse(inputs)
        [{'word': ['百度', '是', '一家', '高科技', '公司'], 'postag': ['ORG', 'v', 'm', 'n', 'n'],
        'head': [2, 0, 5, 5, 2], 'deprel': ['SBV', 'HED', 'ATT', 'ATT', 'VOB'], 'prob': [1.0, 1.0, 1.0, 1.0, 1.0]}]
        """
        if not self.lac:
            self.lac = LAC.LAC(mode="lac" if self.use_pos else "seg", use_cuda=self.args.use_cuda)
        if not inputs:
            return
        if isinstance(inputs, six.string_types):
            inputs = [inputs]
        if all([isinstance(i, six.string_types) and i for i in inputs]):
            lac_results = []
            position = 0
            try:
                inputs = [query if isinstance(query, six.text_type) else query.decode("utf-8") for query in inputs]
            except UnicodeDecodeError:
                logging.warning("encoding only supports UTF-8!")
                return

            while position < len(inputs):
                lac_results += self.lac.run(inputs[position:position + self.args.batch_size])
                position += self.args.batch_size
            predicts = Corpus.load_lac_results(lac_results, self.env.fields)
        else:
            logging.warning("please check the foramt of your inputs.")
            return
        dataset = TextDataset(predicts, [self.env.WORD, self.env.FEAT], self.args.buckets)
        # set the data loader
        dataset.loader = batchify(
            dataset,
            self.args.batch_size,
            use_multiprocess=False,
            sequential_sampler=not self.args.buckets,
        )
        pred_arcs, pred_rels, pred_probs = epoch_predict(self.env, self.args, self.model, dataset.loader)

        if self.args.buckets:
            indices = np.argsort(np.array([i for bucket in dataset.buckets.values() for i in bucket]))
        else:
            indices = range(len(pred_arcs))
        predicts.head = [pred_arcs[i] for i in indices]
        predicts.deprel = [pred_rels[i] for i in indices]
        if self.args.prob:
            predicts.prob = [pred_probs[i] for i in indices]

        outputs = predicts.get_result()
        return outputs
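
Compared with Example #4, this parse() feeds the segmenter in slices of self.args.batch_size rather than passing every sentence at once. The chunking itself is plain list slicing; a minimal standalone sketch (run_batch is a hypothetical stand-in for self.lac.run):

def run_in_chunks(run_batch, inputs, batch_size):
    """Call run_batch on consecutive slices of inputs and concatenate the results."""
    results = []
    position = 0
    while position < len(inputs):
        results += run_batch(inputs[position:position + batch_size])
        position += batch_size
    return results

# e.g. run_in_chunks(lambda chunk: [len(s) for s in chunk], ["百度是一家高科技公司", "他送了一本书"], 1)
# -> [10, 6]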