Example No. 1
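The excerpts below exercise pyabc's Redis-based parallel sampler. The surrounding imports and shared helpers are not shown in the listing; a plausible preamble, including a hypothetical minimal stand-in for the `basic_testcase` helper used in several examples (the real helper lives in the test utilities), is sketched here:

# Assumed common preamble for the excerpts below (not part of the original
# listing; module paths follow the public pyabc API).
import logging
import multiprocessing
import os
import tempfile

import numpy as np
import pandas as pd
import pytest

import pyabc
from pyabc.sampler import RedisEvalParallelSamplerServerStarter


def basic_testcase():
    """Hypothetical stand-in for the `basic_testcase` helper: it is assumed
    to return a model, a prior, a distance function, and observed data."""
    def model(p):
        # simple noisy identity model
        return {'s0': p['p0'] + 0.1 * np.random.randn()}

    prior = pyabc.Distribution(p0=pyabc.RV('uniform', 0, 1))

    def distance(x, y):
        return abs(x['s0'] - y['s0'])

    return model, prior, distance, {'s0': 0.5}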
def test_redis_look_ahead():
    """Test the redis sampler in look-ahead mode."""
    model, prior, distance, obs = basic_testcase()
    eps = pyabc.ListEpsilon([20, 10, 5])
    # spice things up with an adaptive population size
    pop_size = pyabc.AdaptivePopulationSize(start_nr_particles=50,
                                            mean_cv=0.5,
                                            max_population_size=50)
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv') as fh:
        sampler = RedisEvalParallelSamplerServerStarter(
            look_ahead=True,
            look_ahead_delay_evaluation=False,
            log_file=fh.name)
        try:
            abc = pyabc.ABCSMC(model,
                               prior,
                               distance,
                               sampler=sampler,
                               population_size=pop_size,
                               eps=eps)
            abc.new(pyabc.create_sqlite_db_id(), obs)
            h = abc.run(max_nr_populations=3)
        finally:
            sampler.shutdown()

        assert h.n_populations == 3

        # read log file
        df = pd.read_csv(fh.name, sep=',')
        assert (df.n_lookahead > 0).any()
        assert (df.n_lookahead_accepted > 0).any()
        assert (df.n_preliminary == 0).all()
Example No. 2
def test_redis_look_ahead_error():
    """Test whether the look-ahead mode fails as expected."""
    model, prior, distance, obs = basic_testcase()
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv') as fh:
        sampler = RedisEvalParallelSamplerServerStarter(
            look_ahead=True,
            look_ahead_delay_evaluation=False,
            log_file=fh.name)
        args_list = [
            {'eps': pyabc.MedianEpsilon()},
            {'distance_function': pyabc.AdaptivePNormDistance()},
        ]
        for args in args_list:
            if 'distance_function' not in args:
                args['distance_function'] = distance
            try:
                # adaptive eps / adaptive distance must not be combined with
                # (non-delayed) look-ahead mode, so the run should raise
                with pytest.raises(AssertionError) as e:
                    abc = pyabc.ABCSMC(model,
                                       prior,
                                       sampler=sampler,
                                       population_size=10,
                                       **args)
                    abc.new(pyabc.create_sqlite_db_id(), obs)
                    abc.run(max_nr_populations=3)
                assert "cannot be used in look-ahead mode" in str(e.value)
            finally:
                sampler.shutdown()
Example No. 3
def test_redis_catch_error():
    """Test that the Redis sampler tolerates errors raised by the model."""
    def model(pars):
        # fail in roughly 10% of evaluations to exercise error handling
        if np.random.uniform() < 0.1:
            raise ValueError("error")
        return {'s0': pars['p0'] + 0.2 * np.random.uniform()}

    def distance(s0, s1):
        return abs(s0['s0'] - s1['s0'])

    prior = pyabc.Distribution(p0=pyabc.RV("uniform", 0, 10))
    sampler = RedisEvalParallelSamplerServerStarter(batch_size=3,
                                                    workers=1,
                                                    processes_per_worker=1)
    try:
        abc = pyabc.ABCSMC(model,
                           prior,
                           distance,
                           sampler=sampler,
                           population_size=10)

        db_file = "sqlite:///" + os.path.join(tempfile.gettempdir(), "test.db")
        data = {'s0': 2.8}
        abc.new(db_file, data)
        abc.run(minimum_epsilon=.1, max_nr_populations=3)
    finally:
        sampler.shutdown()
Example No. 4
@pytest.fixture  # registration as a fixture is assumed; the decorator is not shown in the excerpt
def redis_starter_sampler(request):
    """Yield a running Redis sampler and shut it down afterwards."""
    s = RedisEvalParallelSamplerServerStarter(batch_size=5)
    try:
        yield s
    finally:
        # release all resources
        s.shutdown()
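As a usage note, pytest injects such a fixture into a test by argument name; a minimal hypothetical consumer, reusing the `basic_testcase` sketch from the preamble above, might look like:

def test_with_redis_sampler(redis_starter_sampler):
    # the fixture provides a running Redis sampler and shuts it down afterwards
    model, prior, distance, obs = basic_testcase()
    abc = pyabc.ABCSMC(model, prior, distance,
                       sampler=redis_starter_sampler, population_size=10)
    abc.new(pyabc.create_sqlite_db_id(), obs)
    abc.run(max_nr_populations=2)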
Example No. 5
def test_redis_continuous_analyses():
    """Test correct behavior of the redis server with multiple analyses."""
    sampler = RedisEvalParallelSamplerServerStarter()
    try:
        sampler.set_analysis_id("id1")
        # try "starting a new run while the old one has not finished yet"
        with pytest.raises(AssertionError) as e:
            sampler.set_analysis_id("id2")
        assert "busy with an analysis " in str(e.value)
        # after stopping it should work
        sampler.stop()
        sampler.set_analysis_id("id2")
    finally:
        sampler.shutdown()
Example No. 6
def test_redis_pw_protection():
    """Test sampling via a password-protected Redis server."""
    def simulate_one():
        accepted = np.random.randint(2)
        return pyabc.Particle(0, {}, 0.1, [], [], accepted)

    sampler = RedisEvalParallelSamplerServerStarter(  # noqa: S106
        password="******")
    try:
        # needs to be always set
        sampler.set_analysis_id("ana_id")
        sample = sampler.sample_until_n_accepted(10, simulate_one, 0)
        assert 10 == len(sample.get_accepted_population())
    finally:
        sampler.shutdown()
Example No. 7
def test_redis_multiprocess():
    """Test the Redis sampler with multiple processes per worker."""
    def simulate_one():
        accepted = np.random.randint(2)
        return pyabc.Particle(0, {}, 0.1, [], [], accepted)

    sampler = RedisEvalParallelSamplerServerStarter(batch_size=3,
                                                    workers=1,
                                                    processes_per_worker=2)
    try:
        # id needs to be set
        sampler.set_analysis_id("ana_id")

        sample = sampler.sample_until_n_accepted(10, simulate_one, 0)
        assert 10 == len(sample.get_accepted_population())
    finally:
        sampler.shutdown()
Example No. 8
def test_redis_look_ahead():
    """Test the redis sampler in look-ahead mode."""
    model, prior, distance, obs = basic_testcase()
    eps = pyabc.ListEpsilon([20, 10, 5])
    # spice things up with an adaptive population size
    pop_size = pyabc.AdaptivePopulationSize(start_nr_particles=50,
                                            mean_cv=0.5,
                                            max_population_size=50)
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv') as fh:
        sampler = RedisEvalParallelSamplerServerStarter(
            look_ahead=True,
            look_ahead_delay_evaluation=False,
            log_file=fh.name,
        )
        try:
            abc = pyabc.ABCSMC(
                model,
                prior,
                distance,
                sampler=sampler,
                population_size=pop_size,
                eps=eps,
            )
            abc.new(pyabc.create_sqlite_db_id(), obs)
            h = abc.run(max_nr_populations=3)
        finally:
            sampler.shutdown()

        assert h.n_populations == 3

        # read log file
        df = pd.read_csv(fh.name, sep=',')
        assert (df.n_lookahead > 0).any()
        assert (df.n_lookahead_accepted > 0).any()
        assert (df.n_preliminary == 0).all()

        # check history proposal ids
        for t in range(0, h.max_t + 1):
            pop = h.get_population(t=t)
            pop_size = len(pop)
            # particles from look-ahead proposals carry proposal_id == -1
            n_lookahead_pop = len(
                [p for p in pop.particles if p.proposal_id == -1])
            n_lookahead_logged = int(df.loc[df.t == t, 'n_lookahead_accepted'])
            assert min(pop_size, n_lookahead_logged) == n_lookahead_pop
Example No. 9
def test_redis_subprocess():
    """Test whether the instructed redis sampler allows worker subprocesses."""
    # print worker output
    logging.getLogger("Redis-Worker").addHandler(logging.StreamHandler())

    def model_process(p, pipe):
        """The actual model."""
        pipe.send({"y": p['p0'] + 0.1 * np.random.randn(10)})

    def model(p):
        """Model calling a subprocess."""
        parent, child = multiprocessing.Pipe()
        proc = multiprocessing.Process(target=model_process, args=(p, child))
        proc.start()
        res = parent.recv()
        proc.join()
        return res

    prior = pyabc.Distribution(p0=pyabc.RV('uniform', -5, 10),
                               p1=pyabc.RV('uniform', -2, 2))

    def distance(y1, y2):
        return np.abs(y1['y'] - y2['y']).sum()

    obs = {'y': 1}
    # False as daemon argument is ok, True and None are not allowed
    sampler = RedisEvalParallelSamplerServerStarter(workers=1,
                                                    processes_per_worker=2,
                                                    daemon=False)
    try:
        abc = pyabc.ABCSMC(model,
                           prior,
                           distance,
                           sampler=sampler,
                           population_size=10)
        abc.new(pyabc.create_sqlite_db_id(), obs)
        # would just never return if model evaluation fails
        abc.run(max_nr_populations=3)
    finally:
        sampler.shutdown()