Example #1
0
def test_KerasBackend_is_trained(clgen_cache_dir, abc_keras_model_config):
    """Test that is_trained changes to True when model is trained."""
    del clgen_cache_dir
    m = models.Model(abc_keras_model_config)
    assert not m.is_trained
    m.Train()
    assert m.is_trained
Example #2
0
def test_benchmark_TensorFlowModel_Train_already_trained(
        clgen_cache_dir, abc_tensorflow_model_config, benchmark):
    """Benchmark the Train() method on an already-trained model."""
    del clgen_cache_dir
    m = models.Model(abc_tensorflow_model_config)
    m.Train()  # "Offline" training from cold.
    benchmark(m.Train)
Example #3
0
def test_KerasBackend_Sample_exact_multiple_of_batch_size(
        clgen_cache_dir, abc_keras_model_config):
    """Test that min_num_samples are returned when a multiple of batch_size."""
    del clgen_cache_dir
    m = models.Model(abc_keras_model_config)
    assert len(m.Sample(MockSampler(batch_size=2), 2)) == 2
    assert len(m.Sample(MockSampler(batch_size=2), 4)) == 4
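Note: the MockSampler used above and throughout these examples is a test double defined elsewhere in the test module; it is not shown here. A minimal sketch of what such a stub might look like, assuming the only attributes the model reads are the ones exercised by these tests (hash and batch_size):

class MockSampler(object):
    """Hypothetical stand-in for the real sampler fixture (sketch only)."""

    def __init__(self, hash: str = "hash", batch_size: int = 1):
        # Cached samples are written under <model cache>/samples/<hash>.
        self.hash = hash
        # Samples are produced in whole batches, so the number of returned
        # samples is rounded up to a multiple of this value.
        self.batch_size = batch_size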
Example #4
0
    def __init__(self, config: clgen_pb2.Instance, dashboard_opts={}):
        """Instantiate an instance.

    Args:
      config: An Instance proto.

    Raises:
      UserError: If the instance proto contains invalid values, is missing
        a model or sampler fields.
    """
        try:
            pbutil.AssertFieldIsSet(config, "model_specification")
            pbutil.AssertFieldIsSet(config, "sampler")
        except pbutil.ProtoValueError as e:
            raise errors.UserError(e)

        self.config = config
        self.working_dir = None
        if config.HasField("working_dir"):
            self.working_dir: pathlib.Path = pathlib.Path(
                os.path.expandvars(
                    config.working_dir)).expanduser().absolute()
        # Enter a session so that the cache paths are set relative to any requested
        # working directory.
        with self.Session():
            if config.HasField("model"):
                self.model: models.Model = models.Model(config.model)
            else:
                self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
                    pathlib.Path(config.pretrained_model))
            self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)

        self.dashboard = dashboard.Launch(**dashboard_opts)
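For context, the constructor above is typically driven by a text-format Instance proto. A minimal usage sketch, assuming the enclosing class is named Instance and that "instance.pbtxt" is an illustrative path to such a proto (both are assumptions, not shown above):

import pathlib

# Hypothetical path to a text-format clgen Instance proto.
config = pbutil.FromFile(pathlib.Path("instance.pbtxt"), clgen_pb2.Instance())
instance = Instance(config)
with instance.Session():
    # Model and sampler caches resolve relative to any configured working_dir.
    instance.model.Train()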
Example #5
0
def test_TensorFlowBackend_Train_GetShortSummary_before_create(
        clgen_cache_dir, abc_tensorflow_model_config):
    """Test that model training produced telemetry files."""
    del clgen_cache_dir
    m = models.Model(abc_tensorflow_model_config)
    with test.Raises(ValueError):
        m.GetShortSummary()
Example #6
0
def test_TensorFlowBackend_Train_missing_intermediate_checkpoints(
        clgen_cache_dir, abc_tensorflow_model_config):
    """Test that a missing intermediate checkpoint does not affect training."""
    del clgen_cache_dir
    abc_tensorflow_model_config.training.num_epochs = 2
    m = models.Model(abc_tensorflow_model_config)
    m.Train()
    assert 2 == len(m.backend.epoch_checkpoints)

    checkpoints_dir = m.cache.path / "checkpoints"
    for path in checkpoints_dir.iterdir():
        # Remove all files except the checkpoints list and the most recent
        # checkpoint.
        if not path.name == "checkpoint" and not path.name.startswith(
                "checkpoint-2"):
            path.unlink()
    f1a = checksumdir.dirhash(checkpoints_dir)

    assert 1 == len(m.backend.epoch_checkpoints)
    assert 2 in m.backend.epoch_checkpoints

    # Run Train() again to check that nothing is changed.
    m.Train()
    assert 1 == len(m.backend.epoch_checkpoints)
    assert 2 in m.backend.epoch_checkpoints
    f1b = checksumdir.dirhash(checkpoints_dir)
    assert f1a == f1b
Example #7
0
def test_TensorFlowBackend_Sample_return_value_matches_cached_sample(
        clgen_cache_dir, abc_tensorflow_model_config):
    """Test that Sample() returns Sample protos."""
    del clgen_cache_dir
    abc_tensorflow_model_config.training.batch_size = 1
    m = models.Model(abc_tensorflow_model_config)
    sample_observer = sample_observers.InMemorySampleSaver()
    m.Sample(
        MockSampler(hash="hash"),
        [
            sample_observers.MaxSampleCountObserver(1),
            sample_observer,
            sample_observers.LegacySampleCacheObserver(),
        ],
    )
    samples = sample_observer.samples
    # Samples are produced in batches of sampler.batch_size elements.
    assert len(samples) == 1
    assert len(list((m.cache.path / "samples" / "hash").iterdir())) == 1
    cached_sample_path = (m.cache.path / "samples" / "hash" / list(
        (m.cache.path / "samples" / "hash").iterdir())[0])
    assert cached_sample_path.is_file()
    cached_sample = pbutil.FromFile(cached_sample_path, model_pb2.Sample())
    assert samples[0].text == cached_sample.text
    assert samples[0].sample_time_ms == cached_sample.sample_time_ms
    assert (samples[0].sample_start_epoch_ms_utc ==
            cached_sample.sample_start_epoch_ms_utc)
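The observer list above composes independent behaviours: MaxSampleCountObserver caps how many samples are produced, InMemorySampleSaver keeps them in memory, and LegacySampleCacheObserver writes the cached protos checked below. A sketch of a custom observer follows, under the assumption that an observer exposes an OnSample(sample) hook returning whether sampling should continue; the base-class interface is not shown in these examples, so treat the method name and return contract as assumptions, and in practice such a class would likely subclass the library's observer base class.

class PrintingSampleObserver(object):
    """Hypothetical observer that logs each sample's text and timing."""

    def OnSample(self, sample: model_pb2.Sample) -> bool:
        # model_pb2.Sample carries text and sample_time_ms fields, as used in
        # the assertions above.
        print(f"{sample.sample_time_ms} ms: {sample.text!r}")
        # Returning True requests that sampling continue (assumed contract).
        return True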
Example #8
0
def test_Model_config_sequence_length_not_set(clgen_cache_dir,
                                              abc_model_config):
    """Test that an error is raised if sequence_length is < 1."""
    del clgen_cache_dir
    abc_model_config.training.sequence_length = -1
    with test.Raises(errors.UserError):
        models.Model(abc_model_config)
Example #9
0
def test_Model_metafile(clgen_cache_dir, abc_model_config):
    """A newly instantiated model's cache has a metafile."""
    del clgen_cache_dir
    m = models.Model(abc_model_config)
    assert (m.cache.path / "META.pbtxt").is_file()
    assert pbutil.ProtoIsReadable(m.cache.path / "META.pbtxt",
                                  internal_pb2.ModelMeta())
Example #10
0
def test_benchmark_KerasBackend_Train_already_trained(
    clgen_cache_dir, abc_keras_model_config, benchmark):
  """Benchmark the Train() method on an already-trained model."""
  del clgen_cache_dir
  m = models.Model(abc_keras_model_config)
  m.Train()  # "Offline" training from cold.
  benchmark(m.Train)
Example #11
0
def test_KerasBackend_GetInferenceModel_predict_output_shape(
        clgen_cache_dir, abc_keras_model_config):
    """Test that predict() on inference model is one-hot encoded."""
    del clgen_cache_dir
    m = models.Model(abc_keras_model_config)
    im, batch_size = m.backend.GetInferenceModel()
    probabilities = im.predict(np.array([[0]] * batch_size))
    assert (batch_size, 1, m.corpus.vocab_size) == probabilities.shape
Example #12
0
def test_TensorFlowBackend_Train_GetShortSummary(clgen_cache_dir,
                                                 abc_tensorflow_model_config):
    """Test that model training produced telemetry files."""
    del clgen_cache_dir
    m = models.Model(abc_tensorflow_model_config)
    m.Create()
    assert (m.GetShortSummary() ==
            "4×1 LSTM network, 59 token corpus with 25-element vocabulary")
Example #13
0
def test_TensorFlowBackend_Sample_exact_multiple_of_batch_size(
        clgen_cache_dir, abc_tensorflow_model_config):
    """Test that min_num_samples are returned when a multiple of batch_size."""
    del clgen_cache_dir
    abc_tensorflow_model_config.training.batch_size = 2
    m = models.Model(abc_tensorflow_model_config)
    assert len(m.Sample(MockSampler(), 2)) == 2
    assert len(m.Sample(MockSampler(), 4)) == 4
Example #14
0
def test_KerasBackend_Sample_implicit_train(clgen_cache_dir,
                                            abc_keras_model_config):
  """Test that Sample() implicitly trains the model."""
  del clgen_cache_dir
  m = models.Model(abc_keras_model_config)
  assert not m.is_trained
  m.Sample(MockSampler(), 1)
  assert m.is_trained
Example #15
0
def test_TensorFlowBackend_Train_is_trained(clgen_cache_dir,
                                            abc_tensorflow_model_config):
  """Test that is_trained is initially false until trained."""
  del clgen_cache_dir
  m = models.Model(abc_tensorflow_model_config)
  assert not m.is_trained
  m.Train()
  assert m.is_trained
Example #16
0
def test_TensorFlowBackend_Sample_implicit_train(clgen_cache_dir,
                                                 abc_tensorflow_model_config):
  """Test that Sample() implicitly trains the model."""
  del clgen_cache_dir
  m = models.Model(abc_tensorflow_model_config)
  assert not m.is_trained
  m.Sample(MockSampler(), [sample_observers.MaxSampleCountObserver(1)])
  assert m.is_trained
Example #17
0
def test_Model_atomizer_symlink(clgen_cache_dir, abc_model_config):
    """Test path of symlink to atomizer."""
    del clgen_cache_dir
    m = models.Model(abc_model_config)
    assert (m.cache.path / 'atomizer').is_symlink()
    path = str((m.cache.path / 'atomizer').resolve())
    # We can't do a literal comparison because of bazel sandboxing.
    assert path.endswith(str(m.corpus.atomizer_path))
Example #18
0
def test_Model_corpus_symlink(clgen_cache_dir, abc_model_config):
    """Test path of symlink to corpus files."""
    del clgen_cache_dir
    m = models.Model(abc_model_config)
    assert (m.cache.path / 'corpus').is_symlink()
    path = str((m.cache.path / 'corpus').resolve())
    # We can't do a literal comparison because of bazel sandboxing.
    assert path.endswith(str(m.corpus.encoded.database_path.parent))
Example #19
0
def test_TensorFlowBackend_Train_with_sample_callback(
    clgen_cache_dir, abc_tensorflow_model_config):
  """Test that sampling during training does not blow up."""
  del clgen_cache_dir
  abc_tensorflow_model_config.training.num_epochs = 2
  sampler = MockSampler()
  m = models.Model(abc_tensorflow_model_config)
  m.Train(test_sampler=sampler)
  assert m.is_trained
Example #20
0
def test_Model_directories(clgen_cache_dir, abc_model_config):
    """A newly instantiated model's cache has checkpoint and sample dirs."""
    del clgen_cache_dir
    m = models.Model(abc_model_config)
    assert (m.cache.path / 'checkpoints').is_dir()
    assert (m.cache.path / 'samples').is_dir()
    # There should be nothing in these directories yet.
    assert not list((m.cache.path / 'checkpoints').iterdir())
    assert not list((m.cache.path / 'samples').iterdir())
Example #21
0
def test_Model_corpus_symlink(clgen_cache_dir, abc_model_config):
    """Test path of symlink to corpus files."""
    del clgen_cache_dir
    m = models.Model(abc_model_config)
    assert (m.cache.path / "corpus").is_symlink()
    path = str((m.cache.path / "corpus").resolve())
    # We can't do a literal comparison because of bazel sandboxing.
    assert path.endswith(
        str(pathlib.Path(m.corpus.encoded.url[len("sqlite:///"):]).parent))
Example #22
0
def test_KerasBackend_Train_telemetry(clgen_cache_dir, abc_keras_model_config):
    """Test that model training produced telemetry files."""
    del clgen_cache_dir
    abc_keras_model_config.training.num_epochs = 2
    m = models.Model(abc_keras_model_config)
    assert len(m.TrainingTelemetry()) == 0
    m.Train()
    assert len(m.TrainingTelemetry()) == 2
    for telemetry in m.TrainingTelemetry():
        assert isinstance(telemetry, telemetry_pb2.ModelEpochTelemetry)
Example #23
0
def test_KerasBackend_Train_epoch_checkpoints(clgen_cache_dir,
                                              abc_keras_model_config):
    """Test that a trained model generates weight checkpoints."""
    del clgen_cache_dir
    abc_keras_model_config.training.num_epochs = 2
    m = models.Model(abc_keras_model_config)
    m.Train()
    assert len(m.backend.epoch_checkpoints) == 2
    for path in m.backend.epoch_checkpoints:
        assert path.is_file()
Example #24
0
def test_TensorFlowBackend_Train_epoch_checkpoints(clgen_cache_dir,
                                                   abc_tensorflow_model_config):
  """Test that epoch_checkpoints returns a <int, str> dict."""
  del clgen_cache_dir
  abc_tensorflow_model_config.training.num_epochs = 2
  m = models.Model(abc_tensorflow_model_config)
  assert not m.backend.epoch_checkpoints
  m.Train()
  epoch_checkpoints = m.backend.epoch_checkpoints
  assert 2 == len(epoch_checkpoints)
  assert 1 in epoch_checkpoints
  assert 2 in epoch_checkpoints
Example #25
0
def test_KerasBackend_Sample_exact_multiple_of_batch_size(
    clgen_cache_dir, abc_keras_model_config):
  """Test that min_num_samples are returned when a multiple of batch_size."""
  del clgen_cache_dir
  m = models.Model(abc_keras_model_config)
  sample_observer = sample_observers.InMemorySampleSaver()
  m.Sample(MockSampler(batch_size=2),
           [sample_observers.MaxSampleCountObserver(2), sample_observer])
  assert len(sample_observer.samples) == 2
  sample_observer = sample_observers.InMemorySampleSaver()
  m.Sample(MockSampler(batch_size=2),
           [sample_observers.MaxSampleCountObserver(4), sample_observer])
  assert len(sample_observer.samples) == 4
Example #26
0
def test_KerasBackend_Train_twice(clgen_cache_dir, abc_keras_model_config):
    """Test that TensorFlow checkpoint does not change after training twice."""
    del clgen_cache_dir
    abc_keras_model_config.training.num_epochs = 1
    m = models.Model(abc_keras_model_config)
    m.Train()
    f1a = checksumdir.dirhash(m.cache.path / "checkpoints")
    f1b = crypto.md5_file(m.cache.path / "META.pbtxt")
    m.Train()
    f2a = checksumdir.dirhash(m.cache.path / "checkpoints")
    f2b = crypto.md5_file(m.cache.path / "META.pbtxt")
    assert f1a == f2a
    assert f1b == f2b
Example #27
0
def LsModels(cache_root: pathlib.Path) -> None:
    """Print the training progress of each model found in the cache."""
    for model_dir in (cache_root / "model").iterdir():
        meta_file = model_dir / "META.pbtxt"
        if pbutil.ProtoIsReadable(meta_file, internal_pb2.ModelMeta()):
            model = models.Model(
                pbutil.FromFile(meta_file, internal_pb2.ModelMeta()).config)
            telemetry = list(model.TrainingTelemetry())
            num_epochs = model.config.training.num_epochs
            n = len(telemetry)
            print(f"{model_dir} {n} / {num_epochs} epochs")
        elif meta_file.is_file():
            app.Warning("Meta file %s cannot be read.", meta_file)
        else:
            app.Warning("Meta file %s not found.", meta_file)
Example #28
0
def test_TensorFlowBackend_Sample_inexact_multiple_of_batch_size(
    clgen_cache_dir, abc_tensorflow_model_config):
  """Test that min_num_samples are returned when a multiple of batch_size."""
  del clgen_cache_dir
  m = models.Model(abc_tensorflow_model_config)
  sampler = MockSampler()
  sampler.batch_size = 3
  # 3 = 1 * sizeof(batch).
  saver = sample_observers.InMemorySampleSaver()
  m.Sample(sampler, [sample_observers.MaxSampleCountObserver(2), saver])
  assert len(saver.samples) == 3
  # 6 = 2 * sizeof(batch).
  saver = sample_observers.InMemorySampleSaver()
  m.Sample(sampler, [sample_observers.MaxSampleCountObserver(4), saver])
  assert len(saver.samples) == 6
Example #29
0
def test_KerasBackend_Sample_return_value_matches_cached_sample(
        clgen_cache_dir, abc_keras_model_config):
    """Test that Sample() returns Sample protos."""
    del clgen_cache_dir
    m = models.Model(abc_keras_model_config)
    samples = m.Sample(MockSampler(hash='hash'), 1)
    assert len(samples) == 1
    assert len(list((m.cache.path / 'samples' / 'hash').iterdir())) == 1
    cached_sample_path = (m.cache.path / 'samples' / 'hash' / list(
        (m.cache.path / 'samples' / 'hash').iterdir())[0])
    assert cached_sample_path.is_file()
    cached_sample = pbutil.FromFile(cached_sample_path, model_pb2.Sample())
    assert samples[0].text == cached_sample.text
    assert samples[0].sample_time_ms == cached_sample.sample_time_ms
    assert (samples[0].sample_start_epoch_ms_utc ==
            cached_sample.sample_start_epoch_ms_utc)
Example #30
0
def test_KerasBackend_Sample_return_value_matches_cached_sample(
        clgen_cache_dir, abc_keras_model_config):
    """Test that Sample() returns Sample protos."""
    del clgen_cache_dir
    m = models.Model(abc_keras_model_config)
    sample_observer = sample_observers.InMemorySampleSaver()
    m.Sample(
        MockSampler(hash="hash"),
        [sample_observers.MaxSampleCountObserver(1), sample_observer],
    )
    samples = sample_observer.samples
    assert len(samples) == 1
    assert len(list((m.cache.path / "samples" / "hash").iterdir())) == 1
    cached_sample_path = (m.cache.path / "samples" / "hash" / list(
        (m.cache.path / "samples" / "hash").iterdir())[0])
    assert cached_sample_path.is_file()
    cached_sample = pbutil.FromFile(cached_sample_path, model_pb2.Sample())
    assert samples[0].text == cached_sample.text
    assert samples[0].sample_time_ms == cached_sample.sample_time_ms
    assert (samples[0].sample_start_epoch_ms_utc ==
            cached_sample.sample_start_epoch_ms_utc)