Example #1
def evaluate(gold_file,
             pred_file,
             do_enhanced_collapse_empty_nodes=False,
             do_copy_cols=True):
    """Evaluate using official CoNLL-X evaluation script (Yuval Krymolowski)

    Args:
      gold_file(str): The gold conllx file
      pred_file(str): The pred conllx file
      do_enhanced_collapse_empty_nodes: ``True`` to collapse empty nodes for enhanced UD evaluation (Default value = False)
      do_copy_cols: ``True`` to copy the missing columns from the gold file into the prediction file (Default value = True)

    Returns:
      The scores computed by ``iwpt20_xud_eval.evaluate_wrapper``.
    """
    if do_enhanced_collapse_empty_nodes:
        gold_file = enhanced_collapse_empty_nodes(gold_file)
        pred_file = enhanced_collapse_empty_nodes(pred_file)
    if do_copy_cols:
        fixed_pred_file = pred_file.replace('.conllu', '.fixed.conllu')
        copy_cols(gold_file, pred_file, fixed_pred_file)
    else:
        fixed_pred_file = pred_file
    args = SerializableDict()
    args.enhancements = '0'
    args.gold_file = gold_file
    args.system_file = fixed_pred_file
    return iwpt20_xud_eval.evaluate_wrapper(args)
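
A hedged usage sketch of the function above, assuming the gold and system outputs are local CoNLL-U files; the file names below are placeholders.

# Hypothetical invocation; with do_copy_cols=True a '*.fixed.conllu' copy of the
# prediction is written next to it before the official scorer runs.
score = evaluate('ud-dev.gold.conllu',
                 'ud-dev.pred.conllu',
                 do_enhanced_collapse_empty_nodes=True)
print(score)  # whatever iwpt20_xud_eval.evaluate_wrapper returns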
Example #2
 def load_vocabs(self, save_dir, filename='vocabs.json'):
     save_dir = get_resource(save_dir)
     vocabs = SerializableDict()
     vocabs.load_json(os.path.join(save_dir, filename))
     for key, value in vocabs.items():
         vocab = VocabTF()
         vocab.copy_from(value)
         setattr(self.transform, key, vocab)
Example #3
File: transform.py Project: lei1993/HanLP
    def load_vocabs(self, save_dir, filename='vocabs.json', vocab_cls=Vocab):
        """Load vocabularies from a directory.

        Args:
            save_dir: The directory to load vocabularies from.
            filename: The filename of the vocabularies.
            vocab_cls: The vocabulary class used to reconstruct each vocab.
        """
        save_dir = get_resource(save_dir)
        vocabs = SerializableDict()
        vocabs.load_json(os.path.join(save_dir, filename))
        self._load_vocabs(self, vocabs, vocab_cls)
Example #4
    def __init__(self, **kwargs) -> None:
        """The base class for all components using PyTorch as backend. It provides common workflows of building vocabs,
        datasets, dataloaders and models. These workflows are more of a conventional guideline than en-forced
        protocols, which means subclass has the freedom to override or completely skip some steps.

        Args:
            **kwargs: Addtional arguments to be stored in the ``config`` property.
        """
        super().__init__()
        self.model: Optional[torch.nn.Module] = None
        self.config = SerializableDict(**kwargs)
        self.vocabs = VocabDict()
Example #5
File: transform.py Project: lei1993/HanLP
    def save_vocabs(self, save_dir, filename='vocabs.json'):
        """Save vocabularies to a directory.

        Args:
            save_dir: The directory to save vocabularies.
            filename:  The name for vocabularies.
        """
        vocabs = SerializableDict()
        for key, value in self.items():
            if isinstance(value, Vocab):
                vocabs[key] = value.to_dict()
        vocabs.save_json(os.path.join(save_dir, filename))
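
To make the JSON round trip concrete, here is a small self-contained sketch; SimpleVocabStore is a stand-in for SerializableDict, and the vocab payload is illustrative rather than the exact format produced by Vocab.to_dict().

import json
import os
import tempfile

class SimpleVocabStore(dict):
    # Minimal stand-in mimicking SerializableDict's save_json/load_json.
    def save_json(self, path):
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self, f, ensure_ascii=False, indent=2)

    def load_json(self, path):
        with open(path, encoding='utf-8') as f:
            self.update(json.load(f))

save_dir = tempfile.mkdtemp()
store = SimpleVocabStore(token={'idx_to_token': ['<pad>', '<unk>', 'hello']})
store.save_json(os.path.join(save_dir, 'vocabs.json'))

restored = SimpleVocabStore()
restored.load_json(os.path.join(save_dir, 'vocabs.json'))
assert restored == store  # the vocabularies survive the round trip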
Example #6
 def __init__(self, transform: Transform) -> None:
     super().__init__()
     self.meta = {
         'class_path': classpath_of(self),
         'hanlp_version': hanlp.version.__version__,
     }
     self.model: Optional[tf.keras.Model] = None
     self.config = SerializableDict()
     self.transform = transform
     # share config with transform for convenience, so we don't need to pass args around
     if self.transform.config:
         for k, v in self.transform.config.items():
             self.config[k] = v
     self.transform.config = self.config
Example #7
File: structure.py Project: lei1993/HanLP
    def __init__(
        self,
        locals_: Dict,
        exclude=('kwargs', 'self', '__class__', 'locals_')) -> None:
        """This base class helps sub-classes to capture their arguments passed to ``__init__``, and also their types so
        that they can be deserialized from a config in dict form.

        Args:
            locals_: Obtained by :meth:`locals`.
            exclude: Arguments to be excluded.

        Examples:
            >>> class MyClass(ConfigTracker):
            >>>     def __init__(self, i_need_this='yes') -> None:
            >>>         super().__init__(locals())
            >>> obj = MyClass()
            >>> print(obj.config)
            {'i_need_this': 'yes', 'classpath': 'test_config_tracker.MyClass'}

        """
        if 'kwargs' in locals_:
            locals_.update(locals_['kwargs'])
        self.config = SerializableDict(
            (k, v.config if hasattr(v, 'config') else v)
            for k, v in locals_.items() if k not in exclude)
        self.config['classpath'] = classpath_of(self)
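
As a self-contained illustration of the capture mechanism, the sketch below uses a plain dict instead of SerializableDict; the class names are made up.

def classpath_of(obj):
    return type(obj).__module__ + '.' + type(obj).__name__

class ConfigTrackerSketch:
    def __init__(self, locals_, exclude=('kwargs', 'self', '__class__', 'locals_')):
        # Merge **kwargs into the captured locals, then drop the excluded names.
        if 'kwargs' in locals_:
            locals_.update(locals_['kwargs'])
        self.config = {k: v for k, v in locals_.items() if k not in exclude}
        self.config['classpath'] = classpath_of(self)

class MyEncoder(ConfigTrackerSketch):
    def __init__(self, hidden_size=256, dropout=0.1, **kwargs):
        super().__init__(locals())

print(MyEncoder(dropout=0.2, extra_flag=True).config)
# (run as a script) {'hidden_size': 256, 'dropout': 0.2, 'extra_flag': True, 'classpath': '__main__.MyEncoder'}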
Example #8
    def _savable_config(self):
        def convert(k, v):
            if not isinstance(v, SerializableDict) and hasattr(v, 'config'):
                v = v.config
            elif isinstance(v, (set, tuple)):
                v = list(v)
            if isinstance(v, dict):
                v = dict(convert(_k, _v) for _k, _v in v.items())
            return k, v

        config = SerializableDict(
            convert(k, v) for k, v in sorted(self.config.items()))
        config.update({
            # 'create_time': now_datetime(),
            'classpath': classpath_of(self),
            'hanlp_version': hanlp.__version__,
        })
        return config
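
A self-contained sketch of what convert() achieves: objects exposing a .config are replaced by that config, and sets/tuples become lists, so the result can be serialized to JSON. The Embedding stand-in is illustrative.

class Embedding:
    def __init__(self):
        self.config = {'dim': 100, 'trainable': False}

def convert(k, v):
    # Same idea as above, with a plain dict instead of SerializableDict.
    if not isinstance(v, dict) and hasattr(v, 'config'):
        v = v.config
    elif isinstance(v, (set, tuple)):
        v = list(v)
    if isinstance(v, dict):
        v = dict(convert(_k, _v) for _k, _v in v.items())
    return k, v

raw = {'embed': Embedding(), 'tags': ('B', 'I', 'O'), 'lr': 1e-3}
print(dict(convert(k, v) for k, v in sorted(raw.items())))
# {'embed': {'dim': 100, 'trainable': False}, 'lr': 0.001, 'tags': ['B', 'I', 'O']}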
Example #9
 def __init__(self,
              config: SerializableDict = None,
              map_x=True,
              map_y=True,
              **kwargs) -> None:
     super().__init__()
     self.map_y = map_y
     self.map_x = map_x
     if kwargs:
         if not config:
             config = SerializableDict()
         for k, v in kwargs.items():
             config[k] = v
     self.config = config
     self.output_types = None
     self.output_shapes = None
     self.padding_values = None
Example #10
 def __init__(self,
              config: SerializableDict = None,
              map_x=True,
              map_y=True,
              **kwargs) -> None:
     super().__init__()
     self.map_y = map_y
     self.map_x = map_x
     if kwargs:
         if not config:
             config = SerializableDict()
         for k, v in kwargs.items():
             config[k] = v
     self.config = config
     self.output_types = None
     self.output_shapes = None
     self.padding_values = None
     # Fix tf memory leak: https://github.com/tensorflow/tensorflow/issues/37653#issuecomment-1000517720
     self.py_func_set_to_cleanup = set()
Example #11
class TorchComponent(Component, ABC):
    def __init__(self, **kwargs) -> None:
        """The base class for all components using PyTorch as backend. It provides common workflows of building vocabs,
        datasets, dataloaders and models. These workflows are more of a conventional guideline than en-forced
        protocols, which means subclass has the freedom to override or completely skip some steps.

        Args:
            **kwargs: Addtional arguments to be stored in the ``config`` property.
        """
        super().__init__()
        self.model: Optional[torch.nn.Module] = None
        self.config = SerializableDict(**kwargs)
        self.vocabs = VocabDict()

    def _capture_config(self, locals_: Dict,
                        exclude=(
                                'trn_data', 'dev_data', 'save_dir', 'kwargs', 'self', 'logger', 'verbose',
                                'dev_batch_size', '__class__', 'devices', 'eval_trn')):
        """Save arguments to config

        Args:
          locals_: Dict: 
          exclude:  (Default value = ('trn_data')
          'dev_data': 
          'save_dir': 
          'kwargs': 
          'self': 
          'logger': 
          'verbose': 
          'dev_batch_size': 
          '__class__': 
          'devices'): 

        Returns:

        
        """
        if 'kwargs' in locals_:
            locals_.update(locals_['kwargs'])
        locals_ = dict((k, v) for k, v in locals_.items() if k not in exclude and not k.startswith('_'))
        self.config.update(locals_)
        return self.config

    def save_weights(self, save_dir, filename='model.pt', trainable_only=True, **kwargs):
        """Save model weights to a directory.

        Args:
            save_dir: The directory to save weights into.
            filename: A file name for weights.
            trainable_only: ``True`` to only save trainable weights. Useful when the model contains lots of static
                embeddings.
            **kwargs: Not used for now.
        """
        model = self.model_
        state_dict = model.state_dict()
        if trainable_only:
            trainable_names = set(n for n, p in model.named_parameters() if p.requires_grad)
            state_dict = dict((n, p) for n, p in state_dict.items() if n in trainable_names)
        torch.save(state_dict, os.path.join(save_dir, filename))

    def load_weights(self, save_dir, filename='model.pt', **kwargs):
        """Load weights from a directory.

        Args:
            save_dir: The directory to load weights from.
            filename: A file name for weights.
            **kwargs: Not used.
        """
        save_dir = get_resource(save_dir)
        filename = os.path.join(save_dir, filename)
        # flash(f'Loading model: {filename} [blink]...[/blink][/yellow]')
        self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False)
        # flash('')

    def save_config(self, save_dir, filename='config.json'):
        """Save config into a directory.

        Args:
            save_dir: The directory to save config.
            filename: A file name for config.
        """
        self._savable_config.save_json(os.path.join(save_dir, filename))

    def load_config(self, save_dir, filename='config.json', **kwargs):
        """Load config from a directory.

        Args:
            save_dir: The directory to load config.
            filename: A file name for config.
            **kwargs: K-V pairs to override config.
        """
        save_dir = get_resource(save_dir)
        self.config.load_json(os.path.join(save_dir, filename))
        self.config.update(kwargs)  # overwrite config loaded from disk
        for k, v in self.config.items():
            if isinstance(v, dict) and 'classpath' in v:
                self.config[k] = Configurable.from_config(v)
        self.on_config_ready(**self.config)

    def save_vocabs(self, save_dir, filename='vocabs.json'):
        """Save vocabularies to a directory.

        Args:
            save_dir: The directory to save vocabularies.
            filename:  The name for vocabularies.
        """
        if hasattr(self, 'vocabs'):
            self.vocabs.save_vocabs(save_dir, filename)

    def load_vocabs(self, save_dir, filename='vocabs.json'):
        """Load vocabularies from a directory.

        Args:
            save_dir: The directory to load vocabularies.
            filename:  The name for vocabularies.
        """
        if hasattr(self, 'vocabs'):
            self.vocabs = VocabDict()
            self.vocabs.load_vocabs(save_dir, filename)

    def save(self, save_dir: str, **kwargs):
        """Save this component to a directory.

        Args:
            save_dir: The directory to save this component.
            **kwargs: Not used.
        """
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.save_weights(save_dir)

    def load(self, save_dir: str, devices=None, verbose=HANLP_VERBOSE, **kwargs):
        """Load from a local/remote component.

        Args:
            save_dir: An identifier which can be a local path or a remote URL or a pre-defined string.
            devices: The devices this component will be moved onto.
            verbose: ``True`` to log loading progress.
            **kwargs: To override some configs.
        """
        save_dir = get_resource(save_dir)
        # flash('Loading config and vocabs [blink][yellow]...[/yellow][/blink]')
        if devices is None and self.model:
            devices = self.devices
        self.load_config(save_dir, **kwargs)
        self.load_vocabs(save_dir)
        if verbose:
            flash('Building model [blink][yellow]...[/yellow][/blink]')
        self.model = self.build_model(
            **merge_dict(self.config, training=False, **kwargs, overwrite=True,
                         inplace=True))
        if verbose:
            flash('')
        self.load_weights(save_dir, **kwargs)
        self.to(devices)
        self.model.eval()

    def fit(self,
            trn_data,
            dev_data,
            save_dir,
            batch_size,
            epochs,
            devices=None,
            logger=None,
            seed=None,
            finetune: Union[bool, str] = False,
            eval_trn=True,
            _device_placeholder=False,
            **kwargs):
        """Fit to data, triggers the training procedure. For training set and dev set, they shall be local or remote
        files.

        Args:
            trn_data: Training set.
            dev_data: Development set.
            save_dir: The directory to save trained component.
            batch_size: The number of samples in a batch.
            epochs: Number of epochs.
            devices: Devices this component will live on.
            logger: Any :class:`logging.Logger` instance.
            seed: Random seed to reproduce this training.
            finetune: ``True`` to load from ``save_dir`` instead of creating a randomly initialized component. ``str``
                to specify a different ``save_dir`` to load from.
            eval_trn: Evaluate training set after each update. This can slow down the training but provides a quick
                diagnostic for debugging.
            _device_placeholder: ``True`` to create a placeholder tensor which triggers PyTorch to occupy devices so
                other components won't take these devices as first choices.
            **kwargs: Hyperparameters used by sub-classes.

        Returns:
            Any results sub-classes would like to return. Usually the best metrics on training set.

        """
        # Common initialization steps
        config = self._capture_config(locals())
        if not logger:
            logger = self.build_logger('train', save_dir)
        if not seed:
            self.config.seed = 233 if isdebugging() else int(time.time())
        set_seed(self.config.seed)
        logger.info(self._savable_config.to_json(sort=True))
        if isinstance(devices, list) or devices is None or isinstance(devices, float):
            flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
            devices = -1 if isdebugging() else cuda_devices(devices)
            flash('')
        # flash(f'Available GPUs: {devices}')
        if isinstance(devices, list):
            first_device = (devices[0] if devices else -1)
        elif isinstance(devices, dict):
            first_device = next(iter(devices.values()))
        elif isinstance(devices, int):
            first_device = devices
        else:
            first_device = -1
        if _device_placeholder and first_device >= 0:
            _dummy_placeholder = self._create_dummy_placeholder_on(first_device)
        if finetune:
            if isinstance(finetune, str):
                self.load(finetune, devices=devices)
            else:
                self.load(save_dir, devices=devices)
            logger.info(
                f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
                f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
        self.on_config_ready(**self.config)
        trn = self.build_dataloader(**merge_dict(config, data=trn_data, batch_size=batch_size, shuffle=True,
                                                 training=True, device=first_device, logger=logger, vocabs=self.vocabs,
                                                 overwrite=True))
        dev = self.build_dataloader(**merge_dict(config, data=dev_data, batch_size=batch_size, shuffle=False,
                                                 training=None, device=first_device, logger=logger, vocabs=self.vocabs,
                                                 overwrite=True)) if dev_data else None
        if not finetune:
            flash('[yellow]Building model [blink]...[/blink][/yellow]')
            self.model = self.build_model(**merge_dict(config, training=True))
            flash('')
            logger.info(f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
                        f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
            assert self.model, 'build_model is not properly implemented.'
        _description = repr(self.model)
        if len(_description.split('\n')) < 10:
            logger.info(_description)
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.to(devices, logger)
        if _device_placeholder and first_device >= 0:
            del _dummy_placeholder
        criterion = self.build_criterion(**merge_dict(config, trn=trn))
        optimizer = self.build_optimizer(**merge_dict(config, trn=trn, criterion=criterion))
        metric = self.build_metric(**self.config)
        if hasattr(trn.dataset, '__len__') and dev and hasattr(dev.dataset, '__len__'):
            logger.info(f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.')
            trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
            ratio_width = len(f'{trn_size}/{trn_size}')
        else:
            ratio_width = None
        return self.execute_training_loop(**merge_dict(config, trn=trn, dev=dev, epochs=epochs, criterion=criterion,
                                                       optimizer=optimizer, metric=metric, logger=logger,
                                                       save_dir=save_dir,
                                                       devices=devices,
                                                       ratio_width=ratio_width,
                                                       trn_data=trn_data,
                                                       dev_data=dev_data,
                                                       eval_trn=eval_trn,
                                                       overwrite=True))

    def build_logger(self, name, save_dir):
        """Build a :class:`logging.Logger`.

        Args:
            name: The name of this logger.
            save_dir: The directory this logger should save logs into.

        Returns:
            logging.Logger: A logger.
        """
        logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO, fmt="%(message)s")
        return logger

    @abstractmethod
    def build_dataloader(self, data, batch_size, shuffle=False, device=None, logger: logging.Logger = None,
                         **kwargs) -> DataLoader:
        """Build dataloader for training, dev and test sets. It's suggested to build vocabs in this method if they are
        not built yet.

        Args:
            data: Data representing samples, which can be a path or a list of samples.
            batch_size: Number of samples per batch.
            shuffle: Whether to shuffle this dataloader.
            device: Device tensors should be loaded onto.
            logger: Logger for reporting some messages if the dataloader takes a long time or if vocabs have to be built.
            **kwargs: Arguments from ``**self.config``.
        """
        pass

    def build_vocabs(self, trn: torch.utils.data.Dataset, logger: logging.Logger):
        """Override this method to build vocabs.

        Args:
            trn: Training set.
            logger: Logger for reporting progress.
        """
        pass

    @property
    def _savable_config(self):
        def convert(k, v):
            if not isinstance(v, SerializableDict) and hasattr(v, 'config'):
                v = v.config
            elif isinstance(v, (set, tuple)):
                v = list(v)
            if isinstance(v, dict):
                v = dict(convert(_k, _v) for _k, _v in v.items())
            return k, v

        config = SerializableDict(
            convert(k, v) for k, v in sorted(self.config.items()))
        config.update({
            # 'create_time': now_datetime(),
            'classpath': classpath_of(self),
            'hanlp_version': hanlp.__version__,
        })
        return config

    @abstractmethod
    def build_optimizer(self, **kwargs):
        """Implement this method to build an optimizer.

        Args:
            **kwargs: The subclass decides the method signature.
        """
        pass

    @abstractmethod
    def build_criterion(self, decoder, **kwargs):
        """Implement this method to build criterion (loss function).

        Args:
            decoder: The model or decoder.
            **kwargs: The subclass decides the method signature.
        """
        pass

    @abstractmethod
    def build_metric(self, **kwargs):
        """Implement this to build metric(s).

        Args:
            **kwargs: The subclass decides the method signature.
        """
        pass

    @abstractmethod
    def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
                              logger: logging.Logger, devices, ratio_width=None,
                              **kwargs):
        """Implement this to run training loop.

        Args:
            trn: Training set.
            dev: Development set.
            epochs: Number of epochs.
            criterion: Loss function.
            optimizer: Optimizer(s).
            metric: Metric(s).
            save_dir: The directory to save this component.
            logger: Logger for reporting progress.
            devices: Devices this component and dataloader will live on.
            ratio_width: The width of dataset size measured in number of characters. Used for logger to align messages.
            **kwargs: Other hyper-parameters passed from sub-class.
        """
        pass

    @abstractmethod
    def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs):
        """Fit onto a dataloader.

        Args:
            trn: Training set.
            criterion: Loss function.
            optimizer: Optimizer.
            metric: Metric(s).
            logger: Logger for reporting progress.
            **kwargs: Other hyper-parameters passed from sub-class.
        """
        pass

    @abstractmethod
    def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric=None, output=False, **kwargs):
        """Evaluate on a dataloader.

        Args:
            data: Dataloader which can build from any data source.
            criterion: Loss function.
            metric: Metric(s).
            output: Whether to save outputs into some file.
            **kwargs: Not used.
        """
        pass

    @abstractmethod
    def build_model(self, training=True, **kwargs) -> torch.nn.Module:
        """Build model.

        Args:
            training: ``True`` if called during training.
            **kwargs: ``**self.config``.
        """
        raise NotImplementedError

    def evaluate(self, tst_data, save_dir=None, logger: logging.Logger = None, batch_size=None, output=False, **kwargs):
        """Evaluate test set.

        Args:
            tst_data: Test set, which is usually a file path.
            save_dir: The directory to save evaluation scores or predictions.
            logger: Logger for reporting progress.
            batch_size: Batch size for test dataloader.
            output: Whether to save outputs into some file.
            **kwargs: Not used.

        Returns:
            (metric, outputs) where outputs are the return values of ``evaluate_dataloader``.
        """
        if not self.model:
            raise RuntimeError('Call fit or load before evaluate.')
        if isinstance(tst_data, str):
            tst_data = get_resource(tst_data)
            filename = os.path.basename(tst_data)
        else:
            filename = None
        if output is True:
            output = self.generate_prediction_filename(tst_data if isinstance(tst_data, str) else 'test.txt', save_dir)
        if logger is None:
            _logger_name = basename_no_ext(filename) if filename else None
            logger = self.build_logger(_logger_name, save_dir)
        if not batch_size:
            batch_size = self.config.get('batch_size', 32)
        data = self.build_dataloader(**merge_dict(self.config, data=tst_data, batch_size=batch_size, shuffle=False,
                                                  device=self.devices[0], logger=logger, overwrite=True))
        dataset = data
        while dataset and hasattr(dataset, 'dataset'):
            dataset = dataset.dataset
        num_samples = len(dataset) if dataset else None
        if output and isinstance(dataset, TransformableDataset):
            def add_idx(samples):
                for idx, sample in enumerate(samples):
                    if sample:
                        sample[IDX] = idx

            add_idx(dataset.data)
            if dataset.cache:
                add_idx(dataset.cache)

        criterion = self.build_criterion(**self.config)
        metric = self.build_metric(**self.config)
        start = time.time()
        outputs = self.evaluate_dataloader(data, criterion=criterion, filename=filename, output=output, input=tst_data,
                                           save_dir=save_dir,
                                           test=True,
                                           num_samples=num_samples,
                                           **merge_dict(self.config, batch_size=batch_size, metric=metric,
                                                        logger=logger, **kwargs))
        elapsed = time.time() - start
        if logger:
            if num_samples:
                logger.info(f'speed: {num_samples / elapsed:.0f} samples/second')
            else:
                logger.info(f'speed: {len(data) / elapsed:.0f} batches/second')
        return metric, outputs

    def generate_prediction_filename(self, tst_data, save_dir):
        assert isinstance(tst_data,
                          str), 'tst_data has to be a str in order to infer the output name'
        output = os.path.splitext(os.path.basename(tst_data))
        output = os.path.join(save_dir, output[0] + '.pred' + output[1])
        return output

    def to(self,
           devices: Union[int, float, List[int], Dict[str, Union[int, torch.device]]] = None,
           logger: logging.Logger = None, verbose=HANLP_VERBOSE):
        """Move this component to devices.

        Args:
            devices: Target devices.
            logger: Logger for printing a progress report, as copying a model from CPU to GPU can take several seconds.
            verbose: ``True`` to print progress when logger is None.
        """
        if devices == -1 or devices == [-1]:
            devices = []
        elif isinstance(devices, (int, float)) or devices is None:
            devices = cuda_devices(devices)
        if devices:
            if logger:
                logger.info(f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]')
            if isinstance(devices, list):
                if verbose:
                    flash(f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]')
                self.model = self.model.to(devices[0])
                if len(devices) > 1 and not isdebugging() and not isinstance(self.model, nn.DataParallel):
                    self.model = self.parallelize(devices)
            elif isinstance(devices, dict):
                for name, module in self.model.named_modules():
                    for regex, device in devices.items():
                        try:
                            on_device: torch.device = next(module.parameters()).device
                        except StopIteration:
                            continue
                        if on_device == device:
                            continue
                        if isinstance(device, int):
                            if on_device.index == device:
                                continue
                        if re.match(regex, name):
                            if not name:
                                name = '*'
                            flash(f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}'
                                  f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n')
                            module.to(device)
            else:
                raise ValueError(f'Unrecognized devices {devices}')
            if verbose:
                flash('')
        else:
            if logger:
                logger.info('Using CPU')

    def parallelize(self, devices: List[Union[int, torch.device]]):
        return nn.DataParallel(self.model, device_ids=devices)

    @property
    def devices(self):
        """The devices this component lives on.
        """
        if self.model is None:
            return None
        # next(parser.model.parameters()).device
        if hasattr(self.model, 'device_ids'):
            return self.model.device_ids
        device: torch.device = next(self.model.parameters()).device
        return [device]

    @property
    def device(self):
        """The first device this component lives on.
        """
        devices = self.devices
        if not devices:
            return None
        return devices[0]

    def on_config_ready(self, **kwargs):
        """Called when config is ready, either during ``fit`` ot ``load``. Subclass can perform extra initialization
        tasks in this callback.

        Args:
            **kwargs: Not used.
        """
        pass

    @property
    def model_(self) -> nn.Module:
        """
        The actual model, unwrapped from ``nn.DataParallel`` if it has been wrapped.

        Returns: The "real" model

        """
        if isinstance(self.model, nn.DataParallel):
            return self.model.module
        return self.model

    # noinspection PyMethodOverriding
    @abstractmethod
    def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs):
        """Predict on data fed by user. Users shall avoid directly call this method since it is not guarded with
        ``torch.no_grad`` and will introduces unnecessary gradient computation. Use ``__call__`` instead.

        Args:
            data: Sentences or tokens.
            batch_size: Decoding batch size.
            **kwargs: Used in sub-classes.
        """
        pass

    @staticmethod
    def _create_dummy_placeholder_on(device):
        if device < 0:
            device = 'cpu:0'
        return torch.zeros(16, 16, device=device)

    @torch.no_grad()
    def __call__(self, data, batch_size=None, **kwargs):
        """Predict on data fed by user. This method calls :meth:`~hanlp.common.torch_component.predict` but decorates
        it with ``torch.no_grad``.

        Args:
            data: Sentences or tokens.
            batch_size: Decoding batch size.
            **kwargs: Used in sub-classes.
        """
        return super().__call__(data, **merge_dict(self.config, overwrite=True,
                                                   batch_size=batch_size or self.config.get('batch_size', None),
                                                   **kwargs))
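
Putting the TorchComponent lifecycle together, a hedged usage sketch: MyTagger stands for any concrete subclass that implements the abstract build_* methods, and the paths and hyperparameters are placeholders.

tagger = MyTagger()
tagger.fit(trn_data='data/train.tsv',
           dev_data='data/dev.tsv',
           save_dir='save/my_tagger',
           batch_size=32,
           epochs=10)          # writes config.json, vocabs.json and model.pt into save_dir

tagger = MyTagger()
tagger.load('save/my_tagger')  # rebuilds the model from config and restores weights
print(tagger(['An', 'example', 'sentence']))  # __call__ wraps predict() in torch.no_grad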
Example #12
class KerasComponent(Component, ABC):
    def __init__(self, transform: Transform) -> None:
        super().__init__()
        self.meta = {
            'class_path': classpath_of(self),
            'hanlp_version': hanlp.version.__version__,
        }
        self.model: Optional[tf.keras.Model] = None
        self.config = SerializableDict()
        self.transform = transform
        # share config with transform for convenience, so we don't need to pass args around
        if self.transform.config:
            for k, v in self.transform.config.items():
                self.config[k] = v
        self.transform.config = self.config

    def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, logger: logging.Logger = None,
                 callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
        input_path = get_resource(input_path)
        file_prefix, ext = os.path.splitext(input_path)
        name = os.path.basename(file_prefix)
        if not name:
            name = 'evaluate'
        if save_dir and not logger:
            logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO if verbose else logging.WARN,
                                 mode='w')
        tst_data = self.transform.file_to_dataset(input_path, batch_size=batch_size)
        samples = self.num_samples_in(tst_data)
        num_batches = math.ceil(samples / batch_size)
        if warm_up:
            for x, y in tst_data:
                self.model.predict_on_batch(x)
                break
        if output:
            assert save_dir, 'Must pass save_dir in order to output'
            if isinstance(output, bool):
                output = os.path.join(save_dir, name) + '.predict' + ext
            elif isinstance(output, str):
                output = output
            else:
                raise RuntimeError('output ({}) must be of type bool or str'.format(repr(output)))
        timer = Timer()
        eval_outputs = self.evaluate_dataset(tst_data, callbacks, output, num_batches, **kwargs)
        loss, score, output = eval_outputs[0], eval_outputs[1], eval_outputs[2]
        delta_time = timer.stop()
        speed = samples / delta_time.delta_seconds

        if logger:
            f1: IOBES_F1_TF = None
            for metric in self.model.metrics:
                if isinstance(metric, IOBES_F1_TF):
                    f1 = metric
                    break
            extra_report = ''
            if f1:
                overall, by_type, extra_report = f1.state.result(full=True, verbose=False)
                extra_report = ' \n' + extra_report
            logger.info('Evaluation results for {} - '
                        'loss: {:.4f} - {} - speed: {:.2f} sample/sec{}'
                        .format(name + ext, loss,
                                format_scores(score) if isinstance(score, dict) else format_metrics(self.model.metrics),
                                speed, extra_report))
        if output:
            logger.info('Saving output to {}'.format(output))
            with open(output, 'w', encoding='utf-8') as out:
                self.evaluate_output(tst_data, out, num_batches, self.model.metrics)

        return loss, score, speed

    def num_samples_in(self, dataset):
        return size_of_dataset(dataset)

    def evaluate_dataset(self, tst_data, callbacks, output, num_batches, **kwargs):
        loss, score = self.model.evaluate(tst_data, callbacks=callbacks, steps=num_batches)
        return loss, score, output

    def evaluate_output(self, tst_data, out, num_batches, metrics: List[tf.keras.metrics.Metric]):
        # out.write('x\ty_true\ty_pred\n')
        for metric in metrics:
            metric.reset_states()
        for idx, batch in enumerate(tst_data):
            outputs = self.model.predict_on_batch(batch[0])
            for metric in metrics:
                metric(batch[1], outputs, outputs._keras_mask if hasattr(outputs, '_keras_mask') else None)
            self.evaluate_output_to_file(batch, outputs, out)
            print('\r{}/{} {}'.format(idx + 1, num_batches, format_metrics(metrics)), end='')
        print()

    def evaluate_output_to_file(self, batch, outputs, out):
        for x, y_gold, y_pred in zip(self.transform.X_to_inputs(batch[0]),
                                     self.transform.Y_to_outputs(batch[1], gold=True),
                                     self.transform.Y_to_outputs(outputs, gold=False)):
            out.write(self.transform.input_truth_output_to_str(x, y_gold, y_pred))

    def _capture_config(self, config: Dict,
                        exclude=(
                                'trn_data', 'dev_data', 'save_dir', 'kwargs', 'self', 'logger', 'verbose',
                                'dev_batch_size', '__class__')):
        """
        Save arguments to config

        Parameters
        ----------
        config
            `locals()`
        exclude
        """
        if 'kwargs' in config:
            config.update(config['kwargs'])
        config = dict(
            (key, tf.keras.utils.serialize_keras_object(value)) if hasattr(value, 'get_config') else (key, value) for
            key, value in config.items())
        for key in exclude:
            config.pop(key, None)
        self.config.update(config)

    def save_meta(self, save_dir, filename='meta.json', **kwargs):
        self.meta['create_time'] = now_datetime()
        self.meta.update(kwargs)
        save_json(self.meta, os.path.join(save_dir, filename))

    def load_meta(self, save_dir, filename='meta.json'):
        save_dir = get_resource(save_dir)
        metapath = os.path.join(save_dir, filename)
        if os.path.isfile(metapath):
            self.meta.update(load_json(metapath))

    def save_config(self, save_dir, filename='config.json'):
        self.config.save_json(os.path.join(save_dir, filename))

    def load_config(self, save_dir, filename='config.json'):
        save_dir = get_resource(save_dir)
        self.config.load_json(os.path.join(save_dir, filename))

    def save_weights(self, save_dir, filename='model.h5'):
        self.model.save_weights(os.path.join(save_dir, filename))

    def load_weights(self, save_dir, filename='model.h5', **kwargs):
        assert self.model.built or self.model.weights, 'You must build the model in build_model() ' \
                                                       'in order to load its weights'
        save_dir = get_resource(save_dir)
        self.model.load_weights(os.path.join(save_dir, filename))

    def save_vocabs(self, save_dir, filename='vocabs.json'):
        vocabs = SerializableDict()
        for key, value in vars(self.transform).items():
            if isinstance(value, VocabTF):
                vocabs[key] = value.to_dict()
        vocabs.save_json(os.path.join(save_dir, filename))

    def load_vocabs(self, save_dir, filename='vocabs.json'):
        save_dir = get_resource(save_dir)
        vocabs = SerializableDict()
        vocabs.load_json(os.path.join(save_dir, filename))
        for key, value in vocabs.items():
            vocab = VocabTF()
            vocab.copy_from(value)
            setattr(self.transform, key, vocab)

    def load_transform(self, save_dir) -> Transform:
        """
        Try to load the transform only. This method might fail because it avoids building the model.
        If it does fail, you have to use `load`, which might be heavier but is the best we can do.
        :param save_dir: The path to load from.
        """
        save_dir = get_resource(save_dir)
        self.load_config(save_dir)
        self.load_vocabs(save_dir)
        self.transform.build_config()
        self.transform.lock_vocabs()
        return self.transform

    def save(self, save_dir: str, **kwargs):
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.save_weights(save_dir)

    def load(self, save_dir: str, logger=hanlp.utils.log_util.logger, **kwargs):
        self.meta['load_path'] = save_dir
        save_dir = get_resource(save_dir)
        self.load_config(save_dir)
        self.load_vocabs(save_dir)
        self.build(**merge_dict(self.config, training=False, logger=logger, **kwargs, overwrite=True, inplace=True))
        self.load_weights(save_dir, **kwargs)
        self.load_meta(save_dir)

    @property
    def input_shape(self) -> List:
        return self.transform.output_shapes[0]

    def build(self, logger, **kwargs):
        self.transform.build_config()
        self.model = self.build_model(**merge_dict(self.config, training=kwargs.get('training', None),
                                                   loss=kwargs.get('loss', None)))
        self.transform.lock_vocabs()
        optimizer = self.build_optimizer(**self.config)
        loss = self.build_loss(
            **self.config if 'loss' in self.config else dict(list(self.config.items()) + [('loss', None)]))
        # allow subclasses to plug in different metrics
        metrics = self.build_metrics(**merge_dict(self.config, metrics=kwargs.get('metrics', 'accuracy'),
                                                  logger=logger, overwrite=True))
        if not isinstance(metrics, list):
            if isinstance(metrics, tf.keras.metrics.Metric):
                metrics = [metrics]
        if not self.model.built:
            sample_inputs = self.sample_data
            if sample_inputs is not None:
                self.model(sample_inputs)
            else:
                if len(self.transform.output_shapes[0]) == 1 and self.transform.output_shapes[0][0] is None:
                    x_shape = self.transform.output_shapes[0]
                else:
                    x_shape = list(self.transform.output_shapes[0])
                    for i, shape in enumerate(x_shape):
                        x_shape[i] = [None] + shape  # batch + X.shape
                self.model.build(input_shape=x_shape)
        self.compile_model(optimizer, loss, metrics)
        return self.model, optimizer, loss, metrics

    def compile_model(self, optimizer, loss, metrics):
        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics, run_eagerly=self.config.run_eagerly)

    def build_optimizer(self, optimizer, **kwargs):
        if isinstance(optimizer, (str, dict)):
            custom_objects = {'AdamWeightDecay': AdamWeightDecay}
            optimizer: tf.keras.optimizers.Optimizer = tf.keras.utils.deserialize_keras_object(optimizer,
                                                                                               module_objects=vars(tf.keras.optimizers),
                                                                                               custom_objects=custom_objects)
        self.config.optimizer = tf.keras.utils.serialize_keras_object(optimizer)
        return optimizer

    def build_loss(self, loss, **kwargs):
        if not loss:
            loss = tf.keras.losses.SparseCategoricalCrossentropy(
                reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
                from_logits=True)
        elif isinstance(loss, (str, dict)):
            loss = tf.keras.utils.deserialize_keras_object(loss, module_objects=vars(tf.keras.losses))
        if isinstance(loss, tf.keras.losses.Loss):
            self.config.loss = tf.keras.utils.serialize_keras_object(loss)
        return loss

    def build_transform(self, **kwargs):
        return self.transform

    def build_vocab(self, trn_data, logger):
        train_examples = self.transform.fit(trn_data, **self.config)
        self.transform.summarize_vocabs(logger)
        return train_examples

    def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
        metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
        return [metric]

    @abstractmethod
    def build_model(self, **kwargs) -> tf.keras.Model:
        pass

    def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
            finetune: str = None, **kwargs):
        self._capture_config(locals())
        self.transform = self.build_transform(**self.config)
        if not save_dir:
            save_dir = tempdir_human()
        if not logger:
            logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
        logger.info('Hyperparameter:\n' + self.config.to_json())
        num_examples = self.build_vocab(trn_data, logger)
        # assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
        logger.info('Building...')
        train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
        self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
        model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
        logger.info('Model built:\n' + summary_of_model(self.model))
        if finetune:
            finetune = get_resource(finetune)
            if os.path.isdir(finetune):
                finetune = os.path.join(finetune, 'model.h5')
            model.load_weights(finetune, by_name=True, skip_mismatch=True)
            logger.info(f'Loaded pretrained weights from {finetune} for finetuning')
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.save_meta(save_dir)
        trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
        dev_data = self.build_valid_dataset(dev_data, batch_size)
        callbacks = self.build_callbacks(save_dir, **merge_dict(self.config, overwrite=True, logger=logger))
        # need to know #batches, otherwise progbar crashes
        dev_steps = math.ceil(self.num_samples_in(dev_data) / batch_size)
        checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
        timer = Timer()
        try:
            history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
                                                   num_examples=num_examples,
                                                   train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
                                                   callbacks=callbacks, logger=logger, model=model, optimizer=optimizer,
                                                   loss=loss,
                                                   metrics=metrics, overwrite=True))
        except KeyboardInterrupt:
            print()
            if not checkpoint or checkpoint.best in (np.Inf, -np.Inf):
                self.save_weights(save_dir)
                logger.info('Aborted with model saved')
            else:
                logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
            # noinspection PyTypeChecker
            history: tf.keras.callbacks.History() = get_callback_by_class(callbacks, tf.keras.callbacks.History)
        delta_time = timer.stop()
        best_epoch_ago = 0
        if history and hasattr(history, 'epoch'):
            trained_epoch = len(history.epoch)
            logger.info('Trained {} epochs in {}, each epoch takes {}'.
                        format(trained_epoch, delta_time, delta_time / trained_epoch if trained_epoch else delta_time))
            save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=io_util.NumpyEncoder)
            monitor_history: List = history.history.get(checkpoint.monitor, None)
            if monitor_history:
                best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
            if checkpoint and monitor_history and checkpoint.best != monitor_history[-1]:
                logger.info(f'Restored the best model saved with best '
                            f'{checkpoint.monitor} = {checkpoint.best:.4f} '
                            f'saved {best_epoch_ago} epochs ago')
                self.load_weights(save_dir)  # restore best model
        return history

    def train_loop(self, trn_data, dev_data, epochs, num_examples, train_steps_per_epoch, dev_steps, model, optimizer,
                   loss, metrics, callbacks,
                   logger, **kwargs):
        history = self.model.fit(trn_data, epochs=epochs, steps_per_epoch=train_steps_per_epoch,
                                 validation_data=dev_data,
                                 callbacks=callbacks,
                                 validation_steps=dev_steps,
                                 )  # type:tf.keras.callbacks.History
        return history

    def build_valid_dataset(self, dev_data, batch_size):
        dev_data = self.transform.file_to_dataset(dev_data, batch_size=batch_size, shuffle=False)
        return dev_data

    def build_train_dataset(self, trn_data, batch_size, num_examples):
        trn_data = self.transform.file_to_dataset(trn_data, batch_size=batch_size,
                                                  shuffle=True,
                                                  repeat=-1 if self.config.train_steps else None)
        return trn_data

    def build_callbacks(self, save_dir, logger, **kwargs):
        metrics = kwargs.get('metrics', 'accuracy')
        if isinstance(metrics, (list, tuple)):
            metrics = metrics[-1]
        monitor = f'val_{metrics}'
        checkpoint = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(save_dir, 'model.h5'),
            # verbose=1,
            monitor=monitor, save_best_only=True,
            mode='max',
            save_weights_only=True)
        logger.debug(f'Monitor {checkpoint.monitor} for checkpoint')
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=io_util.makedirs(io_util.path_join(save_dir, 'logs')))
        csv_logger = FineCSVLogger(os.path.join(save_dir, 'train.log'), separator=' | ', append=True)
        callbacks = [checkpoint, tensorboard_callback, csv_logger]
        lr_decay_per_epoch = self.config.get('lr_decay_per_epoch', None)
        if lr_decay_per_epoch:
            learning_rate = self.model.optimizer.get_config().get('learning_rate', None)
            if not learning_rate:
                logger.warning('Learning rate decay not supported for optimizer={}'.format(repr(self.model.optimizer)))
            else:
                logger.debug(f'Created LearningRateScheduler with lr_decay_per_epoch={lr_decay_per_epoch}')
                callbacks.append(tf.keras.callbacks.LearningRateScheduler(
                    lambda epoch: learning_rate / (1 + lr_decay_per_epoch * epoch)))
        anneal_factor = self.config.get('anneal_factor', None)
        if anneal_factor:
            callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(factor=anneal_factor,
                                                                  patience=self.config.get('anneal_patience', 10)))
        early_stopping_patience = self.config.get('early_stopping_patience', None)
        if early_stopping_patience:
            callbacks.append(tf.keras.callbacks.EarlyStopping(monitor=monitor, mode='max',
                                                              verbose=1,
                                                              patience=early_stopping_patience))
        return callbacks

    def on_train_begin(self):
        """
        Callback before the training starts
        """
        pass

    def predict(self, data: Any, batch_size=None, **kwargs):
        assert self.model, 'Please call fit or load before predict'
        if not data:
            return []
        data, flat = self.transform.input_to_inputs(data)

        if not batch_size:
            batch_size = self.config.batch_size

        dataset = self.transform.inputs_to_dataset(data, batch_size=batch_size, gold=kwargs.get('gold', False))

        results = []
        num_samples = 0
        data_is_list = isinstance(data, list)
        for idx, batch in enumerate(dataset):
            samples_in_batch = tf.shape(batch[-1] if isinstance(batch[-1], tf.Tensor) else batch[-1][0])[0]
            if data_is_list:
                inputs = data[num_samples:num_samples + samples_in_batch]
            else:
                inputs = None  # if data is a generator, it's usually one-time, not able to transform into a list
            for output in self.predict_batch(batch, inputs=inputs, **kwargs):
                results.append(output)
            num_samples += samples_in_batch

        if flat:
            return results[0]
        return results

    def predict_batch(self, batch, inputs=None, **kwargs):
        X = batch[0]
        Y = self.model.predict_on_batch(X)
        for output in self.transform.Y_to_outputs(Y, X=X, inputs=inputs, batch=batch, **kwargs):
            yield output

    @property
    def sample_data(self):
        return None

    @staticmethod
    def from_meta(meta: dict, **kwargs):
        """

        Parameters
        ----------
        meta
        kwargs

        Returns
        -------
        KerasComponent

        """
        cls = str_to_type(meta['class_path'])
        obj: KerasComponent = cls()
        assert 'load_path' in meta, f'{meta} doesn\'t contain load_path field'
        obj.load(meta['load_path'])
        return obj

    def export_model_for_serving(self, export_dir=None, version=1, overwrite=False, show_hint=False):
        assert self.model, 'You have to fit or load a model before exporting it'
        if not export_dir:
            assert 'load_path' in self.meta, 'When not specifying save_dir, load_path has to present'
            export_dir = get_resource(self.meta['load_path'])
        model_path = os.path.join(export_dir, str(version))
        if os.path.isdir(model_path) and not overwrite:
            logger.info(f'{model_path} exists, skip since overwrite = {overwrite}')
            return export_dir
        logger.info(f'Exporting to {export_dir} ...')
        tf.saved_model.save(self.model, model_path)
        logger.info(f'Successfully exported model to {export_dir}')
        if show_hint:
            logger.info(f'You can serve it through \n'
                        f'tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} '
                        f'--model_base_path={export_dir} --rest_api_port=8888')
        return export_dir

    def serve(self, export_dir=None, grpc_port=8500, rest_api_port=0, overwrite=False, dry_run=False):
        export_dir = self.export_model_for_serving(export_dir, show_hint=False, overwrite=overwrite)
        if not dry_run:
            del self.model  # free memory
        logger.info('The inputs of exported model is shown below.')
        os.system(f'saved_model_cli show --all --dir {export_dir}/1')
        cmd = f'nohup tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} ' \
              f'--model_base_path={export_dir} --port={grpc_port} --rest_api_port={rest_api_port} ' \
              f'>serve.log 2>&1 &'
        logger.info(f'Running ...\n{cmd}')
        if not dry_run:
            os.system(cmd)
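
Once serve() (or export_model_for_serving() plus a manually started tensorflow_model_server) is running, the exported model can be queried through TensorFlow Serving's standard REST predict endpoint. The model name, port and payload below are placeholders; the actual input structure depends on the exported signature.

import requests

resp = requests.post(
    'http://localhost:8888/v1/models/my_model:predict',
    json={'instances': [[1, 2, 3, 0, 0]]})  # e.g. token ids padded to a fixed length
print(resp.json())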
Example #13
 def save_vocabs(self, save_dir, filename='vocabs.json'):
     vocabs = SerializableDict()
     for key, value in vars(self.transform).items():
         if isinstance(value, VocabTF):
             vocabs[key] = value.to_dict()
     vocabs.save_json(os.path.join(save_dir, filename))
Example #14
File: transform.py Project: lei1993/HanLP
 def load_vocab(self, save_dir, filename='vocab.json'):
     save_dir = get_resource(save_dir)
     vocab = SerializableDict()
     vocab.load_json(os.path.join(save_dir, filename))
     self.vocab.copy_from(vocab)
Example #15
File: transform.py Project: lei1993/HanLP
 def save_vocab(self, save_dir, filename='vocab.json'):
     vocab = SerializableDict()
     vocab.update(self.vocab.to_dict())
     vocab.save_json(os.path.join(save_dir, filename))
Example #16
    def __call__(self, sample: dict):
        input_tokens = sample[self.input_key]
        input_is_str = isinstance(input_tokens, str)
        tokenizer = self.tokenizer
        ret_token_span = self.ret_token_span
        if input_is_str:  # This happens in a tokenizer component where the raw sentence is fed.

            # noinspection PyShadowingNames
            def tokenize_str(input_str, add_special_tokens=True):
                if tokenizer.is_fast:
                    encoding = tokenizer.encode_plus(input_str,
                                                     return_offsets_mapping=True,
                                                     add_special_tokens=add_special_tokens).encodings[0]
                    subtoken_offsets = encoding.offsets
                    input_tokens = encoding.tokens
                    input_ids = encoding.ids

                    # Fill up missing non-blank characters swallowed by HF tokenizer
                    offset = 0
                    fixed_offsets = []
                    fixed_tokens = []
                    fixed_ids = []
                    for token, id, (b, e) in zip(input_tokens, input_ids, subtoken_offsets):
                        if b > offset:
                            missing_token = input_str[offset: b]
                            if not missing_token.isspace():  # In the future, we may want space back
                                fixed_tokens.append(missing_token)
                                fixed_ids.append(tokenizer.unk_token_id)
                                fixed_offsets.append((offset, b))
                        fixed_tokens.append(token)
                        fixed_ids.append(token_id)
                        fixed_offsets.append((b, e))
                        offset = e
                    subtoken_offsets = fixed_offsets
                    input_tokens = fixed_tokens
                    input_ids = fixed_ids

                    if add_special_tokens:
                        subtoken_offsets = subtoken_offsets[1 if self.has_cls else 0:-1]

                    # Edge case: the whole input_str was swallowed by the tokenizer
                    if not subtoken_offsets and not input_str.isspace():
                        __index = 1 if add_special_tokens and self.has_cls else 0
                        input_tokens.insert(__index, input_str)
                        input_ids.insert(__index, tokenizer.unk_token_id)
                        subtoken_offsets.append((0, len(input_str)))

                    if not self.has_cls:
                        input_tokens = [self.cls_token] + input_tokens
                        input_ids = [self.cls_token_id] + input_ids
                else:
                    input_tokens = tokenizer.tokenize(input_str)
                    subtoken_offsets = []
                    _o = 0
                    for each in input_tokens:
                        subtoken_offsets.append((_o, _o + len(each)))
                        _o += len(each)
                    if add_special_tokens:
                        input_tokens = [self.cls_token] + input_tokens + [self.sep_token]
                    input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
                if self.check_space_before:
                    non_blank_offsets = [i for i in range(len(input_tokens)) if input_tokens[i] != '▁']
                    if add_special_tokens and not self.has_cls:
                        non_blank_offsets.insert(0, 0)
                    input_tokens = [input_tokens[i] for i in non_blank_offsets]
                    input_ids = [input_ids[i] for i in non_blank_offsets]
                    if add_special_tokens:
                        non_blank_offsets = non_blank_offsets[1:-1]
                        subtoken_offsets = [subtoken_offsets[i - 1] for i in non_blank_offsets]
                    else:
                        subtoken_offsets = [subtoken_offsets[i] for i in non_blank_offsets]
                    # SPM-based models like MT5 emit tokens such as ▁of whose offsets include the leading space; shift the start past it.
                    for i, token in enumerate(input_tokens[1:-1] if add_special_tokens else input_tokens):
                        if input_str[subtoken_offsets[i][0]] == ' ':
                            subtoken_offsets[i] = (subtoken_offsets[i][0] + 1, subtoken_offsets[i][1])
                if add_special_tokens:
                    if len(input_tokens) == 2:  # bos and eos, meaning that the text contains only some spaces
                        input_tokens.insert(1, input_str)
                        input_ids.insert(1, tokenizer.unk_token_id)
                        subtoken_offsets.append((0, len(input_str)))
                else:
                    if not input_ids:  # This chunk might be some control chars getting removed by tokenizer
                        input_tokens = [input_str]
                        input_ids = [tokenizer.unk_token_id]
                        subtoken_offsets = [(0, len(input_str))]
                return input_tokens, input_ids, subtoken_offsets

            if self.dict:
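                # A custom dictionary splits the text into plain-text chunks and forced (begin, end, label) spans,
                # each of which is tokenized separately below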
                chunks = self.dict.split(sample.get(f'{self.input_key}_', input_tokens))  # Match original text directly
                _input_tokens, _input_ids, _subtoken_offsets = [self.cls_token], [self.cls_token_id], []
                _offset = 0
                custom_words = sample['custom_words'] = []
                char_offset = 0
                for chunk in chunks:
                    if isinstance(chunk, str):  # Use transformed text as it's what models are trained on
                        chunk = input_tokens[char_offset:char_offset + len(chunk)]
                        tokens, ids, offsets = tokenize_str(chunk, add_special_tokens=False)
                        char_offset += len(chunk)
                    else:
                        begin, end, label = chunk
                        # chunk offset is in char level
                        # custom_words.append(chunk)
                        if isinstance(label, list):
                            tokens, ids, offsets, delta = [], [], [], 0
                            for token in label:
                                _tokens, _ids, _offsets = tokenize_str(token, add_special_tokens=False)
                                tokens.extend(_tokens)
                                # track the subword offset of this chunk, -1 for [CLS]
                                custom_words.append(
                                    (len(_input_ids) + len(ids) - 1, len(_input_ids) + len(ids) - 1 + len(_ids), token))
                                ids.extend(_ids)
                                offsets.extend((x[0] + delta, x[1] + delta) for x in _offsets)
                                delta = offsets[-1][-1]
                        else:
                            tokens, ids, offsets = tokenize_str(input_tokens[begin:end], add_special_tokens=False)
                            # offsets = [(offsets[0][0], offsets[-1][-1])]
                            custom_words.append((len(_input_ids) - 1, len(_input_ids) + len(ids) - 1, label))
                        char_offset = end
                    _input_tokens.extend(tokens)
                    _input_ids.extend(ids)
                    _subtoken_offsets.extend((x[0] + _offset, x[1] + _offset) for x in offsets)
                    _offset = _subtoken_offsets[-1][-1]
                subtoken_offsets = _subtoken_offsets
                input_tokens = _input_tokens + [self.sep_token]
                input_ids = _input_ids + [self.sep_token_id]
            else:
                input_tokens, input_ids, subtoken_offsets = tokenize_str(input_tokens, add_special_tokens=True)

            if self.ret_subtokens:
                sample[f'{self.input_key}_subtoken_offsets'] = subtoken_offsets

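        # Decide whether the sequence is wrapped in BOS/EOS pseudo tokens rather than real [CLS]/[SEP],
        # and strip them when strip_cls_sep is enabled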
        cls_is_bos = self.cls_is_bos
        if cls_is_bos is None:
            cls_is_bos = input_tokens[0] == BOS
        sep_is_eos = self.sep_is_eos
        if sep_is_eos is None:
            sep_is_eos = input_tokens[-1] == EOS
        if self.strip_cls_sep:
            if cls_is_bos:
                input_tokens = input_tokens[1:]
            if sep_is_eos:
                input_tokens = input_tokens[:-1]
        if not self.ret_mask_and_type:  # only need input_ids and token_span, use a light version
            if input_is_str:
                prefix_mask = self._init_prefix_mask(input_ids)
            else:
                if input_tokens:
                    return_offsets_mapping = tokenizer.is_fast and self.ret_subtokens
                    encodings = tokenizer.batch_encode_plus(
                        input_tokens,
                        return_offsets_mapping=return_offsets_mapping,  # Many tokenizers do not offer a fast version
                        add_special_tokens=False
                    )
                    subtoken_ids_per_token = encodings.data['input_ids']
                    if return_offsets_mapping:
                        offsets_mapping = [encoding.offsets for encoding in encodings.encodings]
                    else:
                        offsets_mapping = []
                        for token, subtoken_ids in zip(input_tokens, subtoken_ids_per_token):
                            if len(subtoken_ids) > len(token):  # … --> ...
                                del subtoken_ids[len(token):]
                            if not subtoken_ids:
                                subtoken_ids = [tokenizer.unk_token_id]
                            # Non-fast tokenizers return no offset mapping, so approximate one by splitting the token evenly
                            char_per_subtoken = max(len(token) // len(subtoken_ids), 1)
                            bes = [(b, b + char_per_subtoken) for b in range(0, len(token), char_per_subtoken)]
                            if len(bes) != len(subtoken_ids):
                                bes[len(subtoken_ids) - 1] = (bes[len(subtoken_ids) - 1][0], len(token))
                                del bes[len(subtoken_ids):]
                            offsets_mapping.append(bes)
                else:
                    encodings = SerializableDict()
                    subtoken_ids_per_token = []
                    encodings.data = {'input_ids': subtoken_ids_per_token}
                if self.check_space_before:
                    # noinspection PyUnboundLocalVariable
                    for token, subtokens, mapping, encoding in zip(input_tokens, subtoken_ids_per_token,
                                                                   offsets_mapping, encodings.encodings):
                        # Remove the standalone ▁ generated by SentencePiece, for two reasons:
                        # 1. During decoding, a bare ▁ is rarely produced unless blanks separate tokens (true for
                        #    English, where it would usually be merged into the following token anyway)
                        # 2. For T5, '▁' is used as CLS
                        if len(subtokens) > 1 and encoding.tokens[0] == '▁':
                            subtokens.pop(0)
                            if mapping:
                                mapping.pop(0)
                # Some tokens may be stripped out entirely by the tokenizer; fall back to [UNK] to keep the alignment
                subtoken_ids_per_token = [ids if ids else [tokenizer.unk_token_id] for ids in subtoken_ids_per_token]
                input_ids = sum(subtoken_ids_per_token, [self.cls_token_id])
                if self.sep_is_eos is None:
                    # None means to check whether sep is at the tail or between tokens
                    if sep_is_eos:
                        input_ids += [self.sep_token_id]
                    elif self.sep_token_id not in input_ids:
                        input_ids += [self.sep_token_id]
                else:
                    input_ids += [self.sep_token_id]
                # else: self.sep_is_eos == False means [SEP] sits between tokens, so there is no need to check

                if self.ret_subtokens:
                    prefix_mask = self._init_prefix_mask(input_ids)
                    # if self.check_space_before:
                    #     if offsets_mapping[0] and not input_tokens[0].startswith(' '):
                    #         prefix_mask[1] = False
                else:
                    prefix_mask = [False] * len(input_ids)
                    offset = 1
                    for _subtokens in subtoken_ids_per_token:
                        prefix_mask[offset] = True
                        offset += len(_subtokens)
                if self.ret_subtokens:
                    subtoken_offsets = []
                    for token, offsets in zip(input_tokens, offsets_mapping):
                        if offsets:
                            subtoken_offsets.append(offsets)
                        else:
                            subtoken_offsets.append([(0, len(token))])
                    if self.ret_subtokens_group:
                        sample[f'{self.input_key}_subtoken_offsets_group'] = subtoken_offsets
                    sample[f'{self.input_key}_subtoken_offsets'] = sum(subtoken_offsets, [])
        else:
            input_ids, attention_mask, token_type_ids, prefix_mask = \
                convert_examples_to_features(input_tokens,
                                             None,
                                             tokenizer,
                                             cls_token_at_end=self.cls_token_at_end,
                                             # xlnet has a cls token at the end
                                             cls_token=tokenizer.cls_token,
                                             cls_token_segment_id=self.cls_token_segment_id,
                                             sep_token=self.sep_token,
                                             sep_token_extra=self.sep_token_extra,
                                             # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                                             pad_on_left=self.pad_on_left,
                                             # pad on the left for xlnet
                                             pad_token_id=self.pad_token_id,
                                             pad_token_segment_id=self.pad_token_segment_id,
                                             pad_token_label_id=0,
                                             do_padding=self.do_padding)
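        # Enforce the transformer's maximum sequence length: either truncate with a warning
        # or fall back to a sliding window over the subwords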
        if len(input_ids) > self.max_seq_length:
            if self.truncate_long_sequences:
                # raise SequenceTooLong(
                #     f'Input tokens {input_tokens} exceed the max sequence length of {self.max_seq_length - 2}. '
                #     f'For sequence tasks, truncate_long_sequences = True is not supported.'
                #     f'You are recommended to split your long text into several sentences within '
                #     f'{self.max_seq_length - 2} tokens beforehand. '
                #     f'Or simply set truncate_long_sequences = False to enable sliding window.')
                input_ids = input_ids[:self.max_seq_length]
                prefix_mask = prefix_mask[:self.max_seq_length]
                warnings.warn(
                    f'Input tokens {input_tokens} exceed the max sequence length of {self.max_seq_length - 2}. '
                    f'The part beyond this limit will be truncated and ignored. '
                    f'You are recommended to split your long text into several sentences within '
                    f'{self.max_seq_length - 2} tokens beforehand. '
                    f'Or simply set truncate_long_sequences = False to enable sliding window.'
                )
            else:
                input_ids = self.sliding_window(input_ids, input_ids[-1] == self.sep_token_id)
        if prefix_mask:
            if cls_is_bos:
                prefix_mask[0] = True
            if sep_is_eos:
                prefix_mask[-1] = True
        outputs = [input_ids]
        if self.ret_mask_and_type:
            # noinspection PyUnboundLocalVariable
            outputs += [attention_mask, token_type_ids]
        if self.ret_prefix_mask:
            outputs += [prefix_mask]
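        # Group subword positions into one span per original token, using prefix_mask to find
        # where each token's first subword starts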
        if ret_token_span and prefix_mask:
            if cls_is_bos:
                token_span = [[0]]
            else:
                token_span = []
            offset = 1
            span = []
            for mask in prefix_mask[1:len(prefix_mask) if sep_is_eos is None else -1]:  # skip [CLS] and [SEP]
                if mask and span:
                    token_span.append(span)
                    span = []
                span.append(offset)
                offset += 1
            if span:
                token_span.append(span)
            if sep_is_eos:
                assert offset == len(prefix_mask) - 1
                token_span.append([offset])
            outputs.append(token_span)
        for k, v in zip(self.output_key, outputs):
            sample[k] = v
        return sample
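A minimal sketch of applying this transform to a raw sentence; the `transform` instance, its input_key='token' configuration, and the sentence are illustrative assumptions:

# Hypothetical invocation of the __call__ above on a single sample dict
sample = {'token': 'HanLP tokenizes raw text.'}
sample = transform(sample)
# sample is updated in place and returned with the fields named in
# transform.output_key (e.g. subword ids and per-token spans)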