Example #1
File: adapter.py  Project: zbzhu99/SMARTS
    def __init__(self, agent_name):
        self.policy_params = load_yaml(
            f"ultra/baselines/{agent_name}/{agent_name}/params.yaml"
        )

        social_vehicle_params = self.policy_params["social_vehicles"]
        social_vehicle_params["observation_num_lookahead"] = self.policy_params[
            "observation_num_lookahead"
        ]
        self.observation_num_lookahead = social_vehicle_params[
            "observation_num_lookahead"
        ]
        self.num_social_features = social_vehicle_params["num_social_features"]
        self.social_capacity = social_vehicle_params["social_capacity"]
        self.social_vehicle_config = get_social_vehicle_configs(
            encoder_key=social_vehicle_params["encoder_key"],
            num_social_features=self.num_social_features,
            social_capacity=self.social_capacity,
            seed=social_vehicle_params["seed"],
            social_policy_hidden_units=social_vehicle_params[
                "social_policy_hidden_units"
            ],
            social_policy_init_std=social_vehicle_params["social_policy_init_std"],
        )

        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]

        self.state_preprocessor = BaselineStatePreprocessor(
            social_vehicle_config=self.social_vehicle_config,
            observation_waypoints_lookahead=self.observation_num_lookahead,
            action_size=2,
        )

        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"
        ]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"
        ]

        self.state_size = self.state_preprocessor.num_low_dim_states
        if self.social_feature_encoder_class:
            self.state_size += self.social_feature_encoder_class(
                **self.social_feature_encoder_params
            ).output_dim
        else:
            self.state_size += self.social_capacity * self.num_social_features
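
The constructor above wires the baseline's params.yaml into a social-vehicle encoder and a state preprocessor, so a minimal usage sketch only needs an agent name. The import path and the "sac" baseline name below are assumptions and not confirmed by this page:

# Hedged usage sketch: assumes the constructor shown above belongs to ULTRA's
# BaselineAdapter and that a "sac" baseline ships a params.yaml at
# ultra/baselines/sac/sac/params.yaml.
from ultra.baselines.adapter import BaselineAdapter  # assumed import path

adapter = BaselineAdapter("sac")
# state_size = low-dimensional state features + encoded social-vehicle features
print(adapter.state_size)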
Example #2
    def __new__(
        self,
        policy_class,
        action_type,
        checkpoint_dir=None,
        task=None,
        max_episode_steps=1200,
        experiment_dir=None,
    ):
        if experiment_dir:
            print(f"LOADING SPEC from {experiment_dir}/spec.pkl")
            with open(f"{experiment_dir}/spec.pkl", "rb") as input:
                spec = dill.load(input)
                new_spec = AgentSpec(
                    interface=spec.interface,
                    agent_params=dict(
                        policy_params=spec.agent_params["policy_params"],
                        checkpoint_dir=checkpoint_dir,
                    ),
                    agent_builder=spec.policy_builder,
                    observation_adapter=spec.observation_adapter,
                    reward_adapter=spec.reward_adapter,
                )
                spec = new_spec
        else:
            adapter = BaselineAdapter()
            policy_dir = "/".join(
                inspect.getfile(policy_class).split("/")[:-1])
            policy_params = load_yaml(f"{policy_dir}/params.yaml")
            spec = AgentSpec(
                interface=AgentInterface(
                    waypoints=Waypoints(lookahead=20),
                    neighborhood_vehicles=NeighborhoodVehicles(200),
                    action=action_type,
                    rgb=False,
                    max_episode_steps=max_episode_steps,
                    debug=True,
                ),
                agent_params=dict(policy_params=policy_params,
                                  checkpoint_dir=checkpoint_dir),
                agent_builder=policy_class,
                observation_adapter=adapter.observation_adapter,
                reward_adapter=adapter.reward_adapter,
            )
        return spec
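
A hedged invocation sketch for the __new__ above. The enclosing class name (BaselineAgentSpec), the module paths, and the SAC policy class are assumptions; only the parameter names come from the signature shown:

from smarts.core.controllers import ActionSpaceType
from ultra.baselines.agent_spec import BaselineAgentSpec  # assumed class/module name
from ultra.baselines.sac.sac.policy import SACPolicy  # assumed baseline policy class

# Fresh spec (no experiment_dir): params.yaml is read from the policy's directory.
spec = BaselineAgentSpec(
    policy_class=SACPolicy,
    action_type=ActionSpaceType.Continuous,
    max_episode_steps=1200,
)
agent = spec.build_agent()  # AgentSpec constructs the policy from agent_params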
Example #3
def train(
    task,
    num_episodes,
    max_episode_steps,
    rollout_fragment_length,
    policy,
    eval_info,
    timestep_sec,
    headless,
    seed,
    train_batch_size,
    sgd_minibatch_size,
    log_dir,
):
    agent_name = policy
    policy_params = load_yaml(
        f"ultra/baselines/{agent_name}/{agent_name}/params.yaml")

    action_type = adapters.type_from_string(policy_params["action_type"])
    observation_type = adapters.type_from_string(
        policy_params["observation_type"])
    reward_type = adapters.type_from_string(policy_params["reward_type"])

    if action_type != adapters.AdapterType.DefaultActionContinuous:
        raise Exception(
            f"RLlib training only supports the "
            f"{adapters.AdapterType.DefaultActionContinuous} action type.")
    if observation_type != adapters.AdapterType.DefaultObservationVector:
        # NOTE: The SMARTS observations adaptation that is done in ULTRA's Gym
        #       environment is not done in ULTRA's RLlib environment. If other
        #       observation adapters are used, they may raise an Exception.
        raise Exception(
            f"RLlib training only supports the "
            f"{adapters.AdapterType.DefaultObservationVector} observation type."
        )

    action_space = adapters.space_from_type(adapter_type=action_type)
    observation_space = adapters.space_from_type(adapter_type=observation_type)

    action_adapter = adapters.adapter_from_type(adapter_type=action_type)
    info_adapter = adapters.adapter_from_type(
        adapter_type=adapters.AdapterType.DefaultInfo)
    observation_adapter = adapters.adapter_from_type(
        adapter_type=observation_type)
    reward_adapter = adapters.adapter_from_type(adapter_type=reward_type)

    params_seed = policy_params["seed"]
    encoder_key = policy_params["social_vehicles"]["encoder_key"]
    num_social_features = observation_space["social_vehicles"].shape[1]
    social_capacity = observation_space["social_vehicles"].shape[0]
    social_policy_hidden_units = int(policy_params["social_vehicles"].get(
        "social_policy_hidden_units", 0))
    social_policy_init_std = int(policy_params["social_vehicles"].get(
        "social_policy_init_std", 0))
    social_vehicle_config = get_social_vehicle_configs(
        encoder_key=encoder_key,
        num_social_features=num_social_features,
        social_capacity=social_capacity,
        seed=params_seed,
        social_policy_hidden_units=social_policy_hidden_units,
        social_policy_init_std=social_policy_init_std,
    )

    ModelCatalog.register_custom_model("fc_model", CustomFCModel)
    config = RllibAgent.rllib_default_config(agent_name)

    rllib_policies = {
        "default_policy": (
            None,
            observation_space,
            action_space,
            {
                "model": {
                    "custom_model": "fc_model",
                    "custom_model_config": {
                        "social_vehicle_config": social_vehicle_config
                    },
                }
            },
        )
    }
    agent_specs = {
        "AGENT-007":
        AgentSpec(
            interface=AgentInterface(
                waypoints=Waypoints(lookahead=20),
                neighborhood_vehicles=NeighborhoodVehicles(200),
                action=ActionSpaceType.Continuous,
                rgb=False,
                max_episode_steps=max_episode_steps,
                debug=True,
            ),
            agent_params={},
            agent_builder=None,
            action_adapter=action_adapter,
            info_adapter=info_adapter,
            observation_adapter=observation_adapter,
            reward_adapter=reward_adapter,
        )
    }

    tune_config = {
        "env": RLlibUltraEnv,
        "log_level": "WARN",
        "callbacks": Callbacks,
        "framework": "torch",
        "num_workers": 1,
        "train_batch_size": train_batch_size,
        "sgd_minibatch_size": sgd_minibatch_size,
        "rollout_fragment_length": rollout_fragment_length,
        "in_evaluation": True,
        "evaluation_num_episodes": eval_info["eval_episodes"],
        "evaluation_interval": eval_info[
            "eval_rate"],  # Evaluation occurs after # of eval-intervals (episodes)
        "evaluation_config": {
            "env_config": {
                "seed": seed,
                "scenario_info": task,
                "headless": headless,
                "eval_mode": True,
                "ordered_scenarios": False,
                "agent_specs": agent_specs,
                "timestep_sec": timestep_sec,
            },
            "explore": False,
        },
        "env_config": {
            "seed": seed,
            "scenario_info": task,
            "headless": headless,
            "eval_mode": False,
            "ordered_scenarios": False,
            "agent_specs": agent_specs,
            "timestep_sec": timestep_sec,
        },
        "multiagent": {
            "policies": rllib_policies
        },
    }

    config.update(tune_config)
    agent = RllibAgent(
        agent_name=agent_name,
        env=RLlibUltraEnv,
        config=tune_config,
        logger_creator=log_creator(log_dir),
    )

    # The iteration value in trainer.py (self._iterations) is technically the number of episodes
    for i in range(num_episodes):
        results = agent.train()
        agent.log_evaluation_metrics(
            results)  # Evaluation metrics are now logged to TensorBoard
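
For completeness, a hedged call sketch for train; every value below is a placeholder, and the (task, level) tuple format is an assumption about how ULTRA encodes scenario_info:

# Placeholder arguments for illustration only; valid values depend on the
# installed ULTRA scenarios and registered baselines.
train(
    task=("1", "easy"),  # assumed (task, level) pair passed through as scenario_info
    num_episodes=100,
    max_episode_steps=1200,
    rollout_fragment_length=200,
    policy="ppo",  # assumed baseline with a params.yaml and a continuous action type
    eval_info={"eval_episodes": 2, "eval_rate": 10},
    timestep_sec=0.1,
    headless=True,
    seed=2,
    train_batch_size=4000,
    sgd_minibatch_size=128,
    log_dir="logs",
)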
Example #4
    def __new__(
        self,
        policy_class,
        # action_type,
        policy_params=None,
        checkpoint_dir=None,
        # task=None,
        max_episode_steps=1200,
        experiment_dir=None,
        agent_id="",
    ):
        if experiment_dir:
            print(
                f"Loading spec for {agent_id} from {experiment_dir}/agent_metadata.pkl"
            )
            with open(f"{experiment_dir}/agent_metadata.pkl",
                      "rb") as metadata_file:
                agent_metadata = dill.load(metadata_file)
                spec = agent_metadata["agent_specs"][agent_id]

                new_spec = AgentSpec(
                    interface=spec.interface,
                    agent_params=dict(
                        policy_params=spec.agent_params["policy_params"],
                        checkpoint_dir=checkpoint_dir,
                    ),
                    agent_builder=spec.agent_builder,
                    observation_adapter=spec.observation_adapter,
                    reward_adapter=spec.reward_adapter,
                    info_adapter=spec.info_adapter,
                )

                spec = new_spec
        else:
            # If policy_params is None, then there must be a params.yaml file in the
            # same directory as the policy_class module.
            if not policy_params:
                policy_class_module_file = inspect.getfile(policy_class)
                policy_class_module_directory = os.path.dirname(
                    policy_class_module_file)
                policy_params = load_yaml(
                    os.path.join(policy_class_module_directory, "params.yaml"))

            action_type = adapters.type_from_string(
                string_type=policy_params["action_type"])
            observation_type = adapters.type_from_string(
                string_type=policy_params["observation_type"])
            reward_type = adapters.type_from_string(
                string_type=policy_params["reward_type"])
            info_type = adapters.AdapterType.DefaultInfo

            adapter_interface_requirements = adapters.required_interface_from_types(
                action_type, observation_type, reward_type, info_type)
            action_adapter = adapters.adapter_from_type(
                adapter_type=action_type)
            observation_adapter = adapters.adapter_from_type(
                adapter_type=observation_type)
            reward_adapter = adapters.adapter_from_type(
                adapter_type=reward_type)
            info_adapter = adapters.adapter_from_type(adapter_type=info_type)

            spec = AgentSpec(
                interface=AgentInterface(
                    **adapter_interface_requirements,
                    max_episode_steps=max_episode_steps,
                    debug=True,
                ),
                agent_params=dict(policy_params=policy_params,
                                  checkpoint_dir=checkpoint_dir),
                agent_builder=policy_class,
                action_adapter=action_adapter,
                observation_adapter=observation_adapter,
                reward_adapter=reward_adapter,
                info_adapter=info_adapter,
            )

        return spec
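
A hedged sketch of the restore branch of this newer __new__, which rebuilds a spec from a saved agent_metadata.pkl; the class name and all paths below are placeholders and assumptions:

from ultra.baselines.agent_spec import BaselineAgentSpec  # assumed class/module name

# Restore branch: policy_class and policy_params are ignored because the saved
# spec in agent_metadata.pkl already carries the agent_builder and policy_params.
spec = BaselineAgentSpec(
    policy_class=None,
    checkpoint_dir="logs/experiment/models/000/AGENT_001",  # placeholder checkpoint path
    experiment_dir="logs/experiment",  # placeholder; must contain agent_metadata.pkl
    agent_id="000",  # placeholder key into agent_metadata["agent_specs"]
)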