示例#1
0
def test_KerasBackend_Sample_exact_multiple_of_batch_size(
    clgen_cache_dir, abc_keras_model_config):
  """Sample() yields exactly N samples when N is a multiple of batch_size.

  The same model is sampled twice with batch_size=2, once asking for 2
  samples and once for 4; each run uses a fresh in-memory saver.
  """
  del clgen_cache_dir
  model = models.Model(abc_keras_model_config)
  for wanted in (2, 4):
    saver = sample_observers.InMemorySampleSaver()
    sampler = MockSampler(batch_size=2)
    observers = [sample_observers.MaxSampleCountObserver(wanted), saver]
    model.Sample(sampler, observers)
    assert len(saver.samples) == wanted
示例#2
0
def test_TensorFlowBackend_Sample_inexact_multiple_of_batch_size(
    clgen_cache_dir, abc_tensorflow_model_config):
  """Test that whole batches are returned for a non-multiple sample count.

  With batch_size=3, requesting 2 samples yields one full batch (3 samples)
  and requesting 4 samples yields two full batches (6 samples).
  """
  del clgen_cache_dir
  m = models.Model(abc_tensorflow_model_config)
  sampler = MockSampler()
  sampler.batch_size = 3
  # Requesting 2 samples yields 3 = 1 * sizeof(batch).
  saver = sample_observers.InMemorySampleSaver()
  m.Sample(sampler, [sample_observers.MaxSampleCountObserver(2), saver])
  assert len(saver.samples) == 3
  # Requesting 4 samples yields 6 = 2 * sizeof(batch).
  saver = sample_observers.InMemorySampleSaver()
  m.Sample(sampler, [sample_observers.MaxSampleCountObserver(4), saver])
  assert len(saver.samples) == 6


def test_TensorFlowBackend_Sample_return_value_matches_cached_sample(
        clgen_cache_dir, abc_tensorflow_model_config):
    """Test that the sample returned by Sample() matches the cached sample.

    One sample is requested with the legacy cache observer attached. Exactly
    one entry must exist under <cache>/samples/hash. Its text and timing
    fields must equal those of the returned Sample proto.
    """
    del clgen_cache_dir
    abc_tensorflow_model_config.training.batch_size = 1
    m = models.Model(abc_tensorflow_model_config)
    sample_observer = sample_observers.InMemorySampleSaver()
    m.Sample(
        MockSampler(hash="hash"),
        [
            sample_observers.MaxSampleCountObserver(1),
            sample_observer,
            sample_observers.LegacySampleCacheObserver(),
        ],
    )
    samples = sample_observer.samples
    # Samples are produced in batches of sampler.batch_size elements.
    assert len(samples) == 1
    sample_dir = m.cache.path / "samples" / "hash"
    # iterdir() already yields paths rooted at sample_dir, so use the entry
    # as-is. Joining it onto sample_dir again would double the prefix
    # whenever the cache path is relative.
    cached_entries = list(sample_dir.iterdir())
    assert len(cached_entries) == 1
    cached_sample_path = cached_entries[0]
    assert cached_sample_path.is_file()
    cached_sample = pbutil.FromFile(cached_sample_path, model_pb2.Sample())
    assert samples[0].text == cached_sample.text
    assert samples[0].sample_time_ms == cached_sample.sample_time_ms
    assert (samples[0].sample_start_epoch_ms_utc ==
            cached_sample.sample_start_epoch_ms_utc)
示例#4
0
def test_InMemorySampleSaver():
    """OnSample() returns truthy and appends each sample to .samples."""
    saver = sample_observers.InMemorySampleSaver()
    proto = model_pb2.Sample(text="Hello, world!")

    # Feed the same proto twice; the saver should grow by one each time.
    for expected_len in (1, 2):
        assert saver.OnSample(proto)
        assert len(saver.samples) == expected_len
        assert saver.samples[-1].text == "Hello, world!"
示例#5
0
def test_KerasBackend_Sample_return_value_matches_cached_sample(
    clgen_cache_dir, abc_keras_model_config):
  """Test that the sample returned by Sample() matches the cached sample.

  One sample is requested. Exactly one entry must exist under
  <cache>/samples/hash. Its text and timing fields must equal those of the
  returned Sample proto.
  """
  del clgen_cache_dir
  m = models.Model(abc_keras_model_config)
  sample_observer = sample_observers.InMemorySampleSaver()
  m.Sample(MockSampler(hash='hash'),
           [sample_observers.MaxSampleCountObserver(1), sample_observer])
  samples = sample_observer.samples
  assert len(samples) == 1
  sample_dir = m.cache.path / 'samples' / 'hash'
  # iterdir() already yields paths rooted at sample_dir, so use the entry
  # as-is. Joining it onto sample_dir again would double the prefix
  # whenever the cache path is relative.
  cached_entries = list(sample_dir.iterdir())
  assert len(cached_entries) == 1
  cached_sample_path = cached_entries[0]
  assert cached_sample_path.is_file()
  cached_sample = pbutil.FromFile(cached_sample_path, model_pb2.Sample())
  assert samples[0].text == cached_sample.text
  assert samples[0].sample_time_ms == cached_sample.sample_time_ms
  assert (samples[0].sample_start_epoch_ms_utc ==
          cached_sample.sample_start_epoch_ms_utc)