def make_raw_text_if_necessary(home: str):
    """Build the raw-text JSON-lines cache for a corpus if it does not exist yet.

    Resolves ``home`` to a local resource path, and unless ``text.jsonlines`` is
    already present there, loads the raw sentences in batch and writes them out.
    """
    resolved = get_resource(home)
    cache_path = os.path.join(resolved, 'text.jsonlines')
    if not os.path.isfile(cache_path):
        # Cache miss: materialize the sentences once so later runs can skip this.
        save_json(batch_load_raw_text(resolved), cache_path)
def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
        finetune: str = None, **kwargs):
    """Train this component on ``trn_data``, validating on ``dev_data``.

    Captures hyperparameters into ``self.config``, builds the transform, vocabs,
    model, optimizer, loss and metrics, optionally warm-starts from ``finetune``
    weights, then runs ``self.train_loop``. On KeyboardInterrupt the current (or
    best checkpointed) weights are kept. After training, the best checkpoint is
    restored if it is not the last epoch's.

    Args:
        trn_data: Training data (path or dataset; consumed by build_vocab/build_train_dataset).
        dev_data: Validation data.
        save_dir: Directory for weights, config, vocabs and logs; a temp dir is created if falsy.
        batch_size: Number of samples per batch.
        epochs: Number of passes over the training data.
        run_eagerly: Captured into config; forwarded through ``**self.config`` to build/train — TODO confirm consumer.
        logger: Logger to use; one rooted at ``save_dir`` is created if not given.
        verbose: When True, log at INFO level, otherwise WARN.
        finetune: Identifier or path of pretrained weights to load (by name, mismatches skipped).
        **kwargs: Extra hyperparameters captured into ``self.config``.

    Returns:
        The ``tf.keras.callbacks.History`` of the run (or the History callback after an interrupt).
    """
    # Record every argument (via locals()) as hyperparameters in self.config.
    self._capture_config(locals())
    self.transform = self.build_transform(**self.config)
    if not save_dir:
        save_dir = tempdir_human()
    if not logger:
        logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
    logger.info('Hyperparameter:\n' + self.config.to_json())
    num_examples = self.build_vocab(trn_data, logger)
    # assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
    logger.info('Building...')
    # Steps are only computable when build_vocab reported a sample count.
    train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
    self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
    model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
    logger.info('Model built:\n' + summary_of_model(self.model))
    if finetune:
        finetune = get_resource(finetune)
        # A directory is assumed to contain the canonical weight file name.
        if os.path.isdir(finetune):
            finetune = os.path.join(finetune, 'model.h5')
        model.load_weights(finetune, by_name=True, skip_mismatch=True)
        logger.info(f'Loaded pretrained weights from {finetune} for finetuning')
    # Persist everything needed to reload this component later.
    self.save_config(save_dir)
    self.save_vocabs(save_dir)
    self.save_meta(save_dir)
    trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
    dev_data = self.build_valid_dataset(dev_data, batch_size)
    callbacks = self.build_callbacks(save_dir, **merge_dict(self.config, overwrite=True, logger=logger))
    # need to know #batches, otherwise progbar crashes
    dev_steps = math.ceil(self.num_samples_in(dev_data) / batch_size)
    # May be None when no ModelCheckpoint callback was built.
    checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
    timer = Timer()
    try:
        history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
                                               num_examples=num_examples,
                                               train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
                                               callbacks=callbacks, logger=logger, model=model,
                                               optimizer=optimizer, loss=loss, metrics=metrics, overwrite=True))
    except KeyboardInterrupt:
        print()
        # If no checkpoint ever fired (best still at +/-inf), save current weights manually.
        if not checkpoint or checkpoint.best in (np.Inf, -np.Inf):
            self.save_weights(save_dir)
            logger.info('Aborted with model saved')
        else:
            logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
        # noinspection PyTypeChecker
        # NOTE(review): the annotation calls History() — an instance is created purely as an
        # annotation, which is harmless at runtime but almost certainly meant to be the class.
        history: tf.keras.callbacks.History() = get_callback_by_class(callbacks, tf.keras.callbacks.History)
    delta_time = timer.stop()
    best_epoch_ago = 0
    if history and hasattr(history, 'epoch'):
        trained_epoch = len(history.epoch)
        logger.info('Trained {} epochs in {}, each epoch takes {}'.
                    format(trained_epoch, delta_time, delta_time / trained_epoch if trained_epoch else delta_time))
        save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=io_util.NumpyEncoder)
        # NOTE(review): checkpoint.monitor is read here before the `if checkpoint` guard below —
        # this would raise AttributeError if checkpoint is None; confirm a checkpoint always exists.
        monitor_history: List = history.history.get(checkpoint.monitor, None)
        if monitor_history:
            best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
        # Restore the best checkpoint when the last epoch was not the best one.
        if checkpoint and monitor_history and checkpoint.best != monitor_history[-1]:
            logger.info(f'Restored the best model saved with best '
                        f'{checkpoint.monitor} = {checkpoint.best:.4f} '
                        f'saved {best_epoch_ago} epochs ago')
            self.load_weights(save_dir)  # restore best model
    return history
def save_meta(self, save_dir, filename='meta.json', **kwargs):
    """Write this component's ``meta`` dict to ``save_dir/filename`` as JSON.

    Args:
        save_dir: Directory to save into.
        filename: Name of the meta file, ``meta.json`` by default.
        **kwargs: Extra key-value pairs merged into ``self.meta`` before saving.
    """
    # Bug fix: the original used ':' (a PEP 526 annotated-target statement, which
    # assigns nothing) instead of '=', so create_time was never actually stored.
    self.meta['create_time'] = now_datetime()
    self.meta.update(kwargs)
    save_json(self.meta, os.path.join(save_dir, filename))
def load_from_meta_file(save_dir: str, meta_filename='meta.json', transform_only=False, verbose=HANLP_VERBOSE,
                        **kwargs) -> Component:
    """Load a component from a ``meta.json`` (legacy TensorFlow component) or a ``config.json`` file.

    Args:
        save_dir: The identifier.
        meta_filename (str): The meta file of that saved component, which stores the classpath and version.
        transform_only: Load and return only the transform.
        verbose: Passed through to ``component.load()`` for config.json-style components.
        **kwargs: Extra parameters passed to ``component.load()``.

    Returns:
        A component.

    Raises:
        FileNotFoundError: If the identifier resolves to no meta/config file.
        ModuleNotFoundError: If a module required by the model is not installed.
        ConnectionError: If Hugging Face Transformers fails to download.
    """
    # Keep the raw identifier around: it is used for spell-check tips and recorded as load_path.
    identifier = save_dir
    load_path = save_dir
    save_dir = get_resource(save_dir)
    # A direct path to a .json file implies its directory is the real save_dir.
    if save_dir.endswith('.json'):
        meta_filename = os.path.basename(save_dir)
        save_dir = os.path.dirname(save_dir)
    metapath = os.path.join(save_dir, meta_filename)
    # Presence of meta.json marks a legacy TF component; otherwise fall back to config.json.
    if not os.path.isfile(metapath):
        tf_model = False
        metapath = os.path.join(save_dir, 'config.json')
    else:
        tf_model = True
    cls = None
    if not os.path.isfile(metapath):
        tips = ''
        if save_dir.isupper():
            # The identifier looks like an ALL-CAPS pretrained key: suggest the closest known keys.
            from difflib import SequenceMatcher
            similar_keys = sorted(pretrained.ALL.keys(),
                                  key=lambda k: SequenceMatcher(None, k, identifier).ratio(),
                                  reverse=True)[:5]
            tips = f'Check its spelling based on the available keys:\n' + \
                   f'{sorted(pretrained.ALL.keys())}\n' + \
                   f'Tips: it might be one of {similar_keys}'
        # These components are not intended to be loaded in this way, but I'm tired of explaining it again and again
        if identifier in pretrained.word2vec.ALL.values():
            # Synthesize a config.json on the fly so the raw embedding loads as a component.
            save_dir = os.path.dirname(save_dir)
            metapath = os.path.join(save_dir, 'config.json')
            save_json({'classpath': 'hanlp.layers.embeddings.word2vec.Word2VecEmbeddingComponent',
                       'embed': {'classpath': 'hanlp.layers.embeddings.word2vec.Word2VecEmbedding',
                                 'embed': identifier, 'field': 'token', 'normalize': 'l2'},
                       'hanlp_version': version.__version__}, metapath)
        elif identifier in pretrained.fasttext.ALL.values():
            # Same trick for fastText embeddings.
            save_dir = os.path.dirname(save_dir)
            metapath = os.path.join(save_dir, 'config.json')
            save_json({'classpath': 'hanlp.layers.embeddings.fast_text.FastTextEmbeddingComponent',
                       'embed': {'classpath': 'hanlp.layers.embeddings.fast_text.FastTextEmbedding',
                                 'filepath': identifier, 'src': 'token'},
                       'hanlp_version': version.__version__}, metapath)
        else:
            raise FileNotFoundError(f'The identifier {save_dir} resolves to a nonexistent meta file {metapath}. {tips}')
    meta: dict = load_json(metapath)
    cls = meta.get('classpath', cls)
    if not cls:
        cls = meta.get('class_path', None)  # For older version
    if tf_model:
        # tf models are trained with version < 2.1. To migrate them to 2.1, map their classpath to new locations
        upgrade = {
            'hanlp.components.tok_tf.TransformerTokenizerTF':
                'hanlp.components.tokenizers.tok_tf.TransformerTokenizerTF',
            'hanlp.components.pos.RNNPartOfSpeechTagger':
                'hanlp.components.taggers.pos_tf.RNNPartOfSpeechTaggerTF',
            'hanlp.components.pos_tf.RNNPartOfSpeechTaggerTF':
                'hanlp.components.taggers.pos_tf.RNNPartOfSpeechTaggerTF',
            'hanlp.components.pos_tf.CNNPartOfSpeechTaggerTF':
                'hanlp.components.taggers.pos_tf.CNNPartOfSpeechTaggerTF',
            'hanlp.components.ner_tf.TransformerNamedEntityRecognizerTF':
                'hanlp.components.ner.ner_tf.TransformerNamedEntityRecognizerTF',
            'hanlp.components.parsers.biaffine_parser.BiaffineDependencyParser':
                'hanlp.components.parsers.biaffine_parser_tf.BiaffineDependencyParserTF',
            'hanlp.components.parsers.biaffine_parser.BiaffineSemanticDependencyParser':
                'hanlp.components.parsers.biaffine_parser_tf.BiaffineSemanticDependencyParserTF',
            'hanlp.components.tok_tf.NgramConvTokenizerTF':
                'hanlp.components.tokenizers.tok_tf.NgramConvTokenizerTF',
            'hanlp.components.classifiers.transformer_classifier.TransformerClassifier':
                'hanlp.components.classifiers.transformer_classifier_tf.TransformerClassifierTF',
            'hanlp.components.taggers.transformers.transformer_tagger.TransformerTagger':
                'hanlp.components.taggers.transformers.transformer_tagger_tf.TransformerTaggerTF',
            'hanlp.components.tok.NgramConvTokenizer':
                'hanlp.components.tokenizers.tok_tf.NgramConvTokenizerTF',
        }
        cls = upgrade.get(cls, cls)
    assert cls, f'{meta_filename} doesn\'t contain classpath field'
    try:
        obj: Component = object_from_classpath(cls)
        if hasattr(obj, 'load'):
            if transform_only:
                # noinspection PyUnresolvedReferences
                obj.load_transform(save_dir)
            else:
                # config.json components take the directory; legacy meta components take the meta path.
                if os.path.isfile(os.path.join(save_dir, 'config.json')):
                    obj.load(save_dir, verbose=verbose, **kwargs)
                else:
                    obj.load(metapath, **kwargs)
            obj.config['load_path'] = load_path
        return obj
    except ModuleNotFoundError as e:
        if isdebugging():
            raise e from None
        else:
            raise ModuleNotFoundError(
                f'Some modules ({e.name} etc.) required by this model are missing. Please install the full version:'
                '\n\n\tpip install hanlp[full] -U') from None
    except ValueError as e:
        # Translate Hugging Face's download failure into a clearer ConnectionError.
        if e.args and isinstance(e.args[0], str) and 'Internet connection' in e.args[0]:
            raise ConnectionError(
                'Hugging Face 🤗 Transformers failed to download because your Internet connection is either off or bad.\n'
                'See https://hanlp.hankcs.com/docs/install.html#server-without-internet for solutions.') \
                from None
        raise e from None
    except Exception as e:
        # Some users often install an incompatible tf and put the blame on HanLP. Teach them the basics.
        try:
            you_installed_wrong_versions, extras = check_version_conflicts(extras=('full',) if tf_model else None)
        except:
            # Best-effort only: never let the diagnostic itself mask the original error.
            you_installed_wrong_versions, extras = None, None
        if you_installed_wrong_versions:
            raise version.NotCompatible(you_installed_wrong_versions + '\nPlease reinstall HanLP in the right way:' +
                                        '\n\n\tpip install --upgrade hanlp' + (
                                            f'[{",".join(extras)}]' if extras else '')) from None
        # Print a full environment report so issue reports are actionable.
        eprint(f'Failed to load {identifier}.')
        from pkg_resources import parse_version
        model_version = meta.get("hanlp_version", '2.0.0-alpha.0')
        if model_version == '2.0.0':  # Quick fix: the first version used a wrong string
            model_version = '2.0.0-alpha.0'
        model_version = parse_version(model_version)
        installed_version = parse_version(version.__version__)
        try:
            latest_version = get_latest_info_from_pypi()
        except:
            # Offline or PyPI unreachable: skip the upgrade hint comparison.
            latest_version = None
        if model_version > installed_version:
            eprint(f'{identifier} was created with hanlp-{model_version}, '
                   f'while you are running a lower version: {installed_version}. ')
        if installed_version != latest_version:
            eprint(
                f'Please upgrade HanLP with:\n'
                f'\n\tpip install --upgrade hanlp\n')
        eprint(
            'If the problem still persists, please submit an issue to https://github.com/hankcs/HanLP/issues\n'
            'When reporting an issue, make sure to paste the FULL ERROR LOG below.')
        eprint(f'{"ERROR LOG BEGINS":=^80}')
        import platform
        eprint(f'OS: {platform.platform()}')
        eprint(f'Python: {platform.python_version()}')
        import torch
        eprint(f'PyTorch: {torch.__version__}')
        if tf_model:
            try:
                import tensorflow
                tf_version = tensorflow.__version__
            except ModuleNotFoundError:
                tf_version = 'not installed'
            eprint(f'TensorFlow: {tf_version}')
        eprint(f'HanLP: {version.__version__}')
        import sys
        sys.stderr.flush()
        try:
            # Append the log-end marker to the exception message so pasted logs are delimited.
            if e.args and isinstance(e.args, tuple) and isinstance(e.args[0], str):
                e.args = (e.args[0] + f'\n{"ERROR LOG ENDS":=^80}',) + e.args[1:]
        except:
            pass
        raise e from None
def save_json(self, path):
    # Serialize this object itself to ``path``. The bare name ``save_json`` here
    # resolves to the module-level helper that this method shadows by name;
    # presumably ``self`` is directly JSON-serializable (e.g. a dict subclass) — TODO confirm.
    save_json(self, path)
def save_json(self, path):
    """Serialize the dict representation of this object to a JSON file at ``path``."""
    as_dict = self.to_dict()
    # Delegate the actual writing to the module-level save_json helper.
    save_json(as_dict, path)
def save(self, filepath):
    # Persist this object's ``meta`` dict as a JSON file at ``filepath``.
    save_json(self.meta, filepath)