Exemplo n.º 1
0
def test_AssertFieldConstraint_no_callback_return_value():
  """Without a constraint callback, a set field's value is handed back."""
  message = test_protos_pb2.TestMessage()
  message.string = 'foo'
  message.number = 5

  string_value = pbutil.AssertFieldConstraint(message, 'string')
  assert string_value == 'foo'

  number_value = pbutil.AssertFieldConstraint(message, 'number')
  assert number_value == 5
Exemplo n.º 2
0
def AssertConfigIsValid(config: sampler_pb2.Sampler) -> sampler_pb2.Sampler:
    """Validate the fields of a sampler configuration proto.

    The fields start_text, batch_size, sequence_length and
    temperature_micros are checked in that order; the first check that
    fails aborts validation.

    Args:
        config: The sampler configuration proto to validate.

    Returns:
        The proto that was passed in.

    Raises:
        UserError: If any field check raises pbutil.ProtoValueError.
    """

    def is_positive(value):
        """Constraint shared by the numeric fields: strictly above zero."""
        return 0 < value

    try:
        # A zero-length start_text makes len() falsy, which fails the check.
        pbutil.AssertFieldConstraint(
            config, 'start_text', len, 'Sampler.start_text must be a string')
        pbutil.AssertFieldConstraint(
            config, 'batch_size', is_positive,
            'Sampler.batch_size must be > 0')
        pbutil.AssertFieldConstraint(
            config, 'sequence_length', is_positive,
            'Sampler.sequence_length must be > 0')
        pbutil.AssertFieldConstraint(
            config, 'temperature_micros', is_positive,
            'Sampler.temperature_micros must be > 0')
    except pbutil.ProtoValueError as err:
        raise errors.UserError(err)
    return config
Exemplo n.º 3
0
def test_AssertFieldConstraint_user_callback_passes():
  """A passing user callback yields the field's value."""
  message = test_protos_pb2.TestMessage()
  message.string = 'foo'
  message.number = 5

  string_value = pbutil.AssertFieldConstraint(
      message, 'string', lambda value: value == 'foo')
  assert string_value == 'foo'

  number_value = pbutil.AssertFieldConstraint(
      message, 'number', lambda value: 1 < value < 10)
  assert number_value == 5
Exemplo n.º 4
0
def test_AssertFieldConstraint_field_not_set():
  """ProtoValueError naming the field is raised when the field is unset."""
  message = test_protos_pb2.TestMessage()
  # Each case pairs an unset field with the exact error text expected.
  cases = (
      ('string', "Field not set: 'TestMessage.string'"),
      ('number', "Field not set: 'TestMessage.number'"),
  )
  for field_name, expected_error in cases:
    with pytest.raises(pbutil.ProtoValueError) as exc_info:
      pbutil.AssertFieldConstraint(message, field_name)
    assert str(exc_info.value) == expected_error
Exemplo n.º 5
0
 def __init__(self, config: sampler_pb2.SymmetricalTokenDepth):
     """Read the depth-increase and depth-decrease tokens from the config.

     Args:
       config: A SymmetricalTokenDepth proto supplying both tokens.

     Raises:
       UserError: If a token check raises pbutil.ProtoValueError, or if
         both tokens are equal.
     """

     def is_non_empty(token):
         """Constraint applied to both tokens: at least one character."""
         return len(token) > 0

     try:
         self.left_token = pbutil.AssertFieldConstraint(
             config, 'depth_increase_token', is_non_empty,
             'SymmetricalTokenDepth.depth_increase_token must be a string')
         self.right_token = pbutil.AssertFieldConstraint(
             config, 'depth_decrease_token', is_non_empty,
             'SymmetricalTokenDepth.depth_decrease_token must be a string')
     except pbutil.ProtoValueError as err:
         raise errors.UserError(err)

     tokens_match = self.left_token == self.right_token
     if tokens_match:
         raise errors.UserError(
             'SymmetricalTokenDepth tokens must be different')
Exemplo n.º 6
0
def test_AssertFieldConstraint_user_callback_fails():
  """ProtoValueError is raised when the user callback rejects the value."""
  message = test_protos_pb2.TestMessage()
  message.string = 'foo'
  message.number = 5

  # Each case: field name, a constraint the value fails, expected error text.
  cases = (
      ('string', lambda value: value == 'bar',
       "Field fails constraint check: 'TestMessage.string'"),
      ('number', lambda value: 10 < value < 100,
       "Field fails constraint check: 'TestMessage.number'"),
  )
  for field_name, constraint, expected_error in cases:
    with pytest.raises(pbutil.ProtoValueError) as exc_info:
      pbutil.AssertFieldConstraint(message, field_name, constraint)
    assert str(exc_info.value) == expected_error
Exemplo n.º 7
0
def test_AssertFieldConstraint_user_callback_custom_fail_message():
  """A caller-supplied fail message becomes the error text."""
  custom_message = 'Hello, world!'
  message = test_protos_pb2.TestMessage()
  message.string = 'foo'

  # Case 1: the field is set but the constraint rejects its value.
  with pytest.raises(pbutil.ProtoValueError) as exc_info:
    pbutil.AssertFieldConstraint(
        message, 'string', lambda value: value == 'bar', custom_message)
  assert str(exc_info.value) == 'Hello, world!'

  # Case 2: the field was never set.
  with pytest.raises(pbutil.ProtoValueError) as exc_info:
    pbutil.AssertFieldConstraint(
        message, 'number', fail_message=custom_message)
  assert str(exc_info.value) == 'Hello, world!'
Exemplo n.º 8
0
 def __init__(self, config: generator_pb2.RandCharGenerator):
   """Set up the generator and its deepsmith Generator description.

   The toolchain is taken from config.model.corpus.language, and a
   deepsmith_pb2.Generator named 'randchar' is built whose opts hold the
   validated toolchain, string_min_len and string_max_len as strings.

   Args:
     config: A RandCharGenerator proto, forwarded to the base class.
   """
   super(RandCharGenerator, self).__init__(config)
   # NOTE(review): self.config is presumably set by the base __init__ - confirm.
   self.toolchain = self.config.model.corpus.language

   def is_valid_max_len(value):
     """Max length must be positive and not below the configured minimum."""
     return value > 0 and value >= self.config.string_min_len

   # Validate in the same order as the keys appear in opts.
   toolchain_opt = str(pbutil.AssertFieldConstraint(
       self.config, 'toolchain', lambda value: len(value)))
   min_len_opt = str(pbutil.AssertFieldConstraint(
       self.config, 'string_min_len', lambda value: value > 0))
   max_len_opt = str(pbutil.AssertFieldConstraint(
       self.config, 'string_max_len', is_valid_max_len))

   generator_opts = {
     'toolchain': toolchain_opt,
     'min_len': min_len_opt,
     'max_len': max_len_opt,
   }
   self.generator = deepsmith_pb2.Generator(
       name='randchar', opts=generator_opts)
Exemplo n.º 9
0
 def __init__(self, config: sampler_pb2.MaxTokenLength):
     """Read and validate the maximum sample length from the config.

     Args:
       config: A MaxTokenLength proto. Its maximum_tokens_in_sample field
         must be greater than zero.

     Raises:
       UserError: If the field check raises pbutil.ProtoValueError.
     """
     try:
         # The predicate matches the bound stated in the error message.
         # It previously tested `x > 1`, which rejected the value 1 while
         # reporting "must be > 0".
         self.max_len = pbutil.AssertFieldConstraint(
             config, 'maximum_tokens_in_sample', lambda x: x > 0,
             'MaxTokenLength.maximum_tokens_in_sample must be > 0')
     except pbutil.ProtoValueError as e:
         raise errors.UserError(e)
Exemplo n.º 10
0
def test_AssertFieldConstraint_user_callback_raises_exception():
  """An exception raised inside the callback reaches the caller unchanged."""
  message = test_protos_pb2.TestMessage()
  message.string = 'foo'

  def raising_constraint(_value):
    """Constraint that always raises instead of returning a verdict."""
    raise FileExistsError('foo')

  with pytest.raises(FileExistsError) as exc_info:
    pbutil.AssertFieldConstraint(message, 'string', raising_constraint)
  assert 'foo' == str(exc_info.value)
Exemplo n.º 11
0
def AssertIsBuildable(config: model_pb2.Model) -> model_pb2.Model:
    """Check that a model configuration holds everything needed to build it.

    Checks run in a fixed order and stop at the first failure, so the
    error reported is always for the earliest offending field.

    Args:
        config: The model proto to check.

    Returns:
        The proto that was passed in.

    Raises:
        UserError: If a field check raises pbutil.ProtoValueError.
        InternalError: If neither the adam_optimizer nor the
            rmsprop_optimizer field of the training options is present.
    """

    def is_positive(value):
        return 0 < value

    def is_non_negative(value):
        return 0 <= value

    def is_micros_fraction(value):
        return 0 <= value <= 1000000

    # NOTE: this function must be updated whenever the Model proto schema
    # changes.
    try:
        for required in ('corpus', 'architecture', 'training'):
            pbutil.AssertFieldIsSet(config, required)

        architecture = config.architecture
        for required in ('backend', 'neuron_type'):
            pbutil.AssertFieldIsSet(architecture, required)
        # embedding_size is only checked for the Keras backend.
        if architecture.backend == model_pb2.NetworkArchitecture.KERAS:
            pbutil.AssertFieldConstraint(
                architecture, 'embedding_size', is_positive,
                'NetworkArchitecture.embedding_size must be > 0')
        pbutil.AssertFieldConstraint(
            architecture, 'neurons_per_layer', is_positive,
            'NetworkArchitecture.neurons_per_layer must be > 0')
        pbutil.AssertFieldConstraint(
            architecture, 'num_layers', is_positive,
            'NetworkArchitecture.num_layers must be > 0')
        pbutil.AssertFieldConstraint(
            architecture, 'post_layer_dropout_micros', is_micros_fraction,
            'NetworkArchitecture.post_layer_dropout_micros '
            'must be >= 0 and <= 1000000')

        training = config.training
        pbutil.AssertFieldConstraint(
            training, 'num_epochs', is_positive,
            'TrainingOptions.num_epochs must be > 0')
        pbutil.AssertFieldIsSet(
            training, 'shuffle_corpus_contentfiles_between_epochs')
        pbutil.AssertFieldConstraint(
            training, 'batch_size', is_positive,
            'TrainingOptions.batch_size must be > 0')
        pbutil.AssertFieldIsSet(training, 'optimizer')

        # Optimizer-specific checks: (field, constraint, failure message).
        if training.HasField('adam_optimizer'):
            optimizer = training.adam_optimizer
            optimizer_checks = (
                ('initial_learning_rate_micros', is_non_negative,
                 'AdamOptimizer.initial_learning_rate_micros must be >= 0'),
                ('learning_rate_decay_per_epoch_micros', is_non_negative,
                 'AdamOptimizer.learning_rate_decay_per_epoch_micros must be >= 0'),
                ('beta_1_micros', is_micros_fraction,
                 'AdamOptimizer.beta_1_micros must be >= 0 and <= 1000000'),
                ('beta_2_micros', is_micros_fraction,
                 'AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000'),
                ('normalized_gradient_clip_micros', is_non_negative,
                 'AdamOptimizer.normalized_gradient_clip_micros must be >= 0'),
            )
        elif training.HasField('rmsprop_optimizer'):
            optimizer = training.rmsprop_optimizer
            optimizer_checks = (
                ('initial_learning_rate_micros', is_non_negative,
                 'RmsPropOptimizer.initial_learning_rate_micros must be >= 0'),
                ('learning_rate_decay_per_epoch_micros', is_non_negative,
                 'RmsPropOptimizer.learning_rate_decay_per_epoch_micros must be >= 0'),
            )
        else:
            raise errors.InternalError(
                "Unrecognized value: 'TrainingOptions.optimizer'")

        for field_name, constraint, fail_message in optimizer_checks:
            pbutil.AssertFieldConstraint(
                optimizer, field_name, constraint, fail_message)
    except pbutil.ProtoValueError as err:
        raise errors.UserError(str(err))
    return config
Exemplo n.º 12
0
def test_AssertFieldConstraint_invalid_field_name():
  """Asking for a field the message does not define raises ValueError."""
  message = test_protos_pb2.TestMessage()
  unknown_field = 'not_a_real_field'
  with pytest.raises(ValueError):
    pbutil.AssertFieldConstraint(message, unknown_field)