예제 #1
0
파일: __init__.py 프로젝트: zbzhu99/SMARTS
        "u_accel": 0.1,
        "u_yaw_rate": 1.0,
        "terminal": 0.01,
        "impatience": 0.01,
        "speed": 0.01,
        "rate": 1,
    },
    debug=False,
    aggressiveness=0,
    max_episode_steps=None,
):
    from .agent import OpEnAgent

    return AgentSpec(
        interface=AgentInterface(
            action=ActionSpaceType.Trajectory,
            waypoints=True,
            neighborhood_vehicles=True,
            max_episode_steps=max_episode_steps,
            agent_behavior=AgentBehavior(aggressiveness=aggressiveness),
        ),
        agent_params={
            "gains": gains,
            "debug": debug,
        },
        agent_builder=OpEnAgent,
    )


# Make the `entrypoint` callable available under the "open_agent-v0" locator.
register(locator="open_agent-v0", entry_point=entrypoint)
예제 #2
0
class KeepLaneAgent(Agent):
    """Agent that always answers with the ``"keep_lane"`` action."""

    def act(self, obs):
        """Return ``"keep_lane"``; the observation ``obs`` is not inspected."""
        return "keep_lane"


class StoppedAgent(Agent):
    """Agent whose action is always the tuple ``(0, 0)``."""

    def act(self, obs):
        """Return ``(0, 0)``; the observation ``obs`` is not inspected."""
        # NOTE(review): presumably (target speed, lane change) for the
        # LaneWithContinuousSpeed action space this agent is registered
        # with -- confirm against the controller documentation.
        return (0, 0)


# "laner-agent-v0": builds a spec around KeepLaneAgent with waypoints enabled,
# the Lane action space and max_episode_steps=5000. The entry point accepts
# arbitrary keyword arguments and does not use them.
register(
    locator="laner-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(
            waypoints=True, action=ActionSpaceType.Lane, max_episode_steps=5000
        ),
        agent_builder=KeepLaneAgent,
    ),
)

register(
    locator="buddha-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(
            waypoints=True,
            action=ActionSpaceType.LaneWithContinuousSpeed,
            done_criteria=DoneCriteria(not_moving=True),
        ),
        agent_builder=StoppedAgent,
    ),
예제 #3
0
    target_speed=15,
    lane_change_speed=12.5,
):
    with pkg_resources.path(rl_agent, "checkpoint") as checkpoint_path:
        return AgentSpec(
            interface=agent_interface,
            observation_adapter=get_observation_adapter(
                goal_is_nearby_threshold=goal_is_nearby_threshold,
                lane_end_threshold=lane_end_threshold,
                lane_crash_distance_threshold=lane_crash_distance_threshold,
                lane_crash_ttc_threshold=lane_crash_ttc_threshold,
                intersection_crash_distance_threshold=
                intersection_crash_distance_threshold,
                intersection_crash_ttc_threshold=
                intersection_crash_ttc_threshold,
            ),
            action_adapter=get_action_adapter(
                target_speed=target_speed,
                lane_change_speed=lane_change_speed,
            ),
            agent_builder=lambda: RLAgent(
                load_path=str((checkpoint_path / "checkpoint").absolute()),
                policy_name="default_policy",
                observation_space=OBSERVATION_SPACE,
                action_space=ACTION_SPACE,
            ),
        )


# Make the `entrypoint` callable available under the "rl-agent-v1" locator.
register(locator="rl-agent-v1", entry_point=entrypoint)
예제 #4
0
from smarts.core.agent_interface import AgentInterface, AgentType
from smarts.zoo.registry import register


class SimpleAgent(Agent):
    """Agent that always answers with the ``"keep_lane"`` action."""

    def act(self, obs):
        """Return ``"keep_lane"``; the observation ``obs`` is not inspected."""
        return "keep_lane"


# You can register a callable that will build your AgentSpec
def demo_agent_callable(target_prefix=None, interface=None):
    """Build an ``AgentSpec`` whose agents are ``SimpleAgent`` instances.

    Args:
        target_prefix: Accepted for the caller's convenience; this function
            does not read it.
        interface: Interface to attach to the spec. When ``None``, one is
            created with ``AgentInterface.from_type(AgentType.Laner)``.

    Returns:
        An ``AgentSpec`` combining the chosen interface with ``SimpleAgent``
        as the agent builder.
    """
    chosen_interface = (
        AgentInterface.from_type(AgentType.Laner)
        if interface is None
        else interface
    )
    return AgentSpec(agent_builder=SimpleAgent, interface=chosen_interface)


# "zoo-agent1-v0": the entry point is given as a "module:attribute" string.
# The extra `interface` keyword is handed to register() with it.
# NOTE(review): presumably the registry forwards `interface` to the entry
# point when the agent is made -- confirm against the registry implementation.
register(
    locator="zoo-agent1-v0",
    entry_point="smarts.core.agent:AgentSpec",
    # Also works:
    # entry_point=smarts.core.agent.AgentSpec
    interface=AgentInterface.from_type(AgentType.Laner, max_episode_steps=20000),
)

# "zoo-agent2-v0": the entry point is the callable defined in this module.
register(
    locator="zoo-agent2-v0",
    entry_point=demo_agent_callable,
    # Also works:
    # entry_point="scenarios.zoo_intersection:demo_agent_callable",
)
예제 #5
0
            [
                obs.waypoint_paths[lane_index][i].pos[1]
                for i in range(num_trajectory_points)
            ],
            [
                obs.waypoint_paths[lane_index][i].heading
                for i in range(num_trajectory_points)
            ],
            [desired_speed for i in range(num_trajectory_points)],
        ]
        return trajectory


# "pose-boid-agent-v0": PoseBoidAgent with the TargetPose action space and
# waypoints enabled. The entry point ignores any keyword arguments it is given.
register(
    locator="pose-boid-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(action=ActionSpaceType.TargetPose,
                                 waypoints=True),
        agent_builder=PoseBoidAgent,
    ),
)

# "trajectory-boid-agent-v0": TrajectoryBoidAgent with the Trajectory action
# space and waypoints enabled. Keyword arguments are likewise ignored.
register(
    locator="trajectory-boid-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(action=ActionSpaceType.Trajectory,
                                 waypoints=True),
        agent_builder=TrajectoryBoidAgent,
    ),
)
예제 #6
0
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from smarts.zoo.registry import register
from .sac.sac.policy import SACPolicy
from .ppo.ppo.policy import PPOPolicy
from .dqn.dqn.policy import DQNPolicy
from .td3.td3.policy import TD3Policy
from .bdqn.bdqn.policy import BehavioralDQNPolicy
from smarts.core.controllers import ActionSpaceType
from ultra.baselines.agent_spec import BaselineAgentSpec

# Each baseline below is registered with an entry point that fixes
# `policy_class` and forwards every other keyword argument to
# BaselineAgentSpec unchanged.

# Soft actor-critic policy.
register(
    locator="sac-v0",
    entry_point=lambda **kwargs: BaselineAgentSpec(policy_class=SACPolicy,
                                                   **kwargs),
)
# PPO policy.
register(
    locator="ppo-v0",
    entry_point=lambda **kwargs: BaselineAgentSpec(policy_class=PPOPolicy,
                                                   **kwargs),
)
# TD3 policy.
register(
    locator="td3-v0",
    entry_point=lambda **kwargs: BaselineAgentSpec(policy_class=TD3Policy,
                                                   **kwargs),
)
register(
    locator="dqn-v0",
    entry_point=lambda **kwargs: BaselineAgentSpec(policy_class=DQNPolicy,
예제 #7
0
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from smarts.zoo.registry import register
from .sac.sac.policy import SACPolicy
from .ppo.ppo.policy import PPOPolicy
from .dqn.dqn.policy import DQNPolicy
from .ddpg.ddpg.policy import TD3Policy
from .bdqn.bdqn.policy import BehavioralDQNPolicy
from smarts.core.controllers import ActionSpaceType
from ultra.baselines.agent_spec import BaselineAgentSpec

# Both entry points below fix `action_type` to ActionSpaceType.Continuous and
# `policy_class` to the named policy, then forward every other keyword
# argument to BaselineAgentSpec unchanged.

# SAC policy.
register(
    locator="sac-v0",
    entry_point=lambda **kwargs: BaselineAgentSpec(action_type=ActionSpaceType.
                                                   Continuous,
                                                   policy_class=SACPolicy,
                                                   **kwargs),
)
# PPO policy.
register(
    locator="ppo-v0",
    entry_point=lambda **kwargs: BaselineAgentSpec(action_type=ActionSpaceType.
                                                   Continuous,
                                                   policy_class=PPOPolicy,
                                                   **kwargs),
)
register(
    locator="ddpg-v0",
    entry_point=lambda **kwargs: BaselineAgentSpec(action_type=ActionSpaceType.
                                                   Continuous,
                                                   policy_class=TD3Policy,
예제 #8
0
import cross_rl_agent

from smarts.zoo.agent_spec import AgentSpec
from smarts.zoo.registry import register

from .agent import RLAgent
from .cross_space import (
    action_adapter,
    cross_interface,
    get_aux_info,
    observation_adapter,
    reward_adapter,
)


def entrypoint():
    """Build the ``AgentSpec`` for the cross RL agent.

    The spec combines ``cross_interface``, ``observation_adapter`` and
    ``action_adapter`` with a builder that creates an ``RLAgent`` loading the
    ``"Soc_Mt_TD3Network"`` policy from the ``models`` resource of the
    ``cross_rl_agent`` package.

    Returns:
        The configured ``AgentSpec``.
    """
    with pkg_resources.path(cross_rl_agent, "models") as model_path:

        def _build_agent():
            # The load path is the resource path with a trailing "/" appended.
            # NOTE(review): if this builder is called after `entrypoint` has
            # returned, the `with` block has already exited -- confirm that
            # `model_path` remains valid outside the context manager.
            return RLAgent(
                load_path=str(model_path) + "/",
                policy_name="Soc_Mt_TD3Network",
            )

        spec = AgentSpec(
            interface=cross_interface,
            observation_adapter=observation_adapter,
            action_adapter=action_adapter,
            agent_builder=_build_agent,
        )
        return spec


# Make the `entrypoint` callable available under the "cross_rl_agent-v1" locator.
register(locator="cross_rl_agent-v1", entry_point=entrypoint)
예제 #9
0
    def act(self, obs):
        return "keep_lane"


class MotionPlannerAgent(Agent):
    """Agent that aims for a waypoint a few steps ahead on the first path."""

    def act(self, obs):
        """Return a target pose derived from the first waypoint path.

        The target is the last entry among the first five waypoints (or fewer
        if the path is shorter) of ``obs.waypoint_paths[0]``.

        Returns:
            A NumPy array holding the target waypoint's position components,
            its heading, and the distance from the ego vehicle to it divided
            by a fixed target speed of 5 m/s.
        """
        cruise_speed = 5  # m/s
        lookahead = obs.waypoint_paths[0][:5]
        target = lookahead[-1]
        ego_xy = obs.ego_vehicle_state.position[:2]
        gap = np.linalg.norm(target.pos - ego_xy)
        time_to_target = gap / cruise_speed
        return np.array([*target.pos, target.heading, time_to_target])


# "zoo-agent-v0": KeepLaneAgent with an interface created from AgentType.Laner
# and max_episode_steps=20000. Keyword arguments to the entry point are ignored.
register(
    locator="zoo-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface.from_type(AgentType.Laner,
                                           max_episode_steps=20000),
        agent_builder=KeepLaneAgent,
    ),
)

# "motion-planner-agent-v0": MotionPlannerAgent with waypoints enabled and the
# TargetPose action space. Keyword arguments to the entry point are ignored.
register(
    locator="motion-planner-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(waypoints=True,
                                 action=ActionSpaceType.TargetPose),
        agent_builder=MotionPlannerAgent,
    ),
)
예제 #10
0
import numpy as np

from smarts.zoo.registry import register
from smarts.core.agent_interface import AgentInterface, AgentType
from smarts.core.agent import Agent, AgentSpec
from smarts.core.controllers import ActionSpaceType


class BasicAgent(Agent):
    """Agent that always answers with the ``"keep_lane"`` action."""

    def act(self, obs):
        """Return ``"keep_lane"``; the observation ``obs`` is not inspected."""
        return "keep_lane"


# "minimal": BasicAgent with waypoints enabled and the Lane action space.
# Keyword arguments to the entry point are ignored.
register(
    locator="minimal",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(waypoints=True, action=ActionSpaceType.Lane),
        agent_builder=BasicAgent,
    ),
)
예제 #11
0
        "position": 4.0,
        "obstacle": 3.0,
        "u_accel": 0.1,
        "u_yaw_rate": 1.0,
        "terminal": 0.01,
        "impatience": 0.01,
        "speed": 0.01,
    },
    debug=False,
    max_episode_steps=600,
):
    from .policy import Policy

    return AgentSpec(
        interface=AgentInterface(
            action=ActionSpaceType.Trajectory,
            waypoints=True,
            neighborhood_vehicles=True,
            max_episode_steps=max_episode_steps,
        ),
        policy_params={
            "gains": gains,
            "debug": debug,
        },
        policy_builder=Policy,
        perform_self_test=False,
    )


# Make the `make_agent_spec` callable available under the "open_agent-v0" locator.
register(locator="open_agent-v0", entry_point=make_agent_spec)
예제 #12
0
from typing import Any, Dict

from smarts.core.agent_interface import AgentInterface, AgentType
from smarts.core.controllers import ActionSpaceType
from smarts.zoo.agent_spec import AgentSpec
from smarts.zoo.registry import make, register

from .keep_lane_agent import KeepLaneAgent
from .non_interactive_agent import NonInteractiveAgent

# "non-interactive-agent-v0": NonInteractiveAgent with waypoints enabled and
# the TargetPose action space. Keyword arguments given to the entry point are
# passed on as the spec's `agent_params`.
register(
    locator="non-interactive-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(waypoints=True,
                                 action=ActionSpaceType.TargetPose),
        agent_builder=NonInteractiveAgent,
        agent_params=kwargs,
    ),
)

# "keep-lane-agent-v0": KeepLaneAgent with an interface created from
# AgentType.Laner and max_episode_steps=20000. Keyword arguments are ignored.
register(
    locator="keep-lane-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface.from_type(AgentType.Laner,
                                           max_episode_steps=20000),
        agent_builder=KeepLaneAgent,
    ),
)


def klws_entrypoint(speed):
예제 #13
0
파일: __init__.py 프로젝트: LUMO666/Highway
from smarts.core.agent import AgentSpec
from smarts.core.agent_interface import AgentInterface, AgentType
from smarts.core.controllers import ActionSpaceType
from smarts.zoo.registry import register

from .keep_lane_agent import KeepLaneAgent
from .non_interactive_agent import NonInteractiveAgent


# "non-interactive-agent-v0": NonInteractiveAgent with waypoints enabled and
# the TargetPose action space. Keyword arguments given to the entry point are
# passed on as the spec's `agent_params`.
register(
    locator="non-interactive-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(waypoints=True, action=ActionSpaceType.TargetPose),
        agent_builder=NonInteractiveAgent,
        agent_params=kwargs,
    ),
)

# "keep-lane-agent-v0": KeepLaneAgent with an interface created from
# AgentType.Laner and max_episode_steps=20000. Keyword arguments are ignored.
register(
    locator="keep-lane-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface.from_type(AgentType.Laner, max_episode_steps=20000),
        agent_builder=KeepLaneAgent,
    ),
)
    # interface=AgentInterface.from_type(
    #     AgentType.Laner, max_episode_steps=2000
    # ),
    policy=KeeplanePolicy(),
)

# Module-level agent that pairs a RandomPolicy instance with an interface
# using the Lane action space. Waypoints are the only option switched on;
# neighborhood vehicles, OGM, RGB and lidar are off, and max_episode_steps
# is None.
random_agent = Agent(
    interface=AgentInterface(
        max_episode_steps=None,
        waypoints=True,
        neighborhood_vehicles=False,
        ogm=False,
        rgb=False,
        lidar=False,
        action=ActionSpaceType.Lane,
    ),
    policy=RandomPolicy(),
)

# Keep-lane social agent: the entry point takes no arguments and returns the
# module-level `keeplane_agent` object, so every call yields that same object.
register(
    locator="zoo-agent1-v0",
    entry_point=lambda: keeplane_agent,
)

# Random-action social agent: the entry point takes no arguments and returns
# the module-level `random_agent` object, so every call yields that same object.
register(
    locator="zoo-agent2-v0",
    entry_point=lambda: random_agent,
)
예제 #15
0
            [
                obs.waypoint_paths[lane_index][i].heading
                for i in range(num_trajectory_points)
            ],
            [desired_speed for i in range(num_trajectory_points)],
        ]
        return trajectory


# "pose-boid-agent-v0": PoseBoidPolicy with the MultiTargetPose action space
# and waypoints, OGM, RGB and the drivable-area grid map all enabled.
# Keyword arguments to the entry point are ignored.
register(
    locator="pose-boid-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(
            action=ActionSpaceType.MultiTargetPose,
            waypoints=True,
            ogm=True,
            rgb=True,
            drivable_area_grid_map=True,
        ),
        policy_builder=PoseBoidPolicy,
    ),
)

register(
    locator="trajectory-boid-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(
            action=ActionSpaceType.Trajectory,
            waypoints=True,
        ),
        policy_builder=TrajectoryBoidPolicy,
예제 #16
0
from smarts.zoo.registry import register
from smarts.core.agent_interface import AgentInterface, AgentType
from smarts.core.agent import AgentPolicy, AgentSpec
from smarts.core.controllers import ActionSpaceType


class BoidPolicy(AgentPolicy):
    """Policy that produces one target pose per vehicle observation."""

    def act(self, obs):
        """Return a dict mapping each vehicle id in ``obs`` to its action.

        ``obs`` is iterated with ``.items()``; every value is handed to
        ``_single_act`` and the result is stored under the same key.
        """
        actions = {}
        for vehicle_id, vehicle_obs in obs.items():
            actions[vehicle_id] = self._single_act(vehicle_obs)
        return actions

    def _single_act(self, obs):
        """Compute the target pose for a single vehicle observation.

        The target is the last entry among the first five waypoints (or fewer
        if the path is shorter) of ``obs.waypoint_paths[0]``.

        Returns:
            A NumPy array holding the target waypoint's position components,
            its heading, and the distance from the ego vehicle to it divided
            by a fixed target speed of 5 m/s.
        """
        cruise_speed = 5  # m/s
        lookahead = obs.waypoint_paths[0][:5]
        target = lookahead[-1]
        ego_xy = obs.ego_vehicle_state.position[:2]
        gap = np.linalg.norm(target.pos - ego_xy)
        return np.array([*target.pos, target.heading, gap / cruise_speed])


# "boid-agent-v0": BoidPolicy with waypoints enabled and the MultiTargetPose
# action space. Keyword arguments to the entry point are ignored.
register(
    locator="boid-agent-v0",
    entry_point=lambda **kwargs: AgentSpec(
        interface=AgentInterface(waypoints=True,
                                 action=ActionSpaceType.MultiTargetPose),
        policy_builder=BoidPolicy,
    ),
)