Example #1
0
File: sample.py  Project: BeauJoh/phd
    def __init__(self, config: clgen_pb2.Instance):
        """Instantiate an instance.

        Args:
          config: An Instance proto. The `pretrained_model` and `sampler`
            fields must be set; `working_dir` is optional.

        Raises:
          UserError: If the instance proto is missing the `pretrained_model`
            or `sampler` field. The model and sampler constructors may raise
            their own errors for invalid values (not visible here).
        """
        try:
            pbutil.AssertFieldIsSet(config, 'pretrained_model')
            pbutil.AssertFieldIsSet(config, 'sampler')
        except pbutil.ProtoValueError as e:
            # Chain explicitly so the original proto error is kept as the
            # direct cause in tracebacks.
            raise errors.UserError(e) from e

        # Optional working directory: expand environment variables and '~',
        # then make the path absolute so later relative paths are stable.
        self.working_dir = None
        if config.HasField('working_dir'):
            self.working_dir: pathlib.Path = pathlib.Path(
                os.path.expandvars(
                    config.working_dir)).expanduser().absolute()
        # Enter a session so that the cache paths are set relative to any
        # requested working directory.
        with self.Session():
            self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
                pathlib.Path(config.pretrained_model))
            self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)
Example #2
0
def test_AssertFieldIsSet_field_is_set():
    """AssertFieldIsSet hands back the stored value of a populated field."""
    message = test_protos_pb2.TestMessage()
    message.string = 'foo'
    message.number = 5
    # Checked in insertion order: 'string' first, then 'number'.
    expected_values = {'string': 'foo', 'number': 5}
    for field_name, value in expected_values.items():
        assert pbutil.AssertFieldIsSet(message, field_name) == value
Example #3
0
def AssertConfigIsValid(config: corpus_pb2.Corpus) -> corpus_pb2.Corpus:
    """Assert that config proto is valid.

    Args:
      config: A Corpus proto.

    Returns:
      The Corpus proto, unmodified.

    Raises:
      UserError: If the config is invalid: a required field is unset, or the
        greedy multichar atomizer has no tokens or contains an empty token.
    """
    try:
        pbutil.AssertFieldIsSet(config, 'contentfiles')
        pbutil.AssertFieldIsSet(config, 'atomizer')
        pbutil.AssertFieldIsSet(config, 'contentfile_separator')
        # Check that the preprocessor pipeline resolves to preprocessor
        # functions. Only the lookup's side effect matters (presumably it
        # raises for an unresolvable name), so use a plain loop rather than
        # building and discarding a list.
        for preprocessor_name in config.preprocessor:
            preprocessors.GetPreprocessorFunction(preprocessor_name)

        if config.HasField('greedy_multichar_atomizer'):
            tokens = config.greedy_multichar_atomizer.tokens
            if not tokens:
                raise errors.UserError(
                    'GreedyMulticharAtomizer.tokens is empty')
            if any(not token for token in tokens):
                raise errors.UserError(
                    'Empty string found in GreedyMulticharAtomizer.tokens')

        return config
    except pbutil.ProtoValueError as e:
        # Chain explicitly so the original proto error is kept as the cause.
        raise errors.UserError(e) from e
Example #4
0
def test_AssertFieldIsSet_user_callback_custom_fail_message():
    """A caller-supplied fail_message becomes the ProtoValueError text.

    The message is exercised once as a positional argument and once as the
    `fail_message` keyword argument.
    """
    message = test_protos_pb2.TestMessage()
    custom_text = 'Hello, world!'

    with pytest.raises(pbutil.ProtoValueError) as positional_info:
        pbutil.AssertFieldIsSet(message, 'string', custom_text)
    assert str(positional_info.value) == custom_text

    with pytest.raises(pbutil.ProtoValueError) as keyword_info:
        pbutil.AssertFieldIsSet(message, 'number', fail_message=custom_text)
    assert str(keyword_info.value) == custom_text
Example #5
0
def test_AssertFieldIsSet_field_not_set():
    """An unset field raises ProtoValueError naming that field."""
    message = test_protos_pb2.TestMessage()
    cases = (
        ('string', "Field not set: 'TestMessage.string'"),
        ('number', "Field not set: 'TestMessage.number'"),
    )
    for field_name, expected_text in cases:
        with pytest.raises(pbutil.ProtoValueError) as e_info:
            pbutil.AssertFieldIsSet(message, field_name)
        assert str(e_info.value) == expected_text
Example #6
0
def AssertIsBuildable(config: model_pb2.Model) -> model_pb2.Model:
  """Assert that a model configuration is buildable.

  Checks run in a fixed order; the first failing check determines the error
  that is raised.

  Args:
    config: A model proto.

  Returns:
    The input model proto, unmodified.

  Raises:
    UserError: If the model is not buildable, i.e. a required field is unset
      or a field value violates its constraint.
    InternalError: If the value of the training.optimizer field is not
      understood.
  """
  # Any change to the Model proto schema will require a change to this function.
  try:
    # Top-level sections every model must define.
    pbutil.AssertFieldIsSet(config, 'corpus')
    pbutil.AssertFieldIsSet(config, 'architecture')
    pbutil.AssertFieldIsSet(config, 'training')

    # Network architecture.
    pbutil.AssertFieldIsSet(config.architecture, 'backend')
    pbutil.AssertFieldIsSet(config.architecture, 'neuron_type')
    # embedding_size is only validated for the Keras backend.
    if config.architecture.backend == model_pb2.NetworkArchitecture.KERAS:
      pbutil.AssertFieldConstraint(
          config.architecture, 'embedding_size', lambda x: 0 < x,
          'NetworkArchitecture.embedding_size must be > 0')
    pbutil.AssertFieldConstraint(
        config.architecture, 'neurons_per_layer', lambda x: 0 < x,
        'NetworkArchitecture.neurons_per_layer must be > 0')
    pbutil.AssertFieldConstraint(
        config.architecture, 'num_layers', lambda x: 0 < x,
        'NetworkArchitecture.num_layers must be > 0')
    # The *_micros fields appear to be fixed-point values scaled by 1e6,
    # hence the 1000000 upper bound on probabilities.
    pbutil.AssertFieldConstraint(
        config.architecture, 'post_layer_dropout_micros',
        lambda x: 0 <= x <= 1000000,
        'NetworkArchitecture.post_layer_dropout_micros '
        'must be >= 0 and <= 1000000')

    # Training options.
    pbutil.AssertFieldConstraint(
        config.training, 'num_epochs', lambda x: 0 < x,
        'TrainingOptions.num_epochs must be > 0')
    pbutil.AssertFieldIsSet(
        config.training, 'shuffle_corpus_contentfiles_between_epochs')
    pbutil.AssertFieldConstraint(
        config.training, 'batch_size', lambda x: 0 < x,
        'TrainingOptions.batch_size must be > 0')

    # Optimizer: presumably a oneof with one branch per supported optimizer.
    pbutil.AssertFieldIsSet(config.training, 'optimizer')
    if config.training.HasField('adam_optimizer'):
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer, 'initial_learning_rate_micros',
          lambda x: 0 <= x,
          'AdamOptimizer.initial_learning_rate_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'learning_rate_decay_per_epoch_micros', lambda x: 0 <= x,
          'AdamOptimizer.learning_rate_decay_per_epoch_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'beta_1_micros', lambda x: 0 <= x <= 1000000,
          'AdamOptimizer.beta_1_micros must be >= 0 and <= 1000000')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'beta_2_micros', lambda x: 0 <= x <= 1000000,
          'AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'normalized_gradient_clip_micros', lambda x: 0 <= x,
          'AdamOptimizer.normalized_gradient_clip_micros must be >= 0')
    elif config.training.HasField('rmsprop_optimizer'):
      pbutil.AssertFieldConstraint(
          config.training.rmsprop_optimizer, 'initial_learning_rate_micros',
          lambda x: 0 <= x,
          'RmsPropOptimizer.initial_learning_rate_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.rmsprop_optimizer,
          'learning_rate_decay_per_epoch_micros', lambda x: 0 <= x,
          'RmsPropOptimizer.learning_rate_decay_per_epoch_micros must be >= 0')
    else:
      # The optimizer field is set, but to a branch this function does not
      # handle: a programming error rather than a user error.
      raise errors.InternalError(
          "Unrecognized value: 'TrainingOptions.optimizer'")
  except pbutil.ProtoValueError as e:
    # Chain explicitly so the original proto error is kept as the cause.
    raise errors.UserError(str(e)) from e
  return config
Example #7
0
def test_AssFieldIsSet_oneof_field_no_return():
    """A oneof name yields None, while its set member yields the value."""
    message = test_protos_pb2.TestMessage()
    message.option_a = 1

    oneof_result = pbutil.AssertFieldIsSet(message, 'union_field')
    assert oneof_result is None

    member_result = pbutil.AssertFieldIsSet(message, 'option_a')
    assert member_result == 1
Example #8
0
def test_AssertFieldIsSet_invalid_field_name():
    """Asking about a field the message does not define raises ValueError."""
    message = test_protos_pb2.TestMessage()
    bogus_field = 'not_a_real_field'
    with pytest.raises(ValueError):
        pbutil.AssertFieldIsSet(message, bogus_field)