Example #1
    def test_seeding(self):
        g = nx.Graph()
        g.add_edges_from([(0, 1), (1, 2), (2, 3)])
        nx.set_node_attributes(g, {
            0: (0, 1),
            1: (0, 2),
            2: (1, 1),
            3: (1, 2)
        }, name="coords")
        orders = [(0, 1, 0, 1, 1), (1, 1, 1, 2, 2), (2, 2, 1, 3, 3),
                  (3, 2, 2, 3, 3)]
        drivers = np.array([1, 0, 0, 5])
        action = np.array([0.3, 0.4, 0.3], dtype=float)

        obs, rew, done, info = None, None, None, None
        for i in range(100):
            env = TaxiEnv(g, orders, 1, drivers, 10, seed=123)
            env.step(action)
            obs2, rew2, done2, info2 = env.step(action)
            if i > 0:
                assert (obs == obs2).all()
                assert rew == rew2
                assert done == done2
                assert info == info2
            obs, rew, done, info = obs2, rew2, done2, info2
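These snippets use nx, np, and (in example #10) pytest without showing their imports; a minimal preamble, assuming TaxiEnv and TaxiEnvBatch are importable from the project's own modules (the import paths below are hypothetical), would be:

import networkx as nx
import numpy as np
import pytest

from taxi_env import TaxiEnv              # hypothetical module path
from taxi_env_batch import TaxiEnvBatch   # hypothetical module path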
Example #2
 def testInit(self):
     g = nx.Graph()
     g.add_edges_from([(0, 1)])
     nx.set_node_attributes(g, {0: (0, 1), 1: (1, 2)}, "coords")
     orders = [(1, 1, 1, 1, 0.5)]
     drivers = np.ones((2), dtype=int)
     env = TaxiEnv(g, orders, 1, drivers, 10, 0.5)
     env.step(np.zeros(2))
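In this call the sixth positional argument (0.5) is wc, the per-driver cost that shows up as the -0.5 terms in the reward asserts of examples #6, #7, and #10 (see the constructor signature in examples #4 and #11).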
Example #3
 def testInit(self):
     '''
     Test initialization of driver and order distributions
     '''
     # initialize input parameters
     g = nx.Graph()
     g.add_edges_from([(0, 1), (1, 2), (0, 2), (2, 3)])
     # orders are <source, destination, time, length, price>
     orders = [(1, 1, 0, 1, 0.5), (1, 1, 1, 1, 0.7)]
     drivers = [0, 1, 2, 3]
     order_sampling_rate = 1
     n_intervals = 3
     env = TaxiEnv(g, orders, order_sampling_rate, drivers, n_intervals)
     assert env.n_drivers == 6
     assert len(env.all_driver_list) == env.n_drivers
     assert env.max_reward == 0.7
     assert len(env.orders_per_time_interval) == n_intervals + 1
     assert env.action_space_shape == (4, )
     assert env.done == False
     assert env.time == 0
     assert len(env.world) == 4
     assert sum([env.drivers_per_node[i] for i in range(4)]) == 6
     assert sum(env.world.nodes[i]['info'].get_driver_num()
                for i in range(4)) == 6
     assert sum(env.world.nodes[i]['info'].get_order_num()
                for i in range(4)) == 1
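These asserted values follow directly from the inputs: sum(drivers) == 0 + 1 + 2 + 3 == 6 drivers; max_reward is the larger of the two order prices (0.7); node 2 has the maximum degree (3), so with the extra staying action the per-cell action space is (4, ); and only the first order starts at time 0, so exactly one order is present right after initialization.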
Example #4
    def __init__(self,
                 world: nx.Graph,
                 orders: List[Tuple[int, int, int, int, float]],
                 order_sampling_rate: float,
                 drivers_per_node: Array[int],
                 n_intervals: int,
                 wc: float,
                 count_neighbors: bool = False,
                 weight_poorest: bool = False,
                 normalize_rewards: bool = True,
                 minimum_reward: bool = False,
                 reward_bound: float = None,
                 include_income_to_observation: bool = False,
                 poorest_first: bool = False,
                 idle_reward: bool = False) -> None:

        self.itEnv = TaxiEnv(world, orders, order_sampling_rate,
                             drivers_per_node, n_intervals, wc,
                             count_neighbors, weight_poorest,
                             normalize_rewards, minimum_reward, reward_bound,
                             include_income_to_observation, poorest_first,
                             idle_reward)
        self.world = self.itEnv.world
        self.n_intervals = n_intervals
        self.n_drivers = self.itEnv.n_drivers
        self.time = 0
        self.include_income_to_observation = include_income_to_observation
        self.one_cell_action_space = self.itEnv.max_degree + 1
        self.action_space = spaces.Box(low=0,
                                       high=1,
                                       shape=(self.one_cell_action_space *
                                              self.itEnv.world_size, ))

        if include_income_to_observation:
            assert self.itEnv.observation_space_shape[0] == \
                3 * len(self.world) + self.itEnv.n_intervals + 3
            self.observation_space_shape = (
                self.itEnv.observation_space_shape[0] + 2 * len(self.world) - 3, )
        else:
            assert self.itEnv.observation_space_shape[0] == \
                3 * len(self.world) + self.itEnv.n_intervals
            self.observation_space_shape = (
                self.itEnv.observation_space_shape[0] - len(self.world), )
        self.observation_space = spaces.Box(low=0,
                                            high=1,
                                            shape=self.observation_space_shape)
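The shape arithmetic can be checked by hand: with income included, the wrapped env reports an observation of length 3n + T + 3 (where n = len(world) and T = n_intervals), so the batch shape becomes (3n + T + 3) + 2n - 3 = 5n + T; without income it is (3n + T) - n = 2n + T.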
Example #5
    def test_automatic_return(self):
        """
        Check linear graph: half of the graph is outside the view, and the car that was sent outside
        is returning automatically to the nearest node in the view.
        """
        g = nx.Graph()
        g.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)])
        nx.set_node_attributes(g, {
            0: (0, 1),
            1: (0, 2),
            2: (0, 3),
            3: (0, 4),
            4: (0, 5),
            5: (0, 6)
        }, name="coords")
        orders = [(1, 4, 0, 10, 80)]
        drivers = np.array([0, 1, 0, 0, 0, 0])
        action = np.array([0, 0, 0], dtype=float)

        env = TaxiEnv(g, orders, 1, drivers, 30)
        env.set_view([0, 1, 2])
        env.step(action)
        # the car's final destination is node 2, reached in (10 + 2) intervals;
        # after step(), the env should set current_node_id to 2 and time to 12
        assert env.time == 12
        assert env.current_node_id == 2
        d = env.all_driver_list[0]
        assert d.status == 1
        assert d.income == 80
        assert d.position == 2
Example #6
    def test_count_neighbors(self):
        g = nx.Graph()
        g.add_edges_from([(0, 1), (1, 2), (2, 0)])
        nx.set_node_attributes(g, {
            0: (0, 1),
            1: (0, 2),
            2: (1, 1)
        }, name="coords")
        orders = [(0, 1, 0, 1, 1), (1, 2, 0, 2, 2), (2, 0, 0, 3, 3)]
        drivers = np.array([4, 0, 1])

        env = TaxiEnv(g, orders, 1, drivers, 3, 0.5, count_neighbors=True)
        observation1 = env.reset()
        env.current_node_id = 0
        dispatch_list = env.make_order_dispatch_list_and_remove_orders()
        assert len(dispatch_list) == 2

        env.reset()  # make_order_dispatch_list_and_remove_orders() can only be run once
        env.current_node_id = 0
        observation2, reward, done, info = env.step(np.zeros(3))
        assert reward == (1 + 2 - 0.5 - 0.5) / 4
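The asserted reward can be reconstructed by hand: with count_neighbors=True, the drivers at node 0 serve the local order (price 1) and the neighboring order at node 1 (price 2); the two drivers left without orders each pay the 0.5 cost, and the sum is normalized by the 4 drivers at the node, giving (1 + 2 - 0.5 - 0.5) / 4.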
Example #7
    def testMove3CarsStrictly(self):
        g = nx.Graph()
        N = 9
        g.add_edges_from([(0, 1), (1, 2), (0, 3), (1, 4), (2, 5), (3, 4),
                          (4, 5), (3, 6), (4, 7), (5, 8), (6, 7), (7, 8)])
        nx.set_node_attributes(
            g, {
                0: (0, 0),
                1: (0, 1),
                2: (0, 2),
                3: (1, 0),
                4: (1, 1),
                5: (1, 2),
                6: (2, 0),
                7: (2, 1),
                8: (2, 2)
            }, "coords")
        orders = [(7, 3, 0, 2, 999.4)]
        drivers = np.zeros(N, dtype=int)
        drivers[0] = 1
        drivers[2] = 1
        drivers[7] = 2
        env = TaxiEnv(g, orders, 1, drivers, 3, 0.5, hold_observation=False)
        observation = env.reset()

        # observation layout: [drivers] + [orders] + [idle] + [onehot time] + [onehot cell]
        order_distr = np.zeros(N)
        order_distr[7] = 1
        assert (observation[:N] == drivers / np.max(drivers)).all()
        assert (observation[N:2 * N] == order_distr).all()
        onehot_time = np.zeros(3)
        onehot_time[0] = 1
        assert (observation[3 * N:3 * N + 3] == onehot_time).all()
        node_id = np.argmax(observation[3 * N + 3:4 * N + 3])
        assert node_id in [0, 2, 7]

        env.current_node_id = 0
        env.find_non_empty_nodes()
        action = np.zeros(5)
        action[0] = 1
        observation, reward, done, info = env.step(action)
        assert env.time == 0
        assert reward == -0.5
        assert done == False
        new_drivers = np.copy(drivers)
        new_drivers[0] = 0
        new_drivers[1] = 0
        assert (observation[:N] == new_drivers / np.max(new_drivers)).all()

        env.current_node_id = 2
        env.find_non_empty_nodes()
        action = np.zeros(5)
        action[-1] = 1
        observation, reward, done, info = env.step(action)
        assert done == False
        assert reward == -0.5

        assert env.current_node_id == 7
        env.find_non_empty_nodes()
        assert env.time == 0
        assert (observation[N:2 * N] == order_distr).all()
        new_drivers[2] = 0
        assert (observation[:N] == new_drivers / np.max(new_drivers)).all()
        assert (observation[3 * N:3 * N + 3] == onehot_time).all()
        assert np.argmax(observation[3 * N + 3:4 * N + 3]) == 7
        action = np.zeros(5)
        action[1] = 1
        observation, reward, done, info = env.step(action)

        # should have completed the step
        assert done == False
        assert reward == (999.4 - 0.5) / 2  # averaged over the cars in the node by default
        assert env.time == 1
        assert (observation[N:2 * N] == np.zeros(N)).all()
        new_drivers = np.zeros(N)
        new_drivers[2] = 1
        new_drivers[1] = 1
        new_drivers[6] = 1  # assuming second edge is to 6th node
        # one driver still travelling
        assert (observation[:N] == new_drivers / np.max(new_drivers)).all()
        assert len(env.traveling_pool[2]) == 1
        d = env.traveling_pool[2][0]
        assert d.income == 999.4
        assert d.status == 0
        assert d.position == 3

        action = np.zeros(5)
        action[-1] = 1

        observation, reward, done, info = env.step(action)
        assert done == False
        assert env.time == 1
        observation, reward, done, info = env.step(action)
        assert done == False
        assert env.time == 1
        observation, reward, done, info = env.step(action)
        assert env.time == 2
        assert done == False  # True when time = n_intervals (3)

        # check status of nodes and drivers
        assert env.world.nodes[0]['info'].get_driver_num() == 0
        assert env.world.nodes[1]['info'].get_driver_num() == 1
        assert env.world.nodes[2]['info'].get_driver_num() == 1
        assert env.world.nodes[6]['info'].get_driver_num() == 1
        assert env.world.nodes[3]['info'].get_driver_num() == 1
        assert len(env.all_driver_list) == 4
        total_income = 0
        for d in env.all_driver_list:
            total_income += d.income
            assert d.status == 1
        assert total_income == -0.5 * 3 * 2 + 999.4
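The final bookkeeping is consistent with the per-step rewards: the driver that served the order earned its price (999.4), while the other three drivers each accumulated the 0.5 cost over the two elapsed intervals, hence the total of -0.5 * 3 * 2 + 999.4.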
Example #8
    def test_view(self):
        g = nx.Graph()
        g.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4), (1, 5)])
        nx.set_node_attributes(g, {
            0: (0, 1),
            1: (0, 2),
            2: (0, 3),
            3: (0, 4),
            4: (0, 5),
            5: (1, 1)
        }, name="coords")
        # orders are <source, destination, time, length, price>
        orders = [(3, 2, 2, 3, 3)]
        drivers = np.array([1, 1, 1, 1, 1, 1])
        action = np.array([1, 0, 0], dtype=float)

        env = TaxiEnv(g, orders, 1, drivers, 10)
        env.seed(123)
        env.set_view([2, 3, 4])
        obs, _, _ = env.get_observation()

        # check observation space and content
        assert env.observation_space_shape == obs.shape
        view_size = 3
        # income is not included by default, so the layout is
        # <driver, order, idle, time_id, node_id>
        assert env.observation_space_shape == (view_size * 4 + 10, )
        # node 1 has degree 3, but nodes inside the view have degree at most 2,
        # so the action space is 2 + 1 (the staying action)
        assert env.action_space_shape == (3, )
        assert env.current_node_id in [2, 3, 4]

        # an action [1, 0, 0] for node 2 means "go to node 3", because it's the only neighbor in the view
        env.step(action)
        assert env.current_node_id in [2, 3, 4]
        env.step(action)
        assert env.current_node_id in [2, 3, 4]
        obs, rew, done, info = env.step(action)
        # there are 2 drivers in node 3 at the end: one from node 2, one from node 4
        assert (obs[:view_size] == np.array([0.5, 1, 0])).all()
        assert (obs[view_size:2 * view_size] == np.array([0, 0, 0])).all()
        assert (obs[2 * view_size:3 * view_size] == np.array([0.5, 1, 0])).all()
        # next time iteration should happen
        assert env.time == 1
        assert env.current_node_id in [2, 3]
        assert (obs[2 * view_size:3 * view_size] == np.array([0.5, 1, 0])).all()
        assert (obs[3 * view_size:3 * view_size + 10] == np.array(
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0])).all()
        assert obs[3 * view_size + 10:].shape == (3, )
        assert (obs[3 * view_size + 10:] == np.array([1, 0, 0])).all() or \
               (obs[3 * view_size + 10:] == np.array([0, 1, 0])).all()
        assert [d.position for d in env.all_driver_list] == [0, 1, 3, 2, 3, 5]
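The slicing in these asserts suggests the view observation layout <driver, order, idle, time_onehot, node_onehot>; a small helper sketch for unpacking it (the function name and dict keys are illustrative, not part of the library):

def unpack_view_obs(obs, view_size, n_intervals):
    # split a flat view observation into its named blocks
    v = view_size
    return {
        "drivers": obs[:v],
        "orders": obs[v:2 * v],
        "idle": obs[2 * v:3 * v],
        "time_onehot": obs[3 * v:3 * v + n_intervals],
        "node_onehot": obs[3 * v + n_intervals:],
    }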
Example #9
    def test_sync(self):
        g = nx.Graph()
        g.add_edges_from([(0, 1), (1, 2), (2, 3)])
        nx.set_node_attributes(g, {
            0: (0, 1),
            1: (0, 2),
            2: (1, 1),
            3: (1, 2)
        }, name="coords")
        orders = [(0, 1, 0, 1, 1), (1, 1, 1, 2, 2), (2, 2, 1, 3, 3),
                  (3, 2, 2, 3, 3)]
        drivers = np.array([1, 0, 0, 5])
        action = np.array([0.3, 0.4, 0.3], dtype=float)

        env = TaxiEnv(g, orders, 1, drivers, 10)
        env.step(action)

        env2 = TaxiEnv(g, orders, 1, drivers, 10)
        env2.sync(env)

        o1, _, _ = env.get_observation()
        o2, _, _ = env2.get_observation()
        assert (o1 == o2).all()

        env.seed(1)
        env2.seed(1)

        while not env.done:
            obs, rew, done, info = env.step(action)
            obs2, rew2, done2, info2 = env2.step(action)

            assert (obs == obs2).all()
            assert rew == rew2
            assert done == done2
            assert info == info2
Example #10
    def test_reward_options(self):
        '''
        Test these:
            weight_poorest: bool = False,
            normalize_rewards: bool = True,
            minimum_reward: bool = False,
            reward_bound: float = None,
            include_income_to_observation: bool = False
        '''
        g = nx.Graph()
        g.add_edges_from([(0, 1), (1, 2), (2, 3)])
        nx.set_node_attributes(g, {
            0: (0, 1),
            1: (0, 2),
            2: (1, 1),
            3: (1, 2)
        }, name="coords")
        orders = [(0, 1, 0, 1, 1), (1, 1, 0, 2, 2), (2, 2, 0, 3, 3),
                  (3, 2, 0, 3, 3)]
        drivers = np.array([1, 0, 0, 5])
        action = np.array([1, 0, 0], dtype=float)

        env = TaxiEnv(g,
                      orders,
                      1,
                      drivers,
                      3,
                      0.5,
                      count_neighbors=True,
                      normalize_rewards=False)
        observation = env.reset()
        env.current_node_id = 3
        env.non_empty_nodes = [0, 1, 2]
        observation, reward, done, info = env.step(action)
        assert reward == (3 + 3 - 0.5 - 0.5 - 0.5)

        env = TaxiEnv(g,
                      orders,
                      1,
                      drivers,
                      3,
                      0.5,
                      count_neighbors=True,
                      weight_poorest=True)
        observation = env.reset()
        env.current_node_id = 3
        env.non_empty_nodes = [0, 1, 2]
        observation, reward, done, info = env.step(action)
        # the reward is weighted by (1 - softmax(reward)), favoring the poorest drivers;
        # the leading 0 is for the driver at node 0 that does not move
        r = np.array([0, 3, 3, -0.5, -0.5, -0.5])
        mult = 1 - env.softmax(r)
        rew = mult * r
        rew /= 5
        assert reward == pytest.approx(np.sum(rew))

        env = TaxiEnv(g,
                      orders,
                      1,
                      drivers,
                      3,
                      0.5,
                      count_neighbors=True,
                      minimum_reward=True)
        observation = env.reset()
        env.current_node_id = 3
        env.non_empty_nodes = [0, 1, 2]
        observation, reward, done, info = env.step(action)
        assert reward == -0.5 / 5  # returns the single minimum driver reward, normalized

        env = TaxiEnv(g,
                      orders,
                      1,
                      drivers,
                      3,
                      0.5,
                      count_neighbors=True,
                      normalize_rewards=False,
                      minimum_reward=True)
        observation = env.reset()
        env.current_node_id = 3
        env.non_empty_nodes = [0, 1, 2]
        observation, reward, done, info = env.step(action)
        assert reward == -0.5  # returns the single minimum driver reward, non-normalized

        env = TaxiEnv(g,
                      orders,
                      1,
                      drivers,
                      3,
                      0.5,
                      count_neighbors=True,
                      reward_bound=1)
        observation = env.reset()
        env.current_node_id = 3
        env.non_empty_nodes = [0, 1, 2]
        observation, reward, done, info = env.step(action)
        assert reward == (1 + 1 - 0.5 - 0.5 - 0.5) / 5

        drivers = np.array([2, 0, 0, 5])
        env = TaxiEnv(g,
                      orders,
                      1,
                      drivers,
                      3,
                      0.5,
                      count_neighbors=True,
                      reward_bound=1,
                      include_income_to_observation=True)
        observation = env.reset()
        env.world.nodes[0]['info'].drivers[0].add_income(0.9)
        env.current_node_id = 3
        env.non_empty_nodes = [0, 1, 2]
        # all drivers from node 3 have been dispatched but haven't arrived, so the
        # observation should only show the drivers at node 0
        observation, reward, done, info = env.step(action)
        assert env.current_node_id == 0
        assert observation.shape[0] == 5 * env.world_size + env.n_intervals
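Two of the asserts above can be unpacked: with reward_bound=1, the orders served from node 3 (both priced 3) are clipped to the bound, yielding (1 + 1 - 0.5 - 0.5 - 0.5) / 5; and with include_income_to_observation=True the observation length grows to 5 * world_size + n_intervals, apparently one extra per-node block of incomes on top of the <driver, order, idle, time, node> layout seen in example #8.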
Example #11
class TaxiEnvBatch(gym.Env):
    '''
    This class is a wrapper over taxi_env, providing an interface for cA2C,
    which requires processing drivers in batches plus some additional context
    information.
    '''
    metadata = {'render.modes': ['rgb_array']}

    def __init__(self,
                 world: nx.Graph,
                 orders: List[Tuple[int, int, int, int, float]],
                 order_sampling_rate: float,
                 drivers_per_node: Array[int],
                 n_intervals: int,
                 wc: float,
                 count_neighbors: bool = False,
                 weight_poorest: bool = False,
                 normalize_rewards: bool = True,
                 minimum_reward: bool = False,
                 reward_bound: float = None,
                 include_income_to_observation: bool = False,
                 poorest_first: bool = False,
                 idle_reward: bool = False) -> None:

        self.itEnv = TaxiEnv(world, orders, order_sampling_rate,
                             drivers_per_node, n_intervals, wc,
                             count_neighbors, weight_poorest,
                             normalize_rewards, minimum_reward, reward_bound,
                             include_income_to_observation, poorest_first,
                             idle_reward)
        self.world = self.itEnv.world
        self.n_intervals = n_intervals
        self.n_drivers = self.itEnv.n_drivers
        self.time = 0
        self.include_income_to_observation = include_income_to_observation
        self.one_cell_action_space = self.itEnv.max_degree + 1
        self.action_space = spaces.Box(low=0,
                                       high=1,
                                       shape=(self.one_cell_action_space *
                                              self.itEnv.world_size, ))

        if include_income_to_observation:
            assert self.itEnv.observation_space_shape[0] == \
                3 * len(self.world) + self.itEnv.n_intervals + 3
            self.observation_space_shape = (
                self.itEnv.observation_space_shape[0] + 2 * len(self.world) - 3, )
        else:
            assert self.itEnv.observation_space_shape[0] == \
                3 * len(self.world) + self.itEnv.n_intervals
            self.observation_space_shape = (
                self.itEnv.observation_space_shape[0] - len(self.world), )
        self.observation_space = spaces.Box(low=0,
                                            high=1,
                                            shape=self.observation_space_shape)

    def reset(self) -> Array[int]:
        self.time = 0
        if self.itEnv.include_income_to_observation:
            t = self.itEnv.world_size + 3
            observation = self.itEnv.reset()[:-t]
            # assuming all incomes are zero
            return np.concatenate(
                (observation, np.zeros(3 * self.itEnv.world_size)))
        else:
            t = self.itEnv.world_size
            return self.itEnv.reset()[:-t]

    def get_reset_info(self):
        '''
        Currently used only to get max_orders and max_drivers, which should be
        independent of current_cell.
        '''
        return self.itEnv.get_reset_info()

    def step(self,
             action: Array[float]) -> Tuple[Array[int], float, bool, Dict]:
        cells_with_nonzero_drivers = np.sum([
            1 for n in self.itEnv.world.nodes(data=True)
            if n[1]['info'].get_driver_num() > 0
        ])
        nodes_with_orders = np.sum([
            1 for n in self.itEnv.world.nodes(data=True)
            if n[1]['info'].get_order_num() > 0
        ])
        total_orders = np.sum([
            n[1]['info'].get_order_num()
            for n in self.itEnv.world.nodes(data=True)
        ])
        global_observation = np.zeros(5 * self.itEnv.world_size +
                                      self.itEnv.n_intervals)
        global_done = False
        global_reward = 0
        reward_per_node = np.zeros(self.itEnv.world_size)
        init_t = self.itEnv.time
        self.last_action_for_drawing = action

        total_served_orders = 0
        max_driver = None
        max_order = None

        for i in range(cells_with_nonzero_drivers):
            current_cell = self.itEnv.current_node_id
            a = current_cell * self.one_cell_action_space
            action_per_cell = action[a:a + self.one_cell_action_space]

            observation, reward, done, info = self.itEnv.step(action_per_cell)

            reward_per_node[current_cell] = reward
            global_done = done
            global_reward += reward
            total_served_orders += info['served_orders']

            # updated at each step, but the final value should be correct
            max_driver = info["driver normalization constant"]
            max_order = info["order normalization constant"]

            if self.itEnv.include_income_to_observation:
                assert observation.shape[0] == \
                    3 * self.itEnv.world_size + self.itEnv.n_intervals + 3
                size_without_income = 2 * self.itEnv.world_size + self.itEnv.n_intervals
                offset = current_cell
                global_observation[:size_without_income] = \
                    observation[:size_without_income]
                global_observation[size_without_income + 3 * offset:
                                   size_without_income + 3 * offset + 3] = \
                    observation[-3:]
            else:
                global_observation = observation[:-self.itEnv.world_size]


        assert not global_done or init_t + 1 == self.itEnv.n_intervals
        assert self.itEnv.time == init_t + 1
        self.time += 1

        global_info = {
            "reward_per_node": reward_per_node,
            "served_orders": total_served_orders,
            "nodes_with_drivers": cells_with_nonzero_drivers,
            "nodes_with_orders": nodes_with_orders,
            "driver normalization constant": max_driver,
            "order normalization constant": max_order,
            "total_orders": total_orders,
            "idle_reward": float(
                np.mean([d.get_idle_period()
                         for d in self.itEnv.all_driver_list])),
            "min_idle": float(
                np.min([d.get_idle_period()
                        for d in self.itEnv.all_driver_list]))
        }
        return global_observation, global_reward, global_done, global_info

    def seed(self, seed):
        self.itEnv.seed(seed)

    def get_min_revenue(self):
        return self.itEnv.get_min_revenue()

    def get_total_revenue(self):
        return self.itEnv.get_total_revenue()

    def compute_remaining_drivers_and_orders(self, state):
        return self.itEnv.compute_remaining_drivers_and_orders(state)

    def set_income_bound(self, bound):
        self.itEnv.set_income_bound(bound)

    def render(self, mode='rgb_array'):
        fig = plt.figure(figsize=(20, 20))
        ax = fig.gca()
        ax.axis('off')

        pos = nx.get_node_attributes(self.world, 'coords')
        G = nx.DiGraph(self.world)
        nodelist = []
        edgelist = []
        action = self.last_action_for_drawing
        act = self.itEnv.action_space_shape[0]
        node_colors = []
        edge_colors = []
        for n in self.world.nodes():
            node_action = action[act * n:act * (n + 1)]
            nodelist.append(n)
            node_colors.append(node_action[-1])
            j = 0
            added = 0
            for nn in self.world.neighbors(n):
                if node_action[j] > 0:
                    edgelist.append((n, nn))
                    edge_colors.append(node_action[j])
                    added += 1
                j += 1
            assert abs(np.sum(node_action) - 1) < 0.00001, node_action
            assert node_action[-1] != 0 or added > 0, (node_action, n)

        nx.draw_networkx(G,
                         edgelist=edgelist,
                         edge_color=edge_colors,
                         vmin=-1,
                         vmax=1,
                         node_shape='.',
                         edge_vmax=1.1,
                         cmap=matplotlib.cm.get_cmap("Blues"),
                         edge_cmap=matplotlib.cm.get_cmap("Blues"),
                         node_color=node_colors,
                         nodelist=nodelist,
                         pos=pos,
                         arrows=True,
                         with_labels=False,
                         ax=ax)

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        s, (width, height) = canvas.print_to_buffer()

        # convert the rendered canvas into an (H, W, 4) RGBA NumPy array
        X = np.frombuffer(s, np.uint8).reshape((height, width, 4))
        plt.close(fig)
        return X
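For orientation, a minimal usage sketch of this wrapper, assuming the hypothetical imports listed after example #1 (the world and orders below are illustrative): each node contributes a slice of length one_cell_action_space to the flat action vector, and a single step() dispatches every cell that has drivers.

import networkx as nx
import numpy as np

# build a small 4-node line world with coordinates
g = nx.Graph()
g.add_edges_from([(0, 1), (1, 2), (2, 3)])
nx.set_node_attributes(g, {0: (0, 1), 1: (0, 2), 2: (1, 1), 3: (1, 2)}, "coords")

orders = [(0, 1, 0, 1, 1.0)]        # <source, destination, time, length, price>
drivers = np.array([1, 0, 0, 5])

env = TaxiEnvBatch(g, orders, 1, drivers, n_intervals=10, wc=0.5)
obs = env.reset()

# one action slice of length one_cell_action_space per node, flattened;
# putting all mass on the last entry of each slice means "stay"
action = np.zeros(env.one_cell_action_space * env.itEnv.world_size)
action[env.one_cell_action_space - 1::env.one_cell_action_space] = 1

obs, reward, done, info = env.step(action)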