def test_redis_pw_protection():
    """Sample through a password-protected Redis server started by the sampler.

    Checks that exactly 10 accepted particles are returned when each
    simulation is accepted with probability 1/2. The server and workers are
    always torn down, even when sampling or the assertion fails, so port 8888
    is not left occupied for subsequent tests.
    """
    sampler = RedisEvalParallelSamplerServerStarter(password="******", port=8888)  # noqa: S106

    def simulate_one():
        # Randomly accept (1) or reject (0) each simulated particle.
        accepted = np.random.randint(2)
        return Particle(0, {}, 0.1, [], [], accepted)

    try:
        sample = sampler.sample_until_n_accepted(10, simulate_one)
        assert 10 == len(sample.get_accepted_population())
    finally:
        # Release the Redis server / worker processes regardless of outcome.
        sampler.cleanup()
def test_redis_multiprocess():
    """Sample with a single worker running one process and batch size 3.

    Checks that exactly 10 accepted particles are returned. Cleanup runs in a
    ``finally`` block so that a failure does not leave the Redis server and
    worker processes running.
    """
    sampler = RedisEvalParallelSamplerServerStarter(
        batch_size=3, workers=1, processes_per_worker=1)

    def simulate_one():
        # Randomly accept (1) or reject (0) each simulated particle.
        accepted = np.random.randint(2)
        return Particle(0, {}, 0.1, [], [], accepted)

    try:
        sample = sampler.sample_until_n_accepted(10, simulate_one)
        assert 10 == len(sample.get_accepted_population())
    finally:
        # Release the Redis server / worker processes regardless of outcome.
        sampler.cleanup()
def test_redis_catch_error():
    """Run a short ABC-SMC whose model sometimes raises, via the Redis sampler.

    About 10% of model evaluations raise ``ValueError``. The test passes if
    the run completes despite these errors. The sampler is cleaned up in a
    ``finally`` block so that an unexpected failure does not leak the Redis
    server on port 8775.
    """
    def model(pars):
        # Inject occasional failures to exercise the sampler's error handling.
        if np.random.uniform() < 0.1:
            raise ValueError("error")
        return {'s0': pars['p0'] + 0.2 * np.random.uniform()}

    def distance(s0, s1):
        return abs(s0['s0'] - s1['s0'])

    prior = Distribution(p0=RV("uniform", 0, 10))
    sampler = RedisEvalParallelSamplerServerStarter(
        batch_size=3, workers=1, processes_per_worker=1, port=8775)
    try:
        abc = ABCSMC(model, prior, distance, sampler=sampler,
                     population_size=10)
        db_file = "sqlite:///" + os.path.join(tempfile.gettempdir(),
                                              "test.db")
        data = {'s0': 2.8}
        abc.new(db_file, data)
        abc.run(minimum_epsilon=.1, max_nr_populations=3)
    finally:
        # Release the Redis server / worker processes regardless of outcome.
        sampler.cleanup()
def redis_starter_sampler():
    """Provide a server-starting Redis sampler with batch size 5.

    The sampler is yielded to the consumer. Once the generator is resumed,
    the sampler's ``cleanup`` is called.
    """
    starter = RedisEvalParallelSamplerServerStarter(batch_size=5)
    yield starter
    # Teardown: runs when the generator is advanced past the yield.
    starter.cleanup()