Beispiel #1
0
def test_AssertConfigIsValid_invalid_temperature_micros(abc_sampler_config):
    """Test that an error is thrown if temperature_micros is not positive.

    The sampler requires temperature_micros > 0, so an unset field, zero, and
    a negative value must all be rejected with a UserError.
    """
    expected_message = "Sampler.temperature_micros must be > 0"
    # Field not set.
    abc_sampler_config.ClearField("temperature_micros")
    with test.Raises(errors.UserError) as e_info:
        samplers.Sampler(abc_sampler_config)
    assert expected_message == str(e_info.value)
    # Value is zero or negative.
    for invalid_value in (0, -1):
        abc_sampler_config.temperature_micros = invalid_value
        with test.Raises(errors.UserError) as e_info:
            samplers.Sampler(abc_sampler_config)
        assert expected_message == str(e_info.value)
Beispiel #2
0
def test_AssertConfigIsValid_no_start_text(clgen_cache_dir,
                                           abc_sampler_config):
    """Check that a missing or empty start_text is rejected with a UserError."""
    # Unused here; presumably requested only for the fixture's setup.
    del clgen_cache_dir

    def _AssertRejected():
        """Construct a Sampler and require the start_text validation error."""
        with test.Raises(errors.UserError) as e_info:
            samplers.Sampler(abc_sampler_config)
        assert str(e_info.value) == "Sampler.start_text must be a string"

    # Case 1: the field is absent from the proto.
    abc_sampler_config.ClearField("start_text")
    _AssertRejected()
    # Case 2: the field holds an empty string.
    abc_sampler_config.start_text = ""
    _AssertRejected()
Beispiel #3
0
    def __init__(self, config: clgen_pb2.Instance):
        """Instantiate an instance.

        Args:
          config: An Instance proto. Its `pretrained_model` and `sampler`
            fields must be set.

        Raises:
          UserError: If the instance proto contains invalid values, or is
            missing the pretrained_model or sampler fields.
        """
        try:
            pbutil.AssertFieldIsSet(config, 'pretrained_model')
            pbutil.AssertFieldIsSet(config, 'sampler')
        except pbutil.ProtoValueError as e:
            # Chain the original error so the failing field check remains
            # visible in the traceback.
            raise errors.UserError(e) from e

        self.working_dir = None
        if config.HasField('working_dir'):
            # Expand environment variables and "~", then make the path absolute.
            self.working_dir: pathlib.Path = pathlib.Path(
                os.path.expandvars(
                    config.working_dir)).expanduser().absolute()
        # Enter a session so that the cache paths are set relative to any
        # requested working directory.
        with self.Session():
            self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
                pathlib.Path(config.pretrained_model))
            self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)
Beispiel #4
0
    def __init__(self, config: clgen_pb2.Instance, dashboard_opts=None):
        """Instantiate an instance.

        Args:
          config: An Instance proto. Its `model_specification` and `sampler`
            fields must be set.
          dashboard_opts: Optional dict of keyword arguments forwarded to
            dashboard.Launch(). None (the default) forwards no arguments.

        Raises:
          UserError: If the instance proto contains invalid values, or is
            missing the model_specification or sampler fields.
        """
        try:
            pbutil.AssertFieldIsSet(config, "model_specification")
            pbutil.AssertFieldIsSet(config, "sampler")
        except pbutil.ProtoValueError as e:
            # Chain the original error so the failing field check remains
            # visible in the traceback.
            raise errors.UserError(e) from e

        self.config = config
        self.working_dir = None
        if config.HasField("working_dir"):
            # Expand environment variables and "~", then make the path absolute.
            self.working_dir: pathlib.Path = pathlib.Path(
                os.path.expandvars(
                    config.working_dir)).expanduser().absolute()
        # Enter a session so that the cache paths are set relative to any
        # requested working directory.
        with self.Session():
            if config.HasField("model"):
                self.model: models.Model = models.Model(config.model)
            else:
                # No `model` field: build from the `pretrained_model` path
                # (presumably the other branch of model_specification).
                self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
                    pathlib.Path(config.pretrained_model))
            self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)

        # A None default avoids one mutable dict being shared across all calls.
        if dashboard_opts is None:
            dashboard_opts = {}
        self.dashboard = dashboard.Launch(**dashboard_opts)
Beispiel #5
0
def test_AssertConfigIsValid_invalid_batch_size(abc_sampler_config):
    """Verify that an unset, zero, or negative batch_size raises UserError."""
    expected = "Sampler.batch_size must be > 0"

    # The field is left unset.
    abc_sampler_config.ClearField("batch_size")
    with test.Raises(errors.UserError) as e_info:
        samplers.Sampler(abc_sampler_config)
    assert str(e_info.value) == expected

    # Zero first, then a negative value.
    for bad_size in (0, -1):
        abc_sampler_config.batch_size = bad_size
        with test.Raises(errors.UserError) as e_info:
            samplers.Sampler(abc_sampler_config)
        assert str(e_info.value) == expected
Beispiel #6
0
    def Train(self, *args, **kwargs) -> None:
        """Train the model, supplying a sampler for mid-training samples.

        The extra sampler is a copy of this instance's sampler config whose
        termination criteria are replaced by a single 512-token length limit,
        so that all test samples have the same length. Everything runs inside
        self.Session().
        """
        with self.Session():
            fixed_length_criterion = sampler_pb2.SampleTerminationCriterion(
                maxlen=sampler_pb2.MaxTokenLength(maximum_tokens_in_sample=512))
            test_sampler_config = sampler_pb2.Sampler()
            test_sampler_config.CopyFrom(self.sampler.config)
            # Drop the copied termination criteria before adding our own.
            del test_sampler_config.termination_criteria[:]
            test_sampler_config.termination_criteria.extend(
                [fixed_length_criterion])
            # Injected as `test_sampler` so samples can be created while the
            # model trains.
            self.model.Train(*args,
                             test_sampler=samplers.Sampler(test_sampler_config),
                             **kwargs)
Beispiel #7
0
def main(argv: typing.List[str]):
    """Main entry point: sample from a backtracking model configured by flags.

    Builds the instance proto from flags, points CLGEN_CACHE at its working
    directory, attaches a database-backed logger, and samples using the
    observers selected by flags.

    Raises:
      app.UsageError: If any arguments besides the program name are given.
    """
    unknown_args = argv[1:]
    if unknown_args:
        raise app.UsageError(
            "Unknown arguments: '{}'.".format(" ".join(unknown_args)))

    config = generative_model.CreateInstanceProtoFromFlags()
    os.environ["CLGEN_CACHE"] = config.working_dir

    db_logger = BacktrackingDatabaseLogger(backtracking_db.Database(FLAGS.db))
    model = backtracking_model.BacktrackingModel(config.model,
                                                 logger=db_logger)
    sampler = samplers.Sampler(config.sampler)
    observers = generative_model.SampleObserversFromFlags()

    model.Sample(sampler, observers, seed=FLAGS.sample_seed)
Beispiel #8
0
def test_Sampler_Specialize_multiple_tokens_per(
    abc_sampler_config: sampler_pb2.Sampler, ):
    """Test that InvalidSymtokTokens is raised for multi-token depth tokens."""
    t = abc_sampler_config.termination_criteria.add()
    t.symtok.depth_increase_token = "abc"
    t.symtok.depth_decrease_token = "cba"
    s = samplers.Sampler(abc_sampler_config)

    def MockAtomizeString(string):
        """AtomizeString() that encodes any string to three tokens."""
        del string
        return np.array([1, 2, 3])

    mock = AtomizerMock()
    mock.AtomizeString = MockAtomizeString
    with test.Raises(errors.InvalidSymtokTokens) as e_info:
        s.Specialize(mock)
    # Compare against the raised message: asserting the bare string literal is
    # always true and would make this check a no-op.
    assert ("Sampler symmetrical depth tokens do not encode to a single "
            "token using the corpus vocabulary") == str(e_info.value)
Beispiel #9
0
def SampleStream():
    """Yield text from one backtracking sample, configured from flags.

    Sets CLGEN_CACHE to the configured working directory, creates the model's
    corpus, specializes the sampler to the corpus atomizer, initializes
    sampling on the model backend, and then yields from
    SampleOneWithBacktrackingToTextStream().
    """
    config = generative_model.CreateInstanceProtoFromFlags()

    os.environ["CLGEN_CACHE"] = config.working_dir

    logger = MyLogger()

    model = backtracking_model.BacktrackingModel(config.model, logger=logger)
    model.corpus.Create()
    sampler = samplers.Sampler(config.sampler)
    atomizer = model.corpus.atomizer
    sampler.Specialize(atomizer)
    # Called for its effect on the backend; the batch size it returns is not
    # used by this stream.
    model.backend.InitSampling(sampler, 0)
    backtracker = backtracking_model.OpenClBacktrackingHelper(
        atomizer, model._target_features)
    model.backend.InitSampleBatch(sampler)
    logger.OnSampleStart(backtracker)
    yield from model.SampleOneWithBacktrackingToTextStream(
        sampler, atomizer, backtracker)
Beispiel #10
0
def test_Sampler_Specialize_invalid_depth_tokens(
    abc_sampler_config: sampler_pb2.Sampler, ):
    """Test that InvalidSymtokTokens raised if depth tokens cannot be encoded."""
    t = abc_sampler_config.termination_criteria.add()
    t.symtok.depth_increase_token = "{"
    t.symtok.depth_decrease_token = "}"
    s = samplers.Sampler(abc_sampler_config)

    def MockAtomizeString(string):
        """AtomizeString() that raises VocabError for the depth tokens.

        Any other string encodes to the single token [1].
        """
        if string in ("{", "}"):
            raise errors.VocabError()
        # np.array, not np.ndarray: np.ndarray([1]) allocates an uninitialized
        # array of shape (1,) instead of an array holding the value 1.
        return np.array([1])

    mock = AtomizerMock()
    mock.AtomizeString = MockAtomizeString
    with test.Raises(errors.InvalidSymtokTokens) as e_info:
        s.Specialize(mock)
    assert ("Sampler symmetrical depth tokens cannot be encoded using the "
            "corpus vocabulary") == str(e_info.value)
Beispiel #11
0
def test_Sampler_Specialize_encoded_start_text(
    abc_sampler_config: sampler_pb2.Sampler, ):
    """Test that Specialize() populates encoded_start_text.

    A freshly constructed sampler has no encoded start text; after
    specializing with AtomizerMock, the encoding is [1].
    """
    sampler = samplers.Sampler(abc_sampler_config)
    assert sampler.encoded_start_text is None

    sampler.Specialize(AtomizerMock())
    expected_encoding = np.array([1])
    np.testing.assert_array_equal(expected_encoding, sampler.encoded_start_text)
Beispiel #12
0
def test_Sampler_batch_size(abc_sampler_config: sampler_pb2.Sampler):
    """Check that Sampler.batch_size mirrors the proto's batch_size field."""
    requested_batch_size = 99
    abc_sampler_config.batch_size = requested_batch_size
    sampler = samplers.Sampler(abc_sampler_config)
    assert sampler.batch_size == requested_batch_size
Beispiel #13
0
def test_Sampler_temperature(abc_sampler_config: sampler_pb2.Sampler):
    """Check that temperature_micros is exposed as a float temperature."""
    abc_sampler_config.temperature_micros = 1000000  # One million micros.
    sampler = samplers.Sampler(abc_sampler_config)
    assert sampler.temperature == pytest.approx(1.0)
Beispiel #14
0
def test_Sampler_start_text(abc_sampler_config: sampler_pb2.Sampler):
    """Check that Sampler.start_text matches the proto's start_text field."""
    sampler = samplers.Sampler(abc_sampler_config)
    assert abc_sampler_config.start_text == sampler.start_text
Beispiel #15
0
def test_Sampler_config_type_error():
    """Check that a non-Sampler config raises a descriptive TypeError."""
    expected = "Config must be a Sampler proto. Received: 'int'"
    with test.Raises(TypeError) as e_info:
        samplers.Sampler(1)
    assert str(e_info.value) == expected