Example #1
0
  def testClassRegistry(self):
    # Exercises lookup, explicit naming, aliasing and rejected registrations
    # on a registry created with ``models.Model`` as its base class.
    class_registry = misc.ClassRegistry(base_class=models.Model)

    # Nothing has been registered yet, so the lookup yields None.
    self.assertIsNone(class_registry.get("TransformerBig"))

    # Registration without arguments: the class is found under its own name.
    big_cls = models.TransformerBig
    class_registry.register(big_cls)
    self.assertEqual(class_registry.get("TransformerBig"), big_cls)

    # Registration of the same class under an explicit, different name.
    class_registry.register(big_cls, name="TransformerLarge")
    self.assertEqual(class_registry.get("TransformerLarge"), big_cls)
    self.assertSetEqual(
        class_registry.class_names,
        {"TransformerBig", "TransformerLarge"},
    )

    # Registration with an alias: both the class name and the alias resolve.
    relative_cls = models.TransformerBaseRelative
    class_registry.register(relative_cls, alias="TransformerRelative")
    self.assertEqual(
        class_registry.get("TransformerBaseRelative"), relative_cls)
    self.assertEqual(
        class_registry.get("TransformerRelative"), relative_cls)

    # Registering an already used name is expected to fail.
    with self.assertRaises(ValueError):
      class_registry.register(big_cls)
    # ``misc.ClassRegistry`` is expected to be refused by this registry.
    with self.assertRaises(TypeError):
      class_registry.register(misc.ClassRegistry)
Example #2
0
    Returns:
      The score or dictionary of scores.
    """
    raise NotImplementedError()

  def lower_is_better(self):
    """Returns ``True`` if a lower score is better.

    This implementation always returns ``False``. :meth:`higher_is_better`
    is defined as the negation of this method.
    """
    return False

  def higher_is_better(self):
    """Tells whether a larger score indicates a better result.

    Returns:
      The logical negation of :meth:`lower_is_better`.
    """
    lower_wins = self.lower_is_better()
    return not lower_wins


# Module-level registry of scorer classes, created with ``Scorer`` as its
# base class.
_SCORERS_REGISTRY = misc.ClassRegistry(base_class=Scorer)
# Public alias of the registry's ``register`` method; it is applied below as
# a class decorator (``@register_scorer(name=...)``).
register_scorer = _SCORERS_REGISTRY.register  # pylint: disable=invalid-name


@register_scorer(name="rouge")
class ROUGEScorer(Scorer):
  """ROUGE scorer based on https://github.com/pltrdy/rouge."""

  def __init__(self):
    """Initializes the scorer, passing ``"rouge"`` to the parent constructor."""
    super().__init__("rouge")

  @property
  def scores_name(self):
    """The set of score names associated with this scorer.

    NOTE(review): presumably these are the keys of the scores produced by
    ``__call__`` -- confirm against the implementation.
    """
    return {"rouge-1", "rouge-2", "rouge-l"}

  def __call__(self, ref_path, hyp_path):
Example #3
0
"""Catalog of predefined models."""

import tensorflow as tf
import tensorflow_addons as tfa

from opennmt import decoders, encoders, inputters, layers
from opennmt.models import (
    language_model,
    model,
    sequence_tagger,
    sequence_to_sequence,
    transformer,
)
from opennmt.utils import misc

# Module-level registry backing the model catalog, created with
# ``model.Model`` as its base class.
_CATALOG_MODELS_REGISTRY = misc.ClassRegistry(base_class=model.Model)

# Public alias of the registry's ``register`` method.
register_model_in_catalog = _CATALOG_MODELS_REGISTRY.register


def list_model_names_from_catalog():
    """Lists the names of the models registered in the catalog.

    Returns:
      The ``class_names`` attribute of the catalog registry.
    """
    registered_names = _CATALOG_MODELS_REGISTRY.class_names
    return registered_names


def get_model_from_catalog(name, as_builder=False):
    """Gets a model from the catalog.

    Args:
      name: The model name in the catalog.
      as_builder: If ``True``, return a callable building the model on call.
Example #4
0
        raise NotImplementedError()

    @abc.abstractmethod
    def _detokenize_string(self, tokens):
        """Detokenizes tokens.

        This is an abstract method: the implementation here only raises.

        Args:
          tokens: A list of Python unicode strings.

        Returns:
          A unicode Python string.

        Raises:
          NotImplementedError: always, in this abstract implementation.
        """
        raise NotImplementedError()


# Module-level registry of tokenizer classes, created with ``Tokenizer`` as
# its base class.
_TOKENIZERS_REGISTRY = misc.ClassRegistry(base_class=Tokenizer)

# Public alias of the registry's ``register`` method.
register_tokenizer = _TOKENIZERS_REGISTRY.register  # pylint: disable=invalid-name


def make_tokenizer(config=None):
    """Creates a tokenizer instance from the configuration.

  Args:
    config: Path to a configuration file or the configuration dictionary.

  Returns:
    A :class:`opennmt.tokenizers.Tokenizer` instance.

  Raises:
    ValueError: if :obj:`config` is invalid.
Example #5
0
"""Optimization utilities."""

import inspect

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow_addons.optimizers.weight_decay_optimizers import (
    DecoupledWeightDecayExtension,
)

from opennmt.utils import misc


# Module-level registry of optimizer classes, created with
# ``tf.keras.optimizers.Optimizer`` as its base class.
_OPTIMIZERS_REGISTRY = misc.ClassRegistry(base_class=tf.keras.optimizers.Optimizer)

# Public alias of the registry's ``register`` method.
register_optimizer = _OPTIMIZERS_REGISTRY.register


def get_optimizer_class(name):
    """Returns the optimizer class.

    Args:
      name: The optimizer name.

    Returns:
      A class extending ``tf.keras.optimizers.Optimizer``.

    Raises:
      ValueError: if :obj:`name` can not be resolved to an optimizer class.
    """
Example #6
0
            if extra_assets:
                assets_extra = os.path.join(export_dir, "assets.extra")
                tf.io.gfile.makedirs(assets_extra)
                for filename, path in extra_assets.items():
                    tf.io.gfile.copy(path,
                                     os.path.join(assets_extra, filename),
                                     overwrite=True)
                tf.get_logger().info("Extra assets written to: %s",
                                     assets_extra)

    @abc.abstractmethod
    def _export_model(self, model, export_dir):
        """Abstract export hook; the implementation here only raises.

        Args:
          model: Presumably the model object to export -- confirm against callers.
          export_dir: Presumably the destination directory of the export --
            confirm against callers.

        Raises:
          NotImplementedError: always, in this abstract implementation.
        """
        raise NotImplementedError()


# Module-level registry of exporter classes, created with ``Exporter`` as its
# base class.
_EXPORTERS_REGISTRY = misc.ClassRegistry(base_class=Exporter)
# Public alias of the registry's ``register`` method.
register_exporter = _EXPORTERS_REGISTRY.register


def make_exporter(name, **kwargs):
    """Creates a new exporter.

    Args:
      name: The exporter name.
      **kwargs: Additional arguments to pass to the exporter constructor.

    Returns:
      A :class:`opennmt.utils.Exporter` instance.

    Raises:
      ValueError: if :obj:`name` is invalid.
Example #7
0
"""Define learning rate decay functions."""

import tensorflow as tf
import numpy as np

from opennmt.utils import misc


# Module-level registry of learning rate schedule classes, created with
# ``tf.keras.optimizers.schedules.LearningRateSchedule`` as its base class.
_LR_SCHEDULES_REGISTRY = misc.ClassRegistry(
    base_class=tf.keras.optimizers.schedules.LearningRateSchedule)

# Public alias of the registry's ``register`` method.
register_learning_rate_schedule = _LR_SCHEDULES_REGISTRY.register  # pylint: disable=invalid-name

def get_lr_schedule_class(name):
  """Returns the learning rate schedule class.

  Args:
    name: The schedule class name.

  Returns:
    A class extending ``tf.keras.optimizers.schedules.LearningRateSchedule``.

  Raises:
    ValueError: if :obj:`name` can not be resolved to an existing schedule.
  """
  schedule_class = None
  if schedule_class is None:
    schedule_class = getattr(tf.keras.optimizers.schedules, name, None)
  if schedule_class is None:
    schedule_class = _LR_SCHEDULES_REGISTRY.get(name)
  if schedule_class is None: