示例#1
0
def make_raw_text_if_necessary(home: str):
    """Create ``text.jsonlines`` under ``home`` unless it already exists.

    Args:
        home: Identifier of the corpus root; it is resolved through
            ``get_resource`` before being used as a directory.
    """
    root = get_resource(home)
    target = os.path.join(root, 'text.jsonlines')
    # Only build the file when it is missing; an existing file is left untouched.
    if not os.path.isfile(target):
        save_json(batch_load_raw_text(root), target)
示例#2
0
 def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
         finetune: str = None, **kwargs):
     """Build the transform, vocab and model, then run the training loop.

     Args:
         trn_data: Training data, passed to ``build_vocab`` and ``build_train_dataset``.
         dev_data: Validation data, passed to ``build_valid_dataset``.
         save_dir: Directory for config, vocabs, meta, weights and ``history.json``.
             A directory from ``tempdir_human()`` is used when this is falsy.
         batch_size: Batch size used to derive the number of steps per epoch.
         epochs: Number of epochs handed to ``train_loop``.
         run_eagerly: Captured into the config together with the other arguments.
         logger: Logger to use; one is created under ``save_dir`` when omitted.
         verbose: Selects ``logging.INFO`` (True) or ``logging.WARN`` for the created logger.
         finetune: Optional identifier of pretrained weights. A directory is resolved to
             its ``model.h5``; weights are loaded by name with mismatches skipped.
         **kwargs: Extra hyperparameters captured into the config.

     Returns:
         Whatever ``train_loop`` returns, or the ``History`` callback found among the
         callbacks when training is interrupted with ``KeyboardInterrupt``.
     """
     self._capture_config(locals())
     self.transform = self.build_transform(**self.config)
     if not save_dir:
         save_dir = tempdir_human()
     if not logger:
         logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
     logger.info('Hyperparameter:\n' + self.config.to_json())
     num_examples = self.build_vocab(trn_data, logger)
     # assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
     logger.info('Building...')
     # Step counts can only be derived when build_vocab reported the number of examples.
     train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
     self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
     model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
     logger.info('Model built:\n' + summary_of_model(self.model))
     if finetune:
         finetune = get_resource(finetune)
         if os.path.isdir(finetune):
             finetune = os.path.join(finetune, 'model.h5')
         model.load_weights(finetune, by_name=True, skip_mismatch=True)
         logger.info(f'Loaded pretrained weights from {finetune} for finetuning')
     self.save_config(save_dir)
     self.save_vocabs(save_dir)
     self.save_meta(save_dir)
     trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
     dev_data = self.build_valid_dataset(dev_data, batch_size)
     callbacks = self.build_callbacks(save_dir, **merge_dict(self.config, overwrite=True, logger=logger))
     # need to know #batches, otherwise progbar crashes
     dev_steps = math.ceil(self.num_samples_in(dev_data) / batch_size)
     checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
     timer = Timer()
     try:
         history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
                                                num_examples=num_examples,
                                                train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
                                                callbacks=callbacks, logger=logger, model=model, optimizer=optimizer,
                                                loss=loss,
                                                metrics=metrics, overwrite=True))
     except KeyboardInterrupt:
         print()
         # ``np.inf`` instead of the ``np.Inf`` alias, which NumPy 2.0 removed.
         if not checkpoint or checkpoint.best in (np.inf, -np.inf):
             # No checkpoint has been written yet, so persist the current weights.
             self.save_weights(save_dir)
             logger.info('Aborted with model saved')
         else:
             logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
         history = get_callback_by_class(callbacks, tf.keras.callbacks.History)
     delta_time = timer.stop()
     best_epoch_ago = 0
     if history and hasattr(history, 'epoch'):
         trained_epoch = len(history.epoch)
         logger.info('Trained {} epochs in {}, each epoch takes {}'.
                     format(trained_epoch, delta_time, delta_time / trained_epoch if trained_epoch else delta_time))
         save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=io_util.NumpyEncoder)
         # The callbacks may not include a ModelCheckpoint; only consult it when present.
         monitor_history = history.history.get(checkpoint.monitor, None) if checkpoint else None
         if monitor_history:
             best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
         if checkpoint and monitor_history and checkpoint.best != monitor_history[-1]:
             logger.info(f'Restored the best model saved with best '
                         f'{checkpoint.monitor} = {checkpoint.best:.4f} '
                         f'saved {best_epoch_ago} epochs ago')
             self.load_weights(save_dir)  # restore best model
     return history
示例#3
0
 def save_meta(self, save_dir, filename='meta.json', **kwargs):
     """Stamp the creation time into ``self.meta`` and write it as JSON.

     Args:
         save_dir: Directory the meta file is written to.
         filename: Name of the meta file inside ``save_dir``.
         **kwargs: Extra entries merged into ``self.meta`` before saving.
     """
     # Must be an assignment: ``target: expr`` is only an annotation and stores nothing.
     self.meta['create_time'] = now_datetime()
     self.meta.update(kwargs)
     save_json(self.meta, os.path.join(save_dir, filename))
示例#4
0
def load_from_meta_file(save_dir: str, meta_filename='meta.json', transform_only=False, verbose=HANLP_VERBOSE,
                        **kwargs) -> Component:
    """
    Load a component from a ``meta.json`` (legacy TensorFlow component) or a ``config.json`` file.

    The identifier is resolved with ``get_resource``. The class to instantiate is read from the
    ``classpath`` (or legacy ``class_path``) field of the meta file, and the resulting object is
    asked to load itself from the resolved directory.

    Args:
        save_dir: The identifier. After resolution, a path ending with ``.json`` is split into
            its directory and its file name, and the file name replaces ``meta_filename``.
        meta_filename (str): The meta file of that saved component, which stores the classpath and version.
        transform_only: Load and return only the transform (calls ``load_transform`` instead of ``load``).
        verbose: Forwarded to ``component.load()`` when a ``config.json`` exists in the resolved directory.
        **kwargs: Extra parameters passed to ``component.load()``.

    Returns:
        The object created from the classpath. If it has a ``load`` attribute, the original
        identifier is stored in its config under ``load_path``.

    Raises:
        FileNotFoundError: Neither the meta file nor ``config.json`` exists and the identifier is
            not a known word2vec/fasttext resource.
        AssertionError: The meta file has no classpath field.
        ModuleNotFoundError: A module required by the component is missing.
        ConnectionError: A ``ValueError`` mentioning ``Internet connection`` was raised while loading.
        version.NotCompatible: ``check_version_conflicts`` reported conflicting installed versions.
        Exception: Any other loading error is re-raised after diagnostics are printed.
    """
    identifier = save_dir
    load_path = save_dir
    save_dir = get_resource(save_dir)
    if save_dir.endswith('.json'):
        # The identifier points at the meta/config file itself rather than its folder.
        meta_filename = os.path.basename(save_dir)
        save_dir = os.path.dirname(save_dir)
    metapath = os.path.join(save_dir, meta_filename)
    # The presence of the meta file is what marks the component as a TensorFlow model here.
    # NOTE(review): if the identifier resolved to an existing ``config.json`` file, meta_filename
    # was replaced above and this also yields tf_model = True — confirm that is intended.
    if not os.path.isfile(metapath):
        tf_model = False
        metapath = os.path.join(save_dir, 'config.json')
    else:
        tf_model = True
    cls = None
    if not os.path.isfile(metapath):
        tips = ''
        # NOTE(review): this tests the resolved path, not the original identifier — presumably an
        # unresolved all-caps key is returned unchanged by get_resource; verify.
        if save_dir.isupper():
            from difflib import SequenceMatcher
            # Suggest the five pretrained keys most similar to what the user typed.
            similar_keys = sorted(pretrained.ALL.keys(),
                                  key=lambda k: SequenceMatcher(None, k, identifier).ratio(),
                                  reverse=True)[:5]
            tips = f'Check its spelling based on the available keys:\n' + \
                   f'{sorted(pretrained.ALL.keys())}\n' + \
                   f'Tips: it might be one of {similar_keys}'
        # These components are not intended to be loaded in this way, but I'm tired of explaining it again and again
        # For known embedding identifiers, a config.json is generated on the fly in the parent
        # directory so that the regular loading path below can proceed.
        if identifier in pretrained.word2vec.ALL.values():
            save_dir = os.path.dirname(save_dir)
            metapath = os.path.join(save_dir, 'config.json')
            save_json({'classpath': 'hanlp.layers.embeddings.word2vec.Word2VecEmbeddingComponent',
                       'embed': {'classpath': 'hanlp.layers.embeddings.word2vec.Word2VecEmbedding',
                                 'embed': identifier, 'field': 'token', 'normalize': 'l2'},
                       'hanlp_version': version.__version__}, metapath)
        elif identifier in pretrained.fasttext.ALL.values():
            save_dir = os.path.dirname(save_dir)
            metapath = os.path.join(save_dir, 'config.json')
            save_json({'classpath': 'hanlp.layers.embeddings.fast_text.FastTextEmbeddingComponent',
                       'embed': {'classpath': 'hanlp.layers.embeddings.fast_text.FastTextEmbedding',
                                 'filepath': identifier, 'src': 'token'},
                       'hanlp_version': version.__version__}, metapath)
        else:
            # NOTE(review): the message interpolates the resolved save_dir, not the original identifier.
            raise FileNotFoundError(f'The identifier {save_dir} resolves to a nonexistent meta file {metapath}. {tips}')
    meta: dict = load_json(metapath)
    cls = meta.get('classpath', cls)
    if not cls:
        cls = meta.get('class_path', None)  # For older version
    if tf_model:
        # tf models are trained with version < 2.1. To migrate them to 2.1, map their classpath to new locations
        upgrade = {
            'hanlp.components.tok_tf.TransformerTokenizerTF': 'hanlp.components.tokenizers.tok_tf.TransformerTokenizerTF',
            'hanlp.components.pos.RNNPartOfSpeechTagger': 'hanlp.components.taggers.pos_tf.RNNPartOfSpeechTaggerTF',
            'hanlp.components.pos_tf.RNNPartOfSpeechTaggerTF': 'hanlp.components.taggers.pos_tf.RNNPartOfSpeechTaggerTF',
            'hanlp.components.pos_tf.CNNPartOfSpeechTaggerTF': 'hanlp.components.taggers.pos_tf.CNNPartOfSpeechTaggerTF',
            'hanlp.components.ner_tf.TransformerNamedEntityRecognizerTF': 'hanlp.components.ner.ner_tf.TransformerNamedEntityRecognizerTF',
            'hanlp.components.parsers.biaffine_parser.BiaffineDependencyParser': 'hanlp.components.parsers.biaffine_parser_tf.BiaffineDependencyParserTF',
            'hanlp.components.parsers.biaffine_parser.BiaffineSemanticDependencyParser': 'hanlp.components.parsers.biaffine_parser_tf.BiaffineSemanticDependencyParserTF',
            'hanlp.components.tok_tf.NgramConvTokenizerTF': 'hanlp.components.tokenizers.tok_tf.NgramConvTokenizerTF',
            'hanlp.components.classifiers.transformer_classifier.TransformerClassifier': 'hanlp.components.classifiers.transformer_classifier_tf.TransformerClassifierTF',
            'hanlp.components.taggers.transformers.transformer_tagger.TransformerTagger': 'hanlp.components.taggers.transformers.transformer_tagger_tf.TransformerTaggerTF',
            'hanlp.components.tok.NgramConvTokenizer': 'hanlp.components.tokenizers.tok_tf.NgramConvTokenizerTF',
        }
        # Classpaths without an entry in the table are kept as they are.
        cls = upgrade.get(cls, cls)
    assert cls, f'{meta_filename} doesn\'t contain classpath field'
    try:
        obj: Component = object_from_classpath(cls)
        if hasattr(obj, 'load'):
            if transform_only:
                # noinspection PyUnresolvedReferences
                obj.load_transform(save_dir)
            else:
                # Load from the directory when it has a config.json; otherwise hand over the
                # meta file path itself (verbose is not forwarded in that case).
                if os.path.isfile(os.path.join(save_dir, 'config.json')):
                    obj.load(save_dir, verbose=verbose, **kwargs)
                else:
                    obj.load(metapath, **kwargs)
            # Remember the identifier the caller originally passed in.
            obj.config['load_path'] = load_path
        return obj
    except ModuleNotFoundError as e:
        # While debugging keep the original error; otherwise replace it with install instructions.
        if isdebugging():
            raise e from None
        else:
            raise ModuleNotFoundError(
                f'Some modules ({e.name} etc.) required by this model are missing. Please install the full version:'
                '\n\n\tpip install hanlp[full] -U') from None
    except ValueError as e:
        # Only ValueErrors whose message mentions the Internet connection are translated.
        if e.args and isinstance(e.args[0], str) and 'Internet connection' in e.args[0]:
            raise ConnectionError(
                'Hugging Face 🤗 Transformers failed to download because your Internet connection is either off or bad.\n'
                'See https://hanlp.hankcs.com/docs/install.html#server-without-internet for solutions.') \
                from None
        raise e from None
    except Exception as e:
        # Some users often install an incompatible tf and put the blame on HanLP. Teach them the basics.
        try:
            you_installed_wrong_versions, extras = check_version_conflicts(extras=('full',) if tf_model else None)
        except:
            # Best effort: a failing conflict check must not hide the original error.
            you_installed_wrong_versions, extras = None, None
        if you_installed_wrong_versions:
            raise version.NotCompatible(you_installed_wrong_versions + '\nPlease reinstall HanLP in the right way:' +
                                        '\n\n\tpip install --upgrade hanlp' + (
                                            f'[{",".join(extras)}]' if extras else '')) from None
        eprint(f'Failed to load {identifier}.')
        from pkg_resources import parse_version
        model_version = meta.get("hanlp_version", '2.0.0-alpha.0')
        if model_version == '2.0.0':  # Quick fix: the first version used a wrong string
            model_version = '2.0.0-alpha.0'
        model_version = parse_version(model_version)
        installed_version = parse_version(version.__version__)
        try:
            latest_version = get_latest_info_from_pypi()
        except:
            # Best effort: the lookup may fail, in which case the latest version stays unknown.
            latest_version = None
        if model_version > installed_version:
            eprint(f'{identifier} was created with hanlp-{model_version}, '
                   f'while you are running a lower version: {installed_version}. ')
        # NOTE(review): when latest_version is None this comparison is true, so the upgrade
        # hint is printed even though the latest version could not be determined.
        if installed_version != latest_version:
            eprint(
                f'Please upgrade HanLP with:\n'
                f'\n\tpip install --upgrade hanlp\n')
        eprint(
            'If the problem still persists, please submit an issue to https://github.com/hankcs/HanLP/issues\n'
            'When reporting an issue, make sure to paste the FULL ERROR LOG below.')

        # Print environment details that are useful in a bug report.
        eprint(f'{"ERROR LOG BEGINS":=^80}')
        import platform
        eprint(f'OS: {platform.platform()}')
        eprint(f'Python: {platform.python_version()}')
        import torch
        eprint(f'PyTorch: {torch.__version__}')
        if tf_model:
            try:
                import tensorflow
                tf_version = tensorflow.__version__
            except ModuleNotFoundError:
                tf_version = 'not installed'
            eprint(f'TensorFlow: {tf_version}')
        eprint(f'HanLP: {version.__version__}')
        import sys
        sys.stderr.flush()
        try:
            # Append an end marker to the message of the exception that is re-raised below.
            if e.args and isinstance(e.args, tuple) and isinstance(e.args[0], str):
                e.args = (e.args[0] + f'\n{"ERROR LOG ENDS":=^80}',) + e.args[1:]
        except:
            pass
        raise e from None
示例#5
0
 def save_json(self, path):
     """Write this object to ``path`` through the ``save_json`` helper.

     Args:
         path: Destination handed to the helper unchanged.
     """
     # Inside a method body the bare name does not resolve to this method,
     # so this calls the ``save_json`` visible from the enclosing scope.
     save_json(self, path)
示例#6
0
 def save_json(self, path):
     """Write the ``to_dict()`` form of this object to ``path`` as JSON.

     Args:
         path: Destination handed to the ``save_json`` helper unchanged.
     """
     payload = self.to_dict()
     save_json(payload, path)
示例#7
0
 def save(self, filepath):
     """Write ``self.meta`` to ``filepath`` through the ``save_json`` helper.

     Args:
         filepath: Destination handed to the helper unchanged.
     """
     meta = self.meta
     save_json(meta, filepath)