Example #1
0
class TestEventBuffer(unittest.TestCase):
    """Unit tests for the category-based ``EventBuffer`` API."""

    def setUp(self):
        # Fresh buffer per test so pending/finished pools never leak between tests.
        self.eb = EventBuffer()

    def test_gen_event(self):
        """Generated events should carry exactly the fields they were created with."""
        evt = self.eb.gen_atom_event(1, 1, (0, 0))

        # fields should be same as specified
        self.assertEqual(evt.category, EventCategory.ATOM)
        self.assertEqual(evt.tick, 1)
        self.assertEqual(evt.event_type, 1)
        self.assertEqual(evt.payload, (0, 0))

        evt = self.eb.gen_cascade_event(2, 2, (1, 1, 1))

        self.assertEqual(evt.category, EventCategory.CASCADE)
        self.assertEqual(evt.tick, 2)
        self.assertEqual(evt.event_type, 2)
        self.assertEqual(evt.payload, (1, 1, 1))

    def test_insert_event(self):
        """Inserting an event should add it to the pending pool."""

        # pending pool should be empty at beginning
        self.assertEqual(len(self.eb._pending_events), 0)

        evt = self.eb.gen_atom_event(1, 1, 1)

        self.eb.insert_event(evt)

        # after insert one event, we should have 1 in pending pool
        self.assertEqual(len(self.eb._pending_events), 1)

    def test_event_dispatch(self):
        """execute() should hand the pending event to its registered handler."""
        received = []

        def cb(evt):
            received.append(evt)

            # test event tick
            self.assertEqual(1, evt.tick, msg="received event tick should be 1")

            # test event payload
            self.assertTupleEqual((1, 3), evt.payload, msg="received event's payload should be (1, 3)")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))

        self.eb.insert_event(evt)

        self.eb.register_event_handler(1, cb)

        self.eb.execute(1)  # dispatch event

        # The assertions inside cb only run if cb is called, so verify that it was.
        self.assertEqual(1, len(received), msg="handler should be invoked exactly once")

    def test_get_finish_events(self):
        """Executed events should move into the finished pool."""

        # no finished events at first
        self.assertListEqual([], self.eb.get_finished_events(), msg="finished pool should be empty")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))

        self.eb.insert_event(evt)

        self.eb.execute(1)

        # after dispatching, finish pool should contains 1 object
        self.assertEqual(1, len(self.eb.get_finished_events()), msg="after dispatching, there should 1 object")

    def test_get_pending_events(self):
        """Inserted events should be reported as pending at their tick."""

        # not pending at first
        self.assertEqual(0, len(self.eb.get_pending_events(1)), msg="pending pool should be empty")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))

        self.eb.insert_event(evt)

        self.assertEqual(1, len(self.eb.get_pending_events(1)), msg="pending pool should contains 1 objects")

    def test_reset(self):
        """Test reset, all internal states should be reset"""
        evt = self.eb.gen_atom_event(1, 1, 1)

        self.eb.insert_event(evt)

        self.eb.reset()

        self.assertEqual(len(self.eb._pending_events), 0)
        self.assertEqual(len(self.eb._finished_events), 0)
Example #2
0
class TestEventBuffer(unittest.TestCase):
    """Unit tests for ``EventBuffer`` and its supporting event structures."""

    def setUp(self):
        # Fresh buffer per test so pending/finished pools never leak between tests.
        self.eb = EventBuffer()

    def test_cascade_event(self):
        """Test the immediate-event linked list held by a ``CascadeEvent``."""
        evt = CascadeEvent(0, 1, None, None)
        # A new cascade event holds only a dummy head, with no tail and no count.
        self.assertEqual(type(evt.immediate_event_head), DummyEvent)
        self.assertIsNone(evt.immediate_event_head.next_event)
        self.assertIsNone(evt.immediate_event_tail)
        self.assertEqual(evt.immediate_event_head.next_event,
                         evt.immediate_event_tail)
        self.assertEqual(evt.immediate_event_count, 0)

        # Mix head and tail insertions; the expected order is checked below.
        evt.add_immediate_event(AtomEvent(1, 1, None, None), is_head=False)
        evt.add_immediate_event(AtomEvent(2, 1, None, None), is_head=False)
        evt.add_immediate_event(AtomEvent(3, 1, None, None), is_head=True)
        evt.add_immediate_event(AtomEvent(4, 1, None, None), is_head=True)
        evt.add_immediate_event(AtomEvent(5, 1, None, None), is_head=False)
        evt.add_immediate_event(AtomEvent(6, 1, None, None), is_head=False)
        evt.add_immediate_event(AtomEvent(7, 1, None, None), is_head=True)
        evt.add_immediate_event(AtomEvent(8, 1, None, None), is_head=True)
        self.assertEqual(evt.immediate_event_count, 8)

        # Should be declined because the tick of the events are not equal
        self.assertFalse(evt.add_immediate_event(AtomEvent(9, 2, None, None)))

        iter_evt: Optional[ActualEvent] = evt.immediate_event_head.next_event
        event_ids = []
        while iter_evt is not None:
            event_ids.append(iter_evt.id)
            iter_evt = iter_evt.next_event
        self.assertListEqual(event_ids, [8, 7, 4, 3, 1, 2, 5, 6])

        # clear() must bring the list back to its initial empty state.
        evt.clear()
        self.assertIsNone(evt.immediate_event_head.next_event)
        self.assertIsNone(evt.immediate_event_tail)
        self.assertEqual(evt.immediate_event_head.next_event,
                         evt.immediate_event_tail)
        self.assertEqual(evt.immediate_event_count, 0)

    def test_event_linked_list(self):
        """Test ordering, unfolding and clearing of ``EventLinkedList``."""
        event_linked_list = EventLinkedList()
        self.assertEqual(len(event_linked_list), 0)
        self.assertListEqual([evt for evt in event_linked_list], [])

        # Event 0 owns events 3..6 as immediate (sub) events.
        evt_list = [CascadeEvent(i, None, None, None) for i in range(7)]
        evt_list[0].add_immediate_event(evt_list[3])
        evt_list[0].add_immediate_event(evt_list[4])
        evt_list[0].add_immediate_event(evt_list[5])
        evt_list[0].add_immediate_event(evt_list[6])

        event_linked_list.append_head(evt_list[1])
        event_linked_list.append_tail(evt_list[2])
        event_linked_list.append_head(evt_list[0])
        self.assertEqual(len(event_linked_list), 3)

        event_ids = [event.id for event in event_linked_list]
        self.assertListEqual(event_ids, [0, 1, 2])

        evt = event_linked_list.clear_finished_and_get_front()
        self.assertEqual(evt.id, 0)

        # Test `_clear_finished_events()`: finishing event 0 replaces it with its sub events.
        evt_list[0].state = EventState.FINISHED
        evt = event_linked_list.clear_finished_and_get_front()
        self.assertIsInstance(evt, ActualEvent)
        self.assertEqual(evt.id, 3)
        self.assertEqual(len(event_linked_list), 6)

        self.assertListEqual([evt.id for evt in event_linked_list],
                             [3, 4, 5, 6, 1, 2])

        # Consecutive pending-decision events at the front are returned together as a list.
        evt_list[3].event_type = MaroEvents.PENDING_DECISION
        evt_list[4].event_type = MaroEvents.PENDING_DECISION
        evt_list[5].event_type = MaroEvents.PENDING_DECISION
        evts = event_linked_list.clear_finished_and_get_front()
        self.assertTrue(all(isinstance(evt, ActualEvent) for evt in evts))
        self.assertEqual(len(evts), 3)
        self.assertListEqual([evt.id for evt in evts], [3, 4, 5])
        self.assertListEqual([evt.id for evt in event_linked_list],
                             [3, 4, 5, 6, 1, 2])

        event_linked_list.clear()
        self.assertEqual(len(event_linked_list), 0)
        self.assertListEqual([evt for evt in event_linked_list], [])

    def test_event_pool(self):
        """Test recycling events into ``EventPool`` and generating from it."""
        ep = EventPool()
        cascade_events = [CascadeEvent(i, None, None, None) for i in range(5)]
        atom_events = [AtomEvent(i, None, None, None) for i in range(5, 10)]

        # Recycle cascade events one by one, and atom events as a single list.
        for evt in cascade_events:
            ep.recycle(evt)
        ep.recycle(atom_events)

        self.assertEqual(ep.atom_event_count, 5)
        self.assertEqual(ep.cascade_event_count, 5)

        # Even i -> cascade (3 taken), odd i -> atom (2 taken).
        for i in range(5):
            is_cascade = i % 2 == 0
            evt = ep.gen(tick=i,
                         event_type=-1,
                         payload=-1,
                         is_cascade=is_cascade)
            self.assertEqual(evt.id, i)
            self.assertEqual(evt.tick, i)
            self.assertEqual(evt.event_type, -1)
            self.assertEqual(evt.payload, -1)
            self.assertIsInstance(evt,
                                  CascadeEvent if is_cascade else AtomEvent)

        self.assertEqual(ep.atom_event_count, 3)
        self.assertEqual(ep.cascade_event_count, 2)

    def test_gen_event(self):
        """Generated events should carry exactly the fields they were created with."""
        evt = self.eb.gen_atom_event(1, 1, (0, 0))

        # fields should be same as specified
        self.assertEqual(AtomEvent, type(evt))
        self.assertEqual(evt.tick, 1)
        self.assertEqual(evt.event_type, 1)
        self.assertEqual(evt.payload, (0, 0))

        evt = self.eb.gen_cascade_event(2, 2, (1, 1, 1))

        self.assertEqual(CascadeEvent, type(evt))
        self.assertEqual(evt.tick, 2)
        self.assertEqual(evt.event_type, 2)
        self.assertEqual(evt.payload, (1, 1, 1))

    def test_insert_event(self):
        """Inserting an event should add it to the pending pool."""

        # pending pool should be empty at beginning
        self.assertEqual(len(self.eb._pending_events), 0)

        evt = self.eb.gen_atom_event(1, 1, 1)

        self.eb.insert_event(evt)

        # after insert one event, we should have 1 in pending pool
        self.assertEqual(len(self.eb._pending_events), 1)

    def test_event_dispatch(self):
        """execute() should hand the pending event to its registered handler."""
        received = []

        def cb(evt):
            received.append(evt)

            # test event tick
            self.assertEqual(1,
                             evt.tick,
                             msg="received event tick should be 1")

            # test event payload
            self.assertTupleEqual(
                (1, 3),
                evt.payload,
                msg="received event's payload should be (1, 3)")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))

        self.eb.insert_event(evt)

        self.eb.register_event_handler(1, cb)

        self.eb.execute(1)  # dispatch event

        # The assertions inside cb only run if cb is called, so verify that it was.
        self.assertEqual(1, len(received),
                         msg="handler should be invoked exactly once")

    def test_get_finish_events(self):
        """Executed events should move into the finished pool."""

        # no finished at first
        self.assertListEqual([],
                             self.eb.get_finished_events(),
                             msg="finished pool should be empty")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))

        self.eb.insert_event(evt)

        self.eb.execute(1)

        # after dispatching, finish pool should contains 1 object
        self.assertEqual(1,
                         len(self.eb.get_finished_events()),
                         msg="after dispatching, there should 1 object")

    def test_get_pending_events(self):
        """Inserted events should be reported as pending at their tick."""

        # not pending at first
        self.assertEqual(0,
                         len(self.eb.get_pending_events(1)),
                         msg="pending pool should be empty")

        evt = self.eb.gen_atom_event(1, 1, (1, 3))

        self.eb.insert_event(evt)

        self.assertEqual(1,
                         len(self.eb.get_pending_events(1)),
                         msg="pending pool should contains 1 objects")

    def test_reset(self):
        """Test reset, all internal states should be reset"""
        evt = self.eb.gen_atom_event(1, 1, 1)

        self.eb.insert_event(evt)

        self.eb.reset()

        # reset will not clear the tick (key), just clear the pending pool
        self.assertEqual(len(self.eb._pending_events), 1)

        for pending_pool in self.eb._pending_events.values():
            self.assertEqual(0, len(pending_pool))

        self.assertEqual(len(self.eb._finished_events), 0)

    def test_sub_events(self):
        """Immediate (sub) events of a cascade event should be dispatched too."""
        handled_payloads = []

        def cb1(evt):
            self.assertEqual(1, evt.payload)
            handled_payloads.append(evt.payload)

        def cb2(evt):
            self.assertEqual(2, evt.payload)
            handled_payloads.append(evt.payload)

        self.eb.register_event_handler(1, cb1)
        self.eb.register_event_handler(2, cb2)

        evt: CascadeEvent = self.eb.gen_cascade_event(1, 1, 1)

        evt.add_immediate_event(self.eb.gen_atom_event(1, 2, 2))

        self.eb.insert_event(evt)

        self.eb.execute(1)

        # Both the parent and its sub event must have reached their handlers;
        # without this, the assertions inside cb1/cb2 could silently never run.
        self.assertCountEqual([1, 2], handled_payloads)

    def test_sub_events_with_decision(self):
        """Sub decision events are only surfaced after their parent has finished."""
        evt1 = self.eb.gen_decision_event(1, (1, 1, 1))
        sub1 = self.eb.gen_decision_event(1, (2, 2, 2))
        sub2 = self.eb.gen_decision_event(1, (3, 3, 3))

        evt1.add_immediate_event(sub1, is_head=True)
        evt1.add_immediate_event(sub2)

        self.eb.insert_event(evt1)

        # sub events will be unfolded after the parent has been processed
        decision_events = self.eb.execute(1)

        # so we will get 1 decision event for the 1st execution
        self.assertEqual(1, len(decision_events))
        self.assertEqual(evt1, decision_events[0])

        # mark the parent decision event as finished so its sub events are processed next
        decision_events[0].state = EventState.FINISHED

        # then there will be 2 additional decision events from the sub events
        decision_events = self.eb.execute(1)

        self.assertEqual(2, len(decision_events))
        self.assertEqual(sub1, decision_events[0])
        self.assertEqual(sub2, decision_events[1])

    def test_disable_finished_events(self):
        """With ``disable_finished_events`` set, the finished pool stays empty."""
        eb = EventBuffer(disable_finished_events=True)
        self.assertListEqual([],
                             eb.get_finished_events(),
                             msg="finished pool should be empty")

        for _ in range(4):
            eb.insert_event(eb.gen_atom_event(1, 1, (1, 3)))
        eb.execute(1)

        # after dispatching, finish pool should still contains no object
        self.assertListEqual([],
                             eb.get_finished_events(),
                             msg="finished pool should be empty")

    def test_record_events(self):
        """Finished events should be recorded as CSV rows, tagged per episode."""
        # Recording without a destination path must be rejected.
        with self.assertRaises(ValueError):
            EventBuffer(record_events=True, record_path=None)

        # A private temporary directory guarantees the recording file is removed afterwards.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = f'{temp_dir}/recorded_events.txt'

            eb = EventBuffer(record_events=True, record_path=temp_file_path)
            for _ in range(4):
                eb.insert_event(eb.gen_atom_event(1, 1, (1, 3)))
            eb.execute(1)
            eb.reset()
            for _ in range(4):
                eb.insert_event(eb.gen_atom_event(1, 1, (1, 3)))
            eb.execute(1)
            # NOTE(review): presumably dropping the buffer flushes/closes its
            # recording file; the read below relies on that.
            del eb

            with open(temp_file_path, "r") as input_stream:
                texts = input_stream.readlines()

            self.assertListEqual(texts, [
                'episode,tick,event_type,payload\n', '0,1,1,"(1, 3)"\n',
                '0,1,1,"(1, 3)"\n', '0,1,1,"(1, 3)"\n', '0,1,1,"(1, 3)"\n',
                '1,1,1,"(1, 3)"\n', '1,1,1,"(1, 3)"\n', '1,1,1,"(1, 3)"\n',
                '1,1,1,"(1, 3)"\n'
            ])
Example #3
0
class Env(AbsEnv):
    """Default environment implementation using generator.

    Args:
        scenario (str): Scenario name under maro/simulator/scenarios folder.
        topology (str): Topology name under specified scenario folder.
            If it points to an existing folder, the corresponding topology will be used for the built-in scenario.
        start_tick (int): Start tick of the scenario, usually used for pre-processed data streaming.
        durations (int): Duration ticks of this environment from start_tick.
        snapshot_resolution (int): How many ticks will take a snapshot.
        max_snapshots(int): Max in-memory snapshot number.
            When the number of dumped snapshots reached the limitation, oldest one will be overwrote by new one.
            None means keeping all snapshots in memory. Defaults to None.
        decision_mode (DecisionMode): How pending decision events are exposed to the agent.
            With ``DecisionMode.Sequential`` only the first pending decision event is yielded and the action
            is attached to it; otherwise all pending decision events are yielded as a list and actions are
            matched to them in order (under ``DecisionMode.Joint``, pending events left without an action are
            marked finished). Defaults to ``DecisionMode.Sequential``.
        business_engine_cls (type): Class of business engine. If specified, use it to construct the be instance,
            or search internally by scenario.
        disable_finished_events (bool): Disable finished events list, with this set to True, EventBuffer will
            re-use finished event object, this reduce event object number.
        record_finished_events (bool): If record finished events into csv file, default is False.
        record_file_path (str): Where to save the recording file, only work if record_finished_events is True.
        options (dict): Additional parameters passed to business engine.
    """
    def __init__(self,
                 scenario: str = None,
                 topology: str = None,
                 start_tick: int = 0,
                 durations: int = 100,
                 snapshot_resolution: int = 1,
                 max_snapshots: int = None,
                 decision_mode: DecisionMode = DecisionMode.Sequential,
                 business_engine_cls: type = None,
                 disable_finished_events: bool = False,
                 record_finished_events: bool = False,
                 record_file_path: str = None,
                 options: Optional[dict] = None) -> None:
        super().__init__(scenario, topology, start_tick, durations,
                         snapshot_resolution, max_snapshots, decision_mode,
                         business_engine_cls, disable_finished_events,
                         options if options is not None else {})

        self._name = f'{self._scenario}:{self._topology}' if business_engine_cls is None \
            else business_engine_cls.__name__

        # Pass the flags by keyword so they cannot be silently mis-assigned if
        # EventBuffer's positional parameter order ever changes.
        self._event_buffer = EventBuffer(
            disable_finished_events=disable_finished_events,
            record_events=record_finished_events,
            record_path=record_file_path)

        # decision_events array for dump.
        self._decision_events = []

        # The generator used to push the simulator forward.
        self._simulate_generator = self._simulate()

        # Initialize the business engine.
        self._init_business_engine()

        # The converter only exists when snapshot dumping is enabled; reset() checks the same option.
        if "enable-dump-snapshot" in self._additional_options:
            parent_path = self._additional_options["enable-dump-snapshot"]
            self._converter = DumpConverter(
                parent_path, self._business_engine.scenario_name)
            self._converter.reset_folder_path()

        self._streamit_episode = 0

    def step(
        self, action
    ) -> Tuple[Optional[dict], Optional[List[object]], Optional[bool]]:
        """Push the environment to next step with action.

        NOTE:
            The simulation generator is (re)created by ``__init__`` and ``reset``. As with any Python
            generator, the first call after that must pass ``None`` as the action, because a
            just-started generator cannot receive a non-None value.

        Args:
            action (Action): Action(s) from agent.

        Returns:
            tuple: a tuple of (metrics, decision event, is_done). Once the episode has already ended,
            ``(None, None, True)`` is returned.
        """
        try:
            metrics, decision_event, _is_done = self._simulate_generator.send(
                action)
        except StopIteration:
            return None, None, True

        return metrics, decision_event, _is_done

    def dump(self) -> None:
        """Dump environment for restore.

        NOTE:
            Not implemented.
        """
        return

    def reset(self, keep_seed: bool = False) -> None:
        """Reset environment.

        Args:
            keep_seed (bool): Reset the random seed to the generate the same data sequence or not. Defaults to False.
                Forwarded to the business engine's ``reset``.
        """
        self._tick = self._start_tick

        # Discard the running episode and start a fresh generator.
        self._simulate_generator.close()
        self._simulate_generator = self._simulate()

        self._event_buffer.reset()

        if "enable-dump-snapshot" in self._additional_options and self._business_engine.frame is not None:
            dump_folder = self._converter.get_new_snapshot_folder()

            self._business_engine.frame.dump(dump_folder)
            self._converter.start_processing(self.configs)
            self._converter.dump_descsion_events(self._decision_events,
                                                 self._start_tick,
                                                 self._snapshot_resolution)
            self._business_engine.dump(dump_folder)

        self._decision_events.clear()

        self._business_engine.reset(keep_seed)

    @property
    def configs(self) -> dict:
        """dict: Configurations of current environment."""
        return self._business_engine.configs

    @property
    def summary(self) -> dict:
        """dict: Summary about current simulator, including node details and mappings."""
        return {
            "node_mapping": self._business_engine.get_node_mapping(),
            "node_detail": self.current_frame.get_node_info(),
            "event_payload": self._business_engine.get_event_payload_detail()
        }

    @property
    def name(self) -> str:
        """str: Name of current environment."""
        return self._name

    @property
    def current_frame(self) -> FrameBase:
        """Frame: Frame of current environment."""
        return self._business_engine.frame

    @property
    def tick(self) -> int:
        """int: Current tick of environment."""
        return self._tick

    @property
    def frame_index(self) -> int:
        """int: Frame index in snapshot list for current tick."""
        return tick_to_frame_index(self._start_tick, self._tick,
                                   self._snapshot_resolution)

    @property
    def snapshot_list(self) -> SnapshotList:
        """SnapshotList: A snapshot list containing all the snapshots of frame at each dump point.

        NOTE: Due to different environment configurations, the resolution of the snapshot may be different.
        """
        return self._business_engine.snapshots

    @property
    def agent_idx_list(self) -> List[int]:
        """List[int]: Agent index list that related to this environment."""
        return self._business_engine.get_agent_idx_list()

    def set_seed(self, seed: int) -> None:
        """Set random seed used by simulator.

        NOTE:
            This will not set seed for Python random or other packages' seed, such as NumPy.

        Args:
            seed (int): Seed to set. ``None`` leaves the seed unchanged.
        """
        # NOTE(review): the docstring says Python's random is not seeded, yet this calls
        # `random.seed`; confirm `random` here is the simulator's own random module and
        # not the stdlib one.
        if seed is not None:
            random.seed(seed)

    @property
    def metrics(self) -> dict:
        """Some statistics information provided by business engine.

        Returns:
            dict: Dictionary of metrics, content and format is determined by business engine.
        """

        return self._business_engine.get_metrics()

    def get_finished_events(self) -> List[ActualEvent]:
        """List[Event]: All events finished so far."""
        return self._event_buffer.get_finished_events()

    def get_pending_events(self, tick) -> List[ActualEvent]:
        """Pending events at certain tick.

        Args:
            tick (int): Specified tick to query.
        """
        return self._event_buffer.get_pending_events(tick)

    def _init_business_engine(self) -> None:
        """Initialize business engine object.

        NOTE:
        1. For built-in scenarios, they will always under "maro/simulator/scenarios" folder.
        2. For external scenarios, the business engine instance is built with the loaded business engine class.

        Raises:
            BusinessEngineNotFoundError: No ``AbsBusinessEngine`` subclass found in the scenario module.
        """
        max_tick = self._start_tick + self._durations

        if self._business_engine_cls is not None:
            business_class = self._business_engine_cls
        else:
            # Combine the business engine import path.
            business_class_path = f'maro.simulator.scenarios.{self._scenario}.business_engine'

            # Load the module to find business engine for that scenario.
            business_module = import_module(business_class_path)

            # Take the first class in the module (getmembers orders by name)
            # that is a concrete subclass of AbsBusinessEngine.
            business_class = next(
                (obj for _, obj in getmembers(business_module, isclass)
                 if issubclass(obj, AbsBusinessEngine) and obj != AbsBusinessEngine),
                None)

            if business_class is None:
                raise BusinessEngineNotFoundError()

        self._business_engine: AbsBusinessEngine = business_class(
            event_buffer=self._event_buffer,
            topology=self._topology,
            start_tick=self._start_tick,
            max_tick=max_tick,
            snapshot_resolution=self._snapshot_resolution,
            max_snapshots=self._max_snapshots,
            additional_options=self._additional_options)

    def _simulate(
            self) -> Generator[Tuple[dict, List[object], bool], object, None]:
        """This is the generator to wrap each episode process.

        Yields ``(metrics, decision_events, False)`` whenever the agent must act and receives the
        action(s) via ``send``; finally yields ``(metrics, None, True)`` at the end of the episode.
        """
        self._streamit_episode += 1

        streamit.episode(self._streamit_episode)

        while True:
            # Ask business engine to do thing for this tick, such as generating and pushing events.
            # We do not push events now.
            streamit.tick(self._tick)

            self._business_engine.step(self._tick)

            while True:
                # Keep processing events, until no more events in this tick.
                pending_events = self._event_buffer.execute(self._tick)

                if len(pending_events) == 0:
                    # We have processed all the event of current tick, lets go for next tick.
                    break

                # Insert snapshot before each action.
                self._business_engine.frame.take_snapshot(self.frame_index)

                # The payloads of the pending events are what the agent sees as decision events.
                decision_events = [event.payload for event in pending_events]

                # Sequential mode exposes only the first decision event; other modes expose them all.
                decision_events = decision_events[0] if self._decision_mode == DecisionMode.Sequential \
                    else decision_events

                # Yield current state first, and waiting for action.
                actions = yield self._business_engine.get_metrics(), decision_events, False

                # archive decision events.
                self._decision_events.append(decision_events)

                # Normalize so the business engine always receives a list of actions.
                if actions is None:
                    actions = []
                elif not isinstance(actions, Iterable):
                    actions = [actions]

                if self._decision_mode == DecisionMode.Sequential:
                    # Generate a new atom event first.
                    action_event = self._event_buffer.gen_action_event(
                        self._tick, actions)

                    # NOTE: decision event always be a CascadeEvent
                    # We just append the action into sub event of first pending cascade event.
                    event = pending_events[0]
                    assert isinstance(event, CascadeEvent)
                    event.state = EventState.EXECUTING
                    event.add_immediate_event(action_event, is_head=True)
                else:
                    # For joint mode, we will assign actions from beginning to end.
                    # Then mark others pending events to finished if not sequential action mode.
                    for i, pending_event in enumerate(pending_events):
                        if i >= len(actions):
                            if self._decision_mode == DecisionMode.Joint:
                                # Ignore following pending events that have no action matched.
                                pending_event.state = EventState.FINISHED
                        else:
                            # Set the state as executing, so event buffer will not pop them again.
                            # Then insert the action to it.
                            action = actions[i]
                            pending_event.state = EventState.EXECUTING
                            action_event = self._event_buffer.gen_action_event(
                                self._tick, action)

                            assert isinstance(pending_event, CascadeEvent)
                            pending_event.add_immediate_event(action_event,
                                                              is_head=True)

            # Check the end tick of the simulation to decide if we should end the simulation.
            is_end_tick = self._business_engine.post_step(self._tick)

            if is_end_tick:
                break

            self._tick += 1

        # Make sure we have no missing data.
        if (self._tick + 1) % self._snapshot_resolution != 0:
            self._business_engine.frame.take_snapshot(self.frame_index)

        # The end.
        yield self._business_engine.get_metrics(), None, True
Example #4
0
class Env(AbsEnv):
    """Default environment implementation using generator.

    Args:
        scenario (str): Scenario name under maro/simulator/scenarios folder.
        topology (str): Topology name under specified scenario folder.
            If it points to an existing folder, the corresponding topology will be used for the built-in scenario.
        start_tick (int): Start tick of the scenario, usually used for pre-processed data streaming.
        durations (int): Duration ticks of this environment from start_tick.
        snapshot_resolution (int): How many ticks will take a snapshot.
        max_snapshots(int): Max in-memory snapshot number.
            When the number of dumped snapshots reaches the limitation, the oldest one will be overwritten.
            None means keeping all snapshots in memory. Defaults to None.
        decision_mode (DecisionMode): With ``DecisionMode.Sequential`` each step yields a single decision event;
            with ``DecisionMode.Joint`` each step yields the list of all pending decision events.
            Defaults to ``DecisionMode.Sequential``.
        business_engine_cls: Class of business engine. If specified, use it to construct the be instance,
            or search internally by scenario.
        options (dict): Additional parameters passed to business engine.
            Defaults to None, which is replaced by a new empty dict for each instance.
    """

    def __init__(self,
                 scenario: str = None,
                 topology: str = None,
                 start_tick: int = 0,
                 durations: int = 100,
                 snapshot_resolution: int = 1,
                 max_snapshots: int = None,
                 decision_mode: DecisionMode = DecisionMode.Sequential,
                 business_engine_cls: type = None,
                 options: dict = None):
        # Create a fresh dict per instance: a ``{}`` default would be one object shared
        # by every Env constructed without explicit options.
        if options is None:
            options = {}

        super().__init__(scenario, topology, start_tick, durations,
                         snapshot_resolution, max_snapshots, decision_mode,
                         business_engine_cls, options)

        self._name = f'{self._scenario}:{self._topology}' if business_engine_cls is None \
            else business_engine_cls.__name__
        self._business_engine: AbsBusinessEngine = None

        self._event_buffer = EventBuffer()

        # The generator used to push the simulator forward (created here, first advanced by step()).
        self._simulate_generator = self._simulate()

        # Initialize the business engine.
        self._init_business_engine()

    def step(self, action):
        """Push the environment to next step with action.

        Args:
            action (Action): Action(s) from agent.

        Returns:
            tuple: a tuple of (metrics, decision event, is_done).
                Returns (None, None, True) once the simulation generator is exhausted.
        """
        try:
            metrics, decision_event, _is_done = self._simulate_generator.send(
                action)
        except StopIteration:
            return None, None, True

        return metrics, decision_event, _is_done

    def dump(self):
        """Dump environment for restore.

        NOTE:
            Not implemented.
        """
        return

    def reset(self):
        """Reset environment: tick, simulation generator, event buffer and business engine."""
        self._tick = self._start_tick

        # Discard the running episode and prepare a new, not-yet-started one.
        self._simulate_generator.close()
        self._simulate_generator = self._simulate()

        self._event_buffer.reset()

        self._business_engine.reset()

    @property
    def configs(self) -> dict:
        """dict: Configurations of current environment."""
        return self._business_engine.configs

    @property
    def summary(self) -> dict:
        """dict: Summary about current simulator, including node details and mappings."""
        return {
            "node_mapping": self._business_engine.get_node_mapping(),
            "node_detail": self.current_frame.get_node_info()
        }

    @property
    def name(self) -> str:
        """str: Name of current environment."""
        return self._name

    @property
    def current_frame(self) -> FrameBase:
        """Frame: Frame of current environment."""
        return self._business_engine.frame

    @property
    def tick(self) -> int:
        """int: Current tick of environment."""
        return self._tick

    @property
    def frame_index(self) -> int:
        """int: Frame index in snapshot list for current tick."""
        return tick_to_frame_index(self._start_tick, self._tick,
                                   self._snapshot_resolution)

    @property
    def snapshot_list(self) -> SnapshotList:
        """SnapshotList: A snapshot list containing all the snapshots of frame at each dump point.

        NOTE: Due to different environment configurations, the resolution of the snapshot may be different.
        """
        return self._business_engine.snapshots

    @property
    def agent_idx_list(self) -> List[int]:
        """List[int]: Agent index list that related to this environment."""
        return self._business_engine.get_agent_idx_list()

    def set_seed(self, seed: int):
        """Set random seed used by simulator.

        NOTE:
            This will not set seed for Python random or other packages' seed, such as NumPy.

        Args:
            seed (int): Seed to set. None leaves the seed unchanged.
        """

        if seed is not None:
            sim_seed(seed)

    @property
    def metrics(self) -> dict:
        """Some statistics information provided by business engine.

        Returns:
            dict: Dictionary of metrics, content and format is determined by business engine.
        """

        return self._business_engine.get_metrics()

    def get_finished_events(self):
        """List[Event]: All events finished so far."""
        return self._event_buffer.get_finished_events()

    def get_pending_events(self, tick):
        """Pending events at certain tick.

        Args:
            tick (int): Specified tick to query.

        Returns:
            The pending events of that tick, as returned by the event buffer.
        """
        return self._event_buffer.get_pending_events(tick)

    def _init_business_engine(self):
        """Initialize business engine object.

        NOTE:
        1. For built-in scenarios, they will always under "maro/simulator/scenarios" folder.
        2. For external scenarios, the business engine instance is built with the loaded business engine class.

        Raises:
            BusinessEngineNotFoundError: No AbsBusinessEngine subclass found in the scenario module.
        """
        max_tick = self._start_tick + self._durations

        if self._business_engine_cls is not None:
            business_class = self._business_engine_cls
        else:
            # Combine the business engine import path.
            business_class_path = f'maro.simulator.scenarios.{self._scenario}.business_engine'

            # Load the module to find business engine for that scenario.
            business_module = import_module(business_class_path)

            business_class = None

            for _, obj in getmembers(business_module, isclass):
                if issubclass(obj,
                              AbsBusinessEngine) and obj != AbsBusinessEngine:
                    # We find it.
                    business_class = obj

                    break

            if business_class is None:
                raise BusinessEngineNotFoundError()

        self._business_engine = business_class(
            event_buffer=self._event_buffer,
            topology=self._topology,
            start_tick=self._start_tick,
            max_tick=max_tick,
            snapshot_resolution=self._snapshot_resolution,
            max_snapshots=self._max_snapshots,
            additional_options=self._additional_options)

    def _collect_decision_events(self, pending_events: list):
        """Tag each pending event's payload with its event id and shape the result by decision mode.

        Returns:
            The first payload in sequential mode, otherwise the list of all payloads.
        """
        decision_events = []

        # Append source event id to decision events, to support sequential action in joint mode.
        for evt in pending_events:
            payload = evt.payload

            payload.source_event_id = evt.id

            decision_events.append(payload)

        if self._decision_mode == DecisionMode.Sequential:
            return decision_events[0]

        return decision_events

    @staticmethod
    def _normalize_actions(actions):
        """Turn the value sent into the generator into an iterable of actions.

        None becomes an empty list (so the business engine always receives an iterable) and a
        single non-iterable action is wrapped in a list; iterables are returned unchanged.
        """
        if actions is None:
            return []

        if not isinstance(actions, Iterable):
            return [actions]

        return actions

    @staticmethod
    def _finish_joint_pending_events(pending_events: list, actions) -> None:
        """Joint mode: mark pending events after the first one as finished.

        If more than one action is given and the first action has a ``src_event_id``, only the
        pending event with that id is finished (joint event with sequential action). Otherwise
        every pending event after the first is finished.
        """
        # Only consult actions[0] when there is more than one action; an empty action list
        # (e.g. from step(None)) has no first element to inspect.
        related_event_id = getattr(actions[0], "src_event_id", None) if len(actions) > 1 else None

        # NOTE(review): with a related id this finishes the event the first action refers to,
        # exactly as before; confirm that is the intended target.
        for pending_event in pending_events[1:]:
            if related_event_id is None or pending_event.id == related_event_id:
                pending_event.state = EventState.FINISHED

    def _simulate(self):
        """This is the generator to wrap each episode process.

        Each yield produces (metrics, decision event(s), is_done); the value sent back through
        ``step`` is the action (or actions) for the yielded decision event(s).
        """
        while True:
            # Ask business engine to do thing for this tick, such as generating and pushing events.
            # We do not push events now.
            self._business_engine.step(self._tick)

            while True:
                # Keep processing events, until no more events in this tick.
                pending_events = self._event_buffer.execute(self._tick)

                if len(pending_events) == 0:
                    # We have processed all the event of current tick, lets go for next tick.
                    break

                # Insert snapshot before each action.
                self._business_engine.frame.take_snapshot(self.frame_index)

                decision_events = self._collect_decision_events(pending_events)

                # Yield current state first, and waiting for action.
                actions = yield self._business_engine.get_metrics(), decision_events, False

                actions = self._normalize_actions(actions)

                # Generate a new atom event first.
                action_event = self._event_buffer.gen_atom_event(
                    self._tick, DECISION_EVENT, actions)

                # We just append the action into sub event of first pending cascade event.
                pending_events[0].state = EventState.EXECUTING
                pending_events[0].immediate_event_list.append(action_event)

                if self._decision_mode == DecisionMode.Joint:
                    # For joint event, we will disable following cascade event.
                    self._finish_joint_pending_events(pending_events, actions)

            # Check the end tick of the simulation to decide if we should end the simulation.
            is_end_tick = self._business_engine.post_step(self._tick)

            if is_end_tick:
                break

            self._tick += 1

        # Make sure we have no missing data.
        if (self._tick + 1) % self._snapshot_resolution != 0:
            self._business_engine.frame.take_snapshot(self.frame_index)

        # The end.
        yield self._business_engine.get_metrics(), None, True
Exemple #5
0
class Env(AbsEnv):
    """Default environment.

    Args:
        scenario (str): Scenario name under maro/simulator/scenarios folder.
        topology (str): Topology name under specified scenario folder. If it points to an existing folder,
            it will be used as the topology for the built-in scenario.
        start_tick (int): Start tick of the scenario, usually used for pre-processed data streaming.
        durations (int): Duration ticks of this environment from start_tick.
        snapshot_resolution (int): How many ticks will take a snapshot.
        max_snapshots (int): Max in-memory snapshot number. Default None means keeping all snapshots in memory;
            when taking a snapshot after this limitation is reached, the oldest one will be overwritten.
        decision_mode (DecisionMode): With ``DecisionMode.Sequential`` each step yields a single decision event;
            with ``DecisionMode.Joint`` each step yields the list of all pending decision events.
        business_engine_cls: Class of business engine. If specified, it is used to construct the business engine
            instance, otherwise the business engine is searched internally by scenario.
        options (dict): Additional parameters passed to business engine.
            Defaults to None, which is replaced by a new empty dict for each instance.
    """

    def __init__(self,
                 scenario: str = None,
                 topology: str = None,
                 start_tick: int = 0,
                 durations: int = 100,
                 snapshot_resolution: int = 1,
                 max_snapshots: int = None,
                 decision_mode: DecisionMode = DecisionMode.Sequential,
                 business_engine_cls: type = None,
                 options: dict = None):
        # Create a fresh dict per instance: a ``{}`` default would be one object shared
        # by every Env constructed without explicit options.
        if options is None:
            options = {}

        super().__init__(scenario, topology, start_tick, durations,
                         snapshot_resolution, max_snapshots, decision_mode,
                         business_engine_cls, options)

        self._name = f'{self._scenario}:{self._topology}' if business_engine_cls is None \
            else business_engine_cls.__name__
        self._business_engine: AbsBusinessEngine = None

        self._event_buffer = EventBuffer()

        # generator to push the simulator moving on (created here, first advanced by step())
        self._simulate_generator = self._simulate()

        # initialize business engine
        self._init_business_engine()

    def step(self, action):
        """Push the environment to next step with action.

        Args:
            action (Action): Action(s) from agent.

        Returns:
            (dict, object, bool): a tuple of (metrics, decision event, is_done)

            The returned tuple contains 3 fields:

            - metrics provided by the business engine

            - decision_event for sequential decision mode, or a list of decision_event for joint mode

            - whether the episode ends

            Returns (None, None, True) once the simulation generator is exhausted.
        """

        try:
            metrics, decision_event, _is_done = self._simulate_generator.send(
                action)
        except StopIteration:
            return None, None, True

        return metrics, decision_event, _is_done

    def dump(self):
        """Dump environment for restore.

        NOTE:
            not implemented
        """
        return

    def reset(self):
        """Reset environment: tick, simulation generator, event buffer and business engine."""
        # . reset self
        self._tick = self._start_tick

        # discard the running episode and prepare a new, not-yet-started one
        self._simulate_generator.close()
        self._simulate_generator = self._simulate()

        # . reset event buffer
        self._event_buffer.reset()

        # . ask business engine reset itself
        self._business_engine.reset()

    @property
    def configs(self) -> dict:
        """dict: Configurations of current environment."""
        return self._business_engine.configs

    @property
    def summary(self) -> dict:
        """dict: Summary about current simulator, including node details and mappings.

        NOTE: this is provided by the scenario, so its format and content may differ between scenarios.
        """
        return {
            "node_mapping": self._business_engine.get_node_mapping(),
            "node_detail": self.current_frame.get_node_info()
        }

    @property
    def name(self) -> str:
        """str: Name of current environment."""
        return self._name

    @property
    def current_frame(self) -> FrameBase:
        """Frame: Frame of current environment."""
        return self._business_engine.frame

    @property
    def tick(self) -> int:
        """int: Current tick of environment."""
        return self._tick

    @property
    def frame_index(self) -> int:
        """int: Frame index in snapshot list for current tick."""
        return tick_to_frame_index(self._start_tick, self._tick,
                                   self._snapshot_resolution)

    @property
    def snapshot_list(self) -> SnapshotList:
        """SnapshotList: Current snapshot list.

        The snapshot list contains all the snapshots of the frame taken so far.
        """
        return self._business_engine.snapshots

    @property
    def agent_idx_list(self) -> List[int]:
        """List[int]: Agent index list that related to this environment."""
        return self._business_engine.get_agent_idx_list()

    def set_seed(self, seed: int):
        """Set random seed used by simulator.

        NOTE: this will not set seed for python random or other packages' seed, such as numpy.

        Args:
            seed (int): Seed to set. None leaves the seed unchanged.
        """

        if seed is not None:
            sim_seed(seed)

    @property
    def metrics(self) -> dict:
        """Some statistics information provided by business engine.

        Returns:
            dict: dictionary of metrics, content and format is determined by business engine
        """

        return self._business_engine.get_metrics()

    def get_finished_events(self):
        """List[Event]: All events finished so far."""
        return self._event_buffer.get_finished_events()

    def get_pending_events(self, tick):
        """Pending events at certain tick.

        Args:
            tick (int): Specified tick.

        Returns:
            The pending events of that tick, as returned by the event buffer.
        """
        return self._event_buffer.get_pending_events(tick)

    def _init_business_engine(self):
        """Initialize business engine object.

        NOTE:
        1. internal scenarios will always be under "maro/simulator/scenarios" folder
        2. for external scenarios, the given business engine class is used to create the instance

        Raises:
            BusinessEngineNotFoundError: No AbsBusinessEngine subclass found in the scenario module.
        """
        max_tick = self._start_tick + self._durations

        if self._business_engine_cls is not None:
            business_class = self._business_engine_cls
        else:
            # combine the business engine import path
            business_class_path = f'maro.simulator.scenarios.{self._scenario}.business_engine'

            # load the module to find business engine for that scenario
            business_module = import_module(business_class_path)

            business_class = None

            for _, obj in getmembers(business_module, isclass):
                if issubclass(obj,
                              AbsBusinessEngine) and obj != AbsBusinessEngine:
                    # we find it
                    business_class = obj

                    break

            if business_class is None:
                raise BusinessEngineNotFoundError()

        self._business_engine = business_class(
            event_buffer=self._event_buffer,
            topology=self._topology,
            start_tick=self._start_tick,
            max_tick=max_tick,
            snapshot_resolution=self._snapshot_resolution,
            max_snapshots=self._max_snapshots,
            additional_options=self._additional_options)

    def _collect_decision_events(self, pending_events: list):
        """Tag each pending event's payload with its event id and shape the result by decision mode.

        Returns:
            The first payload in sequential mode, otherwise the list of all payloads.
        """
        decision_events = []

        # append source event id to decision events, to support sequential action in joint mode
        for evt in pending_events:
            payload = evt.payload

            payload.source_event_id = evt.id

            decision_events.append(payload)

        if self._decision_mode == DecisionMode.Sequential:
            return decision_events[0]

        return decision_events

    @staticmethod
    def _normalize_actions(actions):
        """Turn the value sent into the generator into an iterable of actions.

        None becomes an empty list (so the business engine always receives an iterable) and a
        single non-iterable action is wrapped in a list; iterables are returned unchanged.
        """
        if actions is None:
            return []

        if not isinstance(actions, Iterable):
            return [actions]

        return actions

    @staticmethod
    def _finish_joint_pending_events(pending_events: list, actions) -> None:
        """Joint mode: mark pending events after the first one as finished.

        If more than one action is given and the first action has a ``src_event_id``, only the
        pending event with that id is finished (joint event with sequential action). Otherwise
        every pending event after the first is finished.
        """
        # Only consult actions[0] when there is more than one action; an empty action list
        # (e.g. from step(None)) has no first element to inspect.
        related_event_id = getattr(actions[0], "src_event_id", None) if len(actions) > 1 else None

        # NOTE(review): with a related id this finishes the event the first action refers to,
        # exactly as before; confirm that is the intended target.
        for pending_event in pending_events[1:]:
            if related_event_id is None or pending_event.id == related_event_id:
                pending_event.state = EventState.FINISHED

    def _simulate(self):
        """This is the generator to wrap each episode process.

        Each yield produces (metrics, decision event(s), is_done); the value sent back through
        ``step`` is the action (or actions) for the yielded decision event(s).
        """
        while True:
            # ask business engine to do thing for this tick, such as gen and push events
            # we do not push events now
            self._business_engine.step(self._tick)

            while True:
                # we keep process all the events, until no more any events
                pending_events = self._event_buffer.execute(self._tick)

                if len(pending_events) == 0:
                    # we have processed all the event of current tick, lets go for next tick
                    break

                # insert snapshot before each action
                self._business_engine.frame.take_snapshot(self.frame_index)

                decision_events = self._collect_decision_events(pending_events)

                # yield current state first, and waiting for action
                actions = yield self._business_engine.get_metrics(), decision_events, False

                actions = self._normalize_actions(actions)

                # generate a new atom event first
                action_event = self._event_buffer.gen_atom_event(
                    self._tick, DECISION_EVENT, actions)

                # we just append the action into sub event of first pending cascade event
                pending_events[0].state = EventState.EXECUTING
                pending_events[0].immediate_event_list.append(action_event)

                # TODO: support get reward after action complete here, via using event_buffer.execute

                if self._decision_mode == DecisionMode.Joint:
                    # for joint event, we will disable following cascade event
                    self._finish_joint_pending_events(pending_events, actions)

            # check if we should end simulation (truthiness, rather than comparing with True)
            is_end_tick = self._business_engine.post_step(self._tick)

            if is_end_tick:
                break

            self._tick += 1

        # make sure we have no missing data
        if (self._tick + 1) % self._snapshot_resolution != 0:
            self._business_engine.frame.take_snapshot(self.frame_index)

        # the end
        yield self._business_engine.get_metrics(), None, True
Exemple #6
0
class TestEventBuffer(unittest.TestCase):
    def setUp(self):
        """Give every test case its own, freshly constructed event buffer."""
        buffer = EventBuffer()
        self.eb = buffer

    def test_gen_event(self):
        """Generated atom and cascade events have the expected class, tick, event type and payload."""
        atom = self.eb.gen_atom_event(1, 1, (0, 0))

        # The concrete class must match the generator used, and every field must round-trip.
        self.assertIs(type(atom), AtomEvent)
        self.assertEqual(atom.tick, 1)
        self.assertEqual(atom.event_type, 1)
        self.assertEqual(atom.payload, (0, 0))

        cascade = self.eb.gen_cascade_event(2, 2, (1, 1, 1))

        self.assertIs(type(cascade), CascadeEvent)
        self.assertEqual(cascade.tick, 2)
        self.assertEqual(cascade.event_type, 2)
        self.assertEqual(cascade.payload, (1, 1, 1))

    def test_insert_event(self):
        """Inserting a single event adds exactly one entry to the pending pool."""
        # The pending pool starts out empty.
        self.assertEqual(0, len(self.eb._pending_events))

        self.eb.insert_event(self.eb.gen_atom_event(1, 1, 1))

        # Exactly one entry is expected once one event has been inserted.
        self.assertEqual(1, len(self.eb._pending_events))

    def test_event_dispatch(self):
        """Test event dispatching work as expected.

        The registered handler must be invoked exactly once, with the inserted event's tick and payload.
        """
        received = []

        def cb(evt):
            # test event tick
            self.assertEqual(
                1, evt.tick, msg="recieved event tick should be 1")

            # test event payload
            self.assertTupleEqual(
                (1, 3), evt.payload, msg="recieved event's payload should be (1, 3)")

            received.append(evt)

        evt = self.eb.gen_atom_event(1, 1, (1, 3))

        self.eb.insert_event(evt)

        self.eb.register_event_handler(1, cb)

        self.eb.execute(1)  # dispatch event

        # Without this check the test would pass silently if the handler were never called,
        # because the assertions above live inside the callback.
        self.assertEqual(1, len(received), msg="handler should be called exactly once")

    def test_get_finish_events(self):
        """Executing a tick moves its dispatched event into the finished pool."""
        # Nothing has been executed yet, so the finished pool is empty.
        self.assertListEqual([], self.eb.get_finished_events(),
                             msg="finished pool should be empty")

        self.eb.insert_event(self.eb.gen_atom_event(1, 1, (1, 3)))
        self.eb.execute(1)

        finished = self.eb.get_finished_events()

        # The single dispatched event should now be reported as finished.
        self.assertEqual(1, len(finished),
                         msg="after dispathing, there should 1 object")

    def test_get_pending_events(self):
        """get_pending_events reports the events inserted for the queried tick."""
        def pending_count():
            return len(self.eb.get_pending_events(1))

        # Tick 1 has nothing pending before any insertion.
        self.assertEqual(0, pending_count(), msg="pending pool should be empty")

        self.eb.insert_event(self.eb.gen_atom_event(1, 1, (1, 3)))

        # The freshly inserted event is now pending at tick 1.
        self.assertEqual(1, pending_count(), msg="pending pool should contains 1 objects")

    def test_reset(self):
        """reset() empties every per-tick pending pool and the finished pool, but keeps the tick keys."""
        self.eb.insert_event(self.eb.gen_atom_event(1, 1, 1))
        self.eb.reset()

        pending_by_tick = self.eb._pending_events

        # The tick key created by the insertion survives the reset...
        self.assertEqual(len(pending_by_tick), 1)

        # ...but each per-tick pool must now be empty.
        for pool in pending_by_tick.values():
            self.assertEqual(0, len(pool))

        self.assertEqual(len(self.eb._finished_events), 0)

    def test_sub_events(self):
        """Immediate (sub) events of a cascade event are dispatched after their parent.

        Both handlers must actually fire, parent first; otherwise the payload assertions inside
        the callbacks never run and the test would pass vacuously.
        """
        dispatched = []

        def cb1(evt):
            self.assertEqual(1, evt.payload)
            dispatched.append(evt.payload)

        def cb2(evt):
            self.assertEqual(2, evt.payload)
            dispatched.append(evt.payload)

        self.eb.register_event_handler(1, cb1)
        self.eb.register_event_handler(2, cb2)

        evt: CascadeEvent = self.eb.gen_cascade_event(1, 1, 1)

        evt.add_immediate_event(self.eb.gen_atom_event(1, 2, 2))

        self.eb.insert_event(evt)

        self.eb.execute(1)

        # Parent (payload 1) is processed first, then its immediate sub event (payload 2).
        self.assertListEqual([1, 2], dispatched)

    def test_sub_events_with_decision(self):
        evt1 = self.eb.gen_decision_event(1, (1, 1, 1))
        sub1 = self.eb.gen_decision_event(1, (2, 2, 2))
        sub2 = self.eb.gen_decision_event(1, (3, 3, 3))

        evt1.add_immediate_event(sub1, is_head=True)
        evt1.add_immediate_event(sub2)

        self.eb.insert_event(evt1)

        # sub events will be unfold after parent being processed
        decision_events = self.eb.execute(1)

        # so we will get 1 decision events for 1st time executing
        self.assertEqual(1, len(decision_events))
        self.assertEqual(evt1, decision_events[0])

        # mark decision event as executing to make it process folloing events
        decision_events[0].state = EventState.FINISHED

        # then there will be 2 additional decision event from sub events
        decision_events = self.eb.execute(1)

        self.assertEqual(2, len(decision_events))
        self.assertEqual(sub1, decision_events[0])
        self.assertEqual(sub2, decision_events[1])