Example #1
def clgen_preprocessor(func: PreprocessorFunction) -> PreprocessorFunction:
    """A decorator which marks a function as a CLgen preprocessor.

    A CLgen preprocessor is accessible using GetPreprocessFunction(), and is a
    function which accepts a single parameter 'text', and returns a string.
    Type hinting is used to ensure that any function wrapped with this decorator
    has the appropriate argument and return type. If the function does not, an
    InternalError is raised at the time that the module containing the function
    is imported.

    Args:
      func: The preprocessor function to decorate.

    Returns:
      The decorated preprocessor function.

    Raises:
      InternalError: If the function being wrapped does not have the signature
        'def func(text: str) -> str:'.
    """
    type_hints = typing.get_type_hints(func)
    if type_hints != {"text": str, "return": str}:
        raise errors.InternalError(
            f"Preprocessor {func.__name__} does not have signature "
            f'"def {func.__name__}(text: str) -> str".')
    func.__dict__["is_clgen_preprocessor"] = True
    return func
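
For context, here is a minimal sketch of how a preprocessor might be registered with this decorator. The function name and body are hypothetical; the only thing the decorator actually enforces is the 'def func(text: str) -> str' signature, after which it tags the function with the is_clgen_preprocessor attribute.

@clgen_preprocessor
def StripDoubleSlashComments(text: str) -> str:
    """Hypothetical preprocessor: drop everything after '//' on each line."""
    return "\n".join(line.split("//")[0] for line in text.split("\n"))

# The decorator marks the function so GetPreprocessFunction() can find it.
assert StripDoubleSlashComments.__dict__["is_clgen_preprocessor"]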
Example #2
def GetTerminationCriteria(
    config: typing.List[sampler_pb2.SampleTerminationCriterion],
) -> typing.List[TerminationCriterionBase]:
    """Build a list of termination criteria from config protos.

    Args:
      config: A list of SampleTerminationCriterion protos.

    Returns:
      A list of TerminationCriterion instances.

    Raises:
      UserError: In case of invalid configs.
      InternalError: If any of the termination criteria are unrecognized.
    """
    terminators = []
    for criterion in config:
        if criterion.HasField("maxlen"):
            terminators.append(MaxlenTerminationCriterion(criterion.maxlen))
        elif criterion.HasField("symtok"):
            terminators.append(SymmetricalTokenDepthCriterion(
                criterion.symtok))
        else:
            raise errors.InternalError("Unknown Sampler.termination_criteria")
    return terminators
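
A rough usage sketch, assuming sampler_pb2 and the criterion classes above are in scope. The maximum_tokens_in_sample field name inside the maxlen message is an assumption and may differ in the actual sampler.proto schema.

criterion = sampler_pb2.SampleTerminationCriterion()
criterion.maxlen.maximum_tokens_in_sample = 500  # assumed field name
terminators = GetTerminationCriteria([criterion])
# Setting a maxlen criterion yields a MaxlenTerminationCriterion instance.
assert isinstance(terminators[0], MaxlenTerminationCriterion)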
Example #3
def BuildOptimizer(config: model_pb2.Model) -> "keras.optimizers.Optimizer":
    """Construct the training optimizer from config.

    Args:
      config: A Model config proto.

    Returns:
      A Keras optimizer instance built from the training options.

    Raises:
      InternalError: If the value of the optimizer field is not understood.
    """
    # Deferred importing of Keras so that we don't have to activate the
    # TensorFlow backend every time we import this module.
    import keras

    # We do not use *any* default values for arguments, in case for whatever
    # reason the Keras API changes a default arg.
    if config.training.HasField("adam_optimizer"):
        opts = {}
        opt = config.training.adam_optimizer
        if opt.normalized_gradient_clip_micros:
            opts["clipnorm"] = opt.normalized_gradient_clip_micros / 1e6
        return keras.optimizers.Adam(
            lr=opt.initial_learning_rate_micros / 1e6,
            beta_1=opt.beta_1_micros / 1e6,
            beta_2=opt.beta_2_micros / 1e6,
            epsilon=None,
            decay=opt.learning_rate_decay_per_epoch_micros / 1e6,
            amsgrad=False,
            **opts,
        )
    elif config.training.HasField("rmsprop_optimizer"):
        opt = config.training.rmsprop_optimizer
        return keras.optimizers.RMSprop(
            lr=opt.initial_learning_rate_micros / 1e6,
            decay=opt.learning_rate_decay_per_epoch_micros / 1e6,
            rho=0.9,
            epsilon=None,
        )
    else:
        raise errors.InternalError(
            "Unrecognized value: 'TrainingOptions.optimizer'")
Example #4
    def __init__(self, config: model_pb2.Model):
        """Instantiate a model.

        Args:
          config: A Model message.

        Raises:
          TypeError: If the config argument is not a Model proto.
          UserError: In case of an invalid config.
        """
        # Error early, so that a cache isn't created.
        if not isinstance(config, model_pb2.Model):
            t = type(config).__name__
            raise TypeError(f"Config must be a Model proto. Received: '{t}'")
        # Validate config options.
        if config.training.sequence_length < 1:
            raise errors.UserError(
                'TrainingOptions.sequence_length must be >= 1')

        self.config = model_pb2.Model()
        self.config.CopyFrom(builders.AssertIsBuildable(config))
        self.corpus = corpuses.Corpus(config.corpus)
        self.hash = self._ComputeHash(self.corpus, self.config)
        self.cache = cache.mkcache('model', self.hash)
        # Create the necessary cache directories.
        (self.cache.path / 'checkpoints').mkdir(exist_ok=True)
        (self.cache.path / 'samples').mkdir(exist_ok=True)
        (self.cache.path / 'logs').mkdir(exist_ok=True)

        # Create symlink to encoded corpus.
        symlink = self.cache.path / 'corpus'
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(
                    pathlib.Path(
                        self.corpus.encoded.url[len('sqlite:///'):]).parent,
                    self.cache.path), symlink)

        # Create symlink to the atomizer.
        symlink = self.cache.path / 'atomizer'
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(self.corpus.atomizer_path, self.cache.path),
                symlink)

        # Validate metadata against cache.
        if self.cache.get('META.pbtxt'):
            cached_meta = pbutil.FromFile(
                pathlib.Path(self.cache['META.pbtxt']),
                internal_pb2.ModelMeta())
            # Exclude num_epochs and corpus location from metadata comparison.
            config_to_compare = model_pb2.Model()
            config_to_compare.CopyFrom(self.config)
            config_to_compare.corpus.ClearField('contentfiles')
            config_to_compare.training.ClearField('num_epochs')
            # These fields should have already been cleared, but we'll do it again
            # so that metadata comparisons don't fail when the cached meta schema
            # is updated.
            cached_to_compare = model_pb2.Model()
            cached_to_compare.CopyFrom(cached_meta.config)
            cached_to_compare.corpus.ClearField('contentfiles')
            cached_to_compare.training.ClearField('num_epochs')
            if config_to_compare != cached_to_compare:
                raise errors.InternalError('Metadata mismatch')
            self.meta = cached_meta
        else:
            self.meta = internal_pb2.ModelMeta()
            self.meta.config.CopyFrom(self.config)
            self._WriteMetafile()

        self.backend = {
            model_pb2.NetworkArchitecture.TENSORFLOW:
            tensorflow_backend.TensorFlowBackend,
            model_pb2.NetworkArchitecture.KERAS: keras_backend.KerasBackend,
        }[config.architecture.backend](self.config, self.cache, self.corpus)
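
A brief sketch of how this constructor might be driven from a text-format config file, mirroring the pbutil.FromFile call used above for the cached metadata. The enclosing class is not shown in this excerpt, so the name Model is an assumption, as is the model.pbtxt path.

config = pbutil.FromFile(pathlib.Path("model.pbtxt"), model_pb2.Model())
# Raises TypeError, UserError, or InternalError for bad configs or cache mismatches.
model = Model(config)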
Example #5
def AssertIsBuildable(config: model_pb2.Model) -> model_pb2.Model:
  """Assert that a model configuration is buildable.

  Args:
    config: A model proto.

  Returns:
    The input model proto, unmodified.

  Raises:
    UserError: If the model is not buildable.
    InternalError: If the value of the training.optimizer field is not
      understood.
  """
  # Any change to the Model proto schema will require a change to this function.
  try:
    pbutil.AssertFieldIsSet(config, 'corpus')
    pbutil.AssertFieldIsSet(config, 'architecture')
    pbutil.AssertFieldIsSet(config, 'training')
    pbutil.AssertFieldIsSet(config.architecture, 'backend')
    pbutil.AssertFieldIsSet(config.architecture, 'neuron_type')
    if config.architecture.backend == model_pb2.NetworkArchitecture.KERAS:
      pbutil.AssertFieldConstraint(
          config.architecture, 'embedding_size', lambda x: 0 < x,
          'NetworkArchitecture.embedding_size must be > 0')
    pbutil.AssertFieldConstraint(
        config.architecture, 'neurons_per_layer', lambda x: 0 < x,
        'NetworkArchitecture.neurons_per_layer must be > 0')
    pbutil.AssertFieldConstraint(
        config.architecture, 'num_layers', lambda x: 0 < x,
        'NetworkArchitecture.num_layers must be > 0')
    pbutil.AssertFieldConstraint(
        config.architecture, 'post_layer_dropout_micros',
        lambda x: 0 <= x <= 1000000,
        'NetworkArchitecture.post_layer_dropout_micros '
        'must be >= 0 and <= 1000000')
    pbutil.AssertFieldConstraint(
        config.training, 'num_epochs', lambda x: 0 < x,
        'TrainingOptions.num_epochs must be > 0')
    pbutil.AssertFieldIsSet(
        config.training, 'shuffle_corpus_contentfiles_between_epochs')
    pbutil.AssertFieldConstraint(
        config.training, 'batch_size', lambda x: 0 < x,
        'TrainingOptions.batch_size must be > 0')
    pbutil.AssertFieldIsSet(config.training, 'optimizer')
    if config.training.HasField('adam_optimizer'):
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer, 'initial_learning_rate_micros',
          lambda x: 0 <= x,
          'AdamOptimizer.initial_learning_rate_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'learning_rate_decay_per_epoch_micros', lambda x: 0 <= x,
          'AdamOptimizer.learning_rate_decay_per_epoch_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'beta_1_micros', lambda x: 0 <= x <= 1000000,
          'AdamOptimizer.beta_1_micros must be >= 0 and <= 1000000')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'beta_2_micros', lambda x: 0 <= x <= 1000000,
          'AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'normalized_gradient_clip_micros', lambda x: 0 <= x,
          'AdamOptimizer.normalized_gradient_clip_micros must be >= 0')
    elif config.training.HasField('rmsprop_optimizer'):
      pbutil.AssertFieldConstraint(
          config.training.rmsprop_optimizer, 'initial_learning_rate_micros',
          lambda x: 0 <= x,
          'RmsPropOptimizer.initial_learning_rate_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.rmsprop_optimizer,
          'learning_rate_decay_per_epoch_micros', lambda x: 0 <= x,
          'RmsPropOptimizer.learning_rate_decay_per_epoch_micros must be >= 0')
    else:
      raise errors.InternalError(
          "Unrecognized value: 'TrainingOptions.optimizer'")
  except pbutil.ProtoValueError as e:
    raise errors.UserError(str(e))
  return config
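
A small sketch of defensive use at a call site, assuming the caller wants to surface configuration problems to the user rather than crash deeper inside the build; the print call is illustrative.

try:
  AssertIsBuildable(config)
except errors.UserError as e:
  print(f"Invalid model configuration: {e}")
  raise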
Example #6
def MockPreprocessorInternalError(text: str) -> str:
  """A mock preprocessor which raises a BadCodeException."""
  del text
  raise errors.InternalError('internal error')
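
A short sketch of how such a mock might be exercised in a test; the use of pytest is an assumption, as is the example kernel string.

import pytest

def test_mock_preprocessor_internal_error():
  # The mock ignores its input and always raises, regardless of the text passed in.
  with pytest.raises(errors.InternalError):
    MockPreprocessorInternalError("kernel void A(global int* a) {}")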