Beispiel #1
0
def zero_play(**args):
    """Play one episode answering every decision with a zero-quantity action.

    Args:
        **args: Keyword arguments passed straight through to ``Env``.

    Returns:
        ``env.snapshot_list`` once ``env.step`` reports the episode is done.
    """
    env = Env(**args)
    # The very first step carries no action; later steps answer the decision
    # payload returned by the previous step.
    action = None
    while True:
        _, payload, finished = env.step(action)
        if finished:
            return env.snapshot_list
        # Same vessel/port as the pending decision, third argument (quantity) 0.
        action = Action(payload.vessel_idx, payload.port_idx, 0)
Beispiel #2
0
    def setUp(self):
        """Run a short VM-scheduling episode and keep its final metrics.

        Each decision event is answered by allocating ``decision_event.vm_id``
        to the first entry of ``decision_event.valid_pms``. The metrics
        returned by the last ``env.step`` call are stored on ``self.metrics``.
        """
        env = Env(scenario="vm_scheduling",
                  topology="tests/data/vm_scheduling/azure.2019.toy",
                  start_tick=0,
                  durations=5,
                  snapshot_resolution=1)
        # Keep the metrics of the initial step as well, so ``self.metrics`` is
        # defined even when the episode finishes without any decision event.
        self.metrics, decision_event, is_done = env.step(None)

        while not is_done:
            action = AllocateAction(vm_id=decision_event.vm_id,
                                    pm_id=decision_event.valid_pms[0])
            self.metrics, decision_event, is_done = env.step(action)
Beispiel #3
0
def test_cim():
    """Measure the average time of a no-action CIM episode.

    Runs ``eps`` episodes of the ``toy.5p_ssddd_l0.0`` topology with
    ``durations=MAX_TICK``, stepping with ``None`` until done, and prints the
    mean elapsed time per episode.
    """
    # perf_counter is monotonic and high resolution, unlike time.time(),
    # which is wall-clock time and may jump (e.g. on NTP adjustments).
    from time import perf_counter

    eps = 4

    env = Env("cim", "toy.5p_ssddd_l0.0", durations=MAX_TICK)

    start_time = perf_counter()

    for _ in range(eps):
        _, _, is_done = env.step(None)

        while not is_done:
            _, _, is_done = env.step(None)

        # Reset the environment before the next timed episode.
        env.reset()

    end_time = perf_counter()

    print(f"cim 5p topology with {MAX_TICK} total time cost: {(end_time - start_time)/eps}")
Beispiel #4
0
    def run(self):
        """Initialize environment and process commands.

        Builds ``Env(*self._args, **self._kwargs)`` and then serves
        ``(cmd, content)`` requests read from ``self._pipe`` until ``"stop"``.
        Every handled command sends exactly one reply on the pipe:

        * ``"step"``: step the env with ``content`` as the action and reply
          ``(metrics, decision_event, is_done, frame_index)``. Once the episode
          is done, the env is not stepped again and the reply is
          ``(None, None, True, frame_index)``.
        * ``"reset"``: reset the env and the cached step results; reply ``None``.
        * ``"query"``: ``content`` is ``(node_name, args)``; reply with
          ``env.snapshot_list[node_name][args]``.
        * ``"tick"`` / ``"frame_index"`` / ``"is_done"``: reply with that value.
        * ``"stop"``: reply ``None`` and leave the loop.
        """
        metrics = None
        decision_event = None
        is_done = False

        env = Env(*self._args, **self._kwargs)

        while True:
            cmd, content = self._pipe.recv()

            if cmd == "step":
                if is_done:
                    # Skip if current environment is completed.
                    self._pipe.send((None, None, True, env.frame_index))
                else:
                    metrics, decision_event, is_done = env.step(content)

                    # Same 4-tuple shape as the completed branch above, so the
                    # receiver can unpack every "step" reply uniformly and
                    # learn when the episode has finished.
                    self._pipe.send(
                        (metrics, decision_event, is_done, env.frame_index))
            elif cmd == "reset":
                env.reset()

                metrics = None
                decision_event = None
                is_done = False

                self._pipe.send(None)
            elif cmd == "query":
                node_name, args = content

                states = env.snapshot_list[node_name][args]

                self._pipe.send(states)
            elif cmd == "tick":
                self._pipe.send(env.tick)
            elif cmd == "frame_index":
                self._pipe.send(env.frame_index)
            elif cmd == "is_done":
                self._pipe.send(is_done)
            elif cmd == "stop":
                self._pipe.send(None)
                break
            # NOTE(review): any other command is ignored without a reply, so a
            # caller blocked on recv() for it would wait forever.
Beispiel #5
0
              start_tick=config.env.start_tick,
              durations=config.env.durations,
              snapshot_resolution=config.env.resolution)
    shutil.copy(os.path.join(env._business_engine._config_path, "config.yml"),
                os.path.join(LOG_PATH, "BEconfig.yml"))
    shutil.copy(CONFIG_PATH, os.path.join(LOG_PATH, "config.yml"))

    if config.env.seed is not None:
        env.set_seed(config.env.seed)

    metrics: object = None
    decision_event: DecisionPayload = None
    is_done: bool = False
    action: Action = None

    metrics, decision_event, is_done = env.step(None)

    # Get the core & memory capacity of all PMs in this environment.
    pm_capacity = env.snapshot_list["pms"][
        env.frame_index::["cpu_cores_capacity", "memory_capacity"]].reshape(
            -1, 2)
    pm_num = pm_capacity.shape[0]

    # ILP agent.
    ilp_agent = IlpAgent(ilp_config=config.ilp,
                         pm_capacity=pm_capacity,
                         vm_table_path=env.configs.VM_TABLE,
                         env_start_tick=config.env.start_tick,
                         env_duration=config.env.durations,
                         simulation_logger=simulation_logger,
                         ilp_logger=ilp_logger,
Beispiel #6
0
          topology="toy.4s_4t",
          start_tick=start_tick,
          durations=durations,
          snapshot_resolution=60,
          options=opts)

print(env.summary)

for ep in range(max_ep):
    # The env is created once, before this loop, so every episode after the
    # first must start from a reset; otherwise it would keep stepping the
    # simulation that the previous episode already finished. Resetting at the
    # start (not the end) leaves episode 0 untouched and keeps the last
    # episode's snapshots available after the loop.
    if ep > 0:
        env.reset()

    metrics = None
    decision_evt: DecisionEvent = None
    is_done = False
    action = None

    while not is_done:
        # The action built for a decision event is submitted on the next step.
        metrics, decision_evt, is_done = env.step(action)

        # It will be None at the end.
        if decision_evt is not None:
            # Presumably (from_station_idx, to_station_idx, number) - confirm
            # against the Action definition.
            action = Action(decision_evt.station_idx, 0, 10)

    station_ss = env.snapshot_list['stations']
    shortage_states = station_ss[::'shortage']
    print("total shortage", shortage_states.sum())

    trips_states = station_ss[::'trip_requirement']
    print("total trip", trips_states.sum())

    cost_states = station_ss[::["extra_cost", "transfer_cost"]]

    print("total cost", cost_states.sum())
Beispiel #7
0
class TestCimScenarios(unittest.TestCase):
    """End-to-end checks of the CIM business engine.

    Each test builds a fresh ``Env`` for every backend in
    ``backends_to_test``. It then compares engine state with the raw topology
    config, or with hard-coded values recorded under the default random seed.
    ``self._reload_topology`` selects the topology ``_init_env`` loads, which
    lets ``test_dump_and_load`` re-run other tests against a dumped topology.
    """

    def __init__(self, *args, **kwargs):
        super(TestCimScenarios, self).__init__(*args, **kwargs)

        with open(os.path.join(TOPOLOGY_PATH_CONFIG, "config.yml"),
                  "r") as input_stream:
            self._raw_topology = yaml.safe_load(input_stream)

        self._env: Optional[Env] = None
        self._reload_topology: str = TOPOLOGY_PATH_CONFIG
        self._business_engine: Optional[CimBusinessEngine] = None

        random.clear()

    def setUp(self) -> None:
        # ``_init_env`` writes DEFAULT_BACKEND_NAME into the process
        # environment; remember the previous value so tearDown can restore it
        # and the backend choice does not leak into unrelated tests.
        self._saved_backend_name: Optional[str] = os.environ.get(
            "DEFAULT_BACKEND_NAME")

    def tearDown(self) -> None:
        if self._saved_backend_name is None:
            os.environ.pop("DEFAULT_BACKEND_NAME", None)
        else:
            os.environ["DEFAULT_BACKEND_NAME"] = self._saved_backend_name

    def _init_env(self, backend_name: str) -> None:
        """Create a CIM env on ``backend_name`` from ``self._reload_topology``.

        The env runs ticks [0, 200) and dumps snapshots into the system temp
        directory.
        """
        os.environ["DEFAULT_BACKEND_NAME"] = backend_name
        self._env = Env(
            scenario="cim",
            topology=self._reload_topology,
            start_tick=0,
            durations=200,
            options={"enable-dump-snapshot": tempfile.gettempdir()})
        self._business_engine = self._env.business_engine

    def _assert_port_states(self, expected) -> None:
        """Assert port ``i`` has (booking, shortage, empty) == ``expected[i]``."""
        for i, port in enumerate(self._business_engine._ports):
            self.assertEqual(port.booking, expected[i][0])
            self.assertEqual(port.shortage, expected[i][1])
            self.assertEqual(port.empty, expected[i][2])

    def test_load_from_config(self) -> None:
        for backend_name in backends_to_test:
            self._init_env(backend_name)

            #########################################################
            # Env will not have `configs` if loaded from dump/real.
            if len(self._business_engine.configs) > 0:
                self.assertTrue(
                    compare_dictionary(self._business_engine.configs,
                                       self._raw_topology))

            self.assertEqual(len(self._business_engine.frame.ports), 22)
            self.assertEqual(self._business_engine._data_cntr.port_number, 22)
            self.assertEqual(len(self._business_engine.frame.vessels), 46)
            self.assertEqual(self._business_engine._data_cntr.vessel_number,
                             46)
            self.assertEqual(len(self._business_engine.snapshots), 0)

            #########################################################
            # Vessel
            vessels: List[
                VesselSetting] = self._business_engine._data_cntr.vessels
            for i, vessel in enumerate(vessels):
                vessel_config = self._raw_topology["vessels"][vessel.name]
                self.assertEqual(vessel.index, i)
                self.assertEqual(vessel.capacity, vessel_config["capacity"])
                self.assertEqual(vessel.parking_duration,
                                 vessel_config["parking"]["duration"])
                self.assertEqual(vessel.parking_noise,
                                 vessel_config["parking"]["noise"])
                self.assertEqual(vessel.start_port_name,
                                 vessel_config["route"]["initial_port_name"])
                self.assertEqual(vessel.route_name,
                                 vessel_config["route"]["route_name"])
                self.assertEqual(vessel.sailing_noise,
                                 vessel_config["sailing"]["noise"])
                self.assertEqual(vessel.sailing_speed,
                                 vessel_config["sailing"]["speed"])

            for name, idx in self._business_engine.get_node_mapping(
            )["vessels"].items():
                self.assertEqual(vessels[idx].name, name)

            #########################################################
            # Port
            ports: List[PortSetting] = self._business_engine._data_cntr.ports
            port_names = [port.name for port in ports]
            for i, port in enumerate(ports):
                assert isinstance(port, SyntheticPortSetting)
                port_config = self._raw_topology["ports"][port.name]
                self.assertEqual(port.index, i)
                self.assertEqual(port.capacity, port_config["capacity"])
                self.assertEqual(port.empty_return_buffer.noise,
                                 port_config["empty_return"]["noise"])
                self.assertEqual(port.full_return_buffer.noise,
                                 port_config["full_return"]["noise"])
                self.assertEqual(
                    port.source_proportion.noise,
                    port_config["order_distribution"]["source"]["noise"])
                for target in port.target_proportions:
                    self.assertEqual(
                        target.noise, port_config["order_distribution"]
                        ["targets"][port_names[target.index]]["noise"])

            for name, idx in self._business_engine.get_node_mapping(
            )["ports"].items():
                self.assertEqual(ports[idx].name, name)

    def test_load_from_real(self) -> None:
        # Expected (booking, shortage, empty) per port after the first step.
        hard_coded_truth = [
            [556, 0, 20751],
            [1042, 0, 17320],
            [0, 0, 25000],
            [0, 0, 25000],
        ]

        try:
            for topology in [TOPOLOGY_PATH_REAL_BIN, TOPOLOGY_PATH_REAL_CSV]:
                self._reload_topology = topology
                for backend_name in backends_to_test:
                    self._init_env(backend_name)

                    for port in self._business_engine._ports:
                        self.assertEqual(port.booking, 0)
                        self.assertEqual(port.shortage, 0)

                    self._env.step(action=None)
                    self._assert_port_states(hard_coded_truth)

                    # Resetting with keep_seed=True must reproduce the same
                    # first step.
                    self._env.reset(keep_seed=True)
                    self._env.step(action=None)
                    self._assert_port_states(hard_coded_truth)
        finally:
            # Restore the default topology even if an assertion failed.
            self._reload_topology = TOPOLOGY_PATH_CONFIG

    def test_dump_and_load(self) -> None:
        dump_from_config(os.path.join(TOPOLOGY_PATH_CONFIG, "config.yml"),
                         TOPOLOGY_PATH_DUMP, 200)

        self._reload_topology = TOPOLOGY_PATH_DUMP
        try:
            # The reloaded Env should have same behaviors
            self.test_load_from_config()
            self.test_vessel_movement()
            self.test_order_state()
            self.test_order_export()
            self.test_early_discharge()
        finally:
            # Restore the default topology even if one of the checks failed.
            self._reload_topology = TOPOLOGY_PATH_CONFIG

    def test_vessel_movement(self) -> None:
        for backend_name in backends_to_test:
            self._init_env(backend_name)

            hard_coded_period = [
                67, 75, 84, 67, 53, 58, 51, 58, 61, 49, 164, 182, 146, 164,
                182, 146, 90, 98, 79, 95, 104, 84, 87, 97, 78, 154, 169, 136,
                154, 169, 94, 105, 117, 94, 189, 210, 167, 189, 210, 167, 141,
                158, 125, 141, 158, 125
            ]
            self.assertListEqual(
                self._business_engine._data_cntr.vessel_period,
                hard_coded_period)

            ports: List[PortSetting] = self._business_engine._data_cntr.ports
            port_names: List[str] = [port.name for port in ports]
            vessel_stops: VesselStopsWrapper = self._business_engine._data_cntr.vessel_stops
            vessels: List[
                VesselSetting] = self._business_engine._data_cntr.vessels

            # Test invalid argument
            self.assertIsNone(vessel_stops[None])

            #########################################################
            # Every vessel's first stop is its configured start port.
            for i, vessel in enumerate(vessels):
                start_port_index = port_names.index(vessel.start_port_name)
                self.assertEqual(vessel_stops[i, 0].port_idx, start_port_index)

            #########################################################
            # Stops follow the route, rotated so it begins at the start port.
            for i, vessel in enumerate(vessels):
                stop_port_indices = [stop.port_idx for stop in vessel_stops[i]]

                raw_route = self._raw_topology["routes"][vessel.route_name]
                route_stop_names = [stop["port_name"] for stop in raw_route]
                route_stop_indices = [
                    port_names.index(name) for name in route_stop_names
                ]
                start_offset = route_stop_indices.index(
                    port_names.index(vessel.start_port_name))

                for j, stop_port_index in enumerate(stop_port_indices):
                    self.assertEqual(
                        stop_port_index,
                        route_stop_indices[(j + start_offset) %
                                           len(route_stop_indices)])

            #########################################################
            # STEP: beginning
            for i, vessel in enumerate(self._business_engine._vessels):
                self.assertEqual(vessel.idx, i)
                self.assertEqual(vessel.next_loc_idx, 0)
                self.assertEqual(vessel.last_loc_idx, 0)

            #########################################################
            self._env.step(action=None)
            self.assertEqual(
                self._env.tick,
                5)  # Vessel 35 will trigger the first arrival event at tick 5
            for i, vessel in enumerate(self._business_engine._vessels):
                if i == 35:
                    self.assertEqual(vessel.next_loc_idx, 1)
                    self.assertEqual(vessel.last_loc_idx, 1)
                else:
                    self.assertEqual(vessel.next_loc_idx, 1)
                    self.assertEqual(vessel.last_loc_idx, 0)

            #########################################################
            self._env.step(action=None)
            self.assertEqual(
                self._env.tick,
                6)  # Vessel 27 will trigger the second arrival event at tick 6
            for i, vessel in enumerate(self._business_engine._vessels):
                if i == 27:  # Vessel 27 just arrives
                    self.assertEqual(vessel.next_loc_idx, 1)
                    self.assertEqual(vessel.last_loc_idx, 1)
                elif i == 35:  # Vessel 35 has already departed
                    self.assertEqual(vessel.next_loc_idx, 2)
                    self.assertEqual(vessel.last_loc_idx, 1)
                else:
                    self.assertEqual(vessel.next_loc_idx, 1)
                    self.assertEqual(vessel.last_loc_idx, 0)

            #########################################################
            while self._env.tick < 100:
                self._env.step(action=None)
            self.assertEqual(self._env.tick, 100)
            for i, vessel in enumerate(self._business_engine._vessels):
                # Derive the expected location indices from the stop schedule:
                # at a stop on arrival, otherwise between the previous and the
                # next upcoming stop.
                expected_next_loc_idx = expected_last_loc_idx = -1
                for j, stop in enumerate(vessel_stops[i]):
                    if stop.arrival_tick == self._env.tick:
                        expected_next_loc_idx = expected_last_loc_idx = j
                        break
                    if stop.arrival_tick > self._env.tick:
                        expected_next_loc_idx = j
                        expected_last_loc_idx = j - 1
                        break

                self.assertEqual(vessel.next_loc_idx, expected_next_loc_idx)
                self.assertEqual(vessel.last_loc_idx, expected_last_loc_idx)

    def test_order_state(self) -> None:
        for backend_name in backends_to_test:
            self._init_env(backend_name)

            for port in self._business_engine._ports:
                total_containers = self._raw_topology['total_containers']
                initial_container_proportion = self._raw_topology['ports'][
                    port.name]['initial_container_proportion']

                self.assertEqual(port.booking, 0)
                self.assertEqual(port.shortage, 0)
                self.assertEqual(
                    port.empty,
                    int(total_containers * initial_container_proportion))

            #########################################################
            self._env.step(action=None)
            self.assertEqual(self._env.tick, 5)

            hard_coded_truth = [  # Should get same results under default random seed
                [223, 0, 14726], [16, 0, 916], [18, 0, 917], [89, 0, 5516],
                [84, 0, 4613], [72, 0, 4603], [26, 0, 1374], [24, 0, 1378],
                [48, 0, 2756], [54, 0, 2760], [26, 0, 1379], [99, 0, 5534],
                [137, 0, 7340], [19, 0, 912], [13, 0, 925], [107, 0, 6429],
                [136, 0, 9164], [64, 0, 3680], [24, 0, 1377], [31, 0, 1840],
                [109, 0, 6454], [131, 0, 7351]
            ]
            self._assert_port_states(hard_coded_truth)

    def test_keep_seed(self) -> None:
        for backend_name in backends_to_test:
            self._init_env(backend_name)

            vessel_stops_1: List[
                List[Stop]] = self._business_engine._data_cntr.vessel_stops
            self._env.step(action=None)
            port_info_1 = [(port.booking, port.shortage, port.empty)
                           for port in self._business_engine._ports]

            self._env.reset(keep_seed=True)
            vessel_stops_2: List[
                List[Stop]] = self._business_engine._data_cntr.vessel_stops
            self._env.step(action=None)
            port_info_2 = [(port.booking, port.shortage, port.empty)
                           for port in self._business_engine._ports]

            self._env.reset(keep_seed=False)
            vessel_stops_3: List[
                List[Stop]] = self._business_engine._data_cntr.vessel_stops
            self._env.step(action=None)
            port_info_3 = [(port.booking, port.shortage, port.empty)
                           for port in self._business_engine._ports]

            # Vessel
            for i in range(self._business_engine._data_cntr.vessel_number):
                # 1 and 2 should be totally equal
                self.assertListEqual(vessel_stops_1[i], vessel_stops_2[i])

                # 1 and 3 visit the same stops but at different times
                same_timing = True
                for stop1, stop3 in zip(vessel_stops_1[i], vessel_stops_3[i]):
                    self.assertListEqual(
                        [stop1.index, stop1.port_idx, stop1.vessel_idx],
                        [stop3.index, stop3.port_idx, stop3.vessel_idx])
                    if (stop1.arrival_tick, stop1.leave_tick) != (
                            stop3.arrival_tick, stop3.leave_tick):
                        same_timing = False
                self.assertFalse(same_timing)

            # Port
            self.assertListEqual(port_info_1, port_info_2)
            self.assertFalse(
                all(port1 == port3
                    for port1, port3 in zip(port_info_1, port_info_3)))

    def test_order_export(self) -> None:
        """order.tick, order.src_port_idx, order.dest_port_idx, order.quantity"""
        Order = namedtuple(
            "Order", ["tick", "src_port_idx", "dest_port_idx", "quantity"])
        num_orders = 5

        # A disabled exporter must write nothing; an enabled one must write
        # every order, in insertion order.
        for enabled in [False, True]:
            exporter = PortOrderExporter(enabled)

            for i in range(num_orders):
                exporter.add(Order(0, 0, 1, i + 1))

            # Use a private directory: ``_init_env`` dumps into the shared
            # system temp dir, and a stray orders.csv there could make the
            # disabled-case check below flaky.
            with tempfile.TemporaryDirectory() as out_folder:
                orders_path = os.path.join(out_folder, "orders.csv")

                exporter.dump(out_folder)

                if enabled:
                    with open(orders_path) as fp:
                        reader = csv.DictReader(fp)
                        row = 0
                        for line in reader:
                            self.assertEqual(row + 1, int(line["quantity"]))
                            row += 1
                    # Without this, an empty file would pass the loop above
                    # vacuously.
                    self.assertEqual(row, num_orders)
                else:  # Should do nothing
                    self.assertFalse(os.path.exists(orders_path))

    def test_early_discharge(self) -> None:
        for backend_name in backends_to_test:
            self._init_env(backend_name)

            metric, decision_event, is_done = self._env.step(None)
            assert isinstance(decision_event, DecisionEvent)

            self.assertEqual(decision_event.action_scope.load, 1240)
            self.assertEqual(decision_event.action_scope.discharge, 0)
            self.assertEqual(decision_event.early_discharge, 0)

            decision_event = pickle.loads(
                pickle.dumps(decision_event))  # Test serialization

            load_action = Action(vessel_idx=decision_event.vessel_idx,
                                 port_idx=decision_event.port_idx,
                                 quantity=1201,
                                 action_type=ActionType.LOAD)
            discharge_action = Action(vessel_idx=decision_event.vessel_idx,
                                      port_idx=decision_event.port_idx,
                                      quantity=1,
                                      action_type=ActionType.DISCHARGE)
            metric, decision_event, is_done = self._env.step(
                [load_action, discharge_action])

            # Record vessel 35's (full, empty, early_discharge) at each of its
            # later decision events.
            history = []
            while not is_done:
                metric, decision_event, is_done = self._env.step(None)
                assert decision_event is None or isinstance(
                    decision_event, DecisionEvent)
                if decision_event is not None and decision_event.vessel_idx == 35:
                    v = self._business_engine._vessels[35]
                    history.append((v.full, v.empty, v.early_discharge))

            hard_coded_benchmark = [(465, 838, 362), (756, 547, 291),
                                    (1261, 42, 505), (1303, 0, 42),
                                    (1303, 0, 0), (1303, 0, 0), (803, 0, 0)]
            self.assertListEqual(history, hard_coded_benchmark)

            # Expected payload fields for each event type.
            payload_detail_benchmark = {
                'ORDER': ['tick', 'src_port_idx', 'dest_port_idx', 'quantity'],
                'RETURN_FULL': ['src_port_idx', 'dest_port_idx', 'quantity'],
                'VESSEL_ARRIVAL': ['port_idx', 'vessel_idx'],
                'LOAD_FULL': ['port_idx', 'vessel_idx'],
                'DISCHARGE_FULL':
                ['vessel_idx', 'port_idx', 'from_port_idx', 'quantity'],
                'PENDING_DECISION': [
                    'tick', 'port_idx', 'vessel_idx', 'snapshot_list',
                    'action_scope', 'early_discharge'
                ],
                'LOAD_EMPTY':
                ['port_idx', 'vessel_idx', 'action_type', 'quantity'],
                'DISCHARGE_EMPTY':
                ['port_idx', 'vessel_idx', 'action_type', 'quantity'],
                'VESSEL_DEPARTURE': ['port_idx', 'vessel_idx'],
                'RETURN_EMPTY': ['port_idx', 'quantity']
            }
            self.assertTrue(
                compare_dictionary(
                    self._business_engine.get_event_payload_detail(),
                    payload_detail_benchmark))
            port_number = self._business_engine._data_cntr.port_number
            self.assertListEqual(self._business_engine.get_agent_idx_list(),
                                 list(range(port_number)))
Beispiel #8
0
                               (available_bikes, supply_candidate))
                if len(top_k_supplies) > self._supply_top_k:
                    heapq.heappop(top_k_supplies)

            max_reposition, source_idx = random.choice(top_k_supplies)
            action = Action(source_idx, decision_event.station_idx,
                            max_reposition)

        return action


if __name__ == "__main__":
    # Build the simulator from the experiment configuration.
    env = Env(
        scenario=config.env.scenario,
        topology=config.env.topology,
        start_tick=config.env.start_tick,
        durations=config.env.durations,
        snapshot_resolution=config.env.resolution,
    )

    # Only call set_seed when a seed is configured.
    if config.env.seed is not None:
        env.set_seed(config.env.seed)

    policy = GreedyPolicy(config.agent.supply_top_k, config.agent.demand_top_k)

    # The first step carries no action; every later step submits the policy's
    # answer to the decision event returned by the previous step.
    action = None
    while True:
        metrics, decision_event, done = env.step(action)
        if done:
            break
        action = policy.choose_action(decision_event)

    print(f"Greedy agent policy performance: {env.metrics}")

    env.reset()
Beispiel #9
0
def single_player_worker(index, config, exp_idx_mapping, pipe, action_io,
                         exp_output):
    """The A2C worker function to collect experience.

    The worker plays one episode each time it receives ``"reset"`` on
    ``pipe`` and returns on any other message. For every decision event it:

    * writes the model input into ``action_io``;
    * sends ``"features"`` and waits on ``pipe`` for the model action.

    After an episode it:

    * writes the ``compute_shortage`` result into ``action_io["sh"]``;
    * sends ``"done"``;
    * hands the shaped experience to ``exp_output``;
    * sends the episode's decision log as a float ndarray.

    Args:
        index (int): The process index counted from 0.
        config (dict): It is a dottable dictionary that stores the configuration of the simulation, state_shaper and
            postprocessing shaper.
        exp_idx_mapping (dict): The key is agent code and the value is the starting index where the experience is stored
            in the experience batch.
        pipe (Pipe): The pipe instance for communication with the main process.
        action_io (SharedStructure): The shared memory to hold the state information that the main process uses to
            generate an action.
        exp_output (SharedStructure): The shared memory to transfer the experience list to the main process.
    """
    env = Env(**config.env.param)
    fix_seed(env, config.env.seed)
    static_code_list = list(env.summary["node_mapping"]["ports"].values())
    dynamic_code_list = list(env.summary["node_mapping"]["vessels"].values())

    # Create gnn_state_shaper without consuming any resources.
    gnn_state_shaper = GNNStateShaper(
        static_code_list,
        dynamic_code_list,
        config.env.param.durations,
        config.model.feature,
        tick_buffer=config.model.tick_buffer,
        max_value=env.configs["total_containers"])
    gnn_state_shaper.compute_static_graph_structure(env)

    action_io_np = action_io.structuralize()

    action_shaper = DiscreteActionShaper(config.model.action_dim)
    exp_shaper = ExperienceShaper(static_code_list,
                                  dynamic_code_list,
                                  config.env.param.durations,
                                  gnn_state_shaper,
                                  scale_factor=config.env.return_scaler,
                                  time_slot=config.training.td_steps,
                                  discount_factor=config.training.gamma,
                                  idx=index,
                                  shared_storage=exp_output.structuralize(),
                                  exp_idx_mapping=exp_idx_mapping)

    while pipe.recv() == "reset":
        env.reset()
        _, decision_event, is_done = env.step(None)

        logs = []
        while not is_done:
            model_input = gnn_state_shaper(decision_event, env.snapshot_list)
            # "v" and "p" are written at column ``index``; the remaining
            # fields are written at row ``index``.
            action_io_np["v"][:, index] = model_input["v"]
            action_io_np["p"][:, index] = model_input["p"]
            for key in ("vo", "po", "vedge", "pedge", "ppedge", "mask"):
                action_io_np[key][index] = model_input[key]
            action_io_np["pid"][index] = decision_event.port_idx
            action_io_np["vid"][index] = decision_event.vessel_idx
            # Tell the main process the features are ready, then block until
            # it replies with the model action.
            pipe.send("features")
            model_action = pipe.recv()
            env_action = action_shaper(decision_event, model_action)
            exp_shaper.record(decision_event=decision_event,
                              model_action=model_action,
                              model_input=model_input)
            logs.append([
                index, decision_event.tick, decision_event.port_idx,
                decision_event.vessel_idx, model_action, env_action,
                decision_event.action_scope.load,
                decision_event.action_scope.discharge
            ])
            action = Action(decision_event.vessel_idx, decision_event.port_idx,
                            env_action)
            _, decision_event, is_done = env.step(action)
        action_io_np["sh"][index] = compute_shortage(
            env.snapshot_list, config.env.param.durations, static_code_list)
        pipe.send("done")
        gnn_state_shaper.end_ep_callback(env.snapshot_list)
        # Organize and synchronize exp to shared memory.
        exp_shaper(env.snapshot_list)
        exp_shaper.reset()
        # ``np.float`` was only an alias of the builtin ``float`` and was
        # removed in NumPy 1.24; ``float`` gives the same float64 dtype.
        logs = np.array(logs, dtype=float)
        pipe.send(logs)
Beispiel #10
0
    if PEEP_AND_USE_REAL_DATA:
        ENV = env
        TRIP_PICKER = BinaryReader(env.configs["trip_data"]).items_tick_picker(
            start_time_offset=config.env.start_tick,
            end_time_offset=(config.env.start_tick + config.env.durations),
            time_unit="m"
        )

    if config.env.seed is not None:
        env.set_seed(config.env.seed)

    # Start simulation.
    decision_event: DecisionEvent = None
    action: Action = None
    is_done: bool = False
    _, decision_event, is_done = env.step(action=None)

    # TODO: Update the Env interface.
    num_station = len(env.agent_idx_list)
    station_distance_adj = np.array(
        load_adj_from_csv(env.configs["distance_adj_data"], skiprows=1)
    ).reshape(num_station, num_station)
    station_neighbor_list = [
        neighbor_list[1:]
        for neighbor_list in np.argsort(station_distance_adj, axis=1).tolist()
    ]

    # Init a Moving-Average based ILP agent.
    decision_interval = env.configs["decision"]["resolution"]
    ilp = CitiBikeILP(
        num_station=num_station,