def test_AssertConfigIsValid_invalid_temperature_micros(abc_sampler_config):
    """Test that an unset or negative temperature_micros is rejected.

    Two configurations are exercised: the field cleared from the proto, and
    the field set to -1. Both must raise UserError with the same message.
    """
    expected_message = "Sampler.temperature_micros must be > 0"

    # Case 1: the field is cleared from the proto.
    abc_sampler_config.ClearField("temperature_micros")
    with test.Raises(errors.UserError) as unset_error:
        samplers.Sampler(abc_sampler_config)
    assert expected_message == str(unset_error.value)

    # Case 2: the field holds a negative value.
    abc_sampler_config.temperature_micros = -1
    with test.Raises(errors.UserError) as negative_error:
        samplers.Sampler(abc_sampler_config)
    assert expected_message == str(negative_error.value)
def test_AssertConfigIsValid_no_start_text(clgen_cache_dir, abc_sampler_config):
    """Test that a missing or empty start_text field is rejected.

    Both the cleared field and the empty string must raise UserError with
    the same message.
    """
    # The fixture is requested but its value is not needed in the body.
    del clgen_cache_dir

    expected_message = "Sampler.start_text must be a string"

    # Case 1: the field is cleared from the proto.
    abc_sampler_config.ClearField("start_text")
    with test.Raises(errors.UserError) as unset_error:
        samplers.Sampler(abc_sampler_config)
    assert expected_message == str(unset_error.value)

    # Case 2: the field is set to an empty string.
    abc_sampler_config.start_text = ""
    with test.Raises(errors.UserError) as empty_error:
        samplers.Sampler(abc_sampler_config)
    assert expected_message == str(empty_error.value)
def __init__(self, config: clgen_pb2.Instance):
    """Instantiate an instance.

    Args:
        config: An Instance proto. The 'pretrained_model' and 'sampler'
            fields must be set; 'working_dir' is optional.

    Raises:
        UserError: If the instance proto contains invalid values, is missing
            a model or sampler fields. When raised for a missing field, the
            originating ProtoValueError is attached as the cause.
    """
    try:
        pbutil.AssertFieldIsSet(config, 'pretrained_model')
        pbutil.AssertFieldIsSet(config, 'sampler')
    except pbutil.ProtoValueError as e:
        # Chain explicitly so the traceback reports the proto validation
        # failure as the direct cause of the user-facing error.
        raise errors.UserError(e) from e

    # The working directory stays None unless the proto requests one, in
    # which case environment variables and '~' are expanded and the path is
    # made absolute.
    self.working_dir = None
    if config.HasField('working_dir'):
        self.working_dir: pathlib.Path = pathlib.Path(
            os.path.expandvars(config.working_dir)).expanduser().absolute()

    # Enter a session so that the cache paths are set relative to any requested
    # working directory.
    with self.Session():
        self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
            pathlib.Path(config.pretrained_model))
        self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)
def __init__(self, config: clgen_pb2.Instance, dashboard_opts=None):
    """Instantiate an instance.

    Args:
        config: An Instance proto. The 'model_specification' and 'sampler'
            fields must be set; 'working_dir' is optional.
        dashboard_opts: Optional mapping of keyword arguments forwarded to
            dashboard.Launch(). None (the default) forwards no arguments.

    Raises:
        UserError: If the instance proto contains invalid values, is missing
            a model or sampler fields. When raised for a missing field, the
            originating ProtoValueError is attached as the cause.
    """
    try:
        pbutil.AssertFieldIsSet(config, "model_specification")
        pbutil.AssertFieldIsSet(config, "sampler")
    except pbutil.ProtoValueError as e:
        # Chain explicitly so the traceback reports the proto validation
        # failure as the direct cause of the user-facing error.
        raise errors.UserError(e) from e

    self.config = config

    # The working directory stays None unless the proto requests one, in
    # which case environment variables and '~' are expanded and the path is
    # made absolute.
    self.working_dir = None
    if config.HasField("working_dir"):
        self.working_dir: pathlib.Path = pathlib.Path(
            os.path.expandvars(config.working_dir)).expanduser().absolute()

    # Enter a session so that the cache paths are set relative to any requested
    # working directory.
    with self.Session():
        # Use the 'model' field when present, otherwise fall back to the
        # 'pretrained_model' path.
        if config.HasField("model"):
            self.model: models.Model = models.Model(config.model)
        else:
            self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
                pathlib.Path(config.pretrained_model))
        self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)

    # A None default avoids sharing one mutable dict across every call.
    # NOTE(review): the dashboard is launched after the session has exited;
    # confirm this matches the intended placement.
    self.dashboard = dashboard.Launch(**(dashboard_opts or {}))
def test_AssertConfigIsValid_invalid_batch_size(abc_sampler_config):
    """Test that an unset, zero, or negative batch_size is rejected.

    Every case must raise UserError with the same message.
    """
    expected_message = "Sampler.batch_size must be > 0"

    # Case 1: the field is cleared from the proto.
    abc_sampler_config.ClearField("batch_size")
    with test.Raises(errors.UserError) as unset_error:
        samplers.Sampler(abc_sampler_config)
    assert expected_message == str(unset_error.value)

    # Cases 2 and 3: the field is explicitly zero, then negative.
    for invalid_batch_size in (0, -1):
        abc_sampler_config.batch_size = invalid_batch_size
        with test.Raises(errors.UserError) as value_error:
            samplers.Sampler(abc_sampler_config)
        assert expected_message == str(value_error.value)
def Train(self, *args, **kwargs) -> None:
    """Train the model, injecting a sampler for use during training.

    A copy of this instance's sampler config is made, its termination
    criteria are replaced by a single 512-token maximum length, and a
    Sampler built from that copy is passed to the model's Train() as the
    `test_sampler` keyword argument.

    Args:
        *args: Positional arguments forwarded to the model's Train().
        **kwargs: Keyword arguments forwarded to the model's Train().
    """
    with self.Session():
        training_sampler_proto = sampler_pb2.Sampler()
        training_sampler_proto.CopyFrom(self.sampler.config)

        # Make all test samples the same 512-token length by swapping out
        # whatever termination criteria the copied config carried.
        fixed_length_criterion = sampler_pb2.SampleTerminationCriterion(
            maxlen=sampler_pb2.MaxTokenLength(maximum_tokens_in_sample=512))
        del training_sampler_proto.termination_criteria[:]
        training_sampler_proto.termination_criteria.extend(
            [fixed_length_criterion])

        training_sampler = samplers.Sampler(training_sampler_proto)

        # The `test_sampler` argument is injected so that samples can be
        # created during training.
        self.model.Train(*args, test_sampler=training_sampler, **kwargs)
def main(argv: typing.List[str]):
    """Main entry point.

    Builds the instance config from flags, then samples from a backtracking
    model while logging to the database named by FLAGS.db.

    Args:
        argv: Command-line arguments; anything beyond the first element is
            rejected.

    Raises:
        app.UsageError: If any argument beyond the first is present.
    """
    unexpected_args = argv[1:]
    if unexpected_args:
        raise app.UsageError(
            "Unknown arguments: '{}'.".format(" ".join(unexpected_args)))

    config = generative_model.CreateInstanceProtoFromFlags()
    # Export the working directory through the CLGEN_CACHE environment
    # variable before constructing the model.
    os.environ["CLGEN_CACHE"] = config.working_dir

    database = backtracking_db.Database(FLAGS.db)
    db_logger = BacktrackingDatabaseLogger(database)
    model = backtracking_model.BacktrackingModel(config.model, logger=db_logger)
    sampler = samplers.Sampler(config.sampler)
    observers = generative_model.SampleObserversFromFlags()
    model.Sample(sampler, observers, seed=FLAGS.sample_seed)
def test_Sampler_Specialize_multiple_tokens_per(
    abc_sampler_config: sampler_pb2.Sampler,
):
    """Test that InvalidSymtokTokens raised if depth tokens encode to mult.

    The mocked atomizer encodes every string to three tokens, so the
    symmetrical depth tokens cannot each map to a single token.
    """
    t = abc_sampler_config.termination_criteria.add()
    t.symtok.depth_increase_token = "abc"
    t.symtok.depth_decrease_token = "cba"
    s = samplers.Sampler(abc_sampler_config)

    def MockAtomizeString(string):
        """AtomizeString() with a multi-token output."""
        del string
        return np.array([1, 2, 3])

    mock = AtomizerMock()
    mock.AtomizeString = MockAtomizeString
    with test.Raises(errors.InvalidSymtokTokens) as e_info:
        s.Specialize(mock)
    # Compare against the raised message. Asserting the bare string literal
    # would always pass, because a non-empty string is truthy.
    assert (
        "Sampler symmetrical depth tokens do not encode to a single "
        "token using the corpus vocabulary"
    ) == str(e_info.value)
def SampleStream():
    """Generate the text stream of one backtracking sample.

    Builds the instance config from flags, prepares the model, sampler and
    backtracking helper, then delegates to the model's
    SampleOneWithBacktrackingToTextStream() and yields whatever it yields.
    """
    config = generative_model.CreateInstanceProtoFromFlags()
    # Export the working directory through the CLGEN_CACHE environment
    # variable before constructing the model.
    os.environ["CLGEN_CACHE"] = config.working_dir

    stream_logger = MyLogger()
    model = backtracking_model.BacktrackingModel(
        config.model, logger=stream_logger)
    model.corpus.Create()

    sampler = samplers.Sampler(config.sampler)
    atomizer = model.corpus.atomizer
    sampler.Specialize(atomizer)

    # The value returned by InitSampling() is not needed here.
    model.backend.InitSampling(sampler, 0)
    backtracker = backtracking_model.OpenClBacktrackingHelper(
        atomizer, model._target_features)
    model.backend.InitSampleBatch(sampler)

    stream_logger.OnSampleStart(backtracker)
    yield from model.SampleOneWithBacktrackingToTextStream(
        sampler, atomizer, backtracker)
def test_Sampler_Specialize_invalid_depth_tokens(
    abc_sampler_config: sampler_pb2.Sampler,
):
    """Test that InvalidSymtokTokens raised if depth tokens cannot be encoded.

    The mocked atomizer raises VocabError for the depth tokens "{" and "}"
    and returns a single-token encoding for every other string.
    """
    t = abc_sampler_config.termination_criteria.add()
    t.symtok.depth_increase_token = "{"
    t.symtok.depth_decrease_token = "}"
    s = samplers.Sampler(abc_sampler_config)

    def MockAtomizeString(string):
        """AtomizeString() with a vocab error on depth tokens."""
        if string in ("{", "}"):
            raise errors.VocabError()
        # np.array([1]) holds the single token 1. np.ndarray([1]) would
        # instead allocate an uninitialized array of shape (1,).
        return np.array([1])

    mock = AtomizerMock()
    mock.AtomizeString = MockAtomizeString
    with test.Raises(errors.InvalidSymtokTokens) as e_info:
        s.Specialize(mock)
    assert (
        "Sampler symmetrical depth tokens cannot be encoded using the "
        "corpus vocabulary"
    ) == str(e_info.value)
def test_Sampler_Specialize_encoded_start_text(
    abc_sampler_config: sampler_pb2.Sampler,
):
    """Test that encoded_start_text is None until Specialize() sets it."""
    sampler = samplers.Sampler(abc_sampler_config)
    # No encoding exists before the sampler is specialized.
    assert sampler.encoded_start_text is None

    sampler.Specialize(AtomizerMock())
    expected_encoding = np.array([1])
    np.testing.assert_array_equal(expected_encoding, sampler.encoded_start_text)
def test_Sampler_batch_size(abc_sampler_config: sampler_pb2.Sampler):
    """Test that the Sampler exposes the batch_size given in its proto."""
    requested_batch_size = 99
    abc_sampler_config.batch_size = requested_batch_size
    sampler = samplers.Sampler(abc_sampler_config)
    assert sampler.batch_size == requested_batch_size
def test_Sampler_temperature(abc_sampler_config: sampler_pb2.Sampler):
    """Test that the Sampler derives temperature from the proto.

    A temperature_micros of 1000000 is expected to give a temperature of
    approximately 1.0.
    """
    abc_sampler_config.temperature_micros = 1000000
    sampler = samplers.Sampler(abc_sampler_config)
    assert sampler.temperature == pytest.approx(1.0)
def test_Sampler_start_text(abc_sampler_config: sampler_pb2.Sampler):
    """Test that the Sampler exposes the start_text given in its proto."""
    sampler = samplers.Sampler(abc_sampler_config)
    expected_start_text = abc_sampler_config.start_text
    assert sampler.start_text == expected_start_text
def test_Sampler_config_type_error():
    """Test that constructing a Sampler from a non-proto raises TypeError."""
    not_a_proto = 1
    with test.Raises(TypeError) as raised:
        samplers.Sampler(not_a_proto)
    expected_message = "Config must be a Sampler proto. Received: 'int'"
    assert expected_message == str(raised.value)