def test_type_filter(make_filter_snapshot, simulation_factory, type_indices): particle_types = ['A', 'B'] N = 10 filter_snapshot = make_filter_snapshot(n=N, particle_types=particle_types) sim = simulation_factory(filter_snapshot) A_filter = Type(["A"]) B_filter = Type(["B"]) AB_filter = Type(["A", "B"]) assert A_filter(sim.state) == list(range(N)) assert B_filter(sim.state) == [] s = sim.state.get_snapshot() if s.communicator.rank == 0: set_types(s, range(N), particle_types, "B") sim.state.set_snapshot(s) assert A_filter(sim.state) == [] assert B_filter(sim.state) == list(range(N)) A_indices, B_indices = type_indices s = sim.state.get_snapshot() if s.communicator.rank == 0: set_types(s, A_indices, particle_types, "A") set_types(s, B_indices, particle_types, "B") sim.state.set_snapshot(s) assert A_filter(sim.state) == A_indices assert B_filter(sim.state) == B_indices assert AB_filter(sim.state) == list(range(N))
def test_difference(make_filter_snapshot, simulation_factory, set_indices): particle_types = ['A', 'B', 'C'] N = 10 filter_snapshot = make_filter_snapshot(n=N, particle_types=particle_types) sim = simulation_factory(filter_snapshot) A_indices, B_indices, C_indices = set_indices s = sim.state.get_snapshot() if s.communicator.rank == 0: set_types(s, A_indices, particle_types, "A") set_types(s, B_indices, particle_types, "B") set_types(s, C_indices, particle_types, "C") sim.state.set_snapshot(s) for type_combo in combinations(particle_types, 2): combo_filter = Type(type_combo) remaining_type = type_not_in_combo(type_combo, particle_types) remaining_filter = Type([remaining_type]) for i in [0, 1]: type_filter1 = Type([type_combo[i]]) type_filter2 = Type([type_combo[i - 1]]) difference_filter = SetDifference(combo_filter, type_filter1) assert difference_filter(sim.state) == type_filter2(sim.state) difference_filter = SetDifference(combo_filter, type_filter2) assert difference_filter(sim.state) == type_filter1(sim.state) difference_filter = SetDifference(combo_filter, remaining_filter) assert difference_filter(sim.state) == combo_filter(sim.state)
def test_union(make_filter_snapshot, simulation_factory, set_indices): particle_types = ['A', 'B', 'C'] N = 10 filter_snapshot = make_filter_snapshot(n=N, particle_types=particle_types) sim = simulation_factory(filter_snapshot) A_indices, B_indices, C_indices = set_indices s = sim.state.get_snapshot() if s.communicator.rank == 0: set_types(s, A_indices, particle_types, "A") set_types(s, B_indices, particle_types, "B") set_types(s, C_indices, particle_types, "C") sim.state.set_snapshot(s) for type_combo in combinations(particle_types, 2): filter1 = Type([type_combo[0]]) filter2 = Type([type_combo[1]]) combo_filter = Type(type_combo) union_filter = Union(filter1, filter2) assert union_filter(sim.state) == combo_filter(sim.state)
def test_intersection(make_filter_snapshot, simulation_factory, set_indices): particle_types = ['A', 'B', 'C'] N = 10 filter_snapshot = make_filter_snapshot(n=N, particle_types=particle_types) sim = simulation_factory(filter_snapshot) A_indices, B_indices, C_indices = set_indices s = sim.state.get_snapshot() if s.communicator.rank == 0: set_types(s, A_indices, particle_types, "A") set_types(s, B_indices, particle_types, "B") set_types(s, C_indices, particle_types, "C") sim.state.set_snapshot(s) for type_combo in combinations(particle_types, 2): combo_filter = Type(type_combo) for particle_type in type_combo: type_filter = Type([particle_type]) intersection_filter = Intersection(combo_filter, type_filter) assert intersection_filter(sim.state) == type_filter(sim.state) remaining_type = type_not_in_combo(type_combo, particle_types) remaining_filter = Type([remaining_type]) intersection_filter = Intersection(combo_filter, remaining_filter) assert intersection_filter(sim.state) == []
_filter_classes = [ All, Tags, Type, Rigid, SetDifference, Union, Intersection, ] _constructor_args = [ (), ([1, 2, 3], ), ({'a', 'b'}, ), (('center', 'free'), ), (Tags([1, 4, 5]), Type({'a'})), (Tags([1, 4, 5]), Type({'a'})), (Tags([1, 4, 5]), Type({'a'})), ] @pytest.mark.parametrize('constructor, args', zip(_filter_classes, _constructor_args), ids=lambda x: None if isinstance(x, tuple) else x.__name__) def test_pickling(constructor, args): filter_ = constructor(*args) pickled_filter = pickle.loads(pickle.dumps(filter_)) assert pickled_filter == filter_ assert hash(pickled_filter) == hash(filter_)