def testClassRegistry(self):
    """Exercises ``misc.ClassRegistry`` lookup, naming, aliasing and validation.

    Checks, in order: lookup of a name that was never registered, registration
    under the default (class) name and under an explicit ``name``, the
    ``class_names`` set, lookups through an ``alias``, and the errors raised
    when a name is reused (``ValueError``) or when the class is rejected by the
    registry's ``base_class`` check (``TypeError``).
    """
    big = models.TransformerBig
    relative = models.TransformerBaseRelative
    reg = misc.ClassRegistry(base_class=models.Model)

    # Nothing has been registered yet, so the lookup yields None.
    self.assertIsNone(reg.get("TransformerBig"))

    # With no explicit name, the class is stored under its own class name.
    reg.register(big)
    self.assertEqual(reg.get("TransformerBig"), big)

    # The same class can be registered a second time under another name.
    reg.register(big, name="TransformerLarge")
    self.assertEqual(reg.get("TransformerLarge"), big)
    self.assertSetEqual(reg.class_names, {"TransformerBig", "TransformerLarge"})

    # An alias makes the class reachable under both its name and the alias.
    reg.register(relative, alias="TransformerRelative")
    for key in ("TransformerBaseRelative", "TransformerRelative"):
        self.assertEqual(reg.get(key), relative)

    # Registering under a name that is already taken is refused.
    with self.assertRaises(ValueError):
        reg.register(big)

    # ClassRegistry itself is (presumably) not a models.Model subclass, so the
    # base_class check rejects it.
    with self.assertRaises(TypeError):
        reg.register(misc.ClassRegistry)
Returns: The score or dictionary of scores. """ raise NotImplementedError() def lower_is_better(self): """Returns ``True`` if a lower score is better.""" return False def higher_is_better(self): """Returns ``True`` if a higher score is better.""" return not self.lower_is_better() _SCORERS_REGISTRY = misc.ClassRegistry(base_class=Scorer) register_scorer = _SCORERS_REGISTRY.register # pylint: disable=invalid-name @register_scorer(name="rouge") class ROUGEScorer(Scorer): """ROUGE scorer based on https://github.com/pltrdy/rouge.""" def __init__(self): super().__init__("rouge") @property def scores_name(self): return {"rouge-1", "rouge-2", "rouge-l"} def __call__(self, ref_path, hyp_path):
"""Catalog of predefined models.""" import tensorflow as tf import tensorflow_addons as tfa from opennmt import decoders, encoders, inputters, layers from opennmt.models import ( language_model, model, sequence_tagger, sequence_to_sequence, transformer, ) from opennmt.utils import misc _CATALOG_MODELS_REGISTRY = misc.ClassRegistry(base_class=model.Model) register_model_in_catalog = _CATALOG_MODELS_REGISTRY.register def list_model_names_from_catalog(): """Lists the models name registered in the catalog.""" return _CATALOG_MODELS_REGISTRY.class_names def get_model_from_catalog(name, as_builder=False): """Gets a model from the catalog. Args: name: The model name in the catalog. as_builder: If ``True``, return a callable building the model on call.
raise NotImplementedError() @abc.abstractmethod def _detokenize_string(self, tokens): """Detokenizes tokens. Args: tokens: A list of Python unicode strings. Returns: A unicode Python string. """ raise NotImplementedError() _TOKENIZERS_REGISTRY = misc.ClassRegistry(base_class=Tokenizer) register_tokenizer = _TOKENIZERS_REGISTRY.register # pylint: disable=invalid-name def make_tokenizer(config=None): """Creates a tokenizer instance from the configuration. Args: config: Path to a configuration file or the configuration dictionary. Returns: A :class:`opennmt.tokenizers.Tokenizer` instance. Raises: ValueError: if :obj:`config` is invalid.
"""Optimization utilities.""" import inspect import tensorflow as tf import tensorflow_addons as tfa from tensorflow_addons.optimizers.weight_decay_optimizers import ( DecoupledWeightDecayExtension, ) from opennmt.utils import misc _OPTIMIZERS_REGISTRY = misc.ClassRegistry(base_class=tf.keras.optimizers.Optimizer) register_optimizer = _OPTIMIZERS_REGISTRY.register def get_optimizer_class(name): """Returns the optimizer class. Args: name: The optimizer name. Returns: A class extending ``tf.keras.optimizers.Optimizer``. Raises: ValueError: if :obj:`name` can not be resolved to an optimizer class. """
if extra_assets: assets_extra = os.path.join(export_dir, "assets.extra") tf.io.gfile.makedirs(assets_extra) for filename, path in extra_assets.items(): tf.io.gfile.copy(path, os.path.join(assets_extra, filename), overwrite=True) tf.get_logger().info("Extra assets written to: %s", assets_extra) @abc.abstractmethod def _export_model(self, model, export_dir): raise NotImplementedError() _EXPORTERS_REGISTRY = misc.ClassRegistry(base_class=Exporter) register_exporter = _EXPORTERS_REGISTRY.register def make_exporter(name, **kwargs): """Creates a new exporter. Args: name: The exporter name. **kwargs: Additional arguments to pass to the exporter constructor. Returns: A :class:`opennmt.utils.Exporter` instance. Raises: ValueError: if :obj:`name` is invalid.
"""Define learning rate decay functions.""" import tensorflow as tf import numpy as np from opennmt.utils import misc _LR_SCHEDULES_REGISTRY = misc.ClassRegistry( base_class=tf.keras.optimizers.schedules.LearningRateSchedule) register_learning_rate_schedule = _LR_SCHEDULES_REGISTRY.register # pylint: disable=invalid-name def get_lr_schedule_class(name): """Returns the learning rate schedule class. Args: name: The schedule class name. Returns: A class extending ``tf.keras.optimizers.schedules.LearningRateSchedule``. Raises: ValueError: if :obj:`name` can not be resolved to an existing schedule. """ schedule_class = None if schedule_class is None: schedule_class = getattr(tf.keras.optimizers.schedules, name, None) if schedule_class is None: schedule_class = _LR_SCHEDULES_REGISTRY.get(name) if schedule_class is None: