Exemplo n.º 1
0
    def __init__(self, config: clgen_pb2.Instance):
        """Create an instance from its proto configuration.

        Args:
          config: An Instance proto.

        Raises:
          UserError: If the proto leaves the model_specification or the
            sampler field unset.
        """
        try:
            for required_field in ('model_specification', 'sampler'):
                pbutil.AssertFieldIsSet(config, required_field)
        except pbutil.ProtoValueError as err:
            raise errors.UserError(err)

        self.working_dir = None
        if config.HasField('working_dir'):
            # Expand environment variables, then "~", then make absolute.
            expanded_dir = os.path.expandvars(config.working_dir)
            self.working_dir: pathlib.Path = (
                pathlib.Path(expanded_dir).expanduser().absolute())

        # The model and sampler are constructed inside a session so that cache
        # paths are set relative to any requested working directory.
        with self.Session():
            if not config.HasField('model'):
                pretrained_path = pathlib.Path(config.pretrained_model)
                self.model: pretrained.PreTrainedModel = (
                    pretrained.PreTrainedModel(pretrained_path))
            else:
                self.model: models.Model = models.Model(config.model)
            self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)
Exemplo n.º 2
0
def AssertConfigIsValid(config: corpus_pb2.Corpus) -> corpus_pb2.Corpus:
    """Assert that config proto is valid.

    Args:
      config: A Corpus proto.

    Returns:
      The Corpus proto, unmodified.

    Raises:
      UserError: If contentfiles, atomizer or contentfile_separator is unset,
        or if a greedy_multichar_atomizer is set with no tokens or with an
        empty-string token. Errors raised by
        preprocessors.GetPreprocessorFunction for an unresolvable
        preprocessor name propagate unchanged (presumably UserError — TODO
        confirm).
    """
    try:
        pbutil.AssertFieldIsSet(config, 'contentfiles')
        pbutil.AssertFieldIsSet(config, 'atomizer')
        pbutil.AssertFieldIsSet(config, 'contentfile_separator')

        # Check that the preprocessor pipeline resolves to preprocessor
        # functions. Only the lookup's validation matters; the resolved
        # functions are not kept.
        for preprocessor_name in config.preprocessor:
            preprocessors.GetPreprocessorFunction(preprocessor_name)

        if config.HasField('greedy_multichar_atomizer'):
            tokens = config.greedy_multichar_atomizer.tokens
            if not tokens:
                raise errors.UserError(
                    'GreedyMulticharAtomizer.tokens is empty')
            if any(not atom for atom in tokens):
                # NOTE(review): message wording is awkward but is kept verbatim
                # since callers/tests may match on it.
                raise errors.UserError(
                    'Empty string found in GreedyMulticharAtomizer.tokens is empty'
                )
    except pbutil.ProtoValueError as e:
        # Re-raise as a user-facing error, keeping the original as the cause.
        raise errors.UserError(e) from e
    return config
Exemplo n.º 3
0
def test_AssertFieldIsSet_field_is_set():
  """AssertFieldIsSet() hands back the field's value once it has been set."""
  message = test_protos_pb2.TestMessage()
  message.string = 'foo'
  message.number = 5
  for field_name, expected in (('string', 'foo'), ('number', 5)):
    assert pbutil.AssertFieldIsSet(message, field_name) == expected
Exemplo n.º 4
0
def test_AssertFieldIsSet_user_callback_custom_fail_message():
  """A caller-supplied fail message becomes the ProtoValueError's text."""
  t = test_protos_pb2.TestMessage()
  # The fail message is exercised both positionally and as a keyword.
  cases = (
      (('string', 'Hello, world!'), {}),
      (('number',), {'fail_message': 'Hello, world!'}),
  )
  for args, kwargs in cases:
    with pytest.raises(pbutil.ProtoValueError) as e_info:
      pbutil.AssertFieldIsSet(t, *args, **kwargs)
    assert str(e_info.value) == 'Hello, world!'
Exemplo n.º 5
0
def test_AssertFieldIsSet_field_not_set():
  """ProtoValueError names the field when the requested field is not set."""
  t = test_protos_pb2.TestMessage()
  expected_messages = {
      'string': "Field not set: 'TestMessage.string'",
      'number': "Field not set: 'TestMessage.number'",
  }
  for field_name, expected in expected_messages.items():
    with pytest.raises(pbutil.ProtoValueError) as e_info:
      pbutil.AssertFieldIsSet(t, field_name)
    assert str(e_info.value) == expected
Exemplo n.º 6
0
def AssertConfigIsValid(config: corpus_pb2.Corpus) -> corpus_pb2.Corpus:
  """Assert that config proto is valid.

  Configs that set pre_encoded_corpus_url are returned immediately without
  any further checks.

  Args:
    config: A Corpus proto.

  Returns:
    The Corpus proto, unmodified.

  Raises:
    UserError: If contentfiles, atomizer or contentfile_separator is unset,
      or if a greedy_multichar_atomizer is set with no tokens or with an
      empty-string token. Errors raised by
      preprocessors.GetPreprocessorFunction for an unresolvable preprocessor
      name propagate unchanged (presumably UserError — TODO confirm).
  """
  try:
    # Early-exit to support corpuses derived from databases of pre-encoded
    # content files.
    # TODO(github.com/ChrisCummins/phd/issues/46): Refactor after splitting
    # Corpus class.
    if config.HasField('pre_encoded_corpus_url'):
      return config

    pbutil.AssertFieldIsSet(config, 'contentfiles')
    pbutil.AssertFieldIsSet(config, 'atomizer')
    pbutil.AssertFieldIsSet(config, 'contentfile_separator')

    # Check that the preprocessor pipeline resolves to preprocessor functions.
    # Only the lookup's validation matters; the resolved functions are not kept.
    for preprocessor_name in config.preprocessor:
      preprocessors.GetPreprocessorFunction(preprocessor_name)

    if config.HasField('greedy_multichar_atomizer'):
      tokens = config.greedy_multichar_atomizer.tokens
      if not tokens:
        raise errors.UserError('GreedyMulticharAtomizer.tokens is empty')
      if any(not atom for atom in tokens):
        # NOTE(review): message wording is awkward but is kept verbatim since
        # callers/tests may match on it.
        raise errors.UserError(
            'Empty string found in GreedyMulticharAtomizer.tokens is empty')
  except pbutil.ProtoValueError as e:
    # Re-raise as a user-facing error, keeping the original as the cause.
    raise errors.UserError(e) from e
  return config
Exemplo n.º 7
0
def AssertIsBuildable(config: model_pb2.Model) -> model_pb2.Model:
    """Verify that a model configuration describes a buildable model.

    Checks that the required fields of the model, its architecture and its
    training options are set. Also checks that numeric fields fall within
    their permitted ranges.

    Args:
      config: A model proto.

    Returns:
      The input model proto, unmodified.

    Raises:
      UserError: If the model is not buildable.
      InternalError: If the value of the training.optimizer field is not
        understood.
    """

    def is_positive(value):
        return 0 < value

    def is_non_negative(value):
        return 0 <= value

    def is_in_micros_range(value):
        # Closed range [0, 1000000].
        return 0 <= value <= 1000000

    # Any change to the Model proto schema will require a change to this function.
    try:
        for name in ('corpus', 'architecture', 'training'):
            pbutil.AssertFieldIsSet(config, name)
        architecture = config.architecture
        training = config.training
        for name in ('backend', 'neuron_type'):
            pbutil.AssertFieldIsSet(architecture, name)

        # embedding_size is only validated for the Keras backend.
        architecture_constraints = []
        if architecture.backend == model_pb2.NetworkArchitecture.KERAS:
            architecture_constraints.append(
                ('embedding_size', is_positive,
                 'NetworkArchitecture.embedding_size must be > 0'))
        architecture_constraints.extend([
            ('neurons_per_layer', is_positive,
             'NetworkArchitecture.neurons_per_layer must be > 0'),
            ('num_layers', is_positive,
             'NetworkArchitecture.num_layers must be > 0'),
            ('post_layer_dropout_micros', is_in_micros_range,
             'NetworkArchitecture.post_layer_dropout_micros '
             'must be >= 0 and <= 1000000'),
        ])
        for name, predicate, message in architecture_constraints:
            pbutil.AssertFieldConstraint(architecture, name, predicate,
                                         message)

        pbutil.AssertFieldConstraint(training, 'num_epochs', is_positive,
                                     'TrainingOptions.num_epochs must be > 0')
        pbutil.AssertFieldIsSet(training,
                                'shuffle_corpus_contentfiles_between_epochs')
        pbutil.AssertFieldConstraint(training, 'batch_size', is_positive,
                                     'TrainingOptions.batch_size must be > 0')
        pbutil.AssertFieldIsSet(training, 'optimizer')

        # Pick the constraint set that matches whichever optimizer is set.
        if training.HasField('adam_optimizer'):
            optimizer = training.adam_optimizer
            optimizer_constraints = (
                ('initial_learning_rate_micros', is_non_negative,
                 'AdamOptimizer.initial_learning_rate_micros must be >= 0'),
                ('learning_rate_decay_per_epoch_micros', is_non_negative,
                 'AdamOptimizer.learning_rate_decay_per_epoch_micros must be >= 0'
                ),
                ('beta_1_micros', is_in_micros_range,
                 'AdamOptimizer.beta_1_micros must be >= 0 and <= 1000000'),
                ('beta_2_micros', is_in_micros_range,
                 'AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000'),
                ('normalized_gradient_clip_micros', is_non_negative,
                 'AdamOptimizer.normalized_gradient_clip_micros must be >= 0'),
            )
        elif training.HasField('rmsprop_optimizer'):
            optimizer = training.rmsprop_optimizer
            optimizer_constraints = (
                ('initial_learning_rate_micros', is_non_negative,
                 'RmsPropOptimizer.initial_learning_rate_micros must be >= 0'),
                ('learning_rate_decay_per_epoch_micros', is_non_negative,
                 'RmsPropOptimizer.learning_rate_decay_per_epoch_micros must be >= 0'
                ),
            )
        else:
            raise errors.InternalError(
                "Unrecognized value: 'TrainingOptions.optimizer'")
        for name, predicate, message in optimizer_constraints:
            pbutil.AssertFieldConstraint(optimizer, name, predicate, message)
    except pbutil.ProtoValueError as e:
        raise errors.UserError(str(e))
    return config
Exemplo n.º 8
0
def test_AssFieldIsSet_oneof_field_no_return():
  """A oneof name yields None, while its set member field yields its value."""
  message = test_protos_pb2.TestMessage()
  message.option_a = 1
  oneof_result = pbutil.AssertFieldIsSet(message, 'union_field')
  assert oneof_result is None
  assert pbutil.AssertFieldIsSet(message, 'option_a') == 1
Exemplo n.º 9
0
def test_AssertFieldIsSet_invalid_field_name():
  """Asking about a field the message does not define raises ValueError."""
  message = test_protos_pb2.TestMessage()
  with pytest.raises(ValueError):
    pbutil.AssertFieldIsSet(message, 'not_a_real_field')