def evaluate(gold_file, pred_file, do_enhanced_collapse_empty_nodes=False, do_copy_cols=True):
    """Evaluate using the official CoNLL-X evaluation script (Yuval Krymolowski).

    Args:
        gold_file(str): The gold conllx file.
        pred_file(str): The pred conllx file.
        do_enhanced_collapse_empty_nodes: Whether to collapse empty nodes in both files before evaluation.
            (Default value = False)
        do_copy_cols: Whether to copy columns from the gold file into a fixed prediction file.
            (Default value = True)

    Returns:
        The evaluation results returned by ``iwpt20_xud_eval.evaluate_wrapper``.
    """
    if do_enhanced_collapse_empty_nodes:
        gold_file = enhanced_collapse_empty_nodes(gold_file)
        pred_file = enhanced_collapse_empty_nodes(pred_file)
    if do_copy_cols:
        fixed_pred_file = pred_file.replace('.conllu', '.fixed.conllu')
        copy_cols(gold_file, pred_file, fixed_pred_file)
    else:
        fixed_pred_file = pred_file
    args = SerializableDict()
    args.enhancements = '0'
    args.gold_file = gold_file
    args.system_file = fixed_pred_file
    return iwpt20_xud_eval.evaluate_wrapper(args)
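# A usage sketch; the file paths below are hypothetical, and the helpers used above
# (enhanced_collapse_empty_nodes, copy_cols, iwpt20_xud_eval) must be importable:
score = evaluate('ud/gold.conllu', 'ud/pred.conllu', do_enhanced_collapse_empty_nodes=True)
print(score)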
def load_vocabs(self, save_dir, filename='vocabs.json', vocab_cls=Vocab):
    """Load vocabularies from a directory.

    Args:
        save_dir: The directory to load vocabularies from.
        filename: The name for vocabularies.
        vocab_cls: The class each loaded vocabulary is instantiated with. (Default value = Vocab)
    """
    save_dir = get_resource(save_dir)
    vocabs = SerializableDict()
    vocabs.load_json(os.path.join(save_dir, filename))
    self._load_vocabs(self, vocabs, vocab_cls)
def save_vocabs(self, save_dir, filename='vocabs.json'):
    """Save vocabularies to a directory.

    Args:
        save_dir: The directory to save vocabularies.
        filename: The name for vocabularies.
    """
    vocabs = SerializableDict()
    for key, value in self.items():
        if isinstance(value, Vocab):
            vocabs[key] = value.to_dict()
    vocabs.save_json(os.path.join(save_dir, filename))
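# A save/load round-trip sketch for the two methods above (the directory and the
# VocabDict contents are hypothetical):
vocabs = VocabDict()
vocabs['token'] = Vocab()
vocabs.save_vocabs('/tmp/demo_component')        # writes /tmp/demo_component/vocabs.json
restored = VocabDict()
restored.load_vocabs('/tmp/demo_component')      # each entry is rebuilt as a Vocab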
def __init__(self, locals_: Dict, exclude=('kwargs', 'self', '__class__', 'locals_')) -> None:
    """This base class helps sub-classes to capture their arguments passed to ``__init__``, and also their types,
    so that they can be deserialized from a config in dict form.

    Args:
        locals_: Obtained by :meth:`locals`.
        exclude: Arguments to be excluded.

    Examples:
        >>> class MyClass(ConfigTracker):
        ...     def __init__(self, i_need_this='yes') -> None:
        ...         super().__init__(locals())
        >>> obj = MyClass()
        >>> print(obj.config)
        {'i_need_this': 'yes', 'classpath': 'test_config_tracker.MyClass'}
    """
    if 'kwargs' in locals_:
        locals_.update(locals_['kwargs'])
    self.config = SerializableDict(
        (k, v.config if hasattr(v, 'config') else v) for k, v in locals_.items() if k not in exclude)
    self.config['classpath'] = classpath_of(self)
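# A minimal sketch of the **kwargs folding shown above (MyEncoder is hypothetical):
# the contents of kwargs are merged into locals_ before capturing, so ad-hoc
# arguments land in config next to the declared ones.
class MyEncoder(ConfigTracker):
    def __init__(self, hidden=256, **kwargs) -> None:
        super().__init__(locals())

enc = MyEncoder(hidden=512, dropout=0.1)
# enc.config == {'hidden': 512, 'dropout': 0.1, 'classpath': '<module>.MyEncoder'}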
def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, **kwargs) -> None:
    super().__init__()
    self.map_y = map_y
    self.map_x = map_x
    if kwargs:
        if not config:
            config = SerializableDict()
        for k, v in kwargs.items():
            config[k] = v
    self.config = config
    self.output_types = None
    self.output_shapes = None
    self.padding_values = None
    # Fix tf memory leak: https://github.com/tensorflow/tensorflow/issues/37653#issuecomment-1000517720
    self.py_func_set_to_cleanup = set()
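# A minimal sketch (MyTransform is a hypothetical subclass): extra kwargs are
# folded into the SerializableDict config at construction time, while map_x/map_y
# stay as plain attributes.
transform = MyTransform(map_x=False, lower=True, max_seq_len=128)
assert transform.config['lower'] is True
assert transform.map_x is False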
class TorchComponent(Component, ABC):
    def __init__(self, **kwargs) -> None:
        """The base class for all components using PyTorch as backend. It provides common workflows of building
        vocabs, datasets, dataloaders and models. These workflows are more of a conventional guideline than enforced
        protocols, which means a subclass has the freedom to override or completely skip some steps.

        Args:
            **kwargs: Additional arguments to be stored in the ``config`` property.
        """
        super().__init__()
        self.model: Optional[torch.nn.Module] = None
        self.config = SerializableDict(**kwargs)
        self.vocabs = VocabDict()

    def _capture_config(self, locals_: Dict,
                        exclude=('trn_data', 'dev_data', 'save_dir', 'kwargs', 'self', 'logger', 'verbose',
                                 'dev_batch_size', '__class__', 'devices', 'eval_trn')):
        """Capture the arguments in ``locals_`` into ``self.config``.

        Args:
            locals_: Obtained by :meth:`locals`.
            exclude: Arguments to be excluded.

        Returns:
            The updated config.
        """
        if 'kwargs' in locals_:
            locals_.update(locals_['kwargs'])
        locals_ = dict((k, v) for k, v in locals_.items() if k not in exclude and not k.startswith('_'))
        self.config.update(locals_)
        return self.config

    def save_weights(self, save_dir, filename='model.pt', trainable_only=True, **kwargs):
        """Save model weights to a directory.

        Args:
            save_dir: The directory to save weights into.
            filename: A file name for weights.
            trainable_only: ``True`` to only save trainable weights. Useful when the model contains lots of static
                embeddings.
            **kwargs: Not used for now.
        """
        model = self.model_
        state_dict = model.state_dict()
        if trainable_only:
            trainable_names = set(n for n, p in model.named_parameters() if p.requires_grad)
            state_dict = dict((n, p) for n, p in state_dict.items() if n in trainable_names)
        torch.save(state_dict, os.path.join(save_dir, filename))

    def load_weights(self, save_dir, filename='model.pt', **kwargs):
        """Load weights from a directory.

        Args:
            save_dir: The directory to load weights from.
            filename: A file name for weights.
            **kwargs: Not used.
        """
        save_dir = get_resource(save_dir)
        filename = os.path.join(save_dir, filename)
        # flash(f'Loading model: {filename} [blink]...[/blink][/yellow]')
        self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False)
        # flash('')

    def save_config(self, save_dir, filename='config.json'):
        """Save config into a directory.

        Args:
            save_dir: The directory to save config.
            filename: A file name for config.
        """
        self._savable_config.save_json(os.path.join(save_dir, filename))

    def load_config(self, save_dir, filename='config.json', **kwargs):
        """Load config from a directory.

        Args:
            save_dir: The directory to load config.
            filename: A file name for config.
            **kwargs: K-V pairs to override config.
        """
        save_dir = get_resource(save_dir)
        self.config.load_json(os.path.join(save_dir, filename))
        self.config.update(kwargs)  # overwrite config loaded from disk
        for k, v in self.config.items():
            if isinstance(v, dict) and 'classpath' in v:
                self.config[k] = Configurable.from_config(v)
        self.on_config_ready(**self.config)

    def save_vocabs(self, save_dir, filename='vocabs.json'):
        """Save vocabularies to a directory.

        Args:
            save_dir: The directory to save vocabularies.
            filename: The name for vocabularies.
        """
        if hasattr(self, 'vocabs'):
            self.vocabs.save_vocabs(save_dir, filename)

    def load_vocabs(self, save_dir, filename='vocabs.json'):
        """Load vocabularies from a directory.

        Args:
            save_dir: The directory to load vocabularies.
            filename: The name for vocabularies.
        """
        if hasattr(self, 'vocabs'):
            self.vocabs = VocabDict()
            self.vocabs.load_vocabs(save_dir, filename)

    def save(self, save_dir: str, **kwargs):
        """Save this component to a directory.

        Args:
            save_dir: The directory to save this component.
            **kwargs: Not used.
        """
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.save_weights(save_dir)

    def load(self, save_dir: str, devices=None, verbose=HANLP_VERBOSE, **kwargs):
        """Load from a local/remote component.

        Args:
            save_dir: An identifier which can be a local path or a remote URL or a pre-defined string.
            devices: The devices this component will be moved onto.
            verbose: ``True`` to log loading progress.
            **kwargs: To override some configs.
        """
        save_dir = get_resource(save_dir)
        # flash('Loading config and vocabs [blink][yellow]...[/yellow][/blink]')
        if devices is None and self.model:
            devices = self.devices
        self.load_config(save_dir, **kwargs)
        self.load_vocabs(save_dir)
        if verbose:
            flash('Building model [blink][yellow]...[/yellow][/blink]')
        self.model = self.build_model(**merge_dict(self.config, training=False, **kwargs, overwrite=True,
                                                   inplace=True))
        if verbose:
            flash('')
        self.load_weights(save_dir, **kwargs)
        self.to(devices)
        self.model.eval()

    def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, devices=None, logger=None, seed=None,
            finetune: Union[bool, str] = False, eval_trn=True, _device_placeholder=False, **kwargs):
        """Fit to data, triggers the training procedure. The training set and dev set shall be local or remote files.

        Args:
            trn_data: Training set.
            dev_data: Development set.
            save_dir: The directory to save trained component.
            batch_size: The number of samples in a batch.
            epochs: Number of epochs.
            devices: Devices this component will live on.
            logger: Any :class:`logging.Logger` instance.
            seed: Random seed to reproduce this training.
            finetune: ``True`` to load from ``save_dir`` instead of creating a randomly initialized component.
                ``str`` to specify a different ``save_dir`` to load from.
            eval_trn: Evaluate training set after each update. This can slow down the training but provides a quick
                diagnostic for debugging.
            _device_placeholder: ``True`` to create a placeholder tensor which triggers PyTorch to occupy devices so
                other components won't take these devices as first choices.
            **kwargs: Hyperparameters used by sub-classes.

        Returns:
            Any results sub-classes would like to return. Usually the best metrics on training set.
        """
        # Common initialization steps
        config = self._capture_config(locals())
        if not logger:
            logger = self.build_logger('train', save_dir)
        if not seed:
            self.config.seed = 233 if isdebugging() else int(time.time())
        set_seed(self.config.seed)
        logger.info(self._savable_config.to_json(sort=True))
        if isinstance(devices, list) or devices is None or isinstance(devices, float):
            flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
            devices = -1 if isdebugging() else cuda_devices(devices)
            flash('')
        # flash(f'Available GPUs: {devices}')
        if isinstance(devices, list):
            first_device = (devices[0] if devices else -1)
        elif isinstance(devices, dict):
            first_device = next(iter(devices.values()))
        elif isinstance(devices, int):
            first_device = devices
        else:
            first_device = -1
        if _device_placeholder and first_device >= 0:
            _dummy_placeholder = self._create_dummy_placeholder_on(first_device)
        if finetune:
            if isinstance(finetune, str):
                self.load(finetune, devices=devices)
            else:
                self.load(save_dir, devices=devices)
            logger.info(
                f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
                f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
        self.on_config_ready(**self.config)
        trn = self.build_dataloader(**merge_dict(config, data=trn_data, batch_size=batch_size, shuffle=True,
                                                 training=True, device=first_device, logger=logger,
                                                 vocabs=self.vocabs, overwrite=True))
        dev = self.build_dataloader(**merge_dict(config, data=dev_data, batch_size=batch_size, shuffle=False,
                                                 training=None, device=first_device, logger=logger,
                                                 vocabs=self.vocabs, overwrite=True)) if dev_data else None
        if not finetune:
            flash('[yellow]Building model [blink]...[/blink][/yellow]')
            self.model = self.build_model(**merge_dict(config, training=True))
            flash('')
            logger.info(f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
                        f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
            assert self.model, 'build_model is not properly implemented.'
        _description = repr(self.model)
        if len(_description.split('\n')) < 10:
            logger.info(_description)
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.to(devices, logger)
        if _device_placeholder and first_device >= 0:
            del _dummy_placeholder
        criterion = self.build_criterion(**merge_dict(config, trn=trn))
        optimizer = self.build_optimizer(**merge_dict(config, trn=trn, criterion=criterion))
        metric = self.build_metric(**self.config)
        if hasattr(trn.dataset, '__len__') and dev and hasattr(dev.dataset, '__len__'):
            logger.info(f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.')
            trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
            ratio_width = len(f'{trn_size}/{trn_size}')
        else:
            ratio_width = None
        return self.execute_training_loop(**merge_dict(config, trn=trn, dev=dev, epochs=epochs, criterion=criterion,
                                                       optimizer=optimizer, metric=metric, logger=logger,
                                                       save_dir=save_dir, devices=devices, ratio_width=ratio_width,
                                                       trn_data=trn_data, dev_data=dev_data, eval_trn=eval_trn,
                                                       overwrite=True))

    def build_logger(self, name, save_dir):
        """Build a :class:`logging.Logger`.

        Args:
            name: The name of this logger.
            save_dir: The directory this logger should save logs into.

        Returns:
            logging.Logger: A logger.
        """
        logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO, fmt="%(message)s")
        return logger

    @abstractmethod
    def build_dataloader(self, data, batch_size, shuffle=False, device=None, logger: logging.Logger = None,
                         **kwargs) -> DataLoader:
        """Build dataloader for training, dev and test sets. It's suggested to build vocabs in this method if they
        are not built yet.

        Args:
            data: Data representing samples, which can be a path or a list of samples.
            batch_size: Number of samples per batch.
            shuffle: Whether to shuffle this dataloader.
            device: Device tensors should be loaded onto.
            logger: Logger for reporting some message if dataloader takes a long time or if vocabs have to be built.
            **kwargs: Arguments from ``**self.config``.
        """
        pass

    def build_vocabs(self, trn: torch.utils.data.Dataset, logger: logging.Logger):
        """Override this method to build vocabs.

        Args:
            trn: Training set.
            logger: Logger for reporting progress.
        """
        pass

    @property
    def _savable_config(self):
        def convert(k, v):
            if not isinstance(v, SerializableDict) and hasattr(v, 'config'):
                v = v.config
            elif isinstance(v, (set, tuple)):
                v = list(v)
            if isinstance(v, dict):
                v = dict(convert(_k, _v) for _k, _v in v.items())
            return k, v

        config = SerializableDict(convert(k, v) for k, v in sorted(self.config.items()))
        config.update({
            # 'create_time': now_datetime(),
            'classpath': classpath_of(self),
            'hanlp_version': hanlp.__version__,
        })
        return config

    @abstractmethod
    def build_optimizer(self, **kwargs):
        """Implement this method to build an optimizer.

        Args:
            **kwargs: The subclass decides the method signature.
        """
        pass

    @abstractmethod
    def build_criterion(self, decoder, **kwargs):
        """Implement this method to build criterion (loss function).

        Args:
            decoder: The model or decoder.
            **kwargs: The subclass decides the method signature.
        """
        pass

    @abstractmethod
    def build_metric(self, **kwargs):
        """Implement this to build metric(s).

        Args:
            **kwargs: The subclass decides the method signature.
        """
        pass

    @abstractmethod
    def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
                              logger: logging.Logger, devices, ratio_width=None, **kwargs):
        """Implement this to run training loop.

        Args:
            trn: Training set.
            dev: Development set.
            epochs: Number of epochs.
            criterion: Loss function.
            optimizer: Optimizer(s).
            metric: Metric(s).
            save_dir: The directory to save this component.
            logger: Logger for reporting progress.
            devices: Devices this component and dataloader will live on.
            ratio_width: The width of dataset size measured in number of characters. Used for logger to align
                messages.
            **kwargs: Other hyper-parameters passed from sub-class.
        """
        pass

    @abstractmethod
    def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs):
        """Fit onto a dataloader.

        Args:
            trn: Training set.
            criterion: Loss function.
            optimizer: Optimizer.
            metric: Metric(s).
            logger: Logger for reporting progress.
            **kwargs: Other hyper-parameters passed from sub-class.
        """
        pass

    @abstractmethod
    def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric=None, output=False, **kwargs):
        """Evaluate on a dataloader.

        Args:
            data: Dataloader which can build from any data source.
            criterion: Loss function.
            metric: Metric(s).
            output: Whether to save outputs into some file.
            **kwargs: Not used.
        """
        pass

    @abstractmethod
    def build_model(self, training=True, **kwargs) -> torch.nn.Module:
        """Build model.

        Args:
            training: ``True`` if called during training.
            **kwargs: ``**self.config``.
        """
        raise NotImplementedError

    def evaluate(self, tst_data, save_dir=None, logger: logging.Logger = None, batch_size=None, output=False,
                 **kwargs):
        """Evaluate test set.

        Args:
            tst_data: Test set, which is usually a file path.
            save_dir: The directory to save evaluation scores or predictions.
            logger: Logger for reporting progress.
            batch_size: Batch size for test dataloader.
            output: Whether to save outputs into some file.
            **kwargs: Not used.

        Returns:
            (metric, outputs) where outputs are the return values of ``evaluate_dataloader``.
        """
        if not self.model:
            raise RuntimeError('Call fit or load before evaluate.')
        if isinstance(tst_data, str):
            tst_data = get_resource(tst_data)
            filename = os.path.basename(tst_data)
        else:
            filename = None
        if output is True:
            output = self.generate_prediction_filename(tst_data if isinstance(tst_data, str) else 'test.txt',
                                                       save_dir)
        if logger is None:
            _logger_name = basename_no_ext(filename) if filename else None
            logger = self.build_logger(_logger_name, save_dir)
        if not batch_size:
            batch_size = self.config.get('batch_size', 32)
        data = self.build_dataloader(**merge_dict(self.config, data=tst_data, batch_size=batch_size, shuffle=False,
                                                  device=self.devices[0], logger=logger, overwrite=True))
        dataset = data
        while dataset and hasattr(dataset, 'dataset'):
            dataset = dataset.dataset
        num_samples = len(dataset) if dataset else None
        if output and isinstance(dataset, TransformableDataset):
            def add_idx(samples):
                for idx, sample in enumerate(samples):
                    if sample:
                        sample[IDX] = idx

            add_idx(dataset.data)
            if dataset.cache:
                add_idx(dataset.cache)
        criterion = self.build_criterion(**self.config)
        metric = self.build_metric(**self.config)
        start = time.time()
        outputs = self.evaluate_dataloader(data, criterion=criterion, filename=filename, output=output,
                                           input=tst_data, save_dir=save_dir, test=True, num_samples=num_samples,
                                           **merge_dict(self.config, batch_size=batch_size, metric=metric,
                                                        logger=logger, **kwargs))
        elapsed = time.time() - start
        if logger:
            if num_samples:
                logger.info(f'speed: {num_samples / elapsed:.0f} samples/second')
            else:
                logger.info(f'speed: {len(data) / elapsed:.0f} batches/second')
        return metric, outputs

    def generate_prediction_filename(self, tst_data, save_dir):
        assert isinstance(tst_data, str), 'tst_data has to be a str in order to infer the output name'
        output = os.path.splitext(os.path.basename(tst_data))
        output = os.path.join(save_dir, output[0] + '.pred' + output[1])
        return output

    def to(self, devices: Union[int, float, List[int], Dict[str, Union[int, torch.device]]] = None,
           logger: logging.Logger = None, verbose=HANLP_VERBOSE):
        """Move this component to devices.

        Args:
            devices: Target devices.
            logger: Logger for printing progress report, as copying a model from CPU to GPU can take several seconds.
            verbose: ``True`` to print progress when logger is None.
        """
        if devices == -1 or devices == [-1]:
            devices = []
        elif isinstance(devices, (int, float)) or devices is None:
            devices = cuda_devices(devices)
        if devices:
            if logger:
                logger.info(f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]')
            if isinstance(devices, list):
                if verbose:
                    flash(f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]')
                self.model = self.model.to(devices[0])
                if len(devices) > 1 and not isdebugging() and not isinstance(self.model, nn.DataParallel):
                    self.model = self.parallelize(devices)
            elif isinstance(devices, dict):
                for name, module in self.model.named_modules():
                    for regex, device in devices.items():
                        try:
                            on_device: torch.device = next(module.parameters()).device
                        except StopIteration:
                            continue
                        if on_device == device:
                            continue
                        if isinstance(device, int):
                            if on_device.index == device:
                                continue
                        if re.match(regex, name):
                            if not name:
                                name = '*'
                            flash(f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}'
                                  f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n')
                            module.to(device)
            else:
                raise ValueError(f'Unrecognized devices {devices}')
            if verbose:
                flash('')
        else:
            if logger:
                logger.info('Using CPU')

    def parallelize(self, devices: List[Union[int, torch.device]]):
        return nn.DataParallel(self.model, device_ids=devices)

    @property
    def devices(self):
        """The devices this component lives on."""
        if self.model is None:
            return None
        # next(parser.model.parameters()).device
        if hasattr(self.model, 'device_ids'):
            return self.model.device_ids
        device: torch.device = next(self.model.parameters()).device
        return [device]

    @property
    def device(self):
        """The first device this component lives on."""
        devices = self.devices
        if not devices:
            return None
        return devices[0]

    def on_config_ready(self, **kwargs):
        """Called when config is ready, either during ``fit`` or ``load``. Subclass can perform extra initialization
        tasks in this callback.

        Args:
            **kwargs: Not used.
        """
        pass

    @property
    def model_(self) -> nn.Module:
        """The actual model when it's wrapped by a ``DataParallel``.

        Returns:
            The "real" model.
        """
        if isinstance(self.model, nn.DataParallel):
            return self.model.module
        return self.model

    # noinspection PyMethodOverriding
    @abstractmethod
    def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs):
        """Predict on data fed by user. Users shall avoid directly calling this method since it is not guarded with
        ``torch.no_grad`` and will introduce unnecessary gradient computation. Use ``__call__`` instead.

        Args:
            data: Sentences or tokens.
            batch_size: Decoding batch size.
            **kwargs: Used in sub-classes.
        """
        pass

    @staticmethod
    def _create_dummy_placeholder_on(device):
        if device < 0:
            device = 'cpu:0'
        return torch.zeros(16, 16, device=device)

    @torch.no_grad()
    def __call__(self, data, batch_size=None, **kwargs):
        """Predict on data fed by user. This method calls
        :meth:`~hanlp.common.torch_component.TorchComponent.predict` but decorates it with ``torch.no_grad``.

        Args:
            data: Sentences or tokens.
            batch_size: Decoding batch size.
            **kwargs: Used in sub-classes.
        """
        return super().__call__(data, **merge_dict(self.config, overwrite=True,
                                                   batch_size=batch_size or self.config.get('batch_size', None),
                                                   **kwargs))
class KerasComponent(Component, ABC):
    def __init__(self, transform: Transform) -> None:
        super().__init__()
        self.meta = {
            'class_path': classpath_of(self),
            'hanlp_version': hanlp.version.__version__,
        }
        self.model: Optional[tf.keras.Model] = None
        self.config = SerializableDict()
        self.transform = transform
        # share config with transform for convenience, so we don't need to pass args around
        if self.transform.config:
            for k, v in self.transform.config.items():
                self.config[k] = v
        self.transform.config = self.config

    def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, logger: logging.Logger = None,
                 callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
        input_path = get_resource(input_path)
        file_prefix, ext = os.path.splitext(input_path)
        name = os.path.basename(file_prefix)
        if not name:
            name = 'evaluate'
        if save_dir and not logger:
            logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO if verbose else logging.WARN,
                                 mode='w')
        tst_data = self.transform.file_to_dataset(input_path, batch_size=batch_size)
        samples = self.num_samples_in(tst_data)
        num_batches = math.ceil(samples / batch_size)
        if warm_up:
            for x, y in tst_data:
                self.model.predict_on_batch(x)
                break
        if output:
            assert save_dir, 'Must pass save_dir in order to output'
            if isinstance(output, bool):
                output = os.path.join(save_dir, name) + '.predict' + ext
            elif isinstance(output, str):
                output = output
            else:
                raise RuntimeError('output ({}) must be of type bool or str'.format(repr(output)))
        timer = Timer()
        eval_outputs = self.evaluate_dataset(tst_data, callbacks, output, num_batches, **kwargs)
        loss, score, output = eval_outputs[0], eval_outputs[1], eval_outputs[2]
        delta_time = timer.stop()
        speed = samples / delta_time.delta_seconds
        if logger:
            f1: IOBES_F1_TF = None
            for metric in self.model.metrics:
                if isinstance(metric, IOBES_F1_TF):
                    f1 = metric
                    break
            extra_report = ''
            if f1:
                overall, by_type, extra_report = f1.state.result(full=True, verbose=False)
                extra_report = ' \n' + extra_report
            logger.info('Evaluation results for {} - '
                        'loss: {:.4f} - {} - speed: {:.2f} sample/sec{}'
                        .format(name + ext, loss,
                                format_scores(score) if isinstance(score, dict)
                                else format_metrics(self.model.metrics),
                                speed, extra_report))
        if output:
            logger.info('Saving output to {}'.format(output))
            with open(output, 'w', encoding='utf-8') as out:
                self.evaluate_output(tst_data, out, num_batches, self.model.metrics)
        return loss, score, speed

    def num_samples_in(self, dataset):
        return size_of_dataset(dataset)

    def evaluate_dataset(self, tst_data, callbacks, output, num_batches, **kwargs):
        loss, score = self.model.evaluate(tst_data, callbacks=callbacks, steps=num_batches)
        return loss, score, output

    def evaluate_output(self, tst_data, out, num_batches, metrics: List[tf.keras.metrics.Metric]):
        # out.write('x\ty_true\ty_pred\n')
        for metric in metrics:
            metric.reset_states()
        for idx, batch in enumerate(tst_data):
            outputs = self.model.predict_on_batch(batch[0])
            for metric in metrics:
                metric(batch[1], outputs, outputs._keras_mask if hasattr(outputs, '_keras_mask') else None)
            self.evaluate_output_to_file(batch, outputs, out)
            print('\r{}/{} {}'.format(idx + 1, num_batches, format_metrics(metrics)), end='')
        print()

    def evaluate_output_to_file(self, batch, outputs, out):
        for x, y_gold, y_pred in zip(self.transform.X_to_inputs(batch[0]),
                                     self.transform.Y_to_outputs(batch[1], gold=True),
                                     self.transform.Y_to_outputs(outputs, gold=False)):
            out.write(self.transform.input_truth_output_to_str(x, y_gold, y_pred))

    def _capture_config(self, config: Dict,
                        exclude=('trn_data', 'dev_data', 'save_dir', 'kwargs', 'self', 'logger', 'verbose',
                                 'dev_batch_size', '__class__')):
        """Capture the arguments in ``config`` (usually ``locals()``) into ``self.config``.

        Args:
            config: Obtained by :meth:`locals`.
            exclude: Arguments to be excluded.
        """
        if 'kwargs' in config:
            config.update(config['kwargs'])
        config = dict(
            (key, tf.keras.utils.serialize_keras_object(value)) if hasattr(value, 'get_config') else (key, value)
            for key, value in config.items())
        for key in exclude:
            config.pop(key, None)
        self.config.update(config)

    def save_meta(self, save_dir, filename='meta.json', **kwargs):
        self.meta['create_time'] = now_datetime()
        self.meta.update(kwargs)
        save_json(self.meta, os.path.join(save_dir, filename))

    def load_meta(self, save_dir, filename='meta.json'):
        save_dir = get_resource(save_dir)
        metapath = os.path.join(save_dir, filename)
        if os.path.isfile(metapath):
            self.meta.update(load_json(metapath))

    def save_config(self, save_dir, filename='config.json'):
        self.config.save_json(os.path.join(save_dir, filename))

    def load_config(self, save_dir, filename='config.json'):
        save_dir = get_resource(save_dir)
        self.config.load_json(os.path.join(save_dir, filename))

    def save_weights(self, save_dir, filename='model.h5'):
        self.model.save_weights(os.path.join(save_dir, filename))

    def load_weights(self, save_dir, filename='model.h5', **kwargs):
        assert self.model.built or self.model.weights, 'You must call self.model.build() in build_model() ' \
                                                       'in order to load it'
        save_dir = get_resource(save_dir)
        self.model.load_weights(os.path.join(save_dir, filename))

    def save_vocabs(self, save_dir, filename='vocabs.json'):
        vocabs = SerializableDict()
        for key, value in vars(self.transform).items():
            if isinstance(value, VocabTF):
                vocabs[key] = value.to_dict()
        vocabs.save_json(os.path.join(save_dir, filename))

    def load_vocabs(self, save_dir, filename='vocabs.json'):
        save_dir = get_resource(save_dir)
        vocabs = SerializableDict()
        vocabs.load_json(os.path.join(save_dir, filename))
        for key, value in vocabs.items():
            vocab = VocabTF()
            vocab.copy_from(value)
            setattr(self.transform, key, vocab)

    def load_transform(self, save_dir) -> Transform:
        """Try to load the transform only. This method might fail because it avoids building the model.
        If it does fail, you have to use ``load``, which might be too heavy, but that's the best we can do.

        Args:
            save_dir: The path to load from.

        Returns:
            The transform.
        """
        save_dir = get_resource(save_dir)
        self.load_config(save_dir)
        self.load_vocabs(save_dir)
        self.transform.build_config()
        self.transform.lock_vocabs()
        return self.transform

    def save(self, save_dir: str, **kwargs):
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.save_weights(save_dir)

    def load(self, save_dir: str, logger=hanlp.utils.log_util.logger, **kwargs):
        self.meta['load_path'] = save_dir
        save_dir = get_resource(save_dir)
        self.load_config(save_dir)
        self.load_vocabs(save_dir)
        self.build(**merge_dict(self.config, training=False, logger=logger, **kwargs, overwrite=True, inplace=True))
        self.load_weights(save_dir, **kwargs)
        self.load_meta(save_dir)

    @property
    def input_shape(self) -> List:
        return self.transform.output_shapes[0]

    def build(self, logger, **kwargs):
        self.transform.build_config()
        self.model = self.build_model(**merge_dict(self.config, training=kwargs.get('training', None),
                                                   loss=kwargs.get('loss', None)))
        self.transform.lock_vocabs()
        optimizer = self.build_optimizer(**self.config)
        loss = self.build_loss(
            **self.config if 'loss' in self.config
            else dict(list(self.config.items()) + [('loss', None)]))  # default to loss=None when not configured
        metrics = self.build_metrics(**merge_dict(self.config, metrics=kwargs.get('metrics', 'accuracy'),
                                                  logger=logger, overwrite=True))
        if not isinstance(metrics, list):
            if isinstance(metrics, tf.keras.metrics.Metric):
                metrics = [metrics]
        if not self.model.built:
            sample_inputs = self.sample_data
            if sample_inputs is not None:
                self.model(sample_inputs)
            else:
                if len(self.transform.output_shapes[0]) == 1 and self.transform.output_shapes[0][0] is None:
                    x_shape = self.transform.output_shapes[0]
                else:
                    x_shape = list(self.transform.output_shapes[0])
                    for i, shape in enumerate(x_shape):
                        x_shape[i] = [None] + shape  # batch + X.shape
                self.model.build(input_shape=x_shape)
        self.compile_model(optimizer, loss, metrics)
        return self.model, optimizer, loss, metrics

    def compile_model(self, optimizer, loss, metrics):
        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics, run_eagerly=self.config.run_eagerly)

    def build_optimizer(self, optimizer, **kwargs):
        if isinstance(optimizer, (str, dict)):
            custom_objects = {'AdamWeightDecay': AdamWeightDecay}
            optimizer: tf.keras.optimizers.Optimizer = tf.keras.utils.deserialize_keras_object(
                optimizer, module_objects=vars(tf.keras.optimizers), custom_objects=custom_objects)
        self.config.optimizer = tf.keras.utils.serialize_keras_object(optimizer)
        return optimizer

    def build_loss(self, loss, **kwargs):
        if not loss:
            loss = tf.keras.losses.SparseCategoricalCrossentropy(
                reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, from_logits=True)
        elif isinstance(loss, (str, dict)):
            loss = tf.keras.utils.deserialize_keras_object(loss, module_objects=vars(tf.keras.losses))
        if isinstance(loss, tf.keras.losses.Loss):
            self.config.loss = tf.keras.utils.serialize_keras_object(loss)
        return loss

    def build_transform(self, **kwargs):
        return self.transform

    def build_vocab(self, trn_data, logger):
        train_examples = self.transform.fit(trn_data, **self.config)
        self.transform.summarize_vocabs(logger)
        return train_examples

    def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
        metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
        return [metric]

    @abstractmethod
    def build_model(self, **kwargs) -> tf.keras.Model:
        pass

    def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
            finetune: str = None, **kwargs):
        self._capture_config(locals())
        self.transform = self.build_transform(**self.config)
        if not save_dir:
            save_dir = tempdir_human()
        if not logger:
            logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
        logger.info('Hyperparameter:\n' + self.config.to_json())
        num_examples = self.build_vocab(trn_data, logger)
        # assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
        logger.info('Building...')
        train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
        self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
        model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
        logger.info('Model built:\n' + summary_of_model(self.model))
        if finetune:
            finetune = get_resource(finetune)
            if os.path.isdir(finetune):
                finetune = os.path.join(finetune, 'model.h5')
            model.load_weights(finetune, by_name=True, skip_mismatch=True)
            logger.info(f'Loaded pretrained weights from {finetune} for finetuning')
        self.save_config(save_dir)
        self.save_vocabs(save_dir)
        self.save_meta(save_dir)
        trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
        dev_data = self.build_valid_dataset(dev_data, batch_size)
        callbacks = self.build_callbacks(save_dir, **merge_dict(self.config, overwrite=True, logger=logger))
        # need to know #batches, otherwise progbar crashes
        dev_steps = math.ceil(self.num_samples_in(dev_data) / batch_size)
        checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
        timer = Timer()
        try:
            history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
                                                   num_examples=num_examples,
                                                   train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
                                                   callbacks=callbacks, logger=logger, model=model,
                                                   optimizer=optimizer, loss=loss, metrics=metrics, overwrite=True))
        except KeyboardInterrupt:
            print()
            if not checkpoint or checkpoint.best in (np.Inf, -np.Inf):
                self.save_weights(save_dir)
                logger.info('Aborted with model saved')
            else:
                logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
            # noinspection PyTypeChecker
            history: tf.keras.callbacks.History = get_callback_by_class(callbacks, tf.keras.callbacks.History)
        delta_time = timer.stop()
        best_epoch_ago = 0
        if history and hasattr(history, 'epoch'):
            trained_epoch = len(history.epoch)
            logger.info('Trained {} epochs in {}, each epoch takes {}'
                        .format(trained_epoch, delta_time,
                                delta_time / trained_epoch if trained_epoch else delta_time))
            save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=io_util.NumpyEncoder)
            monitor_history: List = history.history.get(checkpoint.monitor, None)
            if monitor_history:
                best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
            if checkpoint and monitor_history and checkpoint.best != monitor_history[-1]:
                logger.info(f'Restored the best model saved with best '
                            f'{checkpoint.monitor} = {checkpoint.best:.4f} '
                            f'saved {best_epoch_ago} epochs ago')
                self.load_weights(save_dir)  # restore best model
        return history

    def train_loop(self, trn_data, dev_data, epochs, num_examples, train_steps_per_epoch, dev_steps, model, optimizer,
                   loss, metrics, callbacks, logger, **kwargs):
        history = self.model.fit(trn_data, epochs=epochs, steps_per_epoch=train_steps_per_epoch,
                                 validation_data=dev_data, callbacks=callbacks,
                                 validation_steps=dev_steps)  # type: tf.keras.callbacks.History
        return history

    def build_valid_dataset(self, dev_data, batch_size):
        dev_data = self.transform.file_to_dataset(dev_data, batch_size=batch_size, shuffle=False)
        return dev_data

    def build_train_dataset(self, trn_data, batch_size, num_examples):
        trn_data = self.transform.file_to_dataset(trn_data, batch_size=batch_size, shuffle=True,
                                                  repeat=-1 if self.config.train_steps else None)
        return trn_data

    def build_callbacks(self, save_dir, logger, **kwargs):
        metrics = kwargs.get('metrics', 'accuracy')
        if isinstance(metrics, (list, tuple)):
            metrics = metrics[-1]
        monitor = f'val_{metrics}'
        checkpoint = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(save_dir, 'model.h5'),
            # verbose=1,
            monitor=monitor, save_best_only=True, mode='max',
            save_weights_only=True)
        logger.debug(f'Monitor {checkpoint.monitor} for checkpoint')
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=io_util.makedirs(io_util.path_join(save_dir, 'logs')))
        csv_logger = FineCSVLogger(os.path.join(save_dir, 'train.log'), separator=' | ', append=True)
        callbacks = [checkpoint, tensorboard_callback, csv_logger]
        lr_decay_per_epoch = self.config.get('lr_decay_per_epoch', None)
        if lr_decay_per_epoch:
            learning_rate = self.model.optimizer.get_config().get('learning_rate', None)
            if not learning_rate:
                logger.warning('Learning rate decay not supported for optimizer={}'.format(repr(self.model.optimizer)))
            else:
                logger.debug(f'Created LearningRateScheduler with lr_decay_per_epoch={lr_decay_per_epoch}')
                callbacks.append(tf.keras.callbacks.LearningRateScheduler(
                    lambda epoch: learning_rate / (1 + lr_decay_per_epoch * epoch)))
        anneal_factor = self.config.get('anneal_factor', None)
        if anneal_factor:
            callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(factor=anneal_factor,
                                                                  patience=self.config.get('anneal_patience', 10)))
        early_stopping_patience = self.config.get('early_stopping_patience', None)
        if early_stopping_patience:
            callbacks.append(tf.keras.callbacks.EarlyStopping(monitor=monitor, mode='max', verbose=1,
                                                              patience=early_stopping_patience))
        return callbacks

    def on_train_begin(self):
        """Callback before the training starts."""
        pass

    def predict(self, data: Any, batch_size=None, **kwargs):
        assert self.model, 'Please call fit or load before predict'
        if not data:
            return []
        data, flat = self.transform.input_to_inputs(data)
        if not batch_size:
            batch_size = self.config.batch_size
        dataset = self.transform.inputs_to_dataset(data, batch_size=batch_size, gold=kwargs.get('gold', False))
        results = []
        num_samples = 0
        data_is_list = isinstance(data, list)
        for idx, batch in enumerate(dataset):
            samples_in_batch = tf.shape(batch[-1] if isinstance(batch[-1], tf.Tensor) else batch[-1][0])[0]
            if data_is_list:
                inputs = data[num_samples:num_samples + samples_in_batch]
            else:
                inputs = None  # if data is a generator, it's usually one-time, not able to transform into a list
            for output in self.predict_batch(batch, inputs=inputs, **kwargs):
                results.append(output)
            num_samples += samples_in_batch
        if flat:
            return results[0]
        return results

    def predict_batch(self, batch, inputs=None, **kwargs):
        X = batch[0]
        Y = self.model.predict_on_batch(X)
        for output in self.transform.Y_to_outputs(Y, X=X, inputs=inputs, batch=batch, **kwargs):
            yield output

    @property
    def sample_data(self):
        return None

    @staticmethod
    def from_meta(meta: dict, **kwargs):
        """Create a component from a meta dict.

        Args:
            meta: A meta dict containing at least ``class_path`` and ``load_path``.
            **kwargs: Not used.

        Returns:
            KerasComponent: The loaded component.
        """
        cls = str_to_type(meta['class_path'])
        obj: KerasComponent = cls()
        assert 'load_path' in meta, f'{meta} doesn\'t contain load_path field'
        obj.load(meta['load_path'])
        return obj

    def export_model_for_serving(self, export_dir=None, version=1, overwrite=False, show_hint=False):
        assert self.model, 'You have to fit or load a model before exporting it'
        if not export_dir:
            assert 'load_path' in self.meta, 'When not specifying save_dir, load_path has to be present'
            export_dir = get_resource(self.meta['load_path'])
        model_path = os.path.join(export_dir, str(version))
        if os.path.isdir(model_path) and not overwrite:
            logger.info(f'{model_path} exists, skip since overwrite = {overwrite}')
            return export_dir
        logger.info(f'Exporting to {export_dir} ...')
        tf.saved_model.save(self.model, model_path)
        logger.info(f'Successfully exported model to {export_dir}')
        if show_hint:
            logger.info(f'You can serve it through \n'
                        f'tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} '
                        f'--model_base_path={export_dir} --rest_api_port=8888')
        return export_dir

    def serve(self, export_dir=None, grpc_port=8500, rest_api_port=0, overwrite=False, dry_run=False):
        export_dir = self.export_model_for_serving(export_dir, show_hint=False, overwrite=overwrite)
        if not dry_run:
            del self.model  # free memory
        logger.info('The inputs of the exported model are shown below.')
        os.system(f'saved_model_cli show --all --dir {export_dir}/1')
        cmd = f'nohup tensorflow_model_server ' \
              f'--model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} ' \
              f'--model_base_path={export_dir} --port={grpc_port} --rest_api_port={rest_api_port} ' \
              f'>serve.log 2>&1 &'
        logger.info(f'Running ...\n{cmd}')
        if not dry_run:
            os.system(cmd)
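# A lifecycle sketch (MyKerasTagger, MyTransform and the file names are hypothetical):
# fit() captures hyperparameters, builds vocabs and the model, then trains with
# checkpointing; load() rebuilds the model from config before restoring weights.
tagger = MyKerasTagger(MyTransform())
tagger.fit('trn.tsv', 'dev.tsv', save_dir='model_dir', batch_size=32, epochs=3)
tagger.evaluate('tst.tsv', save_dir='model_dir', output=True)  # writes model_dir/tst.predict.tsv
tagger.load('model_dir')
tagger.export_model_for_serving('model_dir', version=1)        # SavedModel under model_dir/1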
def load_vocab(self, save_dir, filename='vocab.json'):
    save_dir = get_resource(save_dir)
    vocab = SerializableDict()
    vocab.load_json(os.path.join(save_dir, filename))
    self.vocab.copy_from(vocab)
def save_vocab(self, save_dir, filename='vocab.json'):
    vocab = SerializableDict()
    vocab.update(self.vocab.to_dict())
    vocab.save_json(os.path.join(save_dir, filename))
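# A round trip for the single-vocab pair above (`component` stands for any object
# exposing these methods and a Vocab-like `self.vocab`; the directory is hypothetical):
component.save_vocab('/tmp/demo')   # writes /tmp/demo/vocab.json
component.load_vocab('/tmp/demo')   # copies the JSON back into self.vocab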
def __call__(self, sample: dict):
    input_tokens = sample[self.input_key]
    input_is_str = isinstance(input_tokens, str)
    tokenizer = self.tokenizer
    ret_token_span = self.ret_token_span
    if input_is_str:
        # This happens in a tokenizer component where the raw sentence is fed.
        # noinspection PyShadowingNames
        def tokenize_str(input_str, add_special_tokens=True):
            if tokenizer.is_fast:
                encoding = tokenizer.encode_plus(input_str, return_offsets_mapping=True,
                                                 add_special_tokens=add_special_tokens).encodings[0]
                subtoken_offsets = encoding.offsets
                input_tokens = encoding.tokens
                input_ids = encoding.ids
                # Fill up missing non-blank characters swallowed by HF tokenizer
                offset = 0
                fixed_offsets = []
                fixed_tokens = []
                fixed_ids = []
                for token, id, (b, e) in zip(input_tokens, input_ids, subtoken_offsets):
                    if b > offset:
                        missing_token = input_str[offset: b]
                        if not missing_token.isspace():  # In the future, we may want space back
                            fixed_tokens.append(missing_token)
                            fixed_ids.append(tokenizer.unk_token_id)
                            fixed_offsets.append((offset, b))
                    fixed_tokens.append(token)
                    fixed_ids.append(id)
                    fixed_offsets.append((b, e))
                    offset = e
                subtoken_offsets = fixed_offsets
                input_tokens = fixed_tokens
                input_ids = fixed_ids
                if add_special_tokens:
                    subtoken_offsets = subtoken_offsets[1 if self.has_cls else 0:-1]
                # Edge case that the input_str is swallowed in whole
                if not subtoken_offsets and not input_str.isspace():
                    __index = 1 if add_special_tokens and self.has_cls else 0
                    input_tokens.insert(__index, input_str)
                    input_ids.insert(__index, tokenizer.unk_token_id)
                    subtoken_offsets.append((0, len(input_str)))
                if not self.has_cls:
                    input_tokens = [self.cls_token] + input_tokens
                    input_ids = [self.cls_token_id] + input_ids
            else:
                input_tokens = tokenizer.tokenize(input_str)
                subtoken_offsets = []
                _o = 0
                for each in input_tokens:
                    subtoken_offsets.append((_o, _o + len(each)))
                    _o += len(each)
                if add_special_tokens:
                    input_tokens = [self.cls_token] + input_tokens + [self.sep_token]
                input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
            if self.check_space_before:
                non_blank_offsets = [i for i in range(len(input_tokens)) if input_tokens[i] != '▁']
                if add_special_tokens and not self.has_cls:
                    non_blank_offsets.insert(0, 0)
                input_tokens = [input_tokens[i] for i in non_blank_offsets]
                input_ids = [input_ids[i] for i in non_blank_offsets]
                if add_special_tokens:
                    non_blank_offsets = non_blank_offsets[1:-1]
                    subtoken_offsets = [subtoken_offsets[i - 1] for i in non_blank_offsets]
                else:
                    subtoken_offsets = [subtoken_offsets[i] for i in non_blank_offsets]
                # MT5 generates tokens like ▁of, which is bad for the tokenizer. So we want to remove the prefix.
                for i, token in enumerate(input_tokens[1:-1] if add_special_tokens else input_tokens):
                    if input_str[subtoken_offsets[i][0]] == ' ':
                        subtoken_offsets[i] = (subtoken_offsets[i][0] + 1, subtoken_offsets[i][1])
            if add_special_tokens:
                if len(input_tokens) == 2:  # bos and eos, meaning that the text contains only some spaces
                    input_tokens.insert(1, input_str)
                    input_ids.insert(1, tokenizer.unk_token_id)
                    subtoken_offsets.append((0, len(input_str)))
            else:
                if not input_ids:  # This chunk might be some control chars getting removed by tokenizer
                    input_tokens = [input_str]
                    input_ids = [tokenizer.unk_token_id]
                    subtoken_offsets = [(0, len(input_str))]
            return input_tokens, input_ids, subtoken_offsets

        if self.dict:
            chunks = self.dict.split(sample.get(f'{self.input_key}_', input_tokens))  # Match original text directly
            _input_tokens, _input_ids, _subtoken_offsets = [self.cls_token], [self.cls_token_id], []
            _offset = 0
            custom_words = sample['custom_words'] = []
            char_offset = 0
            for chunk in chunks:
                if isinstance(chunk, str):
                    # Use transformed text as it's what models are trained on
                    chunk = input_tokens[char_offset:char_offset + len(chunk)]
                    tokens, ids, offsets = tokenize_str(chunk, add_special_tokens=False)
                    char_offset += len(chunk)
                else:
                    begin, end, label = chunk  # chunk offset is in char level
                    # custom_words.append(chunk)
                    if isinstance(label, list):
                        tokens, ids, offsets, delta = [], [], [], 0
                        for token in label:
                            _tokens, _ids, _offsets = tokenize_str(token, add_special_tokens=False)
                            tokens.extend(_tokens)
                            # track the subword offset of this chunk, -1 for [CLS]
                            custom_words.append(
                                (len(_input_ids) + len(ids) - 1, len(_input_ids) + len(ids) - 1 + len(_ids), token))
                            ids.extend(_ids)
                            offsets.extend((x[0] + delta, x[1] + delta) for x in _offsets)
                            delta = offsets[-1][-1]
                    else:
                        tokens, ids, offsets = tokenize_str(input_tokens[begin:end], add_special_tokens=False)
                        # offsets = [(offsets[0][0], offsets[-1][-1])]
                        custom_words.append((len(_input_ids) - 1, len(_input_ids) + len(ids) - 1, label))
                    char_offset = end
                _input_tokens.extend(tokens)
                _input_ids.extend(ids)
                _subtoken_offsets.extend((x[0] + _offset, x[1] + _offset) for x in offsets)
                _offset = _subtoken_offsets[-1][-1]
            subtoken_offsets = _subtoken_offsets
            input_tokens = _input_tokens + [self.sep_token]
            input_ids = _input_ids + [self.sep_token_id]
        else:
            input_tokens, input_ids, subtoken_offsets = tokenize_str(input_tokens, add_special_tokens=True)
        if self.ret_subtokens:
            sample[f'{self.input_key}_subtoken_offsets'] = subtoken_offsets
    cls_is_bos = self.cls_is_bos
    if cls_is_bos is None:
        cls_is_bos = input_tokens[0] == BOS
    sep_is_eos = self.sep_is_eos
    if sep_is_eos is None:
        sep_is_eos = input_tokens[-1] == EOS
    if self.strip_cls_sep:
        if cls_is_bos:
            input_tokens = input_tokens[1:]
        if sep_is_eos:
            input_tokens = input_tokens[:-1]
    if not self.ret_mask_and_type:
        # only need input_ids and token_span, use a light version
        if input_is_str:
            prefix_mask = self._init_prefix_mask(input_ids)
        else:
            if input_tokens:
                return_offsets_mapping = tokenizer.is_fast and self.ret_subtokens
                encodings = tokenizer.batch_encode_plus(
                    input_tokens,
                    return_offsets_mapping=return_offsets_mapping,  # Many tokenizers do not offer fast version
                    add_special_tokens=False
                )
                subtoken_ids_per_token = encodings.data['input_ids']
                if return_offsets_mapping:
                    offsets_mapping = [encoding.offsets for encoding in encodings.encodings]
                else:
                    offsets_mapping = []
                    for token, subtoken_ids in zip(input_tokens, subtoken_ids_per_token):
                        if len(subtoken_ids) > len(token):  # … --> ...
                            del subtoken_ids[len(token):]
                        if not subtoken_ids:
                            subtoken_ids = [tokenizer.unk_token_id]
                        # Since non-fast tok generates no mapping, we have to guess
                        char_per_subtoken = max(len(token) // len(subtoken_ids), 1)
                        bes = [(b, b + char_per_subtoken) for b in range(0, len(token), char_per_subtoken)]
                        if len(bes) != len(subtoken_ids):
                            bes[len(subtoken_ids) - 1] = (bes[len(subtoken_ids) - 1][0], len(token))
                            del bes[len(subtoken_ids):]
                        offsets_mapping.append(bes)
            else:
                encodings = SerializableDict()
                subtoken_ids_per_token = []
                encodings.data = {'input_ids': subtoken_ids_per_token}
            if self.check_space_before:
                # noinspection PyUnboundLocalVariable
                for token, subtokens, mapping, encoding in zip(input_tokens, subtoken_ids_per_token, offsets_mapping,
                                                               encodings.encodings):
                    # Remove ▁ generated by spm for 2 reasons:
                    # 1. During decoding, mostly no ▁ will be created unless blanks are placed between tokens (which
                    #    is true for English but in English it will likely be concatenated to the token following it)
                    # 2. For T5, '▁' is used as CLS
                    if len(subtokens) > 1 and encoding.tokens[0] == '▁':
                        subtokens.pop(0)
                        if mapping:
                            mapping.pop(0)
            # Some tokens get stripped out
            subtoken_ids_per_token = [ids if ids else [tokenizer.unk_token_id] for ids in subtoken_ids_per_token]
            input_ids = sum(subtoken_ids_per_token, [self.cls_token_id])
            if self.sep_is_eos is None:
                # None means to check whether sep is at the tail or between tokens
                if sep_is_eos:
                    input_ids += [self.sep_token_id]
                elif self.sep_token_id not in input_ids:
                    input_ids += [self.sep_token_id]
            else:
                input_ids += [self.sep_token_id]
            # else self.sep_is_eos == False means sep is between tokens and don't bother to check
            if self.ret_subtokens:
                prefix_mask = self._init_prefix_mask(input_ids)
                # if self.check_space_before:
                #     if offsets_mapping[0] and not input_tokens[0].startswith(' '):
                #         prefix_mask[1] = False
            else:
                prefix_mask = [False] * len(input_ids)
                offset = 1
                for _subtokens in subtoken_ids_per_token:
                    prefix_mask[offset] = True
                    offset += len(_subtokens)
            if self.ret_subtokens:
                subtoken_offsets = []
                for token, offsets in zip(input_tokens, offsets_mapping):
                    if offsets:
                        subtoken_offsets.append(offsets)
                    else:
                        subtoken_offsets.append([(0, len(token))])
                if self.ret_subtokens_group:
                    sample[f'{self.input_key}_subtoken_offsets_group'] = subtoken_offsets
                sample[f'{self.input_key}_subtoken_offsets'] = sum(subtoken_offsets, [])
    else:
        input_ids, attention_mask, token_type_ids, prefix_mask = \
            convert_examples_to_features(input_tokens, None, tokenizer,
                                         cls_token_at_end=self.cls_token_at_end,  # xlnet has a cls token at the end
                                         cls_token=tokenizer.cls_token,
                                         cls_token_segment_id=self.cls_token_segment_id,
                                         sep_token=self.sep_token,
                                         sep_token_extra=self.sep_token_extra,
                                         # roberta uses an extra separator b/w pairs of sentences,
                                         # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                                         pad_on_left=self.pad_on_left,  # pad on the left for xlnet
                                         pad_token_id=self.pad_token_id,
                                         pad_token_segment_id=self.pad_token_segment_id,
                                         pad_token_label_id=0,
                                         do_padding=self.do_padding)
    if len(input_ids) > self.max_seq_length:
        if self.truncate_long_sequences:
            # raise SequenceTooLong(
            #     f'Input tokens {input_tokens} exceed the max sequence length of {self.max_seq_length - 2}. '
            #     f'For sequence tasks, truncate_long_sequences = True is not supported. '
            #     f'You are recommended to split your long text into several sentences within '
            #     f'{self.max_seq_length - 2} tokens beforehand. '
            #     f'Or simply set truncate_long_sequences = False to enable sliding window.')
            input_ids = input_ids[:self.max_seq_length]
            prefix_mask = prefix_mask[:self.max_seq_length]
            warnings.warn(
                f'Input tokens {input_tokens} exceed the max sequence length of {self.max_seq_length - 2}. '
                f'The exceeded part will be truncated and ignored. '
                f'You are recommended to split your long text into several sentences within '
                f'{self.max_seq_length - 2} tokens beforehand. '
                f'Or simply set truncate_long_sequences = False to enable sliding window.'
            )
        else:
            input_ids = self.sliding_window(input_ids, input_ids[-1] == self.sep_token_id)
    if prefix_mask:
        if cls_is_bos:
            prefix_mask[0] = True
        if sep_is_eos:
            prefix_mask[-1] = True
    outputs = [input_ids]
    if self.ret_mask_and_type:
        # noinspection PyUnboundLocalVariable
        outputs += [attention_mask, token_type_ids]
    if self.ret_prefix_mask:
        outputs += [prefix_mask]
    if ret_token_span and prefix_mask:
        if cls_is_bos:
            token_span = [[0]]
        else:
            token_span = []
        offset = 1
        span = []
        for mask in prefix_mask[1:len(prefix_mask) if sep_is_eos is None else -1]:  # skip [CLS] and [SEP]
            if mask and span:
                token_span.append(span)
                span = []
            span.append(offset)
            offset += 1
        if span:
            token_span.append(span)
        if sep_is_eos:
            assert offset == len(prefix_mask) - 1
            token_span.append([offset])
        outputs.append(token_span)
    for k, v in zip(self.output_key, outputs):
        sample[k] = v
    return sample
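# A usage sketch, assuming this __call__ belongs to HanLP's transformer sequence
# tokenizer transform (the constructor arguments below are abbreviated and
# illustrative; the model name is hypothetical):
tokenize = TransformerSequenceTokenizer('bert-base-uncased', input_key='token', ret_token_span=True)
sample = tokenize({'token': ['HanLP', 'tokenizes', 'subwords']})
# sample now additionally carries the fields named in self.output_key: the flat
# subword ids plus, with ret_token_span=True, one list of subword positions per
# original token, so downstream decoders can pool subwords back into tokens.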