Example #1
File: data.py Project: mecha2k/rl_handson
    @classmethod
    def handle_obs_space(cls, env: magent.GridWorld, handle) -> gym.Space:
        # view shape
        v = env.get_view_space(handle)
        # extra features
        r = env.get_feature_space(handle)

        # rearrange planes to pytorch convention
        view_shape = (v[-1],) + v[:2]
        view_space = spaces.Box(low=0.0, high=1.0, shape=view_shape)
        extra_space = spaces.Box(low=0.0, high=1.0, shape=r)
        return spaces.Tuple((view_space, extra_space))
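A minimal sketch of the result, assuming hypothetical MAgent shapes (a 13x13 view with 7 planes and 20 extra features; the real values come from get_view_space()/get_feature_space() and depend on the map configuration):

from gym import spaces

v = (13, 13, 7)   # hypothetical get_view_space() result: (height, width, planes)
r = (20,)         # hypothetical get_feature_space() result

# planes move to the front, matching PyTorch's channel-first convention
view_shape = (v[-1],) + v[:2]
obs_space = spaces.Tuple((
    spaces.Box(low=0.0, high=1.0, shape=view_shape),
    spaces.Box(low=0.0, high=1.0, shape=r),
))
assert obs_space.spaces[0].shape == (7, 13, 13)
assert obs_space.spaces[1].shape == (20,)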
Example #2
    @classmethod
    def handle_observation_space(cls, environment: GridWorld,
                                 handle: int) -> gym.Space:
        magent_view_space: Tuple = environment.get_view_space(handle)
        magent_feature_space: Tuple = environment.get_feature_space(handle)

        # rearrange planes to the PyTorch channel-first convention
        view_shape: Tuple = (magent_view_space[-1], ) + magent_view_space[:2]
        view_space: spaces.Box = spaces.Box(low=0.0,
                                            high=1.0,
                                            shape=view_shape)
        extra_space: spaces.Box = spaces.Box(low=0.0,
                                             high=1.0,
                                             shape=magent_feature_space)

        return spaces.Tuple((view_space, extra_space))
Example #3
    def __init__(self, env: magent.GridWorld, handle,
                 reset_env_func: Callable[[], None]):
        reset_env_func()
        action_space = spaces.Discrete(env.get_action_space(handle)[0])

        # view shape
        v = env.get_view_space(handle)
        # extra features
        r = env.get_feature_space(handle)

        # rearrange planes to pytorch convention
        view_shape = (v[-1],) + v[:2]
        view_space = spaces.Box(low=0.0, high=1.0, shape=view_shape)
        extra_space = spaces.Box(low=0.0, high=1.0, shape=r)
        observation_space = spaces.Tuple((view_space, extra_space))

        count = env.get_num(handle)

        super(MAgentEnv, self).__init__(count, observation_space, action_space)
        self.action_space = self.single_action_space
        self._env = env
        self._handle = handle
        self._reset_env_func = reset_env_func
Example #4
File: data.py Project: mecha2k/rl_handson
    @classmethod
    def handle_observations(
        cls, env: magent.GridWorld, handle
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        view_obs, feats_obs = env.get_observation(handle)
        entries = view_obs.shape[0]
        if entries == 0:
            return []
        # copy data
        view_obs = np.array(view_obs)
        feats_obs = np.array(feats_obs)
        # move the planes axis first: NHWC -> NCHW
        view_obs = np.moveaxis(view_obs, 3, 1)

        res = []
        for o_view, o_feats in zip(np.vsplit(view_obs, entries), np.vsplit(feats_obs, entries)):
            res.append((o_view[0], o_feats[0]))
        return res
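The axis move and per-agent split can be verified in isolation. A sketch with hypothetical shapes (4 agents, 13x13 views with 7 planes, 20 features per agent):

import numpy as np

view = np.zeros((4, 13, 13, 7), dtype=np.float32)   # NHWC, as MAgent returns it
feats = np.zeros((4, 20), dtype=np.float32)

view = np.moveaxis(view, 3, 1)                      # NHWC -> NCHW
pairs = [(v[0], f[0])
         for v, f in zip(np.vsplit(view, 4), np.vsplit(feats, 4))]
assert pairs[0][0].shape == (7, 13, 13)
assert pairs[0][1].shape == (20,)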
Example #5
def test_model(dqn_model: DQNModel, device: torch.device,
               configuration: Config) -> Tuple[float, float, float]:
    gridworld_test: GridWorld = GridWorld(configuration, map_size=MAP_SIZE)
    deer_handle: int
    tiger_handle: int
    deer_handle, tiger_handle = gridworld_test.get_handles()

    def reset_environment():
        gridworld_test.reset()
        gridworld_test.add_walls(method="random",
                                 n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        gridworld_test.add_agents(deer_handle, method="random", n=COUNT_DEERS)
        gridworld_test.add_agents(tiger_handle,
                                  method="random",
                                  n=COUNT_TIGERS)

    magent_environment: MAgentEnv = MAgentEnv(
        gridworld_test,
        tiger_handle,
        reset_environment_function=reset_environment)
    pre_processor: MAgentPreprocessor = MAgentPreprocessor(device)
    dqn_agent: ptan.agent.DQNAgent = ptan.agent.DQNAgent(
        dqn_model, ArgmaxActionSelector(), device, preprocessor=pre_processor)

    observations = magent_environment.reset()
    steps: int = 0
    rewards: float = 0.0
    survivors: int = COUNT_DEERS

    while True:
        actions = dqn_agent(observations)[0]
        observations, all_rewards, dones, _ = magent_environment.step(actions)
        steps += len(observations)
        rewards += sum(all_rewards)

        # track the lowest deer count seen during the episode
        current_survivors: int = np.count_nonzero(
            gridworld_test.get_alive(deer_handle))
        survivors = min(survivors, current_survivors)

        if dones[0]:
            break

    return rewards / COUNT_TIGERS, steps / COUNT_TIGERS, survivors / COUNT_DEERS
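The survivor bookkeeping relies on get_alive(handle) yielding one boolean per agent in the group, so counting nonzero entries gives the number of deer still standing. A tiny sketch with hypothetical flags:

import numpy as np

alive_flags = np.array([True, False, True, True])   # hypothetical get_alive() output
survivors = int(np.count_nonzero(alive_flags))
assert survivors == 3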
Example #6
    @classmethod
    def handle_observations(cls, environment: GridWorld,
                            handle: int) -> List[Tuple[ndarray, ndarray]]:
        view_observation, feature_observation = environment.get_observation(
            handle)
        entries: int = view_observation.shape[0]
        if entries == 0:
            return []

        view_observation_array: ndarray = np.array(view_observation)
        feature_observation_array: ndarray = np.array(feature_observation)
        # move the planes axis first: NHWC -> NCHW
        view_observation_array = np.moveaxis(view_observation_array, 3, 1)

        result: List[Tuple[ndarray, ndarray]] = []
        for observation, features in zip(
                np.vsplit(view_observation_array, entries),
                np.vsplit(feature_observation_array, entries)):
            result.append((observation[0], features[0]))

        return result
Example #7
    def __init__(self, env: magent.GridWorld, handle,
                 reset_env_func: Callable[[], None],
                 is_slave: bool = False,
                 steps_limit: Optional[int] = None):
        reset_env_func()
        action_space = self.handle_action_space(env, handle)
        observation_space = self.handle_obs_space(env, handle)

        count = env.get_num(handle)

        super(MAgentEnv, self).__init__(count, observation_space,
                                        action_space)
        self.action_space = self.single_action_space
        self._env = env
        self._handle = handle
        self._reset_env_func = reset_env_func
        self._is_slave = is_slave
        self._steps_limit = steps_limit
        self._steps_done = 0
Example #8
    def __init__(self,
                 environment: GridWorld,
                 handle: int,
                 reset_environment_function: Callable[[], None],
                 is_slave: bool = False,
                 step_limit: Optional[int] = None):
        self._steps_done: int = 0
        reset_environment_function()
        action_space: gym.Space = self.handle_action_space(environment, handle)
        observation_space: gym.Space = self.handle_observation_space(
            environment, handle)
        number_of_agents: int = environment.get_num(handle)

        super(MAgentEnv, self).__init__(number_of_agents, observation_space,
                                        action_space)

        self.action_space = self.single_action_space
        self._env: GridWorld = environment
        self._handle: int = handle
        self._reset_env_func: Callable[[], None] = reset_environment_function
        self._is_slave: bool = is_slave
        self._steps_limit: Optional[int] = step_limit
Example #9
    @classmethod
    def handle_action_space(cls, environment: GridWorld,
                            handle: int) -> gym.Space:
        return spaces.Discrete(environment.get_action_space(handle)[0])
Example #10
    @classmethod
    def handle_action_space(cls, env: magent.GridWorld,
                            handle) -> gym.Space:
        return spaces.Discrete(env.get_action_space(handle)[0])
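Both variants assume get_action_space(handle) returns a one-element shape tuple, so index [0] is the number of discrete actions available to the group. A sketch with a hypothetical count of 9:

from gym import spaces

magent_action_space = (9,)   # hypothetical get_action_space() result
action_space = spaces.Discrete(magent_action_space[0])
assert action_space.n == 9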
Example #11

if __name__ == "__main__":

    if not CUDA:
        print("CUDA is not enabled!")
    else:
        print("Training on GPU")
    run_name: str = "edible_tigers"
    configuration: Config = social.get_forest_configuration(MAP_SIZE)

    device: torch.device = torch.device("cuda" if CUDA else "cpu")
    saves_path = os.path.join("saves", run_name)
    os.makedirs(saves_path, exist_ok=True)

    gridworld: GridWorld = GridWorld(configuration, map_size=MAP_SIZE)

    deer_handle: int
    tiger_handle: int
    deer_handle, tiger_handle = gridworld.get_handles()

    def reset_environment():
        gridworld.reset()
        gridworld.add_walls(method="random",
                            n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        gridworld.add_agents(deer_handle, method="random", n=COUNT_DEERS)
        gridworld.add_agents(tiger_handle, method="random", n=COUNT_TIGERS)

    environment: MAgentEnv = MAgentEnv(
        gridworld, tiger_handle, reset_environment_function=reset_environment)
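Once constructed, the environment follows the vectorized reset()/step() protocol used in Example #5. A random-policy rollout sketch, assuming action_space is the per-agent Discrete space set in the constructors above:

observations = environment.reset()
while True:
    # one random action per currently alive tiger
    actions = [environment.action_space.sample() for _ in observations]
    observations, rewards, dones, _ = environment.step(actions)
    if dones[0]:
        break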