Example #1
def from_type(action_type):
    """Map a raw action-type value to the matching action adapter."""
    space_type = ActionSpaceType(action_type)
    if space_type == ActionSpaceType.Continuous:
        return ActionAdapter.continuous_action_adapter
    elif space_type == ActionSpaceType.Lane:
        return ActionAdapter.discrete_action_adapter
    else:
        raise NotImplementedError(f"No adapter for action space type: {space_type}")
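A minimal usage sketch (hypothetical: it assumes `ActionAdapter` and `ActionSpaceType` are importable from the surrounding project, and `model_output` stands in for whatever the policy produces):

adapter = ActionAdapter.from_type(ActionSpaceType.Lane)  # enum members pass through ActionSpaceType(...)
env_action = adapter(model_output)  # convert the raw policy output into an environment action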
Example #2
import gym
import numpy as np

# ActionSpaceType is provided by the surrounding project.


def from_type(action_type: int):
    """Map a raw action-type value to the matching gym action space."""
    space_type = ActionSpaceType(action_type)
    if space_type == ActionSpaceType.Continuous:
        # Likely (throttle, brake, steering): throttle/brake in [0, 1], steering in [-1, 1].
        return gym.spaces.Box(
            low=np.array([0.0, 0.0, -1.0]),
            high=np.array([1.0, 1.0, 1.0]),
            dtype=np.float32,
        )
    elif space_type == ActionSpaceType.Lane:
        # Four discrete lane-following commands.
        return gym.spaces.Discrete(4)
    else:
        raise NotImplementedError(f"No action space for type: {space_type}")
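The returned space can be exercised directly; a short sketch (assuming the enum carries integer values, as the `action_type: int` annotation suggests):

space = from_type(ActionSpaceType.Continuous.value)
action = space.sample()        # random float32 array of shape (3,) within the bounds above
assert space.contains(action)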
Example #3
def _make_rllib_config(config, mode="training"):
    """Generate agent configuration. `mode` can be `train` or 'evaluation', and the
    only difference on the generated configuration is the agent info adapter.
    """

    agent_config = config["agent"]
    interface_config = config["interface"]
    """ Parse the state configuration for agent """
    state_config = agent_config["state"]

    # Default to the `Simple` wrapper when no wrapper config is given.
    wrapper_config = state_config.get("wrapper", {"name": "Simple"})
    features_config = state_config["features"]
    # Space for a single frame only; not the full observation yet.
    frame_space = gym.spaces.Dict(common.subscribe_features(**features_config))
    action_type = ActionSpaceType(agent_config["action"]["type"])
    env_action_space = common.ActionSpace.from_type(action_type)
    wrapper_cls = getattr(rllib_wrappers, wrapper_config["name"])
    """ Parse policy configuration """
    policy_obs_space = wrapper_cls.get_observation_space(
        frame_space, wrapper_config)
    policy_action_space = wrapper_cls.get_action_space(env_action_space,
                                                       wrapper_config)

    observation_adapter = wrapper_cls.get_observation_adapter(
        policy_obs_space,
        feature_configs=features_config,
        wrapper_config=wrapper_config)
    action_adapter = wrapper_cls.get_action_adapter(action_type,
                                                    policy_action_space,
                                                    wrapper_config)
    # The policy observation space depends on the wrapper in use.
    # RLlib policy spec: (policy class, observation space, action space, config);
    # `None` defers to the trainer's default policy class.
    policy_config = (
        None,
        policy_obs_space,
        policy_action_space,
        config["policy"].get(
            "config", {"custom_preprocessor": wrapper_cls.get_preprocessor()}
        ),
    )
    """ Parse agent interface configuration """
    if interface_config.get("neighborhood_vehicles"):
        interface_config["neighborhood_vehicles"] = NeighborhoodVehicles(
            **interface_config["neighborhood_vehicles"])

    if interface_config.get("waypoints"):
        interface_config["waypoints"] = Waypoints(
            **interface_config["waypoints"])

    if interface_config.get("rgb"):
        interface_config["rgb"] = RGB(**interface_config["rgb"])

    if interface_config.get("ogm"):
        interface_config["ogm"] = OGM(**interface_config["ogm"])

    interface_config["action"] = ActionSpaceType(action_type)
    """ Pack environment configuration """
    config["run"]["config"].update({"env": wrapper_cls})
    config["env_config"] = {
        "custom_config": {
            **wrapper_config,
            "reward_adapter":
            wrapper_cls.get_reward_adapter(observation_adapter),
            "observation_adapter":
            observation_adapter,
            "action_adapter":
            action_adapter,
            "info_adapter":
            metrics.agent_info_adapter if mode == "evaluation" else None,
            "observation_space":
            policy_obs_space,
            "action_space":
            policy_action_space,
        },
    }
    config["agent"] = {"interface": AgentInterface(**interface_config)}
    config["trainer"] = _get_trainer(**config["policy"]["trainer"])
    config["policy"] = policy_config

    print(format.pretty_dict(config))

    return config
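For orientation, the nested dict this helper consumes has roughly the following shape; the keys are inferred from the lookups above, and every concrete value below is a placeholder rather than a working configuration:

config = {
    "agent": {
        "state": {
            "wrapper": {"name": "Simple"},  # resolved via getattr(rllib_wrappers, ...)
            "features": {},                 # forwarded to common.subscribe_features
        },
        "action": {"type": 0},              # raw value converted to ActionSpaceType
    },
    "interface": {},                        # optional: neighborhood_vehicles, waypoints, rgb, ogm
    "policy": {"trainer": {}},              # forwarded to _get_trainer as keyword arguments
    "run": {"config": {}},
}
config = _make_rllib_config(config, mode="evaluation")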
Example #4
def _make_rllib_config(config, mode="train"):
    """ Generate agent configuration. `mode` can be `train` or 'evaluation', and the
    only difference on the generated configuration is the agent info adapter.
    """

    agent = config["agent"]
    state_config = agent["state"]

    # initialize environment wrapper if the wrapper config is not None
    wrapper_config = state_config["wrapper"]
    wrapper = (
        getattr(rllib_wrappers, wrapper_config["name"]) if wrapper_config else None
    )

    features = state_config["features"]
    # Space for a single frame only; not the full observation yet.
    frame_space = gym.spaces.Dict(common.subscribe_features(**features))

    action_type = agent["action"]["type"]
    action_space = common.ActionSpace.from_type(action_type)

    # The policy observation space depends on the wrapper in use.
    policy_obs_space = _get_policy_observation_space(
        wrapper, frame_space, wrapper_config
    )
    policy_config = (
        None,
        policy_obs_space,
        action_space,
        config["policy"].get("config", {}),
    )

    interface = config["interface"]
    if interface.get("neighborhood_vehicles"):
        interface["neighborhood_vehicles"] = NeighborhoodVehicles(
            **interface["neighborhood_vehicles"]
        )

    if interface.get("waypoints"):
        interface["waypoints"] = Waypoints(**interface["waypoints"])

    if interface.get("rgb"):
        interface["rgb"] = RGB(**interface["rgb"])

    if interface.get("ogm"):
        interface["ogm"] = OGM(**interface["ogm"])

    interface["action"] = ActionSpaceType(action_type)

    agent_config = dict(
        action_adapter=common.ActionAdapter.from_type(action_type),
        info_adapter=(
            metrics.agent_info_adapter
            if mode == "evaluate"
            else common.default_info_adapter
        ),
    )

    adapter_type = (
        "stack_frame" if wrapper == rllib_wrappers.FrameStack else "single_frame"
    )
    # An explicit `adapter_type` in the agent config overrides the default.
    if agent.get("adapter_type"):
        adapter_type = agent["adapter_type"]
    observation_adapter = common.get_observation_adapter(
        frame_space, adapter_type, wrapper=wrapper, feature_configs=features
    )

    wrapper_config["parameters"] = {
        "observation_adapter": observation_adapter,
        "reward_adapter": common.get_reward_adapter(observation_adapter, adapter_type),
        "base_env_cls": RLlibHiWayEnv,
    }
    env_config = dict(wrapper_config["parameters"])
    config["run"]["config"].update({"env": wrapper})

    config["env_config"] = env_config
    config["agent"] = agent_config
    config["interface"] = interface
    config["trainer"] = _get_trainer(**config["policy"]["trainer"])
    config["policy"] = policy_config

    pprint.pprint(config)

    return config
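The adapter-type selection above reduces to a small rule; a standalone restatement of it (a sketch using only names that appear in the snippet):

def resolve_adapter_type(agent_cfg, wrapper):
    # Stack frames only when the FrameStack wrapper is in use; an explicit
    # `adapter_type` in the agent config wins over the inferred default.
    default = "stack_frame" if wrapper == rllib_wrappers.FrameStack else "single_frame"
    return agent_cfg.get("adapter_type") or default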