def __init__(self, config: clgen_pb2.Instance):
  """Instantiate an instance.

  Args:
    config: An Instance proto.

  Raises:
    UserError: If the instance proto is missing the pretrained_model or
      sampler field.
  """
  try:
    pbutil.AssertFieldIsSet(config, 'pretrained_model')
    pbutil.AssertFieldIsSet(config, 'sampler')
  except pbutil.ProtoValueError as e:
    # Chain the proto error so its traceback is kept as the explicit cause.
    raise errors.UserError(e) from e

  # Optional working directory (a pathlib.Path, or None when the proto does
  # not set one). Environment variables are expanded first, then '~', and the
  # result is made absolute.
  self.working_dir = None
  if config.HasField('working_dir'):
    self.working_dir = pathlib.Path(
        os.path.expandvars(config.working_dir)).expanduser().absolute()

  # Enter a session so that the cache paths are set relative to any requested
  # working directory.
  with self.Session():
    self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
        pathlib.Path(config.pretrained_model))
    self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)
def test_AssertFieldIsSet_field_is_set():
  """AssertFieldIsSet hands back the field's value when the field is set."""
  message = test_protos_pb2.TestMessage()
  message.string = 'foo'
  message.number = 5
  # Both a string field and a numeric field should echo their stored values.
  assert pbutil.AssertFieldIsSet(message, 'string') == 'foo'
  assert pbutil.AssertFieldIsSet(message, 'number') == 5
def AssertConfigIsValid(config: corpus_pb2.Corpus) -> corpus_pb2.Corpus:
  """Assert that config proto is valid.

  Args:
    config: A Corpus proto.

  Returns:
    The Corpus proto, unmodified.

  Raises:
    UserError: If the config is invalid.
  """
  try:
    pbutil.AssertFieldIsSet(config, 'contentfiles')
    pbutil.AssertFieldIsSet(config, 'atomizer')
    pbutil.AssertFieldIsSet(config, 'contentfile_separator')
    # Check that the preprocessor pipeline resolves to preprocessor functions.
    # Each lookup is done purely for validation; the result is discarded.
    for preprocessor_name in config.preprocessor:
      preprocessors.GetPreprocessorFunction(preprocessor_name)
  except pbutil.ProtoValueError as e:
    raise errors.UserError(e) from e

  if config.HasField('greedy_multichar_atomizer'):
    tokens = config.greedy_multichar_atomizer.tokens
    if not tokens:
      raise errors.UserError('GreedyMulticharAtomizer.tokens is empty')
    # An empty token would be meaningless to the atomizer, so reject it.
    if not all(tokens):
      raise errors.UserError(
          'Empty string found in GreedyMulticharAtomizer.tokens')
  return config
def test_AssertFieldIsSet_user_callback_custom_fail_message():
  """A caller-supplied fail message becomes the ProtoValueError's message."""
  message = test_protos_pb2.TestMessage()
  custom = 'Hello, world!'

  # Custom message passed as the third positional argument.
  with pytest.raises(pbutil.ProtoValueError) as exc_info:
    pbutil.AssertFieldIsSet(message, 'string', 'Hello, world!')
  assert str(exc_info.value) == custom

  # Custom message passed by keyword.
  with pytest.raises(pbutil.ProtoValueError) as exc_info:
    pbutil.AssertFieldIsSet(message, 'number', fail_message='Hello, world!')
  assert str(exc_info.value) == custom
def test_AssertFieldIsSet_field_not_set():
  """ProtoValueError naming the field is raised when the field is unset."""
  message = test_protos_pb2.TestMessage()
  cases = (
      ('string', "Field not set: 'TestMessage.string'"),
      ('number', "Field not set: 'TestMessage.number'"),
  )
  for field_name, expected_message in cases:
    with pytest.raises(pbutil.ProtoValueError) as exc_info:
      pbutil.AssertFieldIsSet(message, field_name)
    assert str(exc_info.value) == expected_message
def AssertIsBuildable(config: model_pb2.Model) -> model_pb2.Model:
  """Assert that a model configuration is buildable.

  Args:
    config: A model proto.

  Returns:
    The input model proto, unmodified.

  Raises:
    UserError: If the model is not buildable.
    InternalError: If the value of the training.optimizer field is not
      understood.
  """
  # Any change to the Model proto schema will require a change to this function.
  try:
    pbutil.AssertFieldIsSet(config, 'corpus')
    pbutil.AssertFieldIsSet(config, 'architecture')
    pbutil.AssertFieldIsSet(config, 'training')
    pbutil.AssertFieldIsSet(config.architecture, 'backend')
    pbutil.AssertFieldIsSet(config.architecture, 'neuron_type')
    # embedding_size is only validated for the Keras backend.
    if config.architecture.backend == model_pb2.NetworkArchitecture.KERAS:
      pbutil.AssertFieldConstraint(
          config.architecture, 'embedding_size', lambda x: 0 < x,
          'NetworkArchitecture.embedding_size must be > 0')
    pbutil.AssertFieldConstraint(
        config.architecture, 'neurons_per_layer', lambda x: 0 < x,
        'NetworkArchitecture.neurons_per_layer must be > 0')
    pbutil.AssertFieldConstraint(
        config.architecture, 'num_layers', lambda x: 0 < x,
        'NetworkArchitecture.num_layers must be > 0')
    # *_micros fields hold a fraction scaled by 1e6, hence the 0..1000000 range.
    pbutil.AssertFieldConstraint(
        config.architecture, 'post_layer_dropout_micros',
        lambda x: 0 <= x <= 1000000,
        'NetworkArchitecture.post_layer_dropout_micros '
        'must be >= 0 and <= 1000000')
    pbutil.AssertFieldConstraint(
        config.training, 'num_epochs', lambda x: 0 < x,
        'TrainingOptions.num_epochs must be > 0')
    pbutil.AssertFieldIsSet(
        config.training, 'shuffle_corpus_contentfiles_between_epochs')
    pbutil.AssertFieldConstraint(
        config.training, 'batch_size', lambda x: 0 < x,
        'TrainingOptions.batch_size must be > 0')
    # 'optimizer' names the group holding adam_optimizer / rmsprop_optimizer;
    # this fails if neither is set.
    pbutil.AssertFieldIsSet(config.training, 'optimizer')
    if config.training.HasField('adam_optimizer'):
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer, 'initial_learning_rate_micros',
          lambda x: 0 <= x,
          'AdamOptimizer.initial_learning_rate_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'learning_rate_decay_per_epoch_micros', lambda x: 0 <= x,
          'AdamOptimizer.learning_rate_decay_per_epoch_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer, 'beta_1_micros',
          lambda x: 0 <= x <= 1000000,
          'AdamOptimizer.beta_1_micros must be >= 0 and <= 1000000')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer, 'beta_2_micros',
          lambda x: 0 <= x <= 1000000,
          'AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer, 'normalized_gradient_clip_micros',
          lambda x: 0 <= x,
          'AdamOptimizer.normalized_gradient_clip_micros must be >= 0')
    elif config.training.HasField('rmsprop_optimizer'):
      pbutil.AssertFieldConstraint(
          config.training.rmsprop_optimizer, 'initial_learning_rate_micros',
          lambda x: 0 <= x,
          'RmsPropOptimizer.initial_learning_rate_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.rmsprop_optimizer,
          'learning_rate_decay_per_epoch_micros', lambda x: 0 <= x,
          'RmsPropOptimizer.learning_rate_decay_per_epoch_micros must be >= 0')
    else:
      # An optimizer is set, but it is not one this function knows about.
      raise errors.InternalError(
          "Unrecognized value: 'TrainingOptions.optimizer'")
  except pbutil.ProtoValueError as e:
    # Chain the proto error so its traceback is kept as the explicit cause.
    raise errors.UserError(str(e)) from e
  return config
def test_AssFieldIsSet_oneof_field_no_return():
  """The oneof group name yields None; the set member yields its value."""
  # NOTE(review): 'Ass' in the test name looks like a typo for 'Assert'; it is
  # left unchanged so the test ID stays stable.
  message = test_protos_pb2.TestMessage()
  message.option_a = 1
  # Asserting on the oneof group itself returns nothing...
  assert pbutil.AssertFieldIsSet(message, 'union_field') is None
  # ...while the concrete member that was set returns its value.
  assert pbutil.AssertFieldIsSet(message, 'option_a') == 1
def test_AssertFieldIsSet_invalid_field_name():
  """Naming a field the message type does not define raises ValueError."""
  bogus_field = 'not_a_real_field'
  message = test_protos_pb2.TestMessage()
  with pytest.raises(ValueError):
    pbutil.AssertFieldIsSet(message, bogus_field)