Пример #1
0
def test_type_filter(make_filter_snapshot, simulation_factory, type_indices):
    """Check that ``Type`` filters follow changes to particle types.

    The test walks a 10-particle system through three type assignments:
    1. The initial snapshot, where every particle is expected to be type A.
    2. Every particle reassigned to B.
    3. The split supplied by the ``type_indices`` fixture.

    After each change the single-type filters must select exactly the
    expected tags. In the final configuration the combined A/B filter must
    select every particle.
    """
    types = ['A', 'B']
    n_particles = 10
    sim = simulation_factory(
        make_filter_snapshot(n=n_particles, particle_types=types))
    every_tag = list(range(n_particles))

    only_a = Type(["A"])
    only_b = Type(["B"])
    a_or_b = Type(["A", "B"])

    # Initial state: the test expects every particle to start as type A.
    assert only_a(sim.state) == every_tag
    assert only_b(sim.state) == []

    # Reassign every particle to B. The snapshot is only edited on rank 0,
    # but set_snapshot is called on every rank.
    snapshot = sim.state.get_snapshot()
    if snapshot.communicator.rank == 0:
        set_types(snapshot, range(n_particles), types, "B")
    sim.state.set_snapshot(snapshot)
    assert only_a(sim.state) == []
    assert only_b(sim.state) == every_tag

    # Apply the A/B split provided by the fixture.
    a_tags, b_tags = type_indices
    snapshot = sim.state.get_snapshot()
    if snapshot.communicator.rank == 0:
        set_types(snapshot, a_tags, types, "A")
        set_types(snapshot, b_tags, types, "B")
    sim.state.set_snapshot(snapshot)
    assert only_a(sim.state) == a_tags
    assert only_b(sim.state) == b_tags
    assert a_or_b(sim.state) == every_tag
Пример #2
0
def test_difference(make_filter_snapshot, simulation_factory, set_indices):
    """Check ``SetDifference`` against ``Type`` filters on a three-type system.

    Particles are assigned to types A, B and C using the ``set_indices``
    fixture. For every pair of types:
    - Removing one member of the pair from the pair's combined ``Type`` filter
      must leave exactly the selection of the other member.
    - Removing the third type must leave the pair's selection unchanged.
    """
    particle_types = ['A', 'B', 'C']
    N = 10
    filter_snapshot = make_filter_snapshot(n=N, particle_types=particle_types)
    sim = simulation_factory(filter_snapshot)
    A_indices, B_indices, C_indices = set_indices
    s = sim.state.get_snapshot()
    # Snapshot data is only edited on rank 0; set_snapshot runs on all ranks.
    if s.communicator.rank == 0:
        set_types(s, A_indices, particle_types, "A")
        set_types(s, B_indices, particle_types, "B")
        set_types(s, C_indices, particle_types, "C")
    sim.state.set_snapshot(s)

    for type_combo in combinations(particle_types, 2):
        combo_filter = Type(type_combo)
        first_type, second_type = type_combo
        first_filter = Type([first_type])
        second_filter = Type([second_type])

        # Each direction is asserted once. The previous loop over i in [0, 1]
        # with type_combo[i - 1] repeated the same two assertions twice.
        difference_filter = SetDifference(combo_filter, first_filter)
        assert difference_filter(sim.state) == second_filter(sim.state)
        difference_filter = SetDifference(combo_filter, second_filter)
        assert difference_filter(sim.state) == first_filter(sim.state)

        # Subtracting the type outside the pair must not remove anything.
        remaining_type = type_not_in_combo(type_combo, particle_types)
        remaining_filter = Type([remaining_type])
        difference_filter = SetDifference(combo_filter, remaining_filter)
        assert difference_filter(sim.state) == combo_filter(sim.state)
Пример #3
0
def test_union(make_filter_snapshot, simulation_factory, set_indices):
    """Check that ``Union`` of two single-type filters equals a pair filter.

    Particles are split across types A, B and C using the ``set_indices``
    fixture. For every pair of types, the union of the two single-type
    ``Type`` filters must select the same tags as one ``Type`` filter built
    from the pair.
    """
    types = ['A', 'B', 'C']
    n_particles = 10
    sim = simulation_factory(
        make_filter_snapshot(n=n_particles, particle_types=types))
    a_tags, b_tags, c_tags = set_indices

    # Only rank 0 edits the snapshot; every rank pushes it back to the state.
    snapshot = sim.state.get_snapshot()
    if snapshot.communicator.rank == 0:
        for type_name, tags in zip(types, (a_tags, b_tags, c_tags)):
            set_types(snapshot, tags, types, type_name)
    sim.state.set_snapshot(snapshot)

    for pair in combinations(types, 2):
        first, second = pair
        pair_filter = Type(pair)
        union = Union(Type([first]), Type([second]))
        assert union(sim.state) == pair_filter(sim.state)
Пример #4
0
def test_intersection(make_filter_snapshot, simulation_factory, set_indices):
    """Check ``Intersection`` against ``Type`` filters on a three-type system.

    Particles are split across types A, B and C using the ``set_indices``
    fixture. For every pair of types:
    - Intersecting the pair's ``Type`` filter with one member's filter must
      select exactly that member's particles.
    - Intersecting it with the third type's filter must select nothing.
    """
    types = ['A', 'B', 'C']
    n_particles = 10
    sim = simulation_factory(
        make_filter_snapshot(n=n_particles, particle_types=types))
    a_tags, b_tags, c_tags = set_indices

    # Only rank 0 edits the snapshot; every rank pushes it back to the state.
    snapshot = sim.state.get_snapshot()
    if snapshot.communicator.rank == 0:
        for type_name, tags in zip(types, (a_tags, b_tags, c_tags)):
            set_types(snapshot, tags, types, type_name)
    sim.state.set_snapshot(snapshot)

    for pair in combinations(types, 2):
        pair_filter = Type(pair)
        for member in pair:
            member_filter = Type([member])
            overlap = Intersection(pair_filter, member_filter)
            assert overlap(sim.state) == member_filter(sim.state)
        outsider_filter = Type([type_not_in_combo(pair, types)])
        disjoint = Intersection(pair_filter, outsider_filter)
        assert disjoint(sim.state) == []
Пример #5
0
# Filter classes exercised by test_pickling. Entry i is constructed with the
# positional arguments in _constructor_args[i], so both lists must stay
# aligned.
_filter_classes = [All, Tags, Type, Rigid, SetDifference, Union, Intersection]

# Positional constructor arguments, one tuple per entry of _filter_classes.
_constructor_args = [
    (),  # All
    ([1, 2, 3],),  # Tags
    ({'a', 'b'},),  # Type
    (('center', 'free'),),  # Rigid
    (Tags([1, 4, 5]), Type({'a'})),  # SetDifference
    (Tags([1, 4, 5]), Type({'a'})),  # Union
    (Tags([1, 4, 5]), Type({'a'})),  # Intersection
]


@pytest.mark.parametrize(
    'constructor, args',
    zip(_filter_classes, _constructor_args),
    ids=lambda arg: arg.__name__ if not isinstance(arg, tuple) else None)
def test_pickling(constructor, args):
    """Round-trip a filter through pickle and compare it with the original.

    Test ids come from the filter class name. Argument tuples return None,
    so pytest falls back to its default id for them. The restored filter
    must compare equal to, and hash the same as, the filter it was pickled
    from.
    """
    original = constructor(*args)
    restored = pickle.loads(pickle.dumps(original))
    assert restored == original
    assert hash(restored) == hash(original)