def test_KerasBackend_Sample_exact_multiple_of_batch_size(
        clgen_cache_dir, abc_keras_model_config):
    """Test that exactly the requested number of samples is returned.

    The sampler batch size is 2 and the requested maximum sample counts
    (2 and 4) are whole multiples of it, so no extra samples are expected.
    """
    del clgen_cache_dir  # Value is not used by the test body.
    model = models.Model(abc_keras_model_config)

    for requested_count in (2, 4):
        # A fresh sampler and a fresh saver are used for every request.
        saver = sample_observers.InMemorySampleSaver()
        observers = [
            sample_observers.MaxSampleCountObserver(requested_count),
            saver,
        ]
        model.Sample(MockSampler(batch_size=2), observers)
        assert len(saver.samples) == requested_count
def test_TensorFlowBackend_Sample_inexact_multiple_of_batch_size(
        clgen_cache_dir, abc_tensorflow_model_config):
    """Test the sample count when the request is not a multiple of batch_size.

    With a batch size of 3, requesting at most 2 samples is expected to yield
    3 samples (one batch) and requesting at most 4 is expected to yield 6
    samples (two batches).
    """
    del clgen_cache_dir  # Value is not used by the test body.
    model = models.Model(abc_tensorflow_model_config)

    # The same sampler instance is shared by both Sample() calls.
    sampler = MockSampler()
    sampler.batch_size = 3

    # (requested maximum, expected number of samples actually produced).
    expectations = (
        (2, 3),  # 3 = 1 * sizeof(batch).
        (4, 6),  # 6 = 2 * sizeof(batch).
    )
    for requested_count, produced_count in expectations:
        saver = sample_observers.InMemorySampleSaver()
        limit = sample_observers.MaxSampleCountObserver(requested_count)
        model.Sample(sampler, [limit, saver])
        assert len(saver.samples) == produced_count
def test_TensorFlowBackend_Sample_return_value_matches_cached_sample(
        clgen_cache_dir, abc_tensorflow_model_config):
    """Test that the observed sample matches the sample written to the cache.

    A single sample is generated with an in-memory saver and a
    LegacySampleCacheObserver attached. The test then checks that exactly one
    file exists under ``<cache>/samples/hash`` and that the proto read back
    from it has the same text, sample time, and start time as the sample held
    in memory.
    """
    del clgen_cache_dir  # Value is not used by the test body.
    abc_tensorflow_model_config.training.batch_size = 1
    m = models.Model(abc_tensorflow_model_config)
    sample_observer = sample_observers.InMemorySampleSaver()
    m.Sample(
        MockSampler(hash="hash"),
        [
            sample_observers.MaxSampleCountObserver(1),
            sample_observer,
            sample_observers.LegacySampleCacheObserver(),
        ],
    )
    samples = sample_observer.samples
    # Samples are produced in batches of sampler.batch_size elements.
    assert len(samples) == 1

    # List the cache directory once. iterdir() yields paths that already
    # include the directory, so the entry is used as-is rather than being
    # joined onto the directory a second time.
    cached_sample_paths = list((m.cache.path / "samples" / "hash").iterdir())
    assert len(cached_sample_paths) == 1
    cached_sample_path = cached_sample_paths[0]
    assert cached_sample_path.is_file()

    cached_sample = pbutil.FromFile(cached_sample_path, model_pb2.Sample())
    assert samples[0].text == cached_sample.text
    assert samples[0].sample_time_ms == cached_sample.sample_time_ms
    assert (samples[0].sample_start_epoch_ms_utc ==
            cached_sample.sample_start_epoch_ms_utc)
def test_InMemorySampleSaver():
    """Test that InMemorySampleSaver accumulates every sample it is given.

    The same Sample proto is passed to OnSample() twice; each call must return
    a truthy value, grow ``samples`` by one, and leave the sample's text as
    the most recent entry.
    """
    saver = sample_observers.InMemorySampleSaver()
    proto = model_pb2.Sample(text="Hello, world!")

    for expected_size in (1, 2):
        assert saver.OnSample(proto)
        assert len(saver.samples) == expected_size
        assert saver.samples[-1].text == "Hello, world!"
def test_KerasBackend_Sample_return_value_matches_cached_sample(
        clgen_cache_dir, abc_keras_model_config):
    """Test that the observed sample matches the sample found in the cache.

    A single sample is generated and captured by an in-memory saver. The test
    then checks that exactly one file exists under ``<cache>/samples/hash``
    and that the proto read back from it has the same text, sample time, and
    start time as the sample held in memory.
    """
    del clgen_cache_dir  # Value is not used by the test body.
    m = models.Model(abc_keras_model_config)
    sample_observer = sample_observers.InMemorySampleSaver()
    # NOTE(review): no cache-writing observer is registered here, so the
    # assertions below rely on Sample() itself populating the sample cache.
    # Confirm that this still holds for the Keras backend.
    m.Sample(MockSampler(hash='hash'),
             [sample_observers.MaxSampleCountObserver(1), sample_observer])
    samples = sample_observer.samples
    assert len(samples) == 1

    # List the cache directory once. iterdir() yields paths that already
    # include the directory, so the entry is used as-is rather than being
    # joined onto the directory a second time.
    cached_sample_paths = list((m.cache.path / 'samples' / 'hash').iterdir())
    assert len(cached_sample_paths) == 1
    cached_sample_path = cached_sample_paths[0]
    assert cached_sample_path.is_file()

    cached_sample = pbutil.FromFile(cached_sample_path, model_pb2.Sample())
    assert samples[0].text == cached_sample.text
    assert samples[0].sample_time_ms == cached_sample.sample_time_ms
    assert (samples[0].sample_start_epoch_ms_utc ==
            cached_sample.sample_start_epoch_ms_utc)