Esempio n. 1
0
def test_AssertFieldConstraint_no_callback_return_value():
    """Without a callback, the value of a set field is returned as-is."""
    proto = test_protos_pb2.TestMessage()
    proto.string = 'foo'
    proto.number = 5
    # Each (field name, expected value) pair is checked in turn.
    for field_name, expected in (('string', 'foo'), ('number', 5)):
        assert pbutil.AssertFieldConstraint(proto, field_name) == expected
Esempio n. 2
0
def test_AssertFieldConstraint_user_callback_passes():
    """A callback that accepts the value lets the field value through."""
    proto = test_protos_pb2.TestMessage()
    proto.string = 'foo'
    proto.number = 5

    string_value = pbutil.AssertFieldConstraint(
        proto, 'string', lambda value: value == 'foo')
    assert string_value == 'foo'

    number_value = pbutil.AssertFieldConstraint(
        proto, 'number', lambda value: 1 < value < 10)
    assert number_value == 5
Esempio n. 3
0
def test_AssertFieldConstraint_field_not_set():
    """ProtoValueError naming the field is raised when it is not set."""
    proto = test_protos_pb2.TestMessage()
    expectations = (
        ('string', "Field not set: 'TestMessage.string'"),
        ('number', "Field not set: 'TestMessage.number'"),
    )
    for field_name, expected_error in expectations:
        with pytest.raises(pbutil.ProtoValueError) as e_info:
            pbutil.AssertFieldConstraint(proto, field_name)
        assert str(e_info.value) == expected_error
Esempio n. 4
0
 def __init__(self, config: sampler_pb2.SymmetricalTokenDepth):
     """Read and validate the depth increase/decrease tokens.

     Args:
       config: A SymmetricalTokenDepth proto.

     Raises:
       UserError: If either token fails validation (it must have
         len() > 0), or if the two tokens are equal.
     """
     try:
         self.left_token = pbutil.AssertFieldConstraint(
             config, 'depth_increase_token', lambda s: len(s) > 0,
             'SymmetricalTokenDepth.depth_increase_token must be a string')
         self.right_token = pbutil.AssertFieldConstraint(
             config, 'depth_decrease_token', lambda s: len(s) > 0,
             'SymmetricalTokenDepth.depth_decrease_token must be a string')
     except pbutil.ProtoValueError as e:
         # Re-raise as a user-facing error carrying the message text, and keep
         # the proto error as the explicit cause for tracebacks.
         raise errors.UserError(str(e)) from e
     # Identical tokens could not distinguish increasing from decreasing depth.
     if self.left_token == self.right_token:
         raise errors.UserError(
             'SymmetricalTokenDepth tokens must be different')
Esempio n. 5
0
def test_AssertFieldConstraint_user_callback_fails():
    """ProtoValueError is raised when the user callback rejects the value."""
    proto = test_protos_pb2.TestMessage()
    proto.string = 'foo'
    proto.number = 5
    cases = (
        ('string', lambda value: value == 'bar',
         "Field fails constraint check: 'TestMessage.string'"),
        ('number', lambda value: 10 < value < 100,
         "Field fails constraint check: 'TestMessage.number'"),
    )
    for field_name, constraint, expected_error in cases:
        with pytest.raises(pbutil.ProtoValueError) as e_info:
            pbutil.AssertFieldConstraint(proto, field_name, constraint)
        assert str(e_info.value) == expected_error
Esempio n. 6
0
 def __init__(self, config: generator_pb2.RandCharGenerator):
   """Validate the randchar options and build the deepsmith Generator proto.

   Args:
     config: A RandCharGenerator proto.
   """
   super().__init__(config)
   self.toolchain = self.config.model.corpus.language
   # Validate the options in order: toolchain, then min length, then max
   # length (which must be >= the min length).
   toolchain = pbutil.AssertFieldConstraint(
       self.config, 'toolchain', lambda name: len(name))
   min_len = pbutil.AssertFieldConstraint(
       self.config, 'string_min_len', lambda n: n > 0)
   max_len = pbutil.AssertFieldConstraint(
       self.config, 'string_max_len',
       lambda n: n > 0 and n >= self.config.string_min_len)
   generator_opts = {
     'toolchain': str(toolchain),
     'min_len': str(min_len),
     'max_len': str(max_len),
   }
   self.generator = deepsmith_pb2.Generator(name='randchar',
                                            opts=generator_opts)
Esempio n. 7
0
def test_AssertFieldConstraint_user_callback_custom_fail_message():
    """A caller-supplied fail message replaces the default error text."""
    custom_message = 'Hello, world!'
    proto = test_protos_pb2.TestMessage()
    proto.string = 'foo'

    # The field is set, but the constraint callback rejects its value.
    with pytest.raises(pbutil.ProtoValueError) as e_info:
        pbutil.AssertFieldConstraint(
            proto, 'string', lambda value: value == 'bar', custom_message)
    assert str(e_info.value) == custom_message

    # The requested field was never set.
    with pytest.raises(pbutil.ProtoValueError) as e_info:
        pbutil.AssertFieldConstraint(
            proto, 'number', fail_message=custom_message)
    assert str(e_info.value) == custom_message
Esempio n. 8
0
 def __init__(self, config: sampler_pb2.MaxTokenLength):
     """Read and validate the maximum number of tokens in a sample.

     Args:
       config: A MaxTokenLength proto.

     Raises:
       UserError: If maximum_tokens_in_sample is not greater than 1.
     """
     try:
         # The error message must describe the check actually enforced
         # (x > 1); previously it claimed "> 0" while rejecting a value of 1.
         self.max_len = pbutil.AssertFieldConstraint(
             config, 'maximum_tokens_in_sample', lambda x: x > 1,
             'MaxTokenLength.maximum_tokens_in_sample must be > 1')
     except pbutil.ProtoValueError as e:
         # Re-raise as a user-facing error, keeping the proto error as cause.
         raise errors.UserError(str(e)) from e
Esempio n. 9
0
def test_AssertFieldConstraint_user_callback_raises_exception():
    """An exception raised by the callback propagates to the caller."""
    proto = test_protos_pb2.TestMessage()
    proto.string = 'foo'

    def _AlwaysRaise(unused_value):
        """Constraint callback that unconditionally raises FileExistsError."""
        raise FileExistsError('foo')

    with pytest.raises(FileExistsError) as e_info:
        pbutil.AssertFieldConstraint(proto, 'string', _AlwaysRaise)
    assert 'foo' == str(e_info.value)
Esempio n. 10
0
def AssertConfigIsValid(config: sampler_pb2.Sampler) -> sampler_pb2.Sampler:
    """Assert that a sampler configuration contains no invalid values.

    Checks that start_text is non-empty and that batch_size and
    temperature_micros are both strictly positive.

    Args:
      config: A sampler configuration proto.

    Returns:
      The sampler configuration proto.

    Raises:
      UserError: If there are configuration errors.
    """
    try:
        pbutil.AssertFieldConstraint(config, 'start_text', lambda s: len(s),
                                     'Sampler.start_text must be a string')
        pbutil.AssertFieldConstraint(config, 'batch_size', lambda x: 0 < x,
                                     'Sampler.batch_size must be > 0')
        pbutil.AssertFieldConstraint(config, 'temperature_micros',
                                     lambda x: 0 < x,
                                     'Sampler.temperature_micros must be > 0')
    except pbutil.ProtoValueError as e:
        # Surface the validation failure as a user error, keeping the proto
        # error as the explicit cause.
        raise errors.UserError(str(e)) from e
    return config
Esempio n. 11
0
def AssertIsBuildable(config: model_pb2.Model) -> model_pb2.Model:
  """Assert that a model configuration is buildable.

  Args:
    config: A model proto.

  Returns:
    The input model proto, unmodified.

  Raises:
    UserError: If the model is not buildable.
    InternalError: If the value of the training.optimizer field is not
      understood.
  """
  # Any change to the Model proto schema will require a change to this function
  # and to the private _Assert*IsValid helpers it calls.
  try:
    pbutil.AssertFieldIsSet(config, 'corpus')
    pbutil.AssertFieldIsSet(config, 'architecture')
    pbutil.AssertFieldIsSet(config, 'training')
    _AssertArchitectureIsValid(config.architecture)
    _AssertTrainingIsValid(config.training)
    _AssertOptimizerIsValid(config.training)
  except pbutil.ProtoValueError as e:
    # Surface proto validation failures as user errors, keeping the original
    # exception as the explicit cause. InternalError is not caught here.
    raise errors.UserError(str(e)) from e
  return config


def _AssertArchitectureIsValid(
    architecture: model_pb2.NetworkArchitecture) -> None:
  """Validate the NetworkArchitecture fields of a model config.

  Errors from the pbutil assertions propagate to the caller.
  """
  pbutil.AssertFieldIsSet(architecture, 'backend')
  pbutil.AssertFieldIsSet(architecture, 'neuron_type')
  # embedding_size is only checked for the Keras backend.
  if architecture.backend == model_pb2.NetworkArchitecture.KERAS:
    pbutil.AssertFieldConstraint(
        architecture, 'embedding_size', lambda x: 0 < x,
        'NetworkArchitecture.embedding_size must be > 0')
  pbutil.AssertFieldConstraint(
      architecture, 'neurons_per_layer', lambda x: 0 < x,
      'NetworkArchitecture.neurons_per_layer must be > 0')
  pbutil.AssertFieldConstraint(
      architecture, 'num_layers', lambda x: 0 < x,
      'NetworkArchitecture.num_layers must be > 0')
  pbutil.AssertFieldConstraint(
      architecture, 'post_layer_dropout_micros',
      lambda x: 0 <= x <= 1000000,
      'NetworkArchitecture.post_layer_dropout_micros '
      'must be >= 0 and <= 1000000')


def _AssertTrainingIsValid(training) -> None:
  """Validate the non-optimizer fields of a model's training options.

  Args:
    training: The `training` submessage of a Model proto.

  Errors from the pbutil assertions propagate to the caller.
  """
  pbutil.AssertFieldConstraint(
      training, 'num_epochs', lambda x: 0 < x,
      'TrainingOptions.num_epochs must be > 0')
  pbutil.AssertFieldIsSet(
      training, 'shuffle_corpus_contentfiles_between_epochs')
  pbutil.AssertFieldConstraint(
      training, 'batch_size', lambda x: 0 < x,
      'TrainingOptions.batch_size must be > 0')


def _AssertOptimizerIsValid(training) -> None:
  """Validate whichever optimizer is selected in the training options.

  Args:
    training: The `training` submessage of a Model proto.

  Raises:
    InternalError: If neither adam_optimizer nor rmsprop_optimizer is set.

  Errors from the pbutil assertions propagate to the caller.
  """
  pbutil.AssertFieldIsSet(training, 'optimizer')
  if training.HasField('adam_optimizer'):
    adam = training.adam_optimizer
    pbutil.AssertFieldConstraint(
        adam, 'initial_learning_rate_micros', lambda x: 0 <= x,
        'AdamOptimizer.initial_learning_rate_micros must be >= 0')
    pbutil.AssertFieldConstraint(
        adam, 'learning_rate_decay_per_epoch_micros', lambda x: 0 <= x,
        'AdamOptimizer.learning_rate_decay_per_epoch_micros must be >= 0')
    pbutil.AssertFieldConstraint(
        adam, 'beta_1_micros', lambda x: 0 <= x <= 1000000,
        'AdamOptimizer.beta_1_micros must be >= 0 and <= 1000000')
    pbutil.AssertFieldConstraint(
        adam, 'beta_2_micros', lambda x: 0 <= x <= 1000000,
        'AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000')
    pbutil.AssertFieldConstraint(
        adam, 'normalized_gradient_clip_micros', lambda x: 0 <= x,
        'AdamOptimizer.normalized_gradient_clip_micros must be >= 0')
  elif training.HasField('rmsprop_optimizer'):
    rmsprop = training.rmsprop_optimizer
    pbutil.AssertFieldConstraint(
        rmsprop, 'initial_learning_rate_micros', lambda x: 0 <= x,
        'RmsPropOptimizer.initial_learning_rate_micros must be >= 0')
    pbutil.AssertFieldConstraint(
        rmsprop, 'learning_rate_decay_per_epoch_micros', lambda x: 0 <= x,
        'RmsPropOptimizer.learning_rate_decay_per_epoch_micros must be >= 0')
  else:
    raise errors.InternalError(
        "Unrecognized value: 'TrainingOptions.optimizer'")
Esempio n. 12
0
def test_AssertFieldConstraint_invalid_field_name():
    """Requesting a field the message does not define raises ValueError."""
    proto = test_protos_pb2.TestMessage()
    bogus_field = 'not_a_real_field'
    with pytest.raises(ValueError):
        pbutil.AssertFieldConstraint(proto, bogus_field)