def test_record_events(self):
    """Finished events are recorded to a CSV file, one row per event, tagged with the episode.

    Also checks that enabling recording without a ``record_path`` raises ``ValueError``.
    """
    import os  # only this test needs path joining / file removal

    timestamp = str(time.time()).replace(".", "_")
    temp_file_path = os.path.join(tempfile.gettempdir(), f"{timestamp}.txt")

    def remove_record_file():
        # Best-effort cleanup: the file may not exist if the test failed early,
        # or may still be open (e.g. on Windows) if the buffer was not released.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass

    self.addCleanup(remove_record_file)

    # Recording without a destination file must be rejected.
    with self.assertRaises(ValueError):
        EventBuffer(record_events=True, record_path=None)

    eb = EventBuffer(record_events=True, record_path=temp_file_path)

    # First episode: four identical atom events at tick 1.
    for _ in range(4):
        eb.insert_event(eb.gen_atom_event(1, 1, (1, 3)))
    eb.execute(1)

    # After reset() the following rows are expected to carry episode index 1.
    eb.reset()

    for _ in range(4):
        eb.insert_event(eb.gen_atom_event(1, 1, (1, 3)))
    eb.execute(1)

    # Releasing the buffer is presumably what flushes/closes the record file.
    del eb

    with open(temp_file_path, "r") as input_stream:
        texts = input_stream.readlines()

    expected = ['episode,tick,event_type,payload\n']
    expected += ['0,1,1,"(1, 3)"\n'] * 4
    expected += ['1,1,1,"(1, 3)"\n'] * 4
    self.assertListEqual(texts, expected)
def test_disable_finished_events(self):
    """With ``disable_finished_events=True`` the finished pool stays empty after dispatching."""
    buffer = EventBuffer(disable_finished_events=True)
    empty_msg = "finished pool should be empty"

    # Nothing has been executed yet.
    self.assertListEqual([], buffer.get_finished_events(), msg=empty_msg)

    for _ in range(4):
        buffer.insert_event(buffer.gen_atom_event(1, 1, (1, 3)))
    buffer.execute(1)

    # Even after dispatching, no finished event object is retained.
    self.assertListEqual([], buffer.get_finished_events(), msg=empty_msg)
def next_step(eb: EventBuffer, be: AbsBusinessEngine, tick: int):
    """Drive one tick of a business engine directly through an event buffer.

    For ``tick > 0`` the previous tick is post-processed first; if that reports
    the end of the simulation, nothing else is done for ``tick``.

    Returns:
        bool: True when ``be.post_step(tick - 1)`` returned a truthy value,
        otherwise False after the tick has been executed and snapshotted.
    """
    # Finish the previous tick before starting a new one.
    if tick > 0 and be.post_step(tick - 1):
        return True

    be.step(tick)

    decision_events = eb.execute(tick)
    if decision_events:
        # No actions are supplied here: mark every returned event as finished
        # and let the buffer execute the tick once more.
        for event in decision_events:
            event.state = EventState.FINISHED
        eb.execute(tick)

    be.frame.take_snapshot(tick)
    return False
class TestEventBuffer(unittest.TestCase):
    """Unit tests for ``EventBuffer`` (event generation, insertion, dispatch and reset)."""

    def setUp(self):
        self.eb = EventBuffer()

    def test_gen_event(self):
        """Test that generated events carry the requested fields."""
        evt = self.eb.gen_atom_event(1, 1, (0, 0))

        # fields should be same as specified
        self.assertEqual(evt.category, EventCategory.ATOM)
        self.assertEqual(evt.tick, 1)
        self.assertEqual(evt.event_type, 1)
        self.assertEqual(evt.payload, (0, 0))

        evt = self.eb.gen_cascade_event(2, 2, (1, 1, 1))

        self.assertEqual(evt.category, EventCategory.CASCADE)
        self.assertEqual(evt.tick, 2)
        self.assertEqual(evt.event_type, 2)
        self.assertEqual(evt.payload, (1, 1, 1))

    def test_insert_event(self):
        """Test that inserting an event puts it into the pending pool."""
        # pending pool should be empty at beginning
        self.assertEqual(len(self.eb._pending_events), 0)

        evt = self.eb.gen_atom_event(1, 1, 1)
        self.eb.insert_event(evt)

        # after inserting one event, we should have 1 in pending pool
        self.assertEqual(len(self.eb._pending_events), 1)

    def test_event_dispatch(self):
        """Test that a registered handler receives the dispatched event."""
        received = []

        def cb(evt):
            received.append(evt)
            # test event tick
            self.assertEqual(1, evt.tick, msg="recieved event tick should be 1")
            # test event payload
            self.assertTupleEqual((1, 3), evt.payload, msg="recieved event's payload should be (1, 3)")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))
        self.eb.insert_event(evt)
        self.eb.register_event_handler(1, cb)
        self.eb.execute(1)  # dispatch event

        # The assertions inside ``cb`` only mean something if the handler actually ran.
        self.assertEqual(1, len(received))

    def test_get_finish_events(self):
        """Test that executed events show up in the finished pool."""
        # no finished events at first
        self.assertListEqual([], self.eb.get_finished_events(), msg="finished pool should be empty")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))
        self.eb.insert_event(evt)
        self.eb.execute(1)

        # after dispatching, the finished pool should contain 1 object
        self.assertEqual(1, len(self.eb.get_finished_events()), msg="after dispathing, there should 1 object")

    def test_get_pending_events(self):
        """Test that inserted events are reported as pending for their tick."""
        # nothing pending at first
        self.assertEqual(0, len(self.eb.get_pending_events(1)), msg="pending pool should be empty")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))
        self.eb.insert_event(evt)

        self.assertEqual(1, len(self.eb.get_pending_events(1)), msg="pending pool should contains 1 objects")

    def test_reset(self):
        """Test that reset clears both the pending and the finished pools."""
        evt = self.eb.gen_atom_event(1, 1, 1)
        self.eb.insert_event(evt)

        self.eb.reset()

        self.assertEqual(len(self.eb._pending_events), 0)
        self.assertEqual(len(self.eb._finished_events), 0)
class TestEventBuffer(unittest.TestCase):
    """Unit tests for cascade events, the event linked list, the event pool and ``EventBuffer``."""

    def setUp(self):
        self.eb = EventBuffer()

    def test_cascade_event(self):
        evt = CascadeEvent(0, 1, None, None)
        self.assertIs(type(evt.immediate_event_head), DummyEvent)
        self.assertIsNone(evt.immediate_event_head.next_event)
        self.assertIsNone(evt.immediate_event_tail)
        self.assertEqual(evt.immediate_event_head.next_event, evt.immediate_event_tail)
        self.assertEqual(evt.immediate_event_count, 0)

        evt.add_immediate_event(AtomEvent(1, 1, None, None), is_head=False)
        evt.add_immediate_event(AtomEvent(2, 1, None, None), is_head=False)
        evt.add_immediate_event(AtomEvent(3, 1, None, None), is_head=True)
        evt.add_immediate_event(AtomEvent(4, 1, None, None), is_head=True)
        evt.add_immediate_event(AtomEvent(5, 1, None, None), is_head=False)
        evt.add_immediate_event(AtomEvent(6, 1, None, None), is_head=False)
        evt.add_immediate_event(AtomEvent(7, 1, None, None), is_head=True)
        evt.add_immediate_event(AtomEvent(8, 1, None, None), is_head=True)
        self.assertEqual(evt.immediate_event_count, 8)

        # Should be declined because the tick of the new event (2) differs from the cascade event's (1).
        self.assertFalse(evt.add_immediate_event(AtomEvent(9, 2, None, None)))

        # Head insertions come first (most recent first), tail insertions follow in insertion order.
        iter_evt: Optional[ActualEvent] = evt.immediate_event_head.next_event
        event_ids = []
        while iter_evt is not None:
            event_ids.append(iter_evt.id)
            iter_evt = iter_evt.next_event
        self.assertListEqual(event_ids, [8, 7, 4, 3, 1, 2, 5, 6])

        evt.clear()
        self.assertIsNone(evt.immediate_event_head.next_event)
        self.assertIsNone(evt.immediate_event_tail)
        self.assertEqual(evt.immediate_event_head.next_event, evt.immediate_event_tail)
        self.assertEqual(evt.immediate_event_count, 0)

    def test_event_linked_list(self):
        event_linked_list = EventLinkedList()
        self.assertEqual(len(event_linked_list), 0)
        self.assertListEqual(list(event_linked_list), [])

        evt_list = [CascadeEvent(i, None, None, None) for i in range(7)]
        evt_list[0].add_immediate_event(evt_list[3])
        evt_list[0].add_immediate_event(evt_list[4])
        evt_list[0].add_immediate_event(evt_list[5])
        evt_list[0].add_immediate_event(evt_list[6])

        event_linked_list.append_head(evt_list[1])
        event_linked_list.append_tail(evt_list[2])
        event_linked_list.append_head(evt_list[0])
        self.assertEqual(len(event_linked_list), 3)
        event_ids = [event.id for event in event_linked_list]
        self.assertListEqual(event_ids, [0, 1, 2])

        evt = event_linked_list.clear_finished_and_get_front()
        self.assertEqual(evt.id, 0)

        # Test `_clear_finished_events()`
        evt_list[0].state = EventState.FINISHED
        evt = event_linked_list.clear_finished_and_get_front()
        self.assertIsInstance(evt, ActualEvent)
        self.assertEqual(evt.id, 3)
        self.assertEqual(len(event_linked_list), 6)
        self.assertListEqual([evt.id for evt in event_linked_list], [3, 4, 5, 6, 1, 2])

        evt_list[3].event_type = MaroEvents.PENDING_DECISION
        evt_list[4].event_type = MaroEvents.PENDING_DECISION
        evt_list[5].event_type = MaroEvents.PENDING_DECISION
        evts = event_linked_list.clear_finished_and_get_front()
        self.assertTrue(all(isinstance(evt, ActualEvent) for evt in evts))
        self.assertEqual(len(evts), 3)
        self.assertListEqual([evt.id for evt in evts], [3, 4, 5])
        self.assertListEqual([evt.id for evt in event_linked_list], [3, 4, 5, 6, 1, 2])

        event_linked_list.clear()
        self.assertEqual(len(event_linked_list), 0)
        self.assertListEqual(list(event_linked_list), [])

    def test_event_pool(self):
        ep = EventPool()
        cascade_events = [CascadeEvent(i, None, None, None) for i in range(5)]
        atom_events = [AtomEvent(i, None, None, None) for i in range(5, 10)]

        # Cascade events are recycled one by one, atom events as a whole list.
        for evt in cascade_events:
            ep.recycle(evt)
        ep.recycle(atom_events)

        self.assertEqual(ep.atom_event_count, 5)
        self.assertEqual(ep.cascade_event_count, 5)

        # Even i -> cascade (0, 2, 4), odd i -> atom (1, 3).
        for i in range(5):
            is_cascade = i % 2 == 0
            evt = ep.gen(tick=i, event_type=-1, payload=-1, is_cascade=is_cascade)
            self.assertEqual(evt.id, i)
            self.assertEqual(evt.tick, i)
            self.assertEqual(evt.event_type, -1)
            self.assertEqual(evt.payload, -1)
            self.assertIsInstance(evt, CascadeEvent if is_cascade else AtomEvent)

        self.assertEqual(ep.atom_event_count, 3)
        self.assertEqual(ep.cascade_event_count, 2)

    def test_gen_event(self):
        """Test that generated events carry the requested fields."""
        evt = self.eb.gen_atom_event(1, 1, (0, 0))

        # fields should be same as specified
        self.assertEqual(AtomEvent, type(evt))
        self.assertEqual(evt.tick, 1)
        self.assertEqual(evt.event_type, 1)
        self.assertEqual(evt.payload, (0, 0))

        evt = self.eb.gen_cascade_event(2, 2, (1, 1, 1))

        self.assertEqual(CascadeEvent, type(evt))
        self.assertEqual(evt.tick, 2)
        self.assertEqual(evt.event_type, 2)
        self.assertEqual(evt.payload, (1, 1, 1))

    def test_insert_event(self):
        """Test that inserting an event puts it into the pending pool."""
        # pending pool should be empty at beginning
        self.assertEqual(len(self.eb._pending_events), 0)

        evt = self.eb.gen_atom_event(1, 1, 1)
        self.eb.insert_event(evt)

        # after inserting one event, we should have 1 in pending pool
        self.assertEqual(len(self.eb._pending_events), 1)

    def test_event_dispatch(self):
        """Test that a registered handler receives the dispatched event."""
        received = []

        def cb(evt):
            received.append(evt)
            # test event tick
            self.assertEqual(1, evt.tick, msg="received event tick should be 1")
            # test event payload
            self.assertTupleEqual(
                (1, 3), evt.payload, msg="received event's payload should be (1, 3)")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))
        self.eb.insert_event(evt)
        self.eb.register_event_handler(1, cb)
        self.eb.execute(1)  # dispatch event

        # The assertions inside ``cb`` only mean something if the handler actually ran.
        self.assertEqual(1, len(received))

    def test_get_finish_events(self):
        """Test that executed events show up in the finished pool."""
        # no finished events at first
        self.assertListEqual([], self.eb.get_finished_events(), msg="finished pool should be empty")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))
        self.eb.insert_event(evt)
        self.eb.execute(1)

        # after dispatching, the finished pool should contain 1 object
        self.assertEqual(1, len(self.eb.get_finished_events()), msg="after dispatching, there should 1 object")

    def test_get_pending_events(self):
        """Test that inserted events are reported as pending for their tick."""
        # nothing pending at first
        self.assertEqual(0, len(self.eb.get_pending_events(1)), msg="pending pool should be empty")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))
        self.eb.insert_event(evt)

        self.assertEqual(1, len(self.eb.get_pending_events(1)), msg="pending pool should contains 1 objects")

    def test_reset(self):
        """Test reset: pending pools are emptied and the finished pool is cleared."""
        evt = self.eb.gen_atom_event(1, 1, 1)
        self.eb.insert_event(evt)

        self.eb.reset()

        # reset will not clear the tick (key), just clear the pending pool
        self.assertEqual(len(self.eb._pending_events), 1)
        for pending_pool in self.eb._pending_events.values():
            self.assertEqual(0, len(pending_pool))

        self.assertEqual(len(self.eb._finished_events), 0)

    def test_sub_events(self):
        """Test that both a cascade event and its immediate sub event reach their handlers."""
        cb1_calls = []
        cb2_calls = []

        def cb1(evt):
            cb1_calls.append(evt)
            self.assertEqual(1, evt.payload)

        def cb2(evt):
            cb2_calls.append(evt)
            self.assertEqual(2, evt.payload)

        self.eb.register_event_handler(1, cb1)
        self.eb.register_event_handler(2, cb2)

        evt: CascadeEvent = self.eb.gen_cascade_event(1, 1, 1)
        evt.add_immediate_event(self.eb.gen_atom_event(1, 2, 2))

        self.eb.insert_event(evt)
        self.eb.execute(1)

        # The payload assertions above only count if each handler actually ran.
        self.assertEqual(1, len(cb1_calls))
        self.assertEqual(1, len(cb2_calls))

    def test_sub_events_with_decision(self):
        evt1 = self.eb.gen_decision_event(1, (1, 1, 1))
        sub1 = self.eb.gen_decision_event(1, (2, 2, 2))
        sub2 = self.eb.gen_decision_event(1, (3, 3, 3))

        evt1.add_immediate_event(sub1, is_head=True)
        evt1.add_immediate_event(sub2)

        self.eb.insert_event(evt1)

        # sub events will be unfolded after the parent has been processed
        decision_events = self.eb.execute(1)

        # so we will get 1 decision event the first time we execute
        self.assertEqual(1, len(decision_events))
        self.assertEqual(evt1, decision_events[0])

        # mark the decision event as finished so the buffer unfolds its sub events
        decision_events[0].state = EventState.FINISHED

        # then there will be 2 additional decision events from the sub events
        decision_events = self.eb.execute(1)

        self.assertEqual(2, len(decision_events))
        self.assertEqual(sub1, decision_events[0])
        self.assertEqual(sub2, decision_events[1])

    def test_disable_finished_events(self):
        eb = EventBuffer(disable_finished_events=True)
        self.assertListEqual([], eb.get_finished_events(), msg="finished pool should be empty")

        for _ in range(4):
            eb.insert_event(eb.gen_atom_event(1, 1, (1, 3)))
        eb.execute(1)

        # after dispatching, the finished pool should still contain no object
        self.assertListEqual([], eb.get_finished_events(), msg="finished pool should be empty")

    def test_record_events(self):
        """Finished events are recorded to a CSV file, one row per event, tagged with the episode."""
        import os  # only this test needs path joining / file removal

        timestamp = str(time.time()).replace(".", "_")
        temp_file_path = os.path.join(tempfile.gettempdir(), f"{timestamp}.txt")

        def remove_record_file():
            # Best-effort cleanup: the file may not exist if the test failed early,
            # or may still be open (e.g. on Windows) if the buffer was not released.
            try:
                os.remove(temp_file_path)
            except OSError:
                pass

        self.addCleanup(remove_record_file)

        # Recording without a destination file must be rejected.
        with self.assertRaises(ValueError):
            EventBuffer(record_events=True, record_path=None)

        eb = EventBuffer(record_events=True, record_path=temp_file_path)

        for _ in range(4):
            eb.insert_event(eb.gen_atom_event(1, 1, (1, 3)))
        eb.execute(1)

        # After reset() the following rows are expected to carry episode index 1.
        eb.reset()

        for _ in range(4):
            eb.insert_event(eb.gen_atom_event(1, 1, (1, 3)))
        eb.execute(1)

        # Releasing the buffer is presumably what flushes/closes the record file.
        del eb

        with open(temp_file_path, "r") as input_stream:
            texts = input_stream.readlines()

        expected = ['episode,tick,event_type,payload\n']
        expected += ['0,1,1,"(1, 3)"\n'] * 4
        expected += ['1,1,1,"(1, 3)"\n'] * 4
        self.assertListEqual(texts, expected)
class Env(AbsEnv):
    """Default environment implementation using generator.

    Args:
        scenario (str): Scenario name under maro/simulator/scenarios folder.
        topology (str): Topology name under specified scenario folder.
            If it points to an existing folder, the corresponding topology will be used for the built-in scenario.
        start_tick (int): Start tick of the scenario, usually used for pre-processed data streaming.
        durations (int): Duration ticks of this environment from start_tick.
        snapshot_resolution (int): How many ticks will take a snapshot.
        max_snapshots(int): Max in-memory snapshot number.
            When the number of dumped snapshots reached the limitation, oldest one will be overwrote by new one.
            None means keeping all snapshots in memory. Defaults to None.
        decision_mode (DecisionMode): How pending decision events are exposed and matched with actions.
            With ``DecisionMode.Sequential`` only the first pending decision payload is yielded and the
            action(s) are attached to the first pending event; otherwise the list of all payloads is yielded
            and actions are matched to pending events by position. Defaults to ``DecisionMode.Sequential``.
        business_engine_cls (type): Class of business engine. If specified, use it to construct the be instance,
            or search internally by scenario.
        disable_finished_events (bool): Disable finished events list, with this set to True, EventBuffer will
            re-use finished event object, this reduce event object number.
        record_finished_events (bool): If record finished events into csv file, default is False.
        record_file_path (str): Where to save the recording file, only work if record_finished_events is True.
        options (dict): Additional parameters passed to business engine. None is treated as an empty dict.
    """

    def __init__(
        self, scenario: str = None, topology: str = None,
        start_tick: int = 0, durations: int = 100, snapshot_resolution: int = 1, max_snapshots: int = None,
        decision_mode: DecisionMode = DecisionMode.Sequential,
        business_engine_cls: type = None, disable_finished_events: bool = False,
        record_finished_events: bool = False,
        record_file_path: str = None,
        options: Optional[dict] = None
    ) -> None:
        super().__init__(
            scenario, topology, start_tick, durations,
            snapshot_resolution, max_snapshots, decision_mode, business_engine_cls,
            disable_finished_events, options if options is not None else {}
        )

        self._name = f'{self._scenario}:{self._topology}' if business_engine_cls is None \
            else business_engine_cls.__name__

        self._event_buffer = EventBuffer(disable_finished_events, record_finished_events, record_file_path)

        # Decision events of the current episode, kept for dumping on reset.
        self._decision_events = []

        # The generator used to push the simulator forward.
        self._simulate_generator = self._simulate()

        # Initialize the business engine.
        self._init_business_engine()

        if "enable-dump-snapshot" in self._additional_options:
            parent_path = self._additional_options["enable-dump-snapshot"]
            self._converter = DumpConverter(parent_path, self._business_engine.scenario_name)
            self._converter.reset_folder_path()

        # Incremented each time a new simulate generator starts running (see ``_simulate``).
        self._streamit_episode = 0

    def step(self, action) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]:
        """Push the environment to next step with action.

        NOTE: The first call after construction or ``reset`` drives a freshly created generator,
        which (per Python generator semantics) only accepts ``None`` as the sent value.

        Args:
            action (Action): Action(s) from agent.

        Returns:
            tuple: a tuple of (metrics, decision event, is_done).
        """
        try:
            metrics, decision_event, _is_done = self._simulate_generator.send(action)
        except StopIteration:
            # The episode generator is exhausted: report the episode as done.
            return None, None, True

        return metrics, decision_event, _is_done

    def dump(self) -> None:
        """Dump environment for restore.

        NOTE:
            Not implemented.
        """
        return

    def reset(self, keep_seed: bool = False) -> None:
        """Reset environment.

        If snapshot dumping is enabled (``"enable-dump-snapshot"`` option), the finished episode's frame,
        configs, decision events and business engine state are dumped before resetting.

        Args:
            keep_seed (bool): Reset the random seed to the generate the same data sequence or not. Defaults to False.
        """
        self._tick = self._start_tick

        self._simulate_generator.close()
        self._simulate_generator = self._simulate()

        self._event_buffer.reset()

        if "enable-dump-snapshot" in self._additional_options and self._business_engine.frame is not None:
            dump_folder = self._converter.get_new_snapshot_folder()

            self._business_engine.frame.dump(dump_folder)
            self._converter.start_processing(self.configs)
            self._converter.dump_descsion_events(self._decision_events, self._start_tick, self._snapshot_resolution)
            self._business_engine.dump(dump_folder)

        # Always drop the archived decision events of the finished episode, so they do not
        # accumulate across episodes when dumping is disabled.
        self._decision_events.clear()

        self._business_engine.reset(keep_seed)

    @property
    def configs(self) -> dict:
        """dict: Configurations of current environment."""
        return self._business_engine.configs

    @property
    def summary(self) -> dict:
        """dict: Summary about current simulator, including node details and mappings."""
        return {
            "node_mapping": self._business_engine.get_node_mapping(),
            "node_detail": self.current_frame.get_node_info(),
            "event_payload": self._business_engine.get_event_payload_detail()
        }

    @property
    def name(self) -> str:
        """str: Name of current environment."""
        return self._name

    @property
    def current_frame(self) -> FrameBase:
        """Frame: Frame of current environment."""
        return self._business_engine.frame

    @property
    def tick(self) -> int:
        """int: Current tick of environment."""
        return self._tick

    @property
    def frame_index(self) -> int:
        """int: Frame index in snapshot list for current tick."""
        return tick_to_frame_index(self._start_tick, self._tick, self._snapshot_resolution)

    @property
    def snapshot_list(self) -> SnapshotList:
        """SnapshotList: A snapshot list containing all the snapshots of frame at each dump point.

        NOTE: Due to different environment configurations, the resolution of the snapshot may be different.
        """
        return self._business_engine.snapshots

    @property
    def agent_idx_list(self) -> List[int]:
        """List[int]: Agent index list that related to this environment."""
        return self._business_engine.get_agent_idx_list()

    def set_seed(self, seed: int) -> None:
        """Set random seed used by simulator.

        NOTE:
            This will not set seed for Python random or other packages' seed, such as NumPy.

        Args:
            seed (int): Seed to set.
        """
        if seed is not None:
            random.seed(seed)

    @property
    def metrics(self) -> dict:
        """Some statistics information provided by business engine.

        Returns:
            dict: Dictionary of metrics, content and format is determined by business engine.
        """
        return self._business_engine.get_metrics()

    def get_finished_events(self) -> List[ActualEvent]:
        """List[Event]: All events finished so far."""
        return self._event_buffer.get_finished_events()

    def get_pending_events(self, tick) -> List[ActualEvent]:
        """Pending events at certain tick.

        Args:
            tick (int): Specified tick to query.
        """
        return self._event_buffer.get_pending_events(tick)

    def _init_business_engine(self) -> None:
        """Initialize business engine object.

        NOTE:
        1. For built-in scenarios, they will always under "maro/simulator/scenarios" folder.
        2. For external scenarios, the business engine instance is built with the loaded business engine class.

        Raises:
            BusinessEngineNotFoundError: No ``AbsBusinessEngine`` subclass found in the scenario module.
        """
        max_tick = self._start_tick + self._durations

        if self._business_engine_cls is not None:
            business_class = self._business_engine_cls
        else:
            # Combine the business engine import path.
            business_class_path = f'maro.simulator.scenarios.{self._scenario}.business_engine'

            # Load the module to find business engine for that scenario.
            business_module = import_module(business_class_path)

            business_class = None

            for _, obj in getmembers(business_module, isclass):
                if issubclass(obj, AbsBusinessEngine) and obj != AbsBusinessEngine:
                    # We find it.
                    business_class = obj
                    break

            if business_class is None:
                raise BusinessEngineNotFoundError()

        self._business_engine: AbsBusinessEngine = business_class(
            event_buffer=self._event_buffer,
            topology=self._topology,
            start_tick=self._start_tick,
            max_tick=max_tick,
            snapshot_resolution=self._snapshot_resolution,
            max_snapshots=self._max_snapshots,
            additional_options=self._additional_options
        )

    def _simulate(self) -> Generator[Tuple[dict, List[object], bool], object, None]:
        """This is the generator to wrap each episode process."""
        self._streamit_episode += 1

        streamit.episode(self._streamit_episode)

        while True:
            # Ask business engine to do thing for this tick, such as generating and pushing events.
            # We do not push events now.
            streamit.tick(self._tick)

            self._business_engine.step(self._tick)

            while True:
                # Keep processing events, until no more events in this tick.
                pending_events = self._event_buffer.execute(self._tick)

                if len(pending_events) == 0:
                    # We have processed all the event of current tick, lets go for next tick.
                    break

                # Insert snapshot before each action.
                self._business_engine.frame.take_snapshot(self.frame_index)

                # Expose the payloads of the pending decision events; in sequential mode only the first one.
                decision_events = [event.payload for event in pending_events]
                decision_events = decision_events[0] if self._decision_mode == DecisionMode.Sequential \
                    else decision_events

                # Yield current state first, and waiting for action.
                actions = yield self._business_engine.get_metrics(), decision_events, False

                # Archive decision events (dumped on reset when snapshot dumping is enabled).
                self._decision_events.append(decision_events)

                if actions is None:
                    # Make business engine easy to work.
                    actions = []
                elif not isinstance(actions, Iterable):
                    actions = [actions]

                if self._decision_mode == DecisionMode.Sequential:
                    # Generate a new atom event first.
                    action_event = self._event_buffer.gen_action_event(self._tick, actions)

                    # NOTE: decision event always be a CascadeEvent
                    # We just append the action into sub event of first pending cascade event.
                    event = pending_events[0]
                    assert isinstance(event, CascadeEvent)
                    event.state = EventState.EXECUTING
                    event.add_immediate_event(action_event, is_head=True)
                else:
                    # For joint mode, we will assign actions from beginning to end.
                    # Then mark others pending events to finished if not sequential action mode.
                    for i, pending_event in enumerate(pending_events):
                        if i >= len(actions):
                            if self._decision_mode == DecisionMode.Joint:
                                # Ignore following pending events that have no action matched.
                                pending_event.state = EventState.FINISHED
                        else:
                            # Set the state as executing, so event buffer will not pop them again.
                            # Then insert the action to it.
                            action = actions[i]
                            pending_event.state = EventState.EXECUTING
                            action_event = self._event_buffer.gen_action_event(self._tick, action)

                            assert isinstance(pending_event, CascadeEvent)
                            pending_event.add_immediate_event(action_event, is_head=True)

            # Check the end tick of the simulation to decide if we should end the simulation.
            is_end_tick = self._business_engine.post_step(self._tick)

            if is_end_tick:
                break

            self._tick += 1

        # Make sure we have no missing data.
        if (self._tick + 1) % self._snapshot_resolution != 0:
            self._business_engine.frame.take_snapshot(self.frame_index)

        # The end.
        yield self._business_engine.get_metrics(), None, True
class Env(AbsEnv):
    """Default environment implementation using generator.

    Args:
        scenario (str): Scenario name under maro/simulator/scenarios folder.
        topology (str): Topology name under specified scenario folder.
            If it points to an existing folder, the corresponding topology will be used for the built-in scenario.
        start_tick (int): Start tick of the scenario, usually used for pre-processed data streaming.
        durations (int): Duration ticks of this environment from start_tick.
        snapshot_resolution (int): How many ticks will take a snapshot.
        max_snapshots(int): Max in-memory snapshot number.
            When the number of dumped snapshots reached the limitation, oldest one will be overwrote by new one.
            None means keeping all snapshots in memory. Defaults to None.
        decision_mode (DecisionMode): With ``DecisionMode.Sequential`` only the first pending decision payload
            is yielded; otherwise the list of all pending payloads is yielded. In ``DecisionMode.Joint`` the
            remaining pending events may be marked finished after the actions are received.
            Defaults to ``DecisionMode.Sequential``.
        business_engine_cls: Class of business engine. If specified, use it to construct the be instance,
            or search internally by scenario.
        options (dict): Additional parameters passed to business engine. None (the default) means an empty dict.
    """

    def __init__(self, scenario: str = None, topology: str = None,
                 start_tick: int = 0, durations: int = 100, snapshot_resolution: int = 1, max_snapshots: int = None,
                 decision_mode: DecisionMode = DecisionMode.Sequential,
                 business_engine_cls: type = None, options: dict = None):
        # A fresh dict per instance: a shared mutable default would leak options between environments.
        super().__init__(scenario, topology, start_tick, durations,
                         snapshot_resolution, max_snapshots, decision_mode, business_engine_cls,
                         options if options is not None else {})

        self._name = f'{self._scenario}:{self._topology}' if business_engine_cls is None \
            else business_engine_cls.__name__
        self._business_engine: AbsBusinessEngine = None

        self._event_buffer = EventBuffer()

        # The generator used to push the simulator forward.
        self._simulate_generator = self._simulate()

        # Initialize the business engine.
        self._init_business_engine()

    def step(self, action):
        """Push the environment to next step with action.

        NOTE: The first call after construction or ``reset`` drives a freshly created generator,
        which (per Python generator semantics) only accepts ``None`` as the sent value.

        Args:
            action (Action): Action(s) from agent.

        Returns:
            tuple: a tuple of (metrics, decision event, is_done).
        """
        try:
            metrics, decision_event, _is_done = self._simulate_generator.send(action)
        except StopIteration:
            # The episode generator is exhausted: report the episode as done.
            return None, None, True

        return metrics, decision_event, _is_done

    def dump(self):
        """Dump environment for restore.

        NOTE:
            Not implemented.
        """
        return

    def reset(self):
        """Reset environment."""
        self._tick = self._start_tick

        self._simulate_generator.close()
        self._simulate_generator = self._simulate()

        self._event_buffer.reset()

        self._business_engine.reset()

    @property
    def configs(self) -> dict:
        """dict: Configurations of current environment."""
        return self._business_engine.configs

    @property
    def summary(self) -> dict:
        """dict: Summary about current simulator, including node details and mappings."""
        return {
            "node_mapping": self._business_engine.get_node_mapping(),
            "node_detail": self.current_frame.get_node_info()
        }

    @property
    def name(self) -> str:
        """str: Name of current environment."""
        return self._name

    @property
    def current_frame(self) -> FrameBase:
        """Frame: Frame of current environment."""
        return self._business_engine.frame

    @property
    def tick(self) -> int:
        """int: Current tick of environment."""
        return self._tick

    @property
    def frame_index(self) -> int:
        """int: Frame index in snapshot list for current tick."""
        return tick_to_frame_index(self._start_tick, self._tick, self._snapshot_resolution)

    @property
    def snapshot_list(self) -> SnapshotList:
        """SnapshotList: A snapshot list containing all the snapshots of frame at each dump point.

        NOTE: Due to different environment configurations, the resolution of the snapshot may be different.
        """
        return self._business_engine.snapshots

    @property
    def agent_idx_list(self) -> List[int]:
        """List[int]: Agent index list that related to this environment."""
        return self._business_engine.get_agent_idx_list()

    def set_seed(self, seed: int):
        """Set random seed used by simulator.

        NOTE:
            This will not set seed for Python random or other packages' seed, such as NumPy.

        Args:
            seed (int): Seed to set.
        """
        if seed is not None:
            sim_seed(seed)

    @property
    def metrics(self) -> dict:
        """Some statistics information provided by business engine.

        Returns:
            dict: Dictionary of metrics, content and format is determined by business engine.
        """
        return self._business_engine.get_metrics()

    def get_finished_events(self):
        """List[Event]: All events finished so far."""
        return self._event_buffer.get_finished_events()

    def get_pending_events(self, tick):
        """Pending events at certain tick.

        Args:
            tick (int): Specified tick to query.
        """
        return self._event_buffer.get_pending_events(tick)

    def _init_business_engine(self):
        """Initialize business engine object.

        NOTE:
        1. For built-in scenarios, they will always under "maro/simulator/scenarios" folder.
        2. For external scenarios, the business engine instance is built with the loaded business engine class.

        Raises:
            BusinessEngineNotFoundError: No ``AbsBusinessEngine`` subclass found in the scenario module.
        """
        max_tick = self._start_tick + self._durations

        if self._business_engine_cls is not None:
            business_class = self._business_engine_cls
        else:
            # Combine the business engine import path.
            business_class_path = f'maro.simulator.scenarios.{self._scenario}.business_engine'

            # Load the module to find business engine for that scenario.
            business_module = import_module(business_class_path)

            business_class = None

            for _, obj in getmembers(business_module, isclass):
                if issubclass(obj, AbsBusinessEngine) and obj != AbsBusinessEngine:
                    # We find it.
                    business_class = obj
                    break

            if business_class is None:
                raise BusinessEngineNotFoundError()

        self._business_engine = business_class(
            event_buffer=self._event_buffer,
            topology=self._topology,
            start_tick=self._start_tick,
            max_tick=max_tick,
            snapshot_resolution=self._snapshot_resolution,
            max_snapshots=self._max_snapshots,
            additional_options=self._additional_options
        )

    def _simulate(self):
        """This is the generator to wrap each episode process."""
        while True:
            # Ask business engine to do thing for this tick, such as generating and pushing events.
            # We do not push events now.
            self._business_engine.step(self._tick)

            while True:
                # Keep processing events, until no more events in this tick.
                pending_events = self._event_buffer.execute(self._tick)

                # Processing pending events.
                pending_event_length: int = len(pending_events)

                if pending_event_length == 0:
                    # We have processed all the event of current tick, lets go for next tick.
                    break

                # Insert snapshot before each action.
                self._business_engine.frame.take_snapshot(self.frame_index)

                decision_events = []

                # Append source event id to decision events, to support sequential action in joint mode.
                for evt in pending_events:
                    payload = evt.payload
                    payload.source_event_id = evt.id
                    decision_events.append(payload)

                decision_events = decision_events[0] if self._decision_mode == DecisionMode.Sequential \
                    else decision_events

                # Yield current state first, and waiting for action.
                actions = yield self._business_engine.get_metrics(), decision_events, False

                if actions is None:
                    # Make business engine easy to work.
                    actions = []
                elif not isinstance(actions, Iterable):
                    actions = [actions]

                # Generate a new atom event first.
                action_event = self._event_buffer.gen_atom_event(self._tick, DECISION_EVENT, actions)

                # We just append the action into sub event of first pending cascade event.
                pending_events[0].state = EventState.EXECUTING
                pending_events[0].immediate_event_list.append(action_event)

                if self._decision_mode == DecisionMode.Joint:
                    # For joint event, we will disable following cascade event.

                    # We expect that first action contains a src_event_id to support joint event with
                    # sequential action. With zero or one action there is nothing to look up (and an empty
                    # action list, e.g. from a None action, must not be indexed).
                    action_related_event_id = None if len(actions) <= 1 \
                        else getattr(actions[0], "src_event_id", None)

                    # If the first action has a decision event attached, it means sequential action is supported.
                    is_support_seq_action = action_related_event_id is not None

                    if is_support_seq_action:
                        for i in range(1, pending_event_length):
                            if pending_events[i].id == action_related_event_id:
                                pending_events[i].state = EventState.FINISHED
                    else:
                        for i in range(1, pending_event_length):
                            pending_events[i].state = EventState.FINISHED

            # Check the end tick of the simulation to decide if we should end the simulation.
            is_end_tick = self._business_engine.post_step(self._tick)

            if is_end_tick:
                break

            self._tick += 1

        # Make sure we have no missing data.
        if (self._tick + 1) % self._snapshot_resolution != 0:
            self._business_engine.frame.take_snapshot(self.frame_index)

        # The end.
        yield self._business_engine.get_metrics(), None, True
class Env(AbsEnv):
    """Default environment.

    Args:
        scenario (str): Scenario name under maro/sim/scenarios folder.
        topology (str): Topology name under specified scenario folder, if this points to an existing folder,
            then it will use this as topology for built-in scenario.
        start_tick (int): Start tick of the scenario, usually used for pre-processed data streaming.
        durations (int): Duration ticks of this environment from start_tick.
        snapshot_resolution (int): How many ticks will take a snapshot.
        max_snapshots (int): Max in-memory snapshot number, default None means keep all snapshots in memory,
            when taking a snapshot, if it reaches this limitation, the oldest one will be overwritten.
        decision_mode (DecisionMode): ``DecisionMode.Sequential`` yields one decision event per step,
            ``DecisionMode.Joint`` yields the list of all pending decision events of the step.
        business_engine_cls (type): Class of business engine; if specified, it is used to construct the
            business engine instance, otherwise one is searched for internally by scenario name.
        options (dict): Additional parameters passed to business engine. ``None`` means no options.
    """

    def __init__(self, scenario: str = None, topology: str = None,
                 start_tick: int = 0, durations: int = 100, snapshot_resolution: int = 1, max_snapshots: int = None,
                 decision_mode: DecisionMode = DecisionMode.Sequential,
                 business_engine_cls: type = None,
                 options: dict = None):
        # A literal ``{}`` default would be a single dict object shared by every Env created without options,
        # so any mutation of it would leak between instances; build a fresh one per call instead.
        if options is None:
            options = {}

        super().__init__(scenario, topology, start_tick, durations,
                         snapshot_resolution, max_snapshots, decision_mode, business_engine_cls, options)

        self._name = f'{self._scenario}:{self._topology}' if business_engine_cls is None \
            else business_engine_cls.__name__
        self._business_engine: AbsBusinessEngine = None

        self._event_buffer = EventBuffer()

        # Generator that pushes the simulator forward; it is advanced by ``step``.
        self._simulate_generator = self._simulate()

        # Initialize business engine.
        self._init_business_engine()

    def step(self, action):
        """Push the environment to next step with action.

        NOTE: the first call after construction or ``reset`` must pass ``None``: the internal generator has
        not started yet, and Python refuses to ``send`` a non-None value to a just-started generator.

        Args:
            action (Action): Action(s) from agent; ``None`` means no action.

        Returns:
            tuple: ``(metrics, decision_event, is_done)``:

                - metrics reported by the business engine (``get_metrics()``);
                - decision_event: a single decision event in sequential decision mode, or a list of
                  decision events in joint mode; ``None`` once the episode ends;
                - whether the episode ends.

            After the episode has already ended, ``(None, None, True)`` is returned.
        """
        try:
            reward, decision_event, _is_done = self._simulate_generator.send(action)
        except StopIteration:
            return None, None, True

        return reward, decision_event, _is_done

    def dump(self):
        """Dump environment for restore.

        NOTE: not implemented.
        """
        return

    def reset(self):
        """Reset environment: tick, simulation generator, event buffer and business engine."""
        # Reset self.
        self._tick = self._start_tick

        # Discard the running episode generator and start a fresh one.
        self._simulate_generator.close()
        self._simulate_generator = self._simulate()

        # Reset event buffer.
        self._event_buffer.reset()

        # Ask business engine to reset itself.
        self._business_engine.reset()

    @property
    def configs(self) -> dict:
        """object: Configurations of current environment."""
        return self._business_engine.configs

    @property
    def summary(self) -> dict:
        """Summary about current simulator, including node mapping and node details.

        NOTE: this is provided by scenario, so may have different format and content.
        """
        return {
            "node_mapping": self._business_engine.get_node_mapping(),
            "node_detail": self.current_frame.get_node_info()
        }

    @property
    def name(self) -> str:
        """str: Name of current environment."""
        return self._name

    @property
    def current_frame(self) -> FrameBase:
        """Frame: Frame of current environment."""
        return self._business_engine.frame

    @property
    def tick(self) -> int:
        """int: Current tick of environment."""
        return self._tick

    @property
    def frame_index(self) -> int:
        """int: Frame index in snapshot list for current tick."""
        return tick_to_frame_index(self._start_tick, self._tick, self._snapshot_resolution)

    @property
    def snapshot_list(self) -> SnapshotList:
        """SnapshotList: Current snapshot list.

        A snapshot list contains all the snapshots of frame at each tick.
        """
        return self._business_engine.snapshots

    @property
    def agent_idx_list(self) -> List[int]:
        """List[int]: Agent index list that related to this environment."""
        return self._business_engine.get_agent_idx_list()

    def set_seed(self, seed: int):
        """Set random seed used by simulator.

        NOTE: this will not set seed for python random or other packages' seed, such as numpy.

        Args:
            seed (int): Seed value; ``None`` leaves the seed unchanged.
        """
        if seed is not None:
            sim_seed(seed)

    @property
    def metrics(self) -> dict:
        """Some statistics information provided by business engine.

        Returns:
            dict: Dictionary of metrics, content and format is determined by business engine.
        """
        return self._business_engine.get_metrics()

    def get_finished_events(self):
        """List[Event]: All events finished so far."""
        return self._event_buffer.get_finished_events()

    def get_pending_events(self, tick):
        """Pending events at certain tick.

        Args:
            tick (int): Specified tick.
        """
        return self._event_buffer.get_pending_events(tick)

    def _init_business_engine(self):
        """Initialize business engine object.

        NOTE:
        1. internal scenarios will always be under "maro/simulator/scenarios" folder
        2. for external scenarios, we use the given business engine class to create the instance

        Raises:
            BusinessEngineNotFoundError: No ``AbsBusinessEngine`` subclass found in the scenario module.
        """
        max_tick = self._start_tick + self._durations

        if self._business_engine_cls is not None:
            business_class = self._business_engine_cls
        else:
            # Combine the business engine import path.
            business_class_path = f'maro.simulator.scenarios.{self._scenario}.business_engine'

            # Load the module to find business engine for that scenario.
            business_module = import_module(business_class_path)

            # First concrete subclass of AbsBusinessEngine found among the module's classes.
            business_class = next(
                (obj for _, obj in getmembers(business_module, isclass)
                 if issubclass(obj, AbsBusinessEngine) and obj != AbsBusinessEngine),
                None
            )

            if business_class is None:
                raise BusinessEngineNotFoundError()

        self._business_engine = business_class(
            event_buffer=self._event_buffer,
            topology=self._topology,
            start_tick=self._start_tick,
            max_tick=max_tick,
            snapshot_resolution=self._snapshot_resolution,
            max_snapshots=self._max_snapshots,
            additional_options=self._additional_options
        )

    def _simulate(self):
        """Generator wrapping one episode.

        Yields ``(metrics, decision_events, False)`` whenever events are pending at the current tick and
        receives the agent's actions via ``send``; finally yields ``(metrics, None, True)``.
        """
        while True:
            # Ask business engine to do things for this tick, such as generating and pushing events.
            self._business_engine.step(self._tick)

            while True:
                # Keep processing events until no more events are pending for this tick.
                pending_events = self._event_buffer.execute(self._tick)

                pending_event_length: int = len(pending_events)

                if pending_event_length == 0:
                    # We have processed all the events of current tick, lets go for next tick.
                    break

                # Insert snapshot before each action.
                self._business_engine.frame.take_snapshot(self.frame_index)

                # Attach source event id to each decision payload, to support sequential action in joint mode.
                decision_events = []
                for evt in pending_events:
                    payload = evt.payload
                    payload.source_event_id = evt.id
                    decision_events.append(payload)

                if self._decision_mode == DecisionMode.Sequential:
                    decision_events = decision_events[0]

                # Yield current state first, and wait for action.
                actions = yield self._business_engine.get_metrics(), decision_events, False

                # Normalize actions to an iterable so the business engine always receives a collection.
                if actions is None:
                    actions = []
                elif not isinstance(actions, Iterable):
                    actions = [actions]

                # Generate a new atom event carrying the actions.
                action_event = self._event_buffer.gen_atom_event(self._tick, DECISION_EVENT, actions)

                # Append the action as a sub event of the first pending cascade event.
                pending_events[0].state = EventState.EXECUTING
                pending_events[0].immediate_event_list.append(action_event)

                # TODO: support get reward after action complete here, via using event_buffer.execute
                if self._decision_mode == DecisionMode.Joint:
                    # For joint event, we disable the following cascade events. A first action carrying
                    # ``src_event_id`` signals support for sequential action within joint mode.
                    # Guard on len > 1: with no actions (e.g. ``None`` sent) actions[0] does not exist.
                    action_related_event_id = getattr(actions[0], "src_event_id", None) \
                        if len(actions) > 1 else None

                    if action_related_event_id is not None:
                        # Only the pending event referenced by the first action is finished.
                        for pending_event in pending_events[1:]:
                            if pending_event.id == action_related_event_id:
                                pending_event.state = EventState.FINISHED
                    else:
                        for pending_event in pending_events[1:]:
                            pending_event.state = EventState.FINISHED

            # Check if we should end the simulation.
            if self._business_engine.post_step(self._tick):
                break

            self._tick += 1

        # Make sure we have no missing data: take a final snapshot when the last tick did not fall on a
        # snapshot_resolution boundary (presumably regular snapshots happen on those boundaries - verify).
        if (self._tick + 1) % self._snapshot_resolution != 0:
            self._business_engine.frame.take_snapshot(self.frame_index)

        # The end.
        yield self._business_engine.get_metrics(), None, True
class TestEventBuffer(unittest.TestCase):
    """Unit tests for EventBuffer: event generation, insertion, dispatching, reset and sub events."""

    def setUp(self):
        self.eb = EventBuffer()

    def test_gen_event(self):
        """Test event generating correct"""
        evt = self.eb.gen_atom_event(1, 1, (0, 0))

        # Fields should be the same as specified; exact type is checked on purpose.
        self.assertIs(AtomEvent, type(evt))
        self.assertEqual(evt.tick, 1)
        self.assertEqual(evt.event_type, 1)
        self.assertEqual(evt.payload, (0, 0))

        evt = self.eb.gen_cascade_event(2, 2, (1, 1, 1))

        self.assertIs(CascadeEvent, type(evt))
        self.assertEqual(evt.tick, 2)
        self.assertEqual(evt.event_type, 2)
        self.assertEqual(evt.payload, (1, 1, 1))

    def test_insert_event(self):
        """Test insert event works as expected"""

        # Pending pool should be empty at beginning.
        self.assertEqual(len(self.eb._pending_events), 0)

        evt = self.eb.gen_atom_event(1, 1, 1)

        self.eb.insert_event(evt)

        # After inserting one event, we should have 1 in pending pool.
        self.assertEqual(len(self.eb._pending_events), 1)

    def test_event_dispatch(self):
        """Test event dispatching work as expected"""
        received = []

        def cb(evt):
            received.append(evt)

            # test event tick
            self.assertEqual(
                1, evt.tick, msg="recieved event tick should be 1")

            # test event payload
            self.assertTupleEqual(
                (1, 3), evt.payload, msg="recieved event's payload should be (1, 3)")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))

        self.eb.insert_event(evt)

        self.eb.register_event_handler(1, cb)

        self.eb.execute(1)  # dispatch event

        # The assertions inside cb only run if it is invoked, so verify the handler actually fired.
        self.assertEqual(1, len(received), msg="handler should be invoked exactly once")

    def test_get_finish_events(self):
        """Test if we can get correct finished events"""

        # No finished events at first.
        self.assertListEqual([], self.eb.get_finished_events(),
                             msg="finished pool should be empty")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))

        self.eb.insert_event(evt)

        self.eb.execute(1)

        # After dispatching, finished pool should contain 1 object.
        self.assertEqual(1, len(self.eb.get_finished_events()),
                         msg="after dispathing, there should 1 object")

    def test_get_pending_events(self):
        """Test if we can get correct pending events"""

        # No pending events at first.
        self.assertEqual(0, len(self.eb.get_pending_events(1)),
                         msg="pending pool should be empty")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))

        self.eb.insert_event(evt)

        self.assertEqual(1, len(self.eb.get_pending_events(1)),
                         msg="pending pool should contains 1 objects")

    def test_reset(self):
        """Test reset, all internal states should be reset"""
        evt = self.eb.gen_atom_event(1, 1, 1)

        self.eb.insert_event(evt)

        self.eb.reset()

        # Reset will not clear the tick (key), just clear the pending pool.
        self.assertEqual(len(self.eb._pending_events), 1)

        for pending_pool in self.eb._pending_events.values():
            self.assertEqual(0, len(pending_pool))

        self.assertEqual(len(self.eb._finished_events), 0)

    def test_sub_events(self):
        """Test that a cascade event and its immediate (sub) event both reach their handlers."""
        called = []

        def cb1(evt):
            called.append(1)
            self.assertEqual(1, evt.payload)

        def cb2(evt):
            called.append(2)
            self.assertEqual(2, evt.payload)

        self.eb.register_event_handler(1, cb1)
        self.eb.register_event_handler(2, cb2)

        evt: CascadeEvent = self.eb.gen_cascade_event(1, 1, 1)

        evt.add_immediate_event(self.eb.gen_atom_event(1, 2, 2))

        self.eb.insert_event(evt)

        self.eb.execute(1)

        # The payload assertions live inside the handlers, so make sure both actually ran.
        self.assertCountEqual([1, 2], called)

    def test_sub_events_with_decision(self):
        """Test that decision sub events are only unfolded after their parent is finished."""
        evt1 = self.eb.gen_decision_event(1, (1, 1, 1))
        sub1 = self.eb.gen_decision_event(1, (2, 2, 2))
        sub2 = self.eb.gen_decision_event(1, (3, 3, 3))

        evt1.add_immediate_event(sub1, is_head=True)
        evt1.add_immediate_event(sub2)

        self.eb.insert_event(evt1)

        # Sub events will be unfolded after parent being processed.
        decision_events = self.eb.execute(1)

        # So we will get 1 decision event for the 1st execution.
        self.assertEqual(1, len(decision_events))
        self.assertEqual(evt1, decision_events[0])

        # Mark the decision event as finished so the buffer proceeds to the following events.
        decision_events[0].state = EventState.FINISHED

        # Then there will be 2 additional decision events from sub events.
        decision_events = self.eb.execute(1)

        self.assertEqual(2, len(decision_events))
        self.assertEqual(sub1, decision_events[0])
        self.assertEqual(sub2, decision_events[1])