예제 #1
0
def test_sum_stats_save_load(history: History):
    """Store two particles with mixed-type summary statistics and read them back.

    Each particle carries a float, a random array, the frame returned by
    ``example_df()`` and an R dataset as summary statistics. After the
    population has been appended to ``history``, the weighted summary
    statistics are loaded again and compared entry by entry with what was
    stored.
    """
    vector = sp.random.rand(10)
    matrix = sp.random.rand(10, 2)

    # First particle: scalar, 2-d array, data frame and the R "faithful" data.
    first_parameter = Parameter({"a": 23, "b": 12})
    first_stats = {
        "ss1": 0.1,
        "ss2": matrix,
        "ss3": example_df(),
        # TODO: check why iris fails
        "rdf0": r["faithful"],
    }
    first_particle = Particle(
        m=0,
        parameter=first_parameter,
        weight=0.2,
        accepted_sum_stats=[first_stats],
        accepted_distances=[0.1],
    )

    # Second particle: different keys, 1-d array and the R "mtcars" data.
    second_parameter = Parameter({"a": 23, "b": 12})
    second_stats = {
        "ss12": 0.11,
        "ss22": vector,
        "ss33": example_df(),
        "rdf": r["mtcars"],
    }
    second_particle = Particle(
        m=0,
        parameter=second_parameter,
        weight=0.2,
        accepted_sum_stats=[second_stats],
        accepted_distances=[0.1],
    )

    population = Population([first_particle, second_particle])
    history.append_population(0, 42, population, 2, ["m1", "m2"])

    weights, sum_stats = history.get_weighted_sum_stats_for_model(0, 0)

    # Both particles were stored with weight 0.2; each loaded weight is
    # expected to be 0.5.
    assert (weights == 0.5).all()

    # NOTE(review): `pandas2ri.ri2py` appears to have been renamed in
    # rpy2 3.x -- confirm the pinned rpy2 version still provides it.
    loaded_first = sum_stats[0]
    assert loaded_first["ss1"] == 0.1
    assert (loaded_first["ss2"] == matrix).all()
    assert (loaded_first["ss3"] == example_df()).all().all()
    expected_faithful = pandas2ri.ri2py(r["faithful"])
    assert (loaded_first["rdf0"] == expected_faithful).all().all()

    loaded_second = sum_stats[1]
    assert loaded_second["ss12"] == 0.11
    assert (loaded_second["ss22"] == vector).all()
    assert (loaded_second["ss33"] == example_df()).all().all()
    expected_mtcars = pandas2ri.ri2py(r["mtcars"])
    assert (loaded_second["rdf"] == expected_mtcars).all().all()
예제 #2
0
def test_sum_stats_save_load(history: History):
    """Store two particles with mixed-type summary statistics and read them back.

    Each particle carries a float, a random array, the frame returned by
    ``example_df()`` and an R dataset as its summary statistic. After the
    population has been appended to ``history``, the weighted summary
    statistics are loaded again and compared entry by entry with what was
    stored.
    """
    vector = np.random.rand(10)
    matrix = np.random.rand(10, 2)

    # First particle: scalar, 2-d array, data frame and the R "iris" data.
    first_parameter = Parameter({"a": 23, "b": 12})
    first_stat = {
        "ss1": 0.1,
        "ss2": matrix,
        "ss3": example_df(),
        "rdf0": r["iris"],
    }
    first_particle = Particle(
        m=0,
        parameter=first_parameter,
        weight=0.2,
        sum_stat=first_stat,
        distance=0.1,
    )

    # Second particle: different keys, 1-d array and the R "mtcars" data.
    second_parameter = Parameter({"a": 23, "b": 12})
    second_stat = {
        "ss12": 0.11,
        "ss22": vector,
        "ss33": example_df(),
        "rdf": r["mtcars"],
    }
    second_particle = Particle(
        m=0,
        parameter=second_parameter,
        weight=0.2,
        sum_stat=second_stat,
        distance=0.1,
    )

    population = Population([first_particle, second_particle])
    history.append_population(0, 42, population, 2, ["m1", "m2"])

    weights, sum_stats = history.get_weighted_sum_stats_for_model(0, 0)

    # Both particles were stored with weight 0.2; each loaded weight is
    # expected to be 0.5.
    assert (weights == 0.5).all()

    loaded_first = sum_stats[0]
    assert loaded_first["ss1"] == 0.1
    assert (loaded_first["ss2"] == matrix).all()
    assert (loaded_first["ss3"] == example_df()).all().all()
    # The R dataset is looked up inside the pandas converter context;
    # presumably this yields a pandas object for the comparison -- confirm
    # against the rpy2 documentation.
    with localconverter(pandas2ri.converter):
        assert (loaded_first["rdf0"] == r["iris"]).all().all()

    loaded_second = sum_stats[1]
    assert loaded_second["ss12"] == 0.11
    assert (loaded_second["ss22"] == vector).all()
    assert (loaded_second["ss33"] == example_df()).all().all()
    with localconverter(pandas2ri.converter):
        assert (loaded_second["rdf"] == r["mtcars"]).all().all()