Example #1
def main():
    # To export a model usable by the LTP master branch, pass the output directory path (e.g. ltp_model) as the ltp_adapter argument
    parser = ArgumentParser()
    parser = add_task_specific_args(parser)
    parser = Model.add_model_specific_args(parser)
    parser = optimization.add_optimizer_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    parser.set_defaults(min_epochs=1, max_epochs=10)
    parser.set_defaults(gradient_clip_val=1.0,
                        lr_layers_getter='get_layer_lrs_with_crf')
    args = parser.parse_args()

    if args.ltp_model is not None and args.resume_from_checkpoint is not None:
        deploy_model(args, args.ltp_version)
    elif args.build_ner_dataset:
        build_ner_distill_dataset(args)
    elif args.tune:
        tune_train(args,
                   model_class=Model,
                   task_info=task_info,
                   build_method=build_method)
    else:
        common_train(args,
                     model_class=Model,
                     task_info=task_info,
                     build_method=build_method)
Example #2
def main():
    parser = ArgumentParser()

    # add task level args
    parser = add_common_specific_args(parser)
    parser = add_tune_specific_args(parser)
    parser = add_task_specific_args(parser)

    # add model specific args
    parser = Model.add_model_specific_args(parser)
    parser = optimization.add_optimizer_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    parser.set_defaults(min_epochs=1, max_epochs=10)
    parser.set_defaults(gradient_clip_val=1.0,
                        lr_layers_getter='get_layer_lrs_with_crf')
    args = parser.parse_args()

    if args.ltp_model is not None and args.resume_from_checkpoint is not None:
        deploy_model(args, args.ltp_version)
    elif args.tune:
        tune_train(args,
                   model_class=Model,
                   task_info=task_info,
                   build_method=build_method)
    else:
        common_train(args,
                     model_class=Model,
                     task_info=task_info,
                     build_method=build_method)
Example #3
def deploy_model_4_1(args):
    from argparse import Namespace

    model = Model.load_from_checkpoint(args.resume_from_checkpoint,
                                       hparams=args)
    model_state_dict = model.state_dict()
    model_config = Namespace(**model.hparams)

    ltp_model = {
        'version': "4.1.0",
        'model': model_state_dict,
        'model_config': model_config,
        'transformer_config': model.transformer.config.to_dict(),
        'seg': ['I-W', 'B-W'],
        'pos': load_labels(os.path.join(args.pos_data_dir, 'pos_labels.txt')),
        'ner': load_labels(os.path.join(args.ner_data_dir, 'ner_labels.txt')),
        'srl': load_labels(os.path.join(args.srl_data_dir, 'srl_labels.txt')),
        'dep': load_labels(os.path.join(args.dep_data_dir, 'dep_labels.txt')),
        'sdp': load_labels(os.path.join(args.sdp_data_dir, 'deps_labels.txt')),
    }
    os.makedirs(args.ltp_model, exist_ok=True)
    torch.save(ltp_model, os.path.join(args.ltp_model, 'ltp.model'))

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.transformer)
    tokenizer.save_pretrained(args.ltp_model)
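
Note: the load_labels helper used above is not shown in this listing. A minimal sketch that is consistent with both call styles seen here (a pre-joined path in this example, path fragments in Example #8) could look as follows; the empty-list fallback for missing files is an assumption, not confirmed LTP behavior.

import os
from typing import List

def load_labels(*paths: str) -> List[str]:
    # Hypothetical reconstruction: join the path fragments and read
    # one label per line, skipping blanks; return an empty list when
    # the vocab file does not exist (Example #9 tests for emptiness).
    label_file = os.path.join(*paths)
    if not os.path.exists(label_file):
        return []
    with open(label_file, encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]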
Example #4
    def __init__(self, path: str = 'small', device=None, **kwargs):
        if device is not None:
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        if path in model_map or is_remote_url(path) or os.path.isfile(path):
            proxies = kwargs.pop("proxies", None)
            cache_dir = kwargs.pop("cache_dir", LTP_CACHE)
            force_download = kwargs.pop("force_download", False)
            resume_download = kwargs.pop("resume_download", False)
            local_files_only = kwargs.pop("local_files_only", False)
            path = cached_path(model_map.get(path, path),
                               cache_dir=cache_dir,
                               force_download=force_download,
                               proxies=proxies,
                               resume_download=resume_download,
                               local_files_only=local_files_only,
                               extract_compressed_file=True)
        elif not os.path.isdir(path):
            raise FileNotFoundError()
        try:
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)
        except Exception as e:
            fake_import_pytorch_lightning()
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)

        self.cache_dir = path
        config = AutoConfig.for_model(**ckpt['transformer_config'])
        self.model = Model(ckpt['model_config'], config=config).to(self.device)
        self.model.load_state_dict(ckpt['model'], strict=False)
        self.model.eval()
        self.max_length = self.model.transformer.config.max_position_embeddings
        self.seg_vocab = ckpt.get('seg', [WORD_MIDDLE, WORD_START])
        self.pos_vocab = ckpt.get('pos', [])
        self.ner_vocab = ckpt.get('ner', [])
        self.dep_vocab = ckpt.get('dep', [])
        self.sdp_vocab = ckpt.get('sdp', [])
        self.srl_vocab = [
            re.sub(r'ARG(\d)', r'A\1', tag.lstrip('ARGM-'))
            for tag in ckpt.get('srl', [])
        ]
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, config=self.model.transformer.config, use_fast=True)
        self.trie = Trie()
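
For context, the constructor above resolves the device automatically and accepts a model alias (a key of model_map), a remote URL, a local archive, or an already-extracted directory. A minimal usage sketch, assuming 'small' is a valid alias:

from ltp import LTP

ltp = LTP()                          # 'small' by default, auto-selects cuda/cpu
ltp = LTP('/path/to/extracted_dir')  # a directory that contains ltp.model
ltp = LTP('small', device='cuda:0',  # explicit device plus download options
          local_files_only=True)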
Example #5
def build_ner_distill_dataset(args):
    model = Model.load_from_checkpoint(
        args.resume_from_checkpoint, hparams=args
    )

    model.eval()
    model.freeze()

    dataset, metric = ner.build_dataset(
        model, args.ner_data_dir,
        ner.task_info.task_name
    )
    train_dataloader = torch.utils.data.DataLoader(
        dataset[datasets.Split.TRAIN],
        batch_size=args.batch_size,
        collate_fn=collate,
        num_workers=args.num_workers
    )

    output = os.path.join(args.ner_data_dir, ner.task_info.task_name, 'output.npz')

    if torch.cuda.is_available():
        model.cuda()
        map2cpu = lambda x: map2device(x)
        map2cuda = lambda x: map2device(x, model.device)
    else:
        map2cpu = lambda x: x
        map2cuda = lambda x: x

    with torch.no_grad():
        batchs = []
        for batch in tqdm(train_dataloader):
            batch = map2cuda(batch)
            logits = model.forward(task='ner', **batch).logits
            batch.update(logits=logits)
            batchs.append(map2cpu(batch))
        try:
            numpy.savez(
                output,
                data=convert2npy(batchs),
                extra=convert2npy({
                    'transitions': model.ner_classifier.crf.transitions,
                    'start_transitions': model.ner_classifier.crf.start_transitions,
                    'end_transitions': model.ner_classifier.crf.end_transitions
                })
            )
        except Exception as e:
            numpy.savez(output, data=convert2npy(batchs))

    print("Done")
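
map2device and convert2npy are LTP utility helpers that are not shown in this listing. Under the assumption that a batch is a (possibly nested) dict of tensors, a plausible sketch of map2device matching its two call styles above (default CPU target, explicit model.device) is:

import torch

def map2device(batch, device=torch.device('cpu')):
    # Hypothetical sketch: recursively move every tensor in the batch
    # to the target device and leave non-tensor values untouched.
    if torch.is_tensor(batch):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: map2device(v, device) for k, v in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(map2device(v, device) for v in batch)
    return batch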
Example #6
def main():
    # To export a model usable by the LTP master branch, pass the output directory path (e.g. ltp_model) as the ltp_adapter argument
    parser = ArgumentParser()

    # add task level args
    parser = add_common_specific_args(parser)
    parser = add_tune_specific_args(parser)
    parser = add_task_specific_args(parser)

    # add model specific args
    parser = Model.add_model_specific_args(parser)
    parser = optimization.add_optimizer_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    parser.set_defaults(min_epochs=1, max_epochs=10)
    parser.set_defaults(gradient_clip_val=1.0, lr_layers_getter='get_layer_lrs_with_crf')
    args = parser.parse_args()

    if args.ltp_model is not None and args.resume_from_checkpoint is not None:
        deploy_model(args, args.ltp_version)
    elif args.build_ner_dataset:
        build_ner_distill_dataset(args)
    elif args.tune:
        from ltp.utils.common_train import tune
        tune_config = {
            # 3e-4 for Small, 1e-4 for Base, 5e-5 for Large
            "lr": tune.loguniform(args.tune_min_lr, args.tune_max_lr),

            # dataset split
            "tau": tune.choice([0.8, 0.9, 1.0]),

            # weight decay
            "weight_decay": tune.choice([0.0, 0.01]),

            # gradient clipping
            "gradient_clip_val": tune.choice([1.0, 2.0, 3.0, 4.0, 5.0]),

            # lr scheduler
            "lr_scheduler": tune.choice([
                'linear_schedule_with_warmup',
                'polynomial_decay_schedule_with_warmup',
            ]),
        }
        tune_train(args, model_class=Model, task_info=task_info, build_method=build_method, tune_config=tune_config)
    else:
        common_train(args, model_class=Model, task_info=task_info, build_method=build_method)
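
The search-space constructors above (tune.loguniform, tune.choice) match the Ray Tune API, which ltp.utils.common_train appears to re-export. For illustration only, each trial then receives one concrete sample per dimension:

from ray import tune  # assumption: Ray Tune backs ltp.utils.common_train.tune

space = {"tau": tune.choice([0.8, 0.9, 1.0]),
         "lr": tune.loguniform(1e-5, 1e-3)}
print(space["tau"].sample())  # e.g. 0.9
print(space["lr"].sample())   # e.g. 0.00023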
Example #7
def main():
    parser = ArgumentParser()
    parser = add_task_specific_args(parser)
    parser = Model.add_model_specific_args(parser)
    parser = optimization.add_optimizer_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    parser.set_defaults(gradient_clip_val=1.0)
    args = parser.parse_args()

    if args.ltp_model is not None and args.resume_from_checkpoint is not None:
        deploy_model(args, args.ltp_version)
    else:
        common_train(args,
                     metric=f'val_{task_info.metric_name}',
                     model_class=Model,
                     build_method=build_method,
                     task=task_info.task_name)
Example #8
def deploy_model_4_0(args, version):
    ltp_adapter_mapper = sorted([
        ('transformer', 'pretrained'),
        ('seg_classifier', 'seg_decoder'),
        ('pos_classifier', 'pos_decoder'),
        ('ner_classifier', 'ner_decoder'),
        ('ner_classifier.classifier', 'ner_decoder.mlp'),
        ('ner_classifier.relative_transformer', 'ner_decoder.transformer'),
        ('srl_classifier', 'srl_decoder'),
        ('srl_classifier.rel_atten', 'srl_decoder.biaffine'),
        ('srl_classifier.crf', 'srl_decoder.crf'),
        ('dep_classifier', 'dep_decoder'),
        ('sdp_classifier', 'sdp_decoder'),
    ],
                                key=lambda x: len(x[0]),
                                reverse=True)

    model = Model.load_from_checkpoint(args.resume_from_checkpoint,
                                       hparams=args)
    model_state_dict = OrderedDict(model.state_dict().items())
    for preffix, target_preffix in ltp_adapter_mapper:
        model_state_dict = {
            key.replace(preffix, target_preffix, 1): value
            for key, value in model_state_dict.items()
        }

    pos_labels = load_labels(args.pos_data_dir, 'vocabs', 'xpos.txt')
    ner_labels = load_labels(args.ner_data_dir, 'ner_labels.txt')
    srl_labels = load_labels(args.srl_data_dir, 'srl_labels.txt')
    dep_labels = load_labels(args.dep_data_dir, 'vocabs', 'deprel.txt')
    sdp_labels = load_labels(args.sdp_data_dir, 'vocabs', 'deps.txt')

    ltp_model = {
        'version': '4.0.0',
        'code_version': version,
        'seg': ['I-W', 'B-W'],
        'pos': pos_labels,
        'ner': ner_labels,
        'srl': srl_labels,
        'dep': dep_labels,
        'sdp': sdp_labels,
        'pretrained_config': model.transformer.config,
        'model_config': {
            'class': 'SimpleMultiTaskModel',
            'init': {
                'seg': {
                    'label_num': args.seg_num_labels
                },
                'pos': {
                    'label_num': args.pos_num_labels
                },
                'ner': {
                    'label_num': args.ner_num_labels,
                    'decoder': 'RelativeTransformer',
                    'RelativeTransformer': {
                        'num_heads': args.ner_num_heads,
                        'num_layers': args.ner_num_layers,
                        'hidden_size': args.ner_hidden_size,
                        'dropout': args.dropout
                    }
                },
                'dep': {
                    'label_num': args.dep_num_labels,
                    'decoder': 'Graph',
                    'Graph': {
                        'arc_hidden_size': args.dep_arc_hidden_size,
                        'rel_hidden_size': args.dep_rel_hidden_size,
                        'dropout': args.dropout
                    }
                },
                'sdp': {
                    'label_num': args.sdp_num_labels,
                    'decoder': 'Graph',
                    'Graph': {
                        'arc_hidden_size': args.sdp_arc_hidden_size,
                        'rel_hidden_size': args.sdp_rel_hidden_size,
                        'dropout': args.dropout
                    }
                },
                'srl': {
                    'label_num': args.srl_num_labels,
                    'decoder': 'BiLinearCRF',
                    'BiLinearCRF': {
                        'hidden_size': args.srl_hidden_size,
                        'dropout': args.dropout
                    }
                }
            }
        },
        'model': model_state_dict
    }
    os.makedirs(args.ltp_model, exist_ok=True)
    torch.save(ltp_model, os.path.join(args.ltp_model, 'ltp.model'))

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.transformer)
    tokenizer.save_pretrained(args.ltp_model)
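
The mapper in this example is sorted by prefix length, longest first, so the most specific rename is applied before a shorter prefix can touch the same key. A self-contained worked example of the renaming loop:

state = {'ner_classifier.classifier.weight': 0,
         'ner_classifier.crf.transitions': 1}
mapper = sorted([('ner_classifier', 'ner_decoder'),
                 ('ner_classifier.classifier', 'ner_decoder.mlp')],
                key=lambda x: len(x[0]), reverse=True)
for prefix, target in mapper:
    # replace at most once per key, exactly as in deploy_model_4_0
    state = {key.replace(prefix, target, 1): value
             for key, value in state.items()}
print(state)
# {'ner_decoder.mlp.weight': 0, 'ner_decoder.crf.transitions': 1}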
Example #9
def deploy_model_4_1(args, version):
    from argparse import Namespace

    fake_parser = ArgumentParser()
    fake_parser = Model.add_model_specific_args(fake_parser)
    model_args, _ = fake_parser.parse_known_args(namespace=args)

    transformer_config = AutoConfig.from_pretrained(model_args.transformer)
    model = Model.load_from_checkpoint(args.resume_from_checkpoint,
                                       strict=False,
                                       hparams=model_args,
                                       config=transformer_config)

    model_config = Namespace(**model.hparams)
    # LOAD VOCAB
    pos_labels = load_labels(args.pos_data_dir, 'vocabs', 'xpos.txt')
    ner_labels = load_labels(args.ner_data_dir, 'ner_labels.txt')
    srl_labels = load_labels(args.srl_data_dir, 'srl_labels.txt')
    dep_labels = load_labels(args.dep_data_dir, 'vocabs', 'deprel.txt')
    sdp_labels = load_labels(args.sdp_data_dir, 'vocabs', 'deps.txt')

    # MODEL CLIP
    if not len(pos_labels):
        del model.pos_classifier
        model_config.pos_num_labels = 0

    if not len(ner_labels):
        del model.ner_classifier
        model_config.ner_num_labels = 0

    if not len(srl_labels):
        del model.srl_classifier
        model_config.srl_num_labels = 0

    if not len(dep_labels):
        del model.dep_classifier
        model_config.dep_num_labels = 0

    if not len(sdp_labels):
        del model.sdp_classifier
        model_config.sdp_num_labels = 0

    model_state_dict = OrderedDict(model.state_dict().items())

    ltp_model = {
        'version': version,
        'model': model_state_dict,
        'model_config': model_config,
        'transformer_config': model.transformer.config.to_dict(),
        'seg': ['I-W', 'B-W'],
        'pos': pos_labels,
        'ner': ner_labels,
        'srl': srl_labels,
        'dep': dep_labels,
        'sdp': sdp_labels,
    }
    os.makedirs(args.ltp_model, exist_ok=True)
    torch.save(ltp_model, os.path.join(args.ltp_model, 'ltp.model'))

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.transformer)
    tokenizer.save_pretrained(args.ltp_model)
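
The fake_parser trick above projects the model-specific arguments onto the existing namespace: parse_known_args(namespace=args) fills in defaults for any model options missing from args, returns that same namespace object, and ignores everything else. A minimal illustration (the option name and default are made up):

from argparse import ArgumentParser, Namespace

ns = Namespace(resume_from_checkpoint='ckpt.pt')
fake = ArgumentParser()
fake.add_argument('--transformer', default='some-pretrained-model')
known, unknown = fake.parse_known_args([], namespace=ns)
print(known is ns)        # True: defaults were merged onto the namespace
print(known.transformer)  # some-pretrained-model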
Example #10
def build_method(model: Model, task_info: TaskInfo):
    multi_dataset, multi_metric = task_info.build_dataset(
        model,
        seg=model.hparams.seg_data_dir,
        pos=model.hparams.pos_data_dir,
        ner=model.hparams.ner_data_dir,
        dep=model.hparams.dep_data_dir,
        sdp=model.hparams.sdp_data_dir,
        srl=model.hparams.srl_data_dir)

    def train_dataloader(self):
        multi_dataloader = {
            task:
            torch.utils.data.DataLoader(task_dataset[datasets.Split.TRAIN],
                                        batch_size=self.hparams.batch_size,
                                        collate_fn=collate,
                                        num_workers=self.hparams.num_workers,
                                        pin_memory=True)
            for task, task_dataset in multi_dataset.items()
        }
        res = MultiTaskDataloader(tau=self.hparams.tau, **multi_dataloader)
        return res

    def training_step(self, batch, batch_idx):
        result = self(**batch)
        self.log("loss", result.loss.item())
        return {"loss": result.loss}

    def val_dataloader(self):
        return [
            torch.utils.data.DataLoader(
                task_dataset[datasets.Split.VALIDATION],
                batch_size=self.hparams.batch_size,
                collate_fn=collate,
                num_workers=self.hparams.num_workers,
                pin_memory=True)
            for task, task_dataset in multi_dataset.items()
        ]

    def test_dataloader(self):
        return [
            torch.utils.data.DataLoader(task_dataset[datasets.Split.TEST],
                                        batch_size=self.hparams.batch_size,
                                        collate_fn=collate,
                                        num_workers=self.hparams.num_workers,
                                        pin_memory=True)
            for task, task_dataset in multi_dataset.items()
        ]

    # AdamW + LR scheduler
    def configure_optimizers(self: Model):
        num_epoch_steps = sum(
            (len(dataset[datasets.Split.TRAIN]) + self.hparams.batch_size -
             1) // self.hparams.batch_size
            for dataset in multi_dataset.values())
        num_train_steps = num_epoch_steps * self.hparams.max_epochs
        optimizer, scheduler = optimization.from_argparse_args(
            self.hparams,
            model=self,
            num_train_steps=num_train_steps,
            n_transformer_layers=self.transformer.config.num_hidden_layers)
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    model.configure_optimizers = types.MethodType(configure_optimizers, model)

    model.train_dataloader = types.MethodType(train_dataloader, model)
    model.training_step = types.MethodType(training_step, model)

    validation_step, validation_epoch_end = task_info.validation_method(
        multi_metric,
        loss_tag='val_loss',
        metric_tags={
            task_name: f"val_{task_module.task_info.metric_name}"
            for task_name, task_module in task_builder.items()
        },
        metric_tag=f"val_{task_info.metric_name}")

    model.val_dataloader = types.MethodType(val_dataloader, model)
    model.validation_step = types.MethodType(validation_step, model)
    model.validation_epoch_end = types.MethodType(validation_epoch_end, model)

    test_step, test_epoch_end = task_info.validation_method(
        multi_metric,
        loss_tag='test_loss',
        metric_tags={
            task_name: f"test_{task_module.task_info.metric_name}"
            for task_name, task_module in task_builder.items()
        },
        metric_tag=f"test_{task_info.metric_name}")

    model.test_dataloader = types.MethodType(test_dataloader, model)
    model.test_step = types.MethodType(test_step, model)
    model.test_epoch_end = types.MethodType(test_epoch_end, model)
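
build_method patches the dataloaders, the train/val/test steps, and configure_optimizers onto one model instance instead of subclassing: types.MethodType binds a plain function as a method of that single object. A minimal sketch of the binding mechanism:

import types

class Greeter:
    pass

def hello(self):
    # 'self' is the patched instance, exactly as in build_method above
    return f"hello from {type(self).__name__}"

g = Greeter()
g.hello = types.MethodType(hello, g)
print(g.hello())  # hello from Greeter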
Example #11
def build_method(model: Model, task_info: TaskInfo):
    (multi_dataset, distill_datasets,
     distill_datasets_extra), multi_metric = build_dataset(
         model,
         seg=model.hparams.seg_data_dir,
         pos=model.hparams.pos_data_dir,
         ner=model.hparams.ner_data_dir,
         dep=model.hparams.dep_data_dir,
         sdp=model.hparams.sdp_data_dir,
         srl=model.hparams.srl_data_dir)

    disable_distill = {
        'seg': model.hparams.disable_seg,
        'pos': model.hparams.disable_pos,
        'ner': model.hparams.disable_ner,
        'dep': model.hparams.disable_dep,
        'sdp': model.hparams.disable_sdp,
        'srl': model.hparams.disable_srl,
    }

    disable_distill = {
        task
        for task, disable in disable_distill.items() if disable
    }

    temperature_scheduler = flsw_temperature_scheduler_builder(
        beta=model.hparams.distill_beta,
        gamma=model.hparams.distill_gamma,
        base_temperature=model.hparams.temperature)

    def train_dataloader(self):
        multi_dataloader = {
            task:
            torch.utils.data.DataLoader(task_dataset,
                                        batch_size=None,
                                        num_workers=self.hparams.num_workers,
                                        pin_memory=True,
                                        shuffle=True)
            for task, task_dataset in distill_datasets.items()
        }
        res = MultiTaskDataloader(tau=self.hparams.tau, **multi_dataloader)
        return res

    def training_step(self: Model, batch, batch_idx):
        task = batch['task']
        target_logits = batch.pop('logits')
        result = self(**batch)
        norm_loss = result.loss

        if task not in disable_distill:
            distill_loss = distill_loss_map[task](
                batch,
                result,
                target_logits,
                temperature_scheduler,
                model,
                extra=distill_datasets_extra[task])
            distill_loss_weight = self.global_step / self.num_train_steps
            loss = distill_loss_weight * norm_loss + (
                1 - distill_loss_weight) * distill_loss

            self.log("distill_loss", distill_loss.item())
            self.log("norm_loss", norm_loss.item())
            self.log("loss", loss.item())
            return {"loss": loss}
        else:
            self.log("loss", norm_loss.item())
            return {"loss": norm_loss}

    def val_dataloader(self):
        return [
            torch.utils.data.DataLoader(
                task_dataset[datasets.Split.VALIDATION],
                batch_size=getattr(self.hparams, f'{task}_batch_size')
                or self.hparams.batch_size,
                collate_fn=collate,
                num_workers=self.hparams.num_workers,
                pin_memory=True)
            for task, task_dataset in multi_dataset.items()
        ]

    def test_dataloader(self):
        return [
            torch.utils.data.DataLoader(
                task_dataset[datasets.Split.TEST],
                batch_size=getattr(self.hparams, f'{task}_batch_size')
                or self.hparams.batch_size,
                collate_fn=collate,
                num_workers=self.hparams.num_workers,
                pin_memory=True)
            for task, task_dataset in multi_dataset.items()
        ]

    # AdamW + LR scheduler
    def configure_optimizers(self: Model):
        num_epoch_steps = sum(
            len(dataset) for dataset in distill_datasets.values())
        num_train_steps = num_epoch_steps * self.hparams.max_epochs
        setattr(self, 'num_train_steps', num_train_steps)
        optimizer, scheduler = optimization.from_argparse_args(
            self.hparams,
            model=self,
            num_train_steps=num_train_steps,
            n_transformer_layers=self.transformer.config.num_hidden_layers)
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    model.configure_optimizers = types.MethodType(configure_optimizers, model)

    model.train_dataloader = types.MethodType(train_dataloader, model)
    model.training_step = types.MethodType(training_step, model)

    validation_step, validation_epoch_end = task_info.validation_method(
        multi_metric, task=task_info.task_name, preffix='val')

    model.val_dataloader = types.MethodType(val_dataloader, model)
    model.validation_step = types.MethodType(validation_step, model)
    model.validation_epoch_end = types.MethodType(validation_epoch_end, model)

    test_step, test_epoch_end = task_info.validation_method(
        multi_metric, task=task_info.task_name, preffix='test')

    model.test_dataloader = types.MethodType(test_dataloader, model)
    model.test_step = types.MethodType(test_step, model)
    model.test_epoch_end = types.MethodType(test_epoch_end, model)
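
In this example's training_step, the supervised loss and the distillation loss are blended with a weight that grows linearly with training progress: training starts as pure distillation and ends as pure supervised learning. With num_train_steps = 1000, for instance:

num_train_steps = 1000
for global_step in (0, 500, 1000):
    w = global_step / num_train_steps
    # loss = w * norm_loss + (1 - w) * distill_loss
    print(f"step {global_step}: {w:.1f}*norm_loss + {1 - w:.1f}*distill_loss")
# step 0: 0.0*norm_loss + 1.0*distill_loss
# step 500: 0.5*norm_loss + 0.5*distill_loss
# step 1000: 1.0*norm_loss + 0.0*distill_loss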
Example #12
class LTP(object):
    model: Model
    seg_vocab: List[str]
    pos_vocab: List[str]
    ner_vocab: List[str]
    dep_vocab: List[str]
    sdp_vocab: List[str]
    srl_vocab: List[str]

    tensor: TensorType = TensorType.PYTORCH

    def __init__(self, path: str = 'small', device=None, **kwargs):
        if device is not None:
            if isinstance(device, torch.device):
                self.device = device
            elif isinstance(device, str):
                self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        if path in model_map or is_remote_url(path) or os.path.isfile(path):
            proxies = kwargs.pop("proxies", None)
            cache_dir = kwargs.pop("cache_dir", LTP_CACHE)
            force_download = kwargs.pop("force_download", False)
            resume_download = kwargs.pop("resume_download", False)
            local_files_only = kwargs.pop("local_files_only", False)
            path = cached_path(model_map.get(path, path),
                               cache_dir=cache_dir,
                               force_download=force_download,
                               proxies=proxies,
                               resume_download=resume_download,
                               local_files_only=local_files_only,
                               extract_compressed_file=True)
        elif not os.path.isdir(path):
            raise FileNotFoundError()
        try:
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)
        except Exception as e:
            fake_import_pytorch_lightning()
            ckpt = torch.load(os.path.join(path, "ltp.model"),
                              map_location=self.device)

        patch_4_1_3(ckpt)

        self.cache_dir = path
        transformer_config = ckpt['transformer_config']
        transformer_config['torchscript'] = True
        config = AutoConfig.for_model(**transformer_config)
        self.model = Model(ckpt['model_config'], config=config).to(self.device)
        self.model.load_state_dict(ckpt['model'], strict=False)
        self.model.eval()

        self.seg_vocab = ckpt.get('seg', [WORD_MIDDLE, WORD_START])
        self.seg_vocab_dict = {
            tag: idx
            for idx, tag in enumerate(self.seg_vocab)
        }
        self.pos_vocab = ckpt.get('pos', [])
        self.ner_vocab = ckpt.get('ner', [])
        self.dep_vocab = ckpt.get('dep', [])
        self.sdp_vocab = ckpt.get('sdp', [])
        self.srl_vocab = [
            re.sub(r'ARG(\d)', r'A\1', tag.lstrip('ARGM-'))
            for tag in ckpt.get('srl', [])
        ]
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, config=self.model.transformer.config, use_fast=True)
        self.trie = Trie()
        self._model_version = ckpt.get('version', None)

    def __str__(self):
        return f"LTP {self.version} on {self.device} (model version: {self.model_version}) "

    def __repr__(self):
        return f"LTP {self.version} on {self.device} (model version: {self.model_version}) "

    @property
    def avaliable_models(self):
        return model_map.keys()

    @property
    def version(self):
        from ltp import __version__ as version
        return version

    @property
    def model_version(self):
        return self._model_version or 'unknown'

    @property
    def max_length(self):
        return self.model.transformer.config.max_position_embeddings

    def init_dict(self, path, max_window=None):
        self.trie.init(path, max_window)

    def add_words(self, words, max_window=None):
        self.trie.add_words(words)
        self.trie.max_window = max_window

    @staticmethod
    def sent_split(inputs: List[str], flag: str = "all", limit: int = 510):
        inputs = [
            split_sentence(text, flag=flag, limit=limit) for text in inputs
        ]
        inputs = list(itertools.chain(*inputs))
        return inputs

    def seg_with_dict(self, inputs: List[str], tokenized: BatchEncoding,
                      batch_prefix):
        # forward maximum matching against the user dictionary
        matching = []
        for source_text, encoding, preffix in zip(inputs, tokenized.encodings,
                                                  batch_prefix):
            text = [
                source_text[start:end] for start, end in encoding.offsets[1:-1]
                if end != 0
            ]
            matching_pos = self.trie.maximum_forward_matching(text, preffix)
            matching.append(matching_pos)
        return matching

    @no_gard
    def _seg(self, tokenizerd, is_preseged=False):
        input_ids = tokenizerd['input_ids'].to(self.device)
        attention_mask = tokenizerd['attention_mask'].to(self.device)
        token_type_ids = tokenizerd['token_type_ids'].to(self.device)
        length = torch.sum(attention_mask, dim=-1) - 2

        pretrained_output, *_ = self.model.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=False)

        # remove [CLS] [SEP]
        word_cls = pretrained_output[:, :1]
        char_input = torch.narrow(pretrained_output, 1, 1,
                                  pretrained_output.size(1) - 2)
        if is_preseged:
            segment_output = None
        else:
            segment_output = torch.argmax(
                self.model.seg_classifier(char_input).logits,
                dim=-1).cpu().numpy()
        return word_cls, char_input, segment_output, length

    @no_gard
    def seg(self,
            inputs: Union[List[str], List[List[str]]],
            truncation: bool = True,
            is_preseged=False):
        """
        Word segmentation

        Args:
            inputs: list of sentences
            truncation: whether to truncate overlong sentences; if False, an exception may be raised
            is_preseged: whether the input has already been segmented into words

        Returns:
            words: the segmented word sequences
            hidden: intermediate representations consumed by the other tasks
        """

        # is_pretokenized was renamed to is_split_into_words in newer transformers
        if (transformers_version.major, transformers_version.minor) >= (3, 2):
            kwargs = {'is_split_into_words': is_preseged}
        else:
            kwargs = {'is_pretokenized': is_preseged}

        tokenized = self.tokenizer.batch_encode_plus(
            inputs,
            padding=True,
            truncation=truncation,
            return_tensors=self.tensor,
            max_length=self.max_length,
            **kwargs)
        cls, hidden, seg, lengths = self._seg(tokenized,
                                              is_preseged=is_preseged)

        batch_prefix = [[
            word_idx != encoding.words[idx - 1]
            for idx, word_idx in enumerate(encoding.words)
            if word_idx is not None
        ] for encoding in tokenized.encodings]

        # merge segments with maximum forward matching
        if self.trie.is_init and not is_preseged:
            matches = self.seg_with_dict(inputs, tokenized, batch_prefix)
            for sent_match, sent_seg in zip(matches, seg):
                for start, end in sent_match:
                    sent_seg[start] = self.seg_vocab_dict[WORD_START]
                    sent_seg[start + 1:end] = self.seg_vocab_dict[WORD_MIDDLE]
                    if end < len(sent_seg):
                        sent_seg[end] = self.seg_vocab_dict[WORD_START]

        if is_preseged:
            sentences = inputs
            word_length = [len(sentence) for sentence in sentences]

            word_idx = []
            for encodings in tokenized.encodings:
                sentence_word_idx = []
                for idx, (start, end) in enumerate(encodings.offsets[1:]):
                    if start == 0 and end != 0:
                        sentence_word_idx.append(idx)
                word_idx.append(
                    torch.as_tensor(sentence_word_idx, device=self.device))
        else:
            segment_output = convert_idx_to_name(seg, lengths, self.seg_vocab)
            sentences = []
            word_idx = []
            word_length = []

            for source_text, length, encoding, seg_tag, preffix in \
                    zip(inputs, lengths, tokenized.encodings, segment_output, batch_prefix):
                offsets = encoding.offsets[1:length + 1]
                text = []
                last_offset = None
                for start, end in offsets:
                    text.append('' if last_offset == (
                        start, end) else source_text[start:end])
                    last_offset = (start, end)

                for idx in range(1, length):
                    current_beg = offsets[idx][0]
                    forward_end = offsets[idx - 1][-1]
                    if forward_end < current_beg:
                        text[idx] = source_text[
                            forward_end:current_beg] + text[idx]
                    if not preffix[idx]:
                        seg_tag[idx] = WORD_MIDDLE

                entities = get_entities(seg_tag)
                word_length.append(len(entities))
                sentences.append([
                    ''.join(text[entity[1]:entity[2] + 1]).strip()
                    for entity in entities
                ])
                word_idx.append(
                    torch.as_tensor([entity[1] for entity in entities],
                                    device=self.device))

        word_idx = torch.nn.utils.rnn.pad_sequence(word_idx, batch_first=True)
        word_idx = word_idx.unsqueeze(-1).expand(-1, -1,
                                                 hidden.shape[-1])  # expand to hidden size

        word_input = torch.gather(hidden, dim=1,
                                  index=word_idx)  # first-char vector of each word

        if len(self.dep_vocab) + len(self.sdp_vocab) > 0:
            word_cls_input = torch.cat([cls, word_input], dim=1)
            word_cls_mask = length_to_mask(
                torch.as_tensor(word_length, device=self.device) + 1)
            word_cls_mask[:, 0] = False
        else:
            word_cls_input, word_cls_mask = None, None

        return sentences, {
            'word_cls': cls,
            'word_input': word_input,
            'word_length': word_length,
            'word_cls_input': word_cls_input,
            'word_cls_mask': word_cls_mask
        }

    @no_gard
    def pos(self, hidden: dict):
        """
        Part-of-speech tagging
        Args:
            hidden: the intermediate representation produced by seg

        Returns:
            pos: POS tagging results
        """
        if len(self.pos_vocab) == 0:
            return []
        postagger_output = self.model.pos_classifier(
            hidden['word_input']).logits
        postagger_output = torch.argmax(postagger_output, dim=-1).cpu().numpy()
        postagger_output = convert_idx_to_name(postagger_output,
                                               hidden['word_length'],
                                               self.pos_vocab)
        return postagger_output

    @no_gard
    def ner(self, hidden: dict, as_entities=True):
        """
        Named entity recognition
        Args:
            hidden: the intermediate representation produced by seg
            as_entities: whether to return results as Entity(Type, Start, End) tuples

        Returns:
            ner: named entity recognition results
        """
        if len(self.ner_vocab) == 0:
            return []
        ner_output = self.model.ner_classifier.forward(
            hidden['word_input'],
            word_attention_mask=hidden['word_cls_mask'][:, 1:])
        ner_output = ner_output.decoded or torch.argmax(ner_output.logits,
                                                        dim=-1).cpu().numpy()
        ner_output = convert_idx_to_name(ner_output, hidden['word_length'],
                                         self.ner_vocab)
        return [get_entities(ner)
                for ner in ner_output] if as_entities else ner_output

    @no_gard
    def srl(self, hidden: dict, keep_empty=True):
        """
        Semantic role labeling
        Args:
            hidden: the intermediate representation produced by seg
            keep_empty: whether to keep words that carry no semantic role labels

        Returns:
            srl: semantic role labeling results
        """
        if len(self.srl_vocab) == 0:
            return []
        srl_output = self.model.srl_classifier.forward(
            input=hidden['word_input'],
            word_attention_mask=hidden['word_cls_mask'][:, 1:]).decoded
        srl_entities = get_entities_with_list(srl_output, self.srl_vocab)

        srl_labels_res = []
        for length in hidden['word_length']:
            srl_labels_res.append([])
            curr_srl_labels, srl_entities = (srl_entities[:length],
                                             srl_entities[length:])
            srl_labels_res[-1].extend(curr_srl_labels)

        if not keep_empty:
            srl_labels_res = [[(idx, labels)
                               for idx, labels in enumerate(srl_labels)
                               if len(labels)]
                              for srl_labels in srl_labels_res]
        return srl_labels_res

    @no_gard
    def dep(self, hidden: dict, fast=True, as_tuple=True):
        """
        Dependency parsing
        Args:
            hidden: the intermediate representation produced by seg
            fast: fast mode imposes fewer constraints on the result, trading some accuracy for speed
            as_tuple: if True, return (idx, head, rel) tuples; otherwise return heads and rels separately

        Returns:
            the dependency parse trees
        """
        if len(self.dep_vocab) == 0:
            return []
        word_attention_mask = hidden['word_cls_mask']
        result = self.model.dep_classifier.forward(
            input=hidden['word_cls_input'],
            word_attention_mask=word_attention_mask[:, 1:])
        dep_arc, dep_label = result.arc_logits, result.rel_logits
        dep_arc[:, 0, 1:] = float('-inf')
        dep_arc.diagonal(0, 1, 2).fill_(float('-inf'))
        dep_arc = dep_arc.argmax(
            dim=-1) if fast else eisner(dep_arc, word_attention_mask)

        dep_label = torch.argmax(dep_label, dim=-1)
        dep_label = dep_label.gather(-1, dep_arc.unsqueeze(-1)).squeeze(-1)

        dep_arc[~word_attention_mask] = -1
        dep_label[~word_attention_mask] = -1

        head_pred = [[item for item in arcs if item != -1]
                     for arcs in dep_arc[:, 1:].cpu().numpy().tolist()]
        rel_pred = [[self.dep_vocab[item] for item in rels if item != -1]
                    for rels in dep_label[:, 1:].cpu().numpy().tolist()]
        if not as_tuple:
            return head_pred, rel_pred
        return [[(idx + 1, head, rel)
                 for idx, (head, rel) in enumerate(zip(heads, rels))]
                for heads, rels in zip(head_pred, rel_pred)]

    @no_gard
    def sdp(self, hidden: dict, mode: str = 'graph'):
        """
        Semantic dependency parsing (graph or tree)
        Args:
            hidden: the intermediate representation produced by seg
            mode: one of ['tree', 'graph', 'mix']

        Returns:
            the semantic dependency graphs (or trees)
        """
        if len(self.sdp_vocab) == 0:
            return []

        word_attention_mask = hidden['word_cls_mask']
        result = self.model.sdp_classifier(
            input=hidden['word_cls_input'],
            word_attention_mask=word_attention_mask[:, 1:])
        sdp_arc, sdp_label = result.arc_logits, result.rel_logits
        sdp_arc[:, 0, 1:] = float('-inf')
        sdp_arc.diagonal(0, 1, 2).fill_(float('-inf'))  # avoid self-loops
        sdp_label = torch.argmax(sdp_label, dim=-1)

        if mode == 'tree':
            # semantic dependency tree
            sdp_arc_idx = eisner(
                sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
            sdp_arc_res = torch.zeros_like(sdp_arc, dtype=torch.bool).scatter_(
                -1, sdp_arc_idx, True)
        elif mode == 'mix':
            # mixed decoding
            sdp_arc_idx = eisner(
                sdp_arc, word_attention_mask).unsqueeze_(-1).expand_as(sdp_arc)
            sdp_arc_res = (sdp_arc.sigmoid_() > 0.5).scatter_(
                -1, sdp_arc_idx, True)
        else:
            # semantic dependency graph
            sdp_arc_res = torch.sigmoid_(sdp_arc) > 0.5

        sdp_arc_res[~word_attention_mask] = False
        sdp_label = get_graph_entities(sdp_arc_res, sdp_label, self.sdp_vocab)

        return sdp_label
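
Putting Example #12 together: seg returns both the word sequences and the hidden dict that every downstream task consumes, so a typical pipeline simply chains the calls. A usage sketch (the input sentence is illustrative):

ltp = LTP('small')
sents, hidden = ltp.seg(['他叫汤姆去拿外衣。'])
pos = ltp.pos(hidden)  # part-of-speech tags
ner = ltp.ner(hidden)  # named entities as (type, start, end)
srl = ltp.srl(hidden)  # semantic roles per word
dep = ltp.dep(hidden)  # dependency tree as (idx, head, rel)
sdp = ltp.sdp(hidden)  # semantic dependency graph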