Example #1
File: adapter.py Project: zbzhu99/SMARTS
    def __init__(self, agent_name):
        self.policy_params = load_yaml(
            f"ultra/baselines/{agent_name}/{agent_name}/params.yaml"
        )

        social_vehicle_params = self.policy_params["social_vehicles"]
        social_vehicle_params["observation_num_lookahead"] = self.policy_params[
            "observation_num_lookahead"
        ]
        self.observation_num_lookahead = social_vehicle_params[
            "observation_num_lookahead"
        ]
        self.num_social_features = social_vehicle_params["num_social_features"]
        self.social_capacity = social_vehicle_params["social_capacity"]
        self.social_vehicle_config = get_social_vehicle_configs(
            encoder_key=social_vehicle_params["encoder_key"],
            num_social_features=self.num_social_features,
            social_capacity=self.social_capacity,
            seed=social_vehicle_params["seed"],
            social_policy_hidden_units=social_vehicle_params[
                "social_policy_hidden_units"
            ],
            social_policy_init_std=social_vehicle_params["social_policy_init_std"],
        )

        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]

        self.state_preprocessor = BaselineStatePreprocessor(
            social_vehicle_config=self.social_vehicle_config,
            observation_waypoints_lookahead=self.observation_num_lookahead,
            action_size=2,
        )

        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"
        ]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"
        ]

        self.state_size = self.state_preprocessor.num_low_dim_states
        if self.social_feature_encoder_class:
            self.state_size += self.social_feature_encoder_class(
                **self.social_feature_encoder_params
            ).output_dim
        else:
            self.state_size += self.social_capacity * self.num_social_features
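
Note: the if/else at the end of this adapter decides how much of the state the social vehicles contribute: a configured encoder contributes its output_dim, otherwise the raw features are flattened. Below is a minimal, self-contained sketch of that arithmetic with hypothetical numbers (nothing here is read from params.yaml, and no SMARTS imports are needed):

num_low_dim_states = 27      # hypothetical ego-state feature count
social_capacity = 10         # hypothetical maximum number of tracked social vehicles
num_social_features = 4      # hypothetical features per social vehicle
social_feature_encoder_class = None  # as if the encoder key configured no encoder

state_size = num_low_dim_states
if social_feature_encoder_class:
    # With an encoder, its output_dim replaces the flattened social features.
    state_size += social_feature_encoder_class().output_dim
else:
    # Without an encoder, the raw social features are flattened into the state.
    state_size += social_capacity * num_social_features

print(state_size)  # 27 + 10 * 4 = 67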
Example #2
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.action_size = int(policy_params["action_size"])
        self.action_range = np.asarray([[-1.0, 1.0], [-1.0, 1.0]],
                                       dtype=np.float32)
        self.actor_lr = float(policy_params["actor_lr"])
        self.critic_lr = float(policy_params["critic_lr"])
        self.critic_wd = float(policy_params["critic_wd"])
        self.actor_wd = float(policy_params["actor_wd"])
        self.noise_clip = float(policy_params["noise_clip"])
        self.policy_noise = float(policy_params["policy_noise"])
        self.update_rate = int(policy_params["update_rate"])
        self.policy_delay = int(policy_params["policy_delay"])
        self.warmup = int(policy_params["warmup"])
        self.critic_tau = float(policy_params["critic_tau"])
        self.actor_tau = float(policy_params["actor_tau"])
        self.gamma = float(policy_params["gamma"])
        self.batch_size = int(policy_params["batch_size"])
        self.sigma = float(policy_params["sigma"])
        self.theta = float(policy_params["theta"])
        self.dt = float(policy_params["dt"])
        self.action_low = torch.tensor(
            [[each[0] for each in self.action_range]])
        self.action_high = torch.tensor(
            [[each[1] for each in self.action_range]])
        self.seed = int(policy_params["seed"])
        self.prev_action = np.zeros(self.action_size)

        # state preprocessing
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_capacity = int(policy_params["social_vehicles"].get(
            "social_capacity", 0))
        self.observation_num_lookahead = int(
            policy_params.get("observation_num_lookahead", 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.num_social_features = int(policy_params["social_vehicles"].get(
            "num_social_features", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            **policy_params["social_vehicles"])

        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.state_description = BaselineStatePreprocessor.get_state_description(
            policy_params["social_vehicles"],
            policy_params["observation_num_lookahead"],
            self.action_size,
        )

        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = (policy_params["save_codes"]
                           if "save_codes" in policy_params else None)
        self.memory = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            device_name=self.device_name,
        )
        self.num_actor_updates = 0
        self.current_iteration = 0
        self.step_count = 0
        self.init_networks()
        if checkpoint_dir:
            self.load(checkpoint_dir)
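
Note: the action_low and action_high tensors above are simply the per-dimension bounds pulled out of action_range. A short, self-contained sketch of that step, reusing the same [-1.0, 1.0] range shown above:

import numpy as np
import torch

action_range = np.asarray([[-1.0, 1.0], [-1.0, 1.0]], dtype=np.float32)
action_low = torch.tensor([[low for low, _ in action_range]])     # tensor([[-1., -1.]])
action_high = torch.tensor([[high for _, high in action_range]])  # tensor([[1., 1.]])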
Example #3
def train(
    task,
    num_episodes,
    max_episode_steps,
    rollout_fragment_length,
    policy,
    eval_info,
    timestep_sec,
    headless,
    seed,
    train_batch_size,
    sgd_minibatch_size,
    log_dir,
):
    agent_name = policy
    policy_params = load_yaml(
        f"ultra/baselines/{agent_name}/{agent_name}/params.yaml")

    action_type = adapters.type_from_string(policy_params["action_type"])
    observation_type = adapters.type_from_string(
        policy_params["observation_type"])
    reward_type = adapters.type_from_string(policy_params["reward_type"])

    if action_type != adapters.AdapterType.DefaultActionContinuous:
        raise Exception(
            f"RLlib training only supports the "
            f"{adapters.AdapterType.DefaultActionContinuous} action type.")
    if observation_type != adapters.AdapterType.DefaultObservationVector:
        # NOTE: The SMARTS observations adaptation that is done in ULTRA's Gym
        #       environment is not done in ULTRA's RLlib environment. If other
        #       observation adapters are used, they may raise an Exception.
        raise Exception(
            f"RLlib training only supports the "
            f"{adapters.AdapterType.DefaultObservationVector} observation type."
        )

    action_space = adapters.space_from_type(adapter_type=action_type)
    observation_space = adapters.space_from_type(adapter_type=observation_type)

    action_adapter = adapters.adapter_from_type(adapter_type=action_type)
    info_adapter = adapters.adapter_from_type(
        adapter_type=adapters.AdapterType.DefaultInfo)
    observation_adapter = adapters.adapter_from_type(
        adapter_type=observation_type)
    reward_adapter = adapters.adapter_from_type(adapter_type=reward_type)

    params_seed = policy_params["seed"]
    encoder_key = policy_params["social_vehicles"]["encoder_key"]
    num_social_features = observation_space["social_vehicles"].shape[1]
    social_capacity = observation_space["social_vehicles"].shape[0]
    social_policy_hidden_units = int(policy_params["social_vehicles"].get(
        "social_policy_hidden_units", 0))
    social_policy_init_std = int(policy_params["social_vehicles"].get(
        "social_policy_init_std", 0))
    social_vehicle_config = get_social_vehicle_configs(
        encoder_key=encoder_key,
        num_social_features=num_social_features,
        social_capacity=social_capacity,
        seed=params_seed,
        social_policy_hidden_units=social_policy_hidden_units,
        social_policy_init_std=social_policy_init_std,
    )

    ModelCatalog.register_custom_model("fc_model", CustomFCModel)
    config = RllibAgent.rllib_default_config(agent_name)

    rllib_policies = {
        "default_policy": (
            None,
            observation_space,
            action_space,
            {
                "model": {
                    "custom_model": "fc_model",
                    "custom_model_config": {
                        "social_vehicle_config": social_vehicle_config
                    },
                }
            },
        )
    }
    agent_specs = {
        "AGENT-007":
        AgentSpec(
            interface=AgentInterface(
                waypoints=Waypoints(lookahead=20),
                neighborhood_vehicles=NeighborhoodVehicles(200),
                action=ActionSpaceType.Continuous,
                rgb=False,
                max_episode_steps=max_episode_steps,
                debug=True,
            ),
            agent_params={},
            agent_builder=None,
            action_adapter=action_adapter,
            info_adapter=info_adapter,
            observation_adapter=observation_adapter,
            reward_adapter=reward_adapter,
        )
    }

    tune_config = {
        "env": RLlibUltraEnv,
        "log_level": "WARN",
        "callbacks": Callbacks,
        "framework": "torch",
        "num_workers": 1,
        "train_batch_size": train_batch_size,
        "sgd_minibatch_size": sgd_minibatch_size,
        "rollout_fragment_length": rollout_fragment_length,
        "in_evaluation": True,
        "evaluation_num_episodes": eval_info["eval_episodes"],
        "evaluation_interval": eval_info[
            "eval_rate"],  # Evaluation occurs after # of eval-intervals (episodes)
        "evaluation_config": {
            "env_config": {
                "seed": seed,
                "scenario_info": task,
                "headless": headless,
                "eval_mode": True,
                "ordered_scenarios": False,
                "agent_specs": agent_specs,
                "timestep_sec": timestep_sec,
            },
            "explore": False,
        },
        "env_config": {
            "seed": seed,
            "scenario_info": task,
            "headless": headless,
            "eval_mode": False,
            "ordered_scenarios": False,
            "agent_specs": agent_specs,
            "timestep_sec": timestep_sec,
        },
        "multiagent": {
            "policies": rllib_policies
        },
    }

    config.update(tune_config)
    agent = RllibAgent(
        agent_name=agent_name,
        env=RLlibUltraEnv,
        config=tune_config,
        logger_creator=log_creator(log_dir),
    )

    # The iteration value in trainer.py (self._iterations) is technically the number of episodes.
    for i in range(num_episodes):
        results = agent.train()
        agent.log_evaluation_metrics(
            results)  # Evaluation metrics will now be displayed on TensorBoard
Example #4
File: policy.py Project: yunpeng-ma/SMARTS
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        network_class = DQNWithSocialEncoder
        self.epsilon_obj = EpsilonExplore(1.0, 0.05, 100000)

        discrete_action_spaces = [[0], [1]]
        action_size = discrete_action_spaces
        self.merge_action_spaces = -1

        self.step_count = 0
        self.update_count = 0
        self.num_updates = 0
        self.current_sticky = 0
        self.current_iteration = 0

        lr = float(policy_params["lr"])
        seed = int(policy_params["seed"])
        self.train_step = int(policy_params["train_step"])
        self.target_update = float(policy_params["target_update"])
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.warmup = int(policy_params["warmup"])
        self.gamma = float(policy_params["gamma"])
        self.batch_size = int(policy_params["batch_size"])
        self.use_ddqn = policy_params["use_ddqn"]
        self.sticky_actions = int(policy_params["sticky_actions"])
        prev_action_size = 1
        self.prev_action = np.zeros(prev_action_size)

        index_to_actions = [
            e.tolist() if not isinstance(e, list) else e for e in action_size
        ]

        action_to_indexs = {
            str(k): v
            for k, v in zip(index_to_actions,
                            np.arange(len(index_to_actions)).astype(int))
        }
        self.index2actions, self.action2indexs = (
            [index_to_actions],
            [action_to_indexs],
        )
        self.num_actions = [len(index_to_actions)]

        # state preprocessing
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_capacity = int(policy_params["social_vehicles"].get(
            "social_capacity", 0))
        self.observation_num_lookahead = int(
            policy_params.get("observation_num_lookahead", 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.num_social_features = int(policy_params["social_vehicles"].get(
            "num_social_features", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            **policy_params["social_vehicles"])

        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.state_description = get_state_description(
            policy_params["social_vehicles"],
            policy_params["observation_num_lookahead"],
            prev_action_size,
        )
        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        self.checkpoint_dir = checkpoint_dir
        self.reset()

        torch.manual_seed(seed)
        network_params = {
            "state_size": self.state_size,
            "social_feature_encoder_class": self.social_feature_encoder_class,
            "social_feature_encoder_params":
            self.social_feature_encoder_params,
        }
        self.online_q_network = network_class(
            num_actions=self.num_actions,
            **(network_params if network_params else {}),
        ).to(self.device)
        self.target_q_network = network_class(
            num_actions=self.num_actions,
            **(network_params if network_params else {}),
        ).to(self.device)
        self.update_target_network()

        self.optimizers = torch.optim.Adam(
            params=self.online_q_network.parameters(), lr=lr)
        self.loss_func = nn.MSELoss(reduction="none")

        if self.checkpoint_dir:
            self.load(self.checkpoint_dir)

        self.action_space_type = "lane"
        self.to_real_action = lambda action: self.lane_actions[action[0]]
        self.state_preprocessor = StatePreprocessor(preprocess_state,
                                                    self.lane_action_to_index,
                                                    self.state_description)
        self.replay = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            state_preprocessor=self.state_preprocessor,
            device_name=self.device_name,
        )
Example #5
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.batch_size = int(policy_params["batch_size"])
        self.hidden_units = int(policy_params["hidden_units"])
        self.mini_batch_size = int(policy_params["mini_batch_size"])
        self.epoch_count = int(policy_params["epoch_count"])
        self.gamma = float(policy_params["gamma"])
        self.l = float(policy_params["l"])
        self.eps = float(policy_params["eps"])
        self.actor_tau = float(policy_params["actor_tau"])
        self.critic_tau = float(policy_params["critic_tau"])
        self.entropy_tau = float(policy_params["entropy_tau"])
        self.logging_freq = int(policy_params["logging_freq"])
        self.current_iteration = 0
        self.current_log_prob = None
        self.current_value = None
        self.seed = int(policy_params["seed"])
        self.lr = float(policy_params["lr"])
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.actions = []
        self.states = []
        self.terminals = []
        self.action_size = 2
        self.prev_action = np.zeros(self.action_size)
        self.action_type = adapters.type_from_string(
            policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"])
        self.reward_type = adapters.type_from_string(
            policy_params["reward_type"])

        if self.action_type != adapters.AdapterType.DefaultActionContinuous:
            raise Exception(
                f"PPO baseline only supports the "
                f"{adapters.AdapterType.DefaultActionContinuous} action type.")
        if self.observation_type != adapters.AdapterType.DefaultObservationVector:
            raise Exception(
                f"PPO baseline only supports the "
                f"{adapters.AdapterType.DefaultObservationVector} observation type."
            )

        self.observation_space = adapters.space_from_type(
            self.observation_type)
        self.low_dim_states_size = self.observation_space[
            "low_dim_states"].shape[0]
        self.social_capacity = self.observation_space["social_vehicles"].shape[
            0]
        self.num_social_features = self.observation_space[
            "social_vehicles"].shape[1]

        self.encoder_key = policy_params["social_vehicles"]["encoder_key"]
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            encoder_key=self.encoder_key,
            num_social_features=self.num_social_features,
            social_capacity=self.social_capacity,
            seed=self.seed,
            social_policy_hidden_units=self.social_policy_hidden_units,
            social_policy_init_std=self.social_policy_init_std,
        )
        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = (policy_params["save_codes"]
                           if "save_codes" in policy_params else None)

        # PPO
        self.ppo_net = PPONetwork(
            self.action_size,
            self.state_size,
            hidden_units=self.hidden_units,
            init_std=self.social_policy_init_std,
            seed=self.seed,
            social_feature_encoder_class=self.social_feature_encoder_class,
            social_feature_encoder_params=self.social_feature_encoder_params,
        ).to(self.device)
        self.optimizer = torch.optim.Adam(self.ppo_net.parameters(),
                                          lr=self.lr)
        self.step_count = 0
        if self.checkpoint_dir:
            self.load(self.checkpoint_dir)
Example #6
File: policy.py Project: valaxkong/SMARTS
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.lr = float(policy_params["lr"])
        self.seed = int(policy_params["seed"])
        self.train_step = int(policy_params["train_step"])
        self.target_update = float(policy_params["target_update"])
        self.warmup = int(policy_params["warmup"])
        self.gamma = float(policy_params["gamma"])
        self.batch_size = int(policy_params["batch_size"])
        self.use_ddqn = policy_params["use_ddqn"]
        self.sticky_actions = int(policy_params["sticky_actions"])
        self.epsilon_obj = EpsilonExplore(1.0, 0.05, 100000)
        self.step_count = 0
        self.update_count = 0
        self.num_updates = 0
        self.current_sticky = 0
        self.current_iteration = 0
        self.action_type = adapters.type_from_string(
            policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"])
        self.reward_type = adapters.type_from_string(
            policy_params["reward_type"])

        if self.action_type == adapters.AdapterType.DefaultActionContinuous:
            discrete_action_spaces = [
                np.asarray([-0.25, 0.0, 0.5, 0.75, 1.0]),
                np.asarray([
                    -1.0, -0.75, -0.5, -0.25, -0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
                    1.0
                ]),
            ]
            self.index2actions = [
                merge_discrete_action_spaces([discrete_action_space])[0]
                for discrete_action_space in discrete_action_spaces
            ]
            self.action2indexs = [
                merge_discrete_action_spaces([discrete_action_space])[1]
                for discrete_action_space in discrete_action_spaces
            ]
            self.merge_action_spaces = 0
            self.num_actions = [
                len(discrete_action_space)
                for discrete_action_space in discrete_action_spaces
            ]
            self.action_size = 2
            self.to_real_action = to_3d_action
        elif self.action_type == adapters.AdapterType.DefaultActionDiscrete:
            discrete_action_spaces = [[0], [1], [2], [3]]
            index_to_actions = [
                discrete_action_space.tolist()
                if not isinstance(discrete_action_space, list) else
                discrete_action_space
                for discrete_action_space in discrete_action_spaces
            ]
            action_to_indexs = {
                str(discrete_action): index
                for discrete_action, index in zip(
                    index_to_actions,
                    np.arange(len(index_to_actions)).astype(int))
            }
            self.index2actions = [index_to_actions]
            self.action2indexs = [action_to_indexs]
            self.merge_action_spaces = -1
            self.num_actions = [len(index_to_actions)]
            self.action_size = 1
            self.to_real_action = lambda action: self.lane_actions[action[0]]
        else:
            raise Exception(
                f"DQN baseline does not support the '{self.action_type}' action type."
            )

        if self.observation_type == adapters.AdapterType.DefaultObservationVector:
            observation_space = adapters.space_from_type(self.observation_type)
            low_dim_states_size = observation_space["low_dim_states"].shape[0]
            social_capacity = observation_space["social_vehicles"].shape[0]
            num_social_features = observation_space["social_vehicles"].shape[1]

            # Get information to build the encoder.
            encoder_key = policy_params["social_vehicles"]["encoder_key"]
            social_policy_hidden_units = int(
                policy_params["social_vehicles"].get(
                    "social_policy_hidden_units", 0))
            social_policy_init_std = int(policy_params["social_vehicles"].get(
                "social_policy_init_std", 0))
            social_vehicle_config = get_social_vehicle_configs(
                encoder_key=encoder_key,
                num_social_features=num_social_features,
                social_capacity=social_capacity,
                seed=self.seed,
                social_policy_hidden_units=social_policy_hidden_units,
                social_policy_init_std=social_policy_init_std,
            )
            social_vehicle_encoder = social_vehicle_config["encoder"]
            social_feature_encoder_class = social_vehicle_encoder[
                "social_feature_encoder_class"]
            social_feature_encoder_params = social_vehicle_encoder[
                "social_feature_encoder_params"]

            # Calculate the state size based on the number of features (ego + social).
            state_size = low_dim_states_size
            if social_feature_encoder_class:
                state_size += social_feature_encoder_class(
                    **social_feature_encoder_params).output_dim
            else:
                state_size += social_capacity * num_social_features
            # Add the action size to account for the previous action.
            state_size += self.action_size

            network_class = DQNWithSocialEncoder
            network_params = {
                "num_actions": self.num_actions,
                "state_size": state_size,
                "social_feature_encoder_class": social_feature_encoder_class,
                "social_feature_encoder_params": social_feature_encoder_params,
            }
        elif self.observation_type == adapters.AdapterType.DefaultObservationImage:
            observation_space = adapters.space_from_type(self.observation_type)
            stack_size = observation_space.shape[0]
            image_shape = (observation_space.shape[1],
                           observation_space.shape[2])

            network_class = DQNCNN
            network_params = {
                "n_in_channels": stack_size,
                "image_dim": image_shape,
                "num_actions": self.num_actions,
            }
        else:
            raise Exception(
                f"DQN baseline does not support the '{self.observation_type}' "
                f"observation type.")

        self.prev_action = np.zeros(self.action_size)
        self.checkpoint_dir = checkpoint_dir
        torch.manual_seed(self.seed)
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.online_q_network = network_class(**network_params).to(self.device)
        self.target_q_network = network_class(**network_params).to(self.device)
        self.update_target_network()
        self.optimizers = torch.optim.Adam(
            params=self.online_q_network.parameters(), lr=self.lr)
        self.loss_func = nn.MSELoss(reduction="none")
        self.replay = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            observation_type=self.observation_type,
            device_name=self.device_name,
        )
        self.reset()
        if self.checkpoint_dir:
            self.load(self.checkpoint_dir)
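
Note: the discrete-action branch above builds a string-keyed lookup from actions to indices. A minimal, dependency-free sketch of that mapping with the same hypothetical lane-action list (enumerate plays the role of the zip/np.arange pair above):

discrete_action_spaces = [[0], [1], [2], [3]]
index_to_actions = [
    space if isinstance(space, list) else space.tolist()
    for space in discrete_action_spaces
]
action_to_indexs = {str(action): index for index, action in enumerate(index_to_actions)}
print(action_to_indexs)  # {'[0]': 0, '[1]': 1, '[2]': 2, '[3]': 3}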
Example #7
File: policy.py Project: zbzhu99/SMARTS
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.batch_size = int(policy_params["batch_size"])
        self.hidden_units = int(policy_params["hidden_units"])
        self.mini_batch_size = int(policy_params["mini_batch_size"])
        self.epoch_count = int(policy_params["epoch_count"])
        self.gamma = float(policy_params["gamma"])
        self.l = float(policy_params["l"])
        self.eps = float(policy_params["eps"])
        self.actor_tau = float(policy_params["actor_tau"])
        self.critic_tau = float(policy_params["critic_tau"])
        self.entropy_tau = float(policy_params["entropy_tau"])
        self.logging_freq = int(policy_params["logging_freq"])
        self.current_iteration = 0
        self.current_log_prob = None
        self.current_value = None
        self.seed = int(policy_params["seed"])
        self.lr = float(policy_params["lr"])
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.actions = []
        self.states = []
        self.terminals = []
        self.action_size = int(policy_params["action_size"])
        self.prev_action = np.zeros(self.action_size)

        # state preprocessing
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_capacity = int(policy_params["social_vehicles"].get(
            "social_capacity", 0))
        self.observation_num_lookahead = int(
            policy_params.get("observation_num_lookahead", 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.num_social_features = int(policy_params["social_vehicles"].get(
            "num_social_features", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            **policy_params["social_vehicles"])

        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.state_description = BaselineStatePreprocessor.get_state_description(
            policy_params["social_vehicles"],
            policy_params["observation_num_lookahead"],
            self.action_size,
        )
        # self.state_preprocessor = StatePreprocessor(
        #     preprocess_state, to_2d_action, self.state_description
        # )
        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = (policy_params["save_codes"]
                           if "save_codes" in policy_params else None)

        # PPO
        self.ppo_net = PPONetwork(
            self.action_size,
            self.state_size,
            hidden_units=self.hidden_units,
            init_std=self.social_policy_init_std,
            seed=self.seed,
            social_feature_encoder_class=self.social_feature_encoder_class,
            social_feature_encoder_params=self.social_feature_encoder_params,
        ).to(self.device)
        self.optimizer = torch.optim.Adam(self.ppo_net.parameters(),
                                          lr=self.lr)
        self.step_count = 0
        if self.checkpoint_dir:
            self.load(self.checkpoint_dir)
Example #8
File: policy.py Project: qyshen815/SMARTS
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        # print("LOADING THE PARAMS", policy_params, checkpoint_dir)
        self.policy_params = policy_params
        self.gamma = float(policy_params["gamma"])
        self.critic_lr = float(policy_params["critic_lr"])
        self.actor_lr = float(policy_params["actor_lr"])
        self.critic_update_rate = int(policy_params["critic_update_rate"])
        self.policy_update_rate = int(policy_params["policy_update_rate"])
        self.warmup = int(policy_params["warmup"])
        self.seed = int(policy_params["seed"])
        self.batch_size = int(policy_params["batch_size"])
        self.hidden_units = int(policy_params["hidden_units"])
        self.tau = float(policy_params["tau"])
        self.initial_alpha = float(policy_params["initial_alpha"])
        self.logging_freq = int(policy_params["logging_freq"])
        self.action_size = 2
        self.prev_action = np.zeros(self.action_size)
        self.action_type = adapters.type_from_string(
            policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"])
        self.reward_type = adapters.type_from_string(
            policy_params["reward_type"])

        if self.action_type != adapters.AdapterType.DefaultActionContinuous:
            raise Exception(
                f"SAC baseline only supports the "
                f"{adapters.AdapterType.DefaultActionContinuous} action type.")
        if self.observation_type != adapters.AdapterType.DefaultObservationVector:
            raise Exception(
                f"SAC baseline only supports the "
                f"{adapters.AdapterType.DefaultObservationVector} observation type."
            )

        self.observation_space = adapters.space_from_type(
            self.observation_type)
        self.low_dim_states_size = self.observation_space[
            "low_dim_states"].shape[0]
        self.social_capacity = self.observation_space["social_vehicles"].shape[
            0]
        self.num_social_features = self.observation_space[
            "social_vehicles"].shape[1]

        self.encoder_key = policy_params["social_vehicles"]["encoder_key"]
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            encoder_key=self.encoder_key,
            num_social_features=self.num_social_features,
            social_capacity=self.social_capacity,
            seed=self.seed,
            social_policy_hidden_units=self.social_policy_hidden_units,
            social_policy_init_std=self.social_policy_init_std,
        )
        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = (policy_params["save_codes"]
                           if "save_codes" in policy_params else None)
        self.memory = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            observation_type=self.observation_type,
            device_name=self.device_name,
        )
        self.current_iteration = 0
        self.steps = 0
        self.init_networks()
        if checkpoint_dir:
            self.load(checkpoint_dir)
Example #9
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        # print("LOADING THE PARAMS", policy_params, checkpoint_dir)
        self.policy_params = policy_params
        self.gamma = float(policy_params["gamma"])
        self.critic_lr = float(policy_params["critic_lr"])
        self.actor_lr = float(policy_params["actor_lr"])
        self.critic_update_rate = int(policy_params["critic_update_rate"])
        self.policy_update_rate = int(policy_params["policy_update_rate"])
        self.warmup = int(policy_params["warmup"])
        self.seed = int(policy_params["seed"])
        self.batch_size = int(policy_params["batch_size"])
        self.hidden_units = int(policy_params["hidden_units"])
        self.tau = float(policy_params["tau"])
        self.initial_alpha = float(policy_params["initial_alpha"])
        self.logging_freq = int(policy_params["logging_freq"])
        self.action_size = int(policy_params["action_size"])
        self.prev_action = np.zeros(self.action_size)

        # state preprocessing
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_capacity = int(policy_params["social_vehicles"].get(
            "social_capacity", 0))
        self.observation_num_lookahead = int(
            policy_params.get("observation_num_lookahead", 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.num_social_features = int(policy_params["social_vehicles"].get(
            "num_social_features", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            **policy_params["social_vehicles"])

        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.state_description = get_state_description(
            policy_params["social_vehicles"],
            policy_params["observation_num_lookahead"],
            self.action_size,
        )
        self.state_preprocessor = StatePreprocessor(preprocess_state,
                                                    to_2d_action,
                                                    self.state_description)
        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = (policy_params["save_codes"]
                           if "save_codes" in policy_params else None)
        self.memory = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            state_preprocessor=self.state_preprocessor,
            device_name=self.device_name,
        )
        self.current_iteration = 0
        self.steps = 0
        self.init_networks()
        if checkpoint_dir:
            self.load(checkpoint_dir)
Example #10
File: policy.py Project: valaxkong/SMARTS
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.action_size = 2
        self.action_range = np.asarray([[-1.0, 1.0], [-1.0, 1.0]], dtype=np.float32)
        self.actor_lr = float(policy_params["actor_lr"])
        self.critic_lr = float(policy_params["critic_lr"])
        self.critic_wd = float(policy_params["critic_wd"])
        self.actor_wd = float(policy_params["actor_wd"])
        self.noise_clip = float(policy_params["noise_clip"])
        self.policy_noise = float(policy_params["policy_noise"])
        self.update_rate = int(policy_params["update_rate"])
        self.policy_delay = int(policy_params["policy_delay"])
        self.warmup = int(policy_params["warmup"])
        self.critic_tau = float(policy_params["critic_tau"])
        self.actor_tau = float(policy_params["actor_tau"])
        self.gamma = float(policy_params["gamma"])
        self.batch_size = int(policy_params["batch_size"])
        self.sigma = float(policy_params["sigma"])
        self.theta = float(policy_params["theta"])
        self.dt = float(policy_params["dt"])
        self.action_low = torch.tensor([[each[0] for each in self.action_range]])
        self.action_high = torch.tensor([[each[1] for each in self.action_range]])
        self.seed = int(policy_params["seed"])
        self.prev_action = np.zeros(self.action_size)
        self.action_type = adapters.type_from_string(policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"]
        )
        self.reward_type = adapters.type_from_string(policy_params["reward_type"])

        if self.action_type != adapters.AdapterType.DefaultActionContinuous:
            raise Exception(
                f"TD3 baseline only supports the "
                f"{adapters.AdapterType.DefaultActionContinuous} action type."
            )

        if self.observation_type == adapters.AdapterType.DefaultObservationVector:
            observation_space = adapters.space_from_type(self.observation_type)
            low_dim_states_size = observation_space["low_dim_states"].shape[0]
            social_capacity = observation_space["social_vehicles"].shape[0]
            num_social_features = observation_space["social_vehicles"].shape[1]

            # Get information to build the encoder.
            encoder_key = policy_params["social_vehicles"]["encoder_key"]
            social_policy_hidden_units = int(
                policy_params["social_vehicles"].get("social_policy_hidden_units", 0)
            )
            social_policy_init_std = int(
                policy_params["social_vehicles"].get("social_policy_init_std", 0)
            )
            social_vehicle_config = get_social_vehicle_configs(
                encoder_key=encoder_key,
                num_social_features=num_social_features,
                social_capacity=social_capacity,
                seed=self.seed,
                social_policy_hidden_units=social_policy_hidden_units,
                social_policy_init_std=social_policy_init_std,
            )
            social_vehicle_encoder = social_vehicle_config["encoder"]
            social_feature_encoder_class = social_vehicle_encoder[
                "social_feature_encoder_class"
            ]
            social_feature_encoder_params = social_vehicle_encoder[
                "social_feature_encoder_params"
            ]

            # Calculate the state size based on the number of features (ego + social).
            state_size = low_dim_states_size
            if social_feature_encoder_class:
                state_size += social_feature_encoder_class(
                    **social_feature_encoder_params
                ).output_dim
            else:
                state_size += social_capacity * num_social_features
            # Add the action size to account for the previous action.
            state_size += self.action_size

            actor_network_class = FCActorNetwork
            critic_network_class = FCCrtiicNetwork
            network_params = {
                "state_space": state_size,
                "action_space": self.action_size,
                "seed": self.seed,
                "social_feature_encoder": social_feature_encoder_class(
                    **social_feature_encoder_params
                )
                if social_feature_encoder_class
                else None,
            }
        elif self.observation_type == adapters.AdapterType.DefaultObservationImage:
            observation_space = adapters.space_from_type(self.observation_type)
            stack_size = observation_space.shape[0]
            image_shape = (observation_space.shape[1], observation_space.shape[2])

            actor_network_class = CNNActorNetwork
            critic_network_class = CNNCriticNetwork
            network_params = {
                "input_channels": stack_size,
                "input_dimension": image_shape,
                "action_size": self.action_size,
                "seed": self.seed,
            }
        else:
            raise Exception(
                f"TD3 baseline does not support the '{self.observation_type}' "
                f"observation type."
            )

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = (
            policy_params["save_codes"] if "save_codes" in policy_params else None
        )
        self.memory = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            observation_type=self.observation_type,
            device_name=self.device_name,
        )
        self.num_actor_updates = 0
        self.current_iteration = 0
        self.step_count = 0
        self.init_networks(actor_network_class, critic_network_class, network_params)
        if checkpoint_dir:
            self.load(checkpoint_dir)
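
Note: in the image-observation branch above, the CNN parameters are read directly off the stacked observation shape. A small sketch with a hypothetical (stack, height, width) shape; the shape, action size, and seed below are made up for illustration:

observation_shape = (4, 64, 64)  # hypothetical: 4 stacked 64x64 frames
stack_size = observation_shape[0]
image_shape = (observation_shape[1], observation_shape[2])
network_params = {
    "input_channels": stack_size,    # 4
    "input_dimension": image_shape,  # (64, 64)
    "action_size": 2,
    "seed": 0,
}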