Esempio n. 1
0
def test_KerasBackend_Sample_exact_multiple_of_batch_size(
    clgen_cache_dir, abc_keras_model_config):
  """Exactly the requested sample count is produced for whole batches.

  With a sampler batch size of 2, stopping after 2 and after 4 samples is
  expected to yield exactly 2 and 4 collected samples respectively.
  """
  del clgen_cache_dir  # The fixture value itself is not used.
  model = models.Model(abc_keras_model_config)
  for requested in (2, 4):
    collected = sample_observers.InMemorySampleSaver()
    sampler = MockSampler(batch_size=2)
    stop_after = sample_observers.MaxSampleCountObserver(requested)
    model.Sample(sampler, [stop_after, collected])
    assert len(collected.samples) == requested
Esempio n. 2
0
def test_TensorFlowBackend_Sample_inexact_multiple_of_batch_size(
    clgen_cache_dir, abc_tensorflow_model_config):
  """Sample counts that are not a batch multiple yield whole batches.

  With a sampler batch size of 3, the test expects a limit of 2 to produce
  3 samples (one batch) and a limit of 4 to produce 6 samples (two batches).
  """
  del clgen_cache_dir  # The fixture value itself is not used.
  model = models.Model(abc_tensorflow_model_config)
  mock_sampler = MockSampler()
  mock_sampler.batch_size = 3
  # Pairs of (observer limit, expected number of collected samples).
  for limit, expected in ((2, 3), (4, 6)):
    collected = sample_observers.InMemorySampleSaver()
    stop_after = sample_observers.MaxSampleCountObserver(limit)
    model.Sample(mock_sampler, [stop_after, collected])
    assert len(collected.samples) == expected
def test_TensorFlowBackend_Sample_return_value_matches_cached_sample(
        clgen_cache_dir, abc_tensorflow_model_config):
    """Test that the observed Sample proto matches the one cached on disk.

    A single sample is generated with an in-memory saver and a
    LegacySampleCacheObserver attached. The test then checks that exactly one
    file exists under <cache>/samples/hash and that its text and timing fields
    equal those of the sample held in memory.
    """
    del clgen_cache_dir  # The fixture value itself is not used.
    abc_tensorflow_model_config.training.batch_size = 1
    m = models.Model(abc_tensorflow_model_config)
    sample_observer = sample_observers.InMemorySampleSaver()
    m.Sample(
        MockSampler(hash="hash"),
        [
            sample_observers.MaxSampleCountObserver(1),
            sample_observer,
            sample_observers.LegacySampleCacheObserver(),
        ],
    )
    samples = sample_observer.samples
    # Samples are produced in batches of sampler.batch_size elements.
    assert len(samples) == 1
    # List the directory once. iterdir() yields paths that already include
    # the directory, so the entry is used as-is rather than re-joined onto
    # the directory (which would only be a no-op for absolute cache paths).
    cached_paths = list((m.cache.path / "samples" / "hash").iterdir())
    assert len(cached_paths) == 1
    cached_sample_path = cached_paths[0]
    assert cached_sample_path.is_file()
    cached_sample = pbutil.FromFile(cached_sample_path, model_pb2.Sample())
    assert samples[0].text == cached_sample.text
    assert samples[0].sample_time_ms == cached_sample.sample_time_ms
    assert (samples[0].sample_start_epoch_ms_utc ==
            cached_sample.sample_start_epoch_ms_utc)
Esempio n. 4
0
def test_TensorFlowBackend_Sample_implicit_train(clgen_cache_dir,
                                                 abc_tensorflow_model_config):
  """A model that is not yet trained reports is_trained after Sample()."""
  del clgen_cache_dir  # The fixture value itself is not used.
  model = models.Model(abc_tensorflow_model_config)
  assert not model.is_trained
  sampler = MockSampler()
  stop_after_one = sample_observers.MaxSampleCountObserver(1)
  model.Sample(sampler, [stop_after_one])
  assert model.is_trained
Esempio n. 5
0
def SampleObserversFromFlags() -> typing.List[sample_observers.SampleObserver]:
  """Create sample observers for use with model.Sample() from flags values.

  Reads --clgen_min_sample_count, --clgen_cache_sample_protos and
  --clgen_print_samples.

  Returns:
    A list of observers, in the order: sample-count limit (only when
    --clgen_min_sample_count is positive), legacy sample cache, sample
    printer.
  """
  observers = []
  if FLAGS.clgen_min_sample_count <= 0:
    # No count-limiting observer is installed in this case, so warn the user
    # that nothing added here will stop the sampling loop.
    app.Warning('--clgen_min_sample_count <= 0 means that sampling (and this '
                'process) will never terminate!')
  else:
    observers.append(
        sample_observers.MaxSampleCountObserver(FLAGS.clgen_min_sample_count))
  if FLAGS.clgen_cache_sample_protos:
    observers.append(sample_observers.LegacySampleCacheObserver())
  if FLAGS.clgen_print_samples:
    observers.append(sample_observers.PrintSampleObserver())
  return observers
Esempio n. 6
0
def SampleObserversFromFlags():
  """Build the list of sample observers selected by the command-line flags.

  Returns:
    A list of observers, in the order: sample-count limit (only when
    --min_samples is positive), sample printer, legacy sample cache, sample
    text saver.
  """
  observers = []

  min_samples = FLAGS.min_samples
  if min_samples <= 0:
    # Without a count limit, no observer added here stops the sample loop.
    app.Warning(
        'Entering an infinite sample loop, this process will never end!')
  else:
    observers.append(sample_observers_lib.MaxSampleCountObserver(min_samples))

  if FLAGS.print_samples:
    observers.append(sample_observers_lib.PrintSampleObserver())

  if FLAGS.cache_samples:
    observers.append(sample_observers_lib.LegacySampleCacheObserver())

  text_dir = FLAGS.sample_text_dir
  if text_dir:
    text_saver = sample_observers_lib.SaveSampleTextObserver(
        pathlib.Path(text_dir))
    observers.append(text_saver)

  return observers
Esempio n. 7
0
def test_KerasBackend_Sample_return_value_matches_cached_sample(
    clgen_cache_dir, abc_keras_model_config):
  """Test that the observed Sample proto matches the one cached on disk.

  A single sample is generated and captured in memory. The test then checks
  that exactly one file exists under <cache>/samples/hash and that its text
  and timing fields equal those of the captured sample.
  """
  del clgen_cache_dir  # The fixture value itself is not used.
  m = models.Model(abc_keras_model_config)
  sample_observer = sample_observers.InMemorySampleSaver()
  # NOTE(review): no cache-writing observer is registered here, yet a cached
  # file is expected below - presumably Model.Sample() writes it; confirm.
  m.Sample(MockSampler(hash='hash'),
           [sample_observers.MaxSampleCountObserver(1), sample_observer])
  samples = sample_observer.samples
  assert len(samples) == 1
  # List the directory once. iterdir() yields paths that already include the
  # directory, so the entry is used as-is rather than re-joined onto the
  # directory (which would only be a no-op for absolute cache paths).
  cached_paths = list((m.cache.path / 'samples' / 'hash').iterdir())
  assert len(cached_paths) == 1
  cached_sample_path = cached_paths[0]
  assert cached_sample_path.is_file()
  cached_sample = pbutil.FromFile(cached_sample_path, model_pb2.Sample())
  assert samples[0].text == cached_sample.text
  assert samples[0].sample_time_ms == cached_sample.sample_time_ms
  assert (samples[0].sample_start_epoch_ms_utc ==
          cached_sample.sample_start_epoch_ms_utc)
Esempio n. 8
0
def test_MaxSampleCountObserver():
    """An observer with a limit of 3 is truthy twice, then falsy on call three."""
    limit_observer = sample_observers.MaxSampleCountObserver(3)
    # The first two notifications are expected to return a truthy value.
    for _ in range(2):
        assert limit_observer.OnSample(None)
    # The third notification reaches the limit.
    assert not limit_observer.OnSample(None)
Esempio n. 9
0
def main(argv):
    """Sample from every instance in PROTOS with a 1000-sample count observer.

    Args:
        argv: Command-line arguments; unused.
    """
    del argv
    for proto_path in PROTOS:
        instance = clgen.Instance.FromFile(proto_path)
        observers = [sample_observers.MaxSampleCountObserver(1000)]
        instance.Sample(sample_observers=observers)