Example 1: testing the default discrete action adapter
    def test_default_action_discrete_adapter(self):
        ADAPTER_TYPE = adapters.AdapterType.DefaultActionDiscrete
        adapter = adapters.adapter_from_type(ADAPTER_TYPE)
        interface = adapters.required_interface_from_types(ADAPTER_TYPE)
        space = adapters.space_from_type(ADAPTER_TYPE)

        AVAILABLE_ACTIONS = [
            "keep_lane",
            "slow_down",
            "change_lane_left",
            "change_lane_right",
        ]

        agent, environment = prepare_test_agent_and_environment(
            required_interface=interface,
            action_adapter=adapter,
        )
        action_sequence, _, _, _ = run_experiment(agent, environment)

        for action in action_sequence:
            self.assertIsInstance(action, str)
            self.assertIn(action, AVAILABLE_ACTIONS)
            self.assertEqual(space.dtype, type(action))
            self.assertEqual(space.shape, ())
            self.assertTrue(space.contains(action))
Example 2: testing the default continuous action adapter
    def test_default_action_continuous_adapter(self):
        ADAPTER_TYPE = adapters.AdapterType.DefaultActionContinuous
        adapter = adapters.adapter_from_type(ADAPTER_TYPE)
        interface = adapters.required_interface_from_types(ADAPTER_TYPE)
        space = adapters.space_from_type(ADAPTER_TYPE)

        agent, environment = prepare_test_agent_and_environment(
            required_interface=interface,
            action_adapter=adapter,
        )
        action_sequence, _, _, _ = run_experiment(agent, environment)

        for action in action_sequence:
            self.assertIsInstance(action, np.ndarray)
            self.assertEqual(action.dtype, "float32")
            self.assertEqual(action.shape, (3, ))
            self.assertGreaterEqual(action[0], 0.0)
            self.assertLessEqual(action[0], 1.0)
            self.assertGreaterEqual(action[1], 0.0)
            self.assertLessEqual(action[1], 1.0)
            self.assertGreaterEqual(action[2], -1.0)
            self.assertLessEqual(action[2], 1.0)
            self.assertEqual(space.dtype, action.dtype)
            self.assertEqual(space.shape, action.shape)
            self.assertTrue(space.contains(action))
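
Taken together, these assertions pin down the continuous action space: three float32 components, the first two in [0, 1] and the third in [-1, 1] (throttle, brake, and steering in SMARTS' continuous action type). A minimal stand-in that satisfies the same checks, assuming a plain gym Box (the real space is whatever adapters.space_from_type returns):

import gym
import numpy as np

# Hypothetical reconstruction of the default continuous action space implied by
# the assertions above; bounds and ordering follow the test, not the source.
continuous_action_space = gym.spaces.Box(
    low=np.array([0.0, 0.0, -1.0], dtype=np.float32),
    high=np.array([1.0, 1.0, 1.0], dtype=np.float32),
    dtype=np.float32,
)
assert continuous_action_space.shape == (3,)
assert continuous_action_space.dtype == np.float32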
Example 3: testing the default vector observation adapter
    def test_default_observation_vector_adapter(self):
        ADAPTER_TYPE = adapters.AdapterType.DefaultObservationVector
        adapter = adapters.adapter_from_type(ADAPTER_TYPE)
        interface = adapters.required_interface_from_types(ADAPTER_TYPE)
        space = adapters.space_from_type(ADAPTER_TYPE)

        agent, environment = prepare_test_agent_and_environment(
            required_interface=interface,
            observation_adapter=adapter,
        )
        _, _, observations_sequence, _ = run_experiment(agent,
                                                        environment,
                                                        max_steps=1)

        observations = observations_sequence[0]
        self.assertIsInstance(observations, dict)
        self.assertIn(AGENT_ID, observations)
        self.assertIn("low_dim_states", observations[AGENT_ID])
        self.assertIn("social_vehicles", observations[AGENT_ID])
        self.assertIsInstance(observations[AGENT_ID]["low_dim_states"],
                              np.ndarray)
        self.assertIsInstance(observations[AGENT_ID]["social_vehicles"],
                              np.ndarray)
        self.assertEqual(observations[AGENT_ID]["low_dim_states"].dtype,
                         "float32")
        self.assertEqual(observations[AGENT_ID]["social_vehicles"].dtype,
                         "float32")
        self.assertEqual(observations[AGENT_ID]["low_dim_states"].shape,
                         (47, ))
        self.assertEqual(observations[AGENT_ID]["social_vehicles"].shape,
                         (10, 4))
        self.assertEqual(space.dtype, None)
        self.assertEqual(
            space["low_dim_states"].dtype,
            observations[AGENT_ID]["low_dim_states"].dtype,
        )
        self.assertEqual(
            space["social_vehicles"].dtype,
            observations[AGENT_ID]["social_vehicles"].dtype,
        )
        self.assertEqual(space.shape, None)
        self.assertEqual(
            space["low_dim_states"].shape,
            observations[AGENT_ID]["low_dim_states"].shape,
        )
        self.assertEqual(
            space["social_vehicles"].shape,
            observations[AGENT_ID]["social_vehicles"].shape,
        )
        self.assertTrue(space.contains(observations[AGENT_ID]))
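
The vector observation space checked here is a dict-like space with a 47-element "low_dim_states" vector and a 10 x 4 "social_vehicles" matrix, both float32; a gym Dict space also reports dtype and shape as None, matching the two assertions above. A sketch of an equivalent space, with bounds that are assumptions (the test only checks dtype, shape, and membership):

import gym
import numpy as np

# Hypothetical stand-in for the default vector observation space; only the
# structure, shapes, and dtypes are taken from the test, the bounds are not.
vector_observation_space = gym.spaces.Dict({
    "low_dim_states": gym.spaces.Box(
        low=-1e10, high=1e10, shape=(47,), dtype=np.float32),
    "social_vehicles": gym.spaces.Box(
        low=-1e10, high=1e10, shape=(10, 4), dtype=np.float32),
})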
Example 4: testing the default image observation adapter
    def test_default_observation_image_adapter(self):
        ADAPTER_TYPE = adapters.AdapterType.DefaultObservationImage
        adapter = adapters.adapter_from_type(ADAPTER_TYPE)
        interface = adapters.required_interface_from_types(ADAPTER_TYPE)
        space = adapters.space_from_type(ADAPTER_TYPE)

        agent, environment = prepare_test_agent_and_environment(
            required_interface=interface,
            observation_adapter=adapter,
        )
        _, _, observations_sequence, _ = run_experiment(agent,
                                                        environment,
                                                        max_steps=1)

        observations = observations_sequence[0]
        self.assertIsInstance(observations, dict)
        self.assertIn(AGENT_ID, observations)
        self.assertIsInstance(observations[AGENT_ID], np.ndarray)
        self.assertEqual(observations[AGENT_ID].dtype, "float32")
        self.assertEqual(observations[AGENT_ID].shape, (4, 64, 64))
        self.assertEqual(space.dtype, observations[AGENT_ID].dtype)
        self.assertEqual(space.shape, observations[AGENT_ID].shape)
        self.assertTrue(space.contains(observations[AGENT_ID]))
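
Analogously, the image observation here is a stack of four 64 x 64 float32 frames. A short Box sketch consistent with those checks (bounds again assumed, not taken from the source):

import gym
import numpy as np

# Hypothetical stand-in for the default image observation space.
image_observation_space = gym.spaces.Box(
    low=-1e10, high=1e10, shape=(4, 64, 64), dtype=np.float32)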
Example 5: RLlib training entry point
def train(
    task,
    num_episodes,
    max_episode_steps,
    rollout_fragment_length,
    policy,
    eval_info,
    timestep_sec,
    headless,
    seed,
    train_batch_size,
    sgd_minibatch_size,
    log_dir,
):
    agent_name = policy
    policy_params = load_yaml(
        f"ultra/baselines/{agent_name}/{agent_name}/params.yaml")

    action_type = adapters.type_from_string(policy_params["action_type"])
    observation_type = adapters.type_from_string(
        policy_params["observation_type"])
    reward_type = adapters.type_from_string(policy_params["reward_type"])

    if action_type != adapters.AdapterType.DefaultActionContinuous:
        raise Exception(
            f"RLlib training only supports the "
            f"{adapters.AdapterType.DefaultActionContinuous} action type.")
    if observation_type != adapters.AdapterType.DefaultObservationVector:
        # NOTE: The SMARTS observation adaptation done in ULTRA's Gym
        #       environment is not done in ULTRA's RLlib environment, so other
        #       observation adapters may raise an Exception.
        raise Exception(
            f"RLlib training only supports the "
            f"{adapters.AdapterType.DefaultObservationVector} observation type."
        )

    action_space = adapters.space_from_type(adapter_type=action_type)
    observation_space = adapters.space_from_type(adapter_type=observation_type)

    action_adapter = adapters.adapter_from_type(adapter_type=action_type)
    info_adapter = adapters.adapter_from_type(
        adapter_type=adapters.AdapterType.DefaultInfo)
    observation_adapter = adapters.adapter_from_type(
        adapter_type=observation_type)
    reward_adapter = adapters.adapter_from_type(adapter_type=reward_type)

    params_seed = policy_params["seed"]
    encoder_key = policy_params["social_vehicles"]["encoder_key"]
    num_social_features = observation_space["social_vehicles"].shape[1]
    social_capacity = observation_space["social_vehicles"].shape[0]
    social_policy_hidden_units = int(policy_params["social_vehicles"].get(
        "social_policy_hidden_units", 0))
    social_policy_init_std = int(policy_params["social_vehicles"].get(
        "social_policy_init_std", 0))
    social_vehicle_config = get_social_vehicle_configs(
        encoder_key=encoder_key,
        num_social_features=num_social_features,
        social_capacity=social_capacity,
        seed=params_seed,
        social_policy_hidden_units=social_policy_hidden_units,
        social_policy_init_std=social_policy_init_std,
    )

    ModelCatalog.register_custom_model("fc_model", CustomFCModel)
    config = RllibAgent.rllib_default_config(agent_name)

    rllib_policies = {
        "default_policy": (
            None,
            observation_space,
            action_space,
            {
                "model": {
                    "custom_model": "fc_model",
                    "custom_model_config": {
                        "social_vehicle_config": social_vehicle_config
                    },
                }
            },
        )
    }
    agent_specs = {
        "AGENT-007":
        AgentSpec(
            interface=AgentInterface(
                waypoints=Waypoints(lookahead=20),
                neighborhood_vehicles=NeighborhoodVehicles(200),
                action=ActionSpaceType.Continuous,
                rgb=False,
                max_episode_steps=max_episode_steps,
                debug=True,
            ),
            agent_params={},
            agent_builder=None,
            action_adapter=action_adapter,
            info_adapter=info_adapter,
            observation_adapter=observation_adapter,
            reward_adapter=reward_adapter,
        )
    }

    tune_config = {
        "env": RLlibUltraEnv,
        "log_level": "WARN",
        "callbacks": Callbacks,
        "framework": "torch",
        "num_workers": 1,
        "train_batch_size": train_batch_size,
        "sgd_minibatch_size": sgd_minibatch_size,
        "rollout_fragment_length": rollout_fragment_length,
        "in_evaluation": True,
        "evaluation_num_episodes": eval_info["eval_episodes"],
        "evaluation_interval": eval_info[
            "eval_rate"],  # Evaluation occurs after # of eval-intervals (episodes)
        "evaluation_config": {
            "env_config": {
                "seed": seed,
                "scenario_info": task,
                "headless": headless,
                "eval_mode": True,
                "ordered_scenarios": False,
                "agent_specs": agent_specs,
                "timestep_sec": timestep_sec,
            },
            "explore": False,
        },
        "env_config": {
            "seed": seed,
            "scenario_info": task,
            "headless": headless,
            "eval_mode": False,
            "ordered_scenarios": False,
            "agent_specs": agent_specs,
            "timestep_sec": timestep_sec,
        },
        "multiagent": {
            "policies": rllib_policies
        },
    }

    config.update(tune_config)
    agent = RllibAgent(
        agent_name=agent_name,
        env=RLlibUltraEnv,
        config=tune_config,
        logger_creator=log_creator(log_dir),
    )

    # The iteration value in trainer.py (self._iterations) is effectively the number of episodes.
    for i in range(num_episodes):
        results = agent.train()
        agent.log_evaluation_metrics(
            results)  # Display the evaluation metrics in TensorBoard.
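
For context, a hypothetical invocation of train() with illustrative values (the real entry point builds these arguments from command-line options; the specific numbers and the task tuple below are assumptions, not taken from the source):

# All values below are placeholders for illustration only.
train(
    task=("1", "easy"),  # assumed (task, level) pair passed through as scenario_info
    num_episodes=100,
    max_episode_steps=1200,
    rollout_fragment_length=200,
    policy="ppo",
    eval_info={"eval_episodes": 10, "eval_rate": 10},
    timestep_sec=0.1,
    headless=True,
    seed=2,
    train_batch_size=4000,
    sgd_minibatch_size=128,
    log_dir="logs",
)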
Example 6: PPO baseline policy initialization
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.batch_size = int(policy_params["batch_size"])
        self.hidden_units = int(policy_params["hidden_units"])
        self.mini_batch_size = int(policy_params["mini_batch_size"])
        self.epoch_count = int(policy_params["epoch_count"])
        self.gamma = float(policy_params["gamma"])
        self.l = float(policy_params["l"])
        self.eps = float(policy_params["eps"])
        self.actor_tau = float(policy_params["actor_tau"])
        self.critic_tau = float(policy_params["critic_tau"])
        self.entropy_tau = float(policy_params["entropy_tau"])
        self.logging_freq = int(policy_params["logging_freq"])
        self.current_iteration = 0
        self.current_log_prob = None
        self.current_value = None
        self.seed = int(policy_params["seed"])
        self.lr = float(policy_params["lr"])
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.actions = []
        self.states = []
        self.terminals = []
        self.action_size = 2
        self.prev_action = np.zeros(self.action_size)
        self.action_type = adapters.type_from_string(
            policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"])
        self.reward_type = adapters.type_from_string(
            policy_params["reward_type"])

        if self.action_type != adapters.AdapterType.DefaultActionContinuous:
            raise Exception(
                f"PPO baseline only supports the "
                f"{adapters.AdapterType.DefaultActionContinuous} action type.")
        if self.observation_type != adapters.AdapterType.DefaultObservationVector:
            raise Exception(
                f"PPO baseline only supports the "
                f"{adapters.AdapterType.DefaultObservationVector} observation type."
            )

        self.observation_space = adapters.space_from_type(
            self.observation_type)
        self.low_dim_states_size = self.observation_space[
            "low_dim_states"].shape[0]
        self.social_capacity = self.observation_space["social_vehicles"].shape[
            0]
        self.num_social_features = self.observation_space[
            "social_vehicles"].shape[1]

        self.encoder_key = policy_params["social_vehicles"]["encoder_key"]
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            encoder_key=self.encoder_key,
            num_social_features=self.num_social_features,
            social_capacity=self.social_capacity,
            seed=self.seed,
            social_policy_hidden_units=self.social_policy_hidden_units,
            social_policy_init_std=self.social_policy_init_std,
        )
        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = (policy_params["save_codes"]
                           if "save_codes" in policy_params else None)

        # Input size of the PPO network: ego features, plus encoded (or raw,
        # flattened) social-vehicle features, plus the previous action.
        self.state_size = self.low_dim_states_size
        if self.social_feature_encoder_class:
            self.state_size += self.social_feature_encoder_class(
                **self.social_feature_encoder_params).output_dim
        else:
            self.state_size += self.social_capacity * self.num_social_features
        self.state_size += self.action_size

        # PPO
        self.ppo_net = PPONetwork(
            self.action_size,
            self.state_size,
            hidden_units=self.hidden_units,
            init_std=self.social_policy_init_std,
            seed=self.seed,
            social_feature_encoder_class=self.social_feature_encoder_class,
            social_feature_encoder_params=self.social_feature_encoder_params,
        ).to(self.device)
        self.optimizer = torch.optim.Adam(self.ppo_net.parameters(),
                                          lr=self.lr)
        self.step_count = 0
        if self.checkpoint_dir:
            self.load(self.checkpoint_dir)
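
The constructor above reads a fixed set of keys from policy_params (normally loaded from the baseline's params.yaml). A hypothetical dictionary with placeholder values covering exactly those keys; the numbers and the adapter-type strings are assumptions, not the project's defaults:

# Illustrative only: the values and the string forms accepted by
# adapters.type_from_string are assumed, not taken from the source.
ppo_policy_params = {
    "batch_size": 2048,
    "hidden_units": 512,
    "mini_batch_size": 64,
    "epoch_count": 10,
    "gamma": 0.99,
    "l": 0.95,
    "eps": 0.2,
    "actor_tau": 0.01,
    "critic_tau": 0.01,
    "entropy_tau": 0.01,
    "logging_freq": 1000,
    "seed": 2,
    "lr": 0.0003,
    "action_type": "default_action_continuous",
    "observation_type": "default_observation_vector",
    "reward_type": "default_reward",
    "social_vehicles": {
        "encoder_key": "no_encoder",
        "social_policy_hidden_units": 128,
        "social_policy_init_std": 1,
    },
}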
Example 7: DQN baseline policy initialization
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.lr = float(policy_params["lr"])
        self.seed = int(policy_params["seed"])
        self.train_step = int(policy_params["train_step"])
        self.target_update = float(policy_params["target_update"])
        self.warmup = int(policy_params["warmup"])
        self.gamma = float(policy_params["gamma"])
        self.batch_size = int(policy_params["batch_size"])
        self.use_ddqn = policy_params["use_ddqn"]
        self.sticky_actions = int(policy_params["sticky_actions"])
        self.epsilon_obj = EpsilonExplore(1.0, 0.05, 100000)
        self.step_count = 0
        self.update_count = 0
        self.num_updates = 0
        self.current_sticky = 0
        self.current_iteration = 0
        self.action_type = adapters.type_from_string(
            policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"])
        self.reward_type = adapters.type_from_string(
            policy_params["reward_type"])

        if self.action_type == adapters.AdapterType.DefaultActionContinuous:
            discrete_action_spaces = [
                np.asarray([-0.25, 0.0, 0.5, 0.75, 1.0]),
                np.asarray([
                    -1.0, -0.75, -0.5, -0.25, -0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
                    1.0
                ]),
            ]
            self.index2actions = [
                merge_discrete_action_spaces([discrete_action_space])[0]
                for discrete_action_space in discrete_action_spaces
            ]
            self.action2indexs = [
                merge_discrete_action_spaces([discrete_action_space])[1]
                for discrete_action_space in discrete_action_spaces
            ]
            self.merge_action_spaces = 0
            self.num_actions = [
                len(discrete_action_space)
                for discrete_action_space in discrete_action_spaces
            ]
            self.action_size = 2
            self.to_real_action = to_3d_action
        elif self.action_type == adapters.AdapterType.DefaultActionDiscrete:
            discrete_action_spaces = [[0], [1], [2], [3]]
            index_to_actions = [
                discrete_action_space.tolist()
                if not isinstance(discrete_action_space, list) else
                discrete_action_space
                for discrete_action_space in discrete_action_spaces
            ]
            action_to_indexs = {
                str(discrete_action): index
                for discrete_action, index in zip(
                    index_to_actions,
                    np.arange(len(index_to_actions)).astype(int))
            }
            self.index2actions = [index_to_actions]
            self.action2indexs = [action_to_indexs]
            self.merge_action_spaces = -1
            self.num_actions = [len(index_to_actions)]
            self.action_size = 1
            # Lane commands that discrete action indices are mapped back to
            # (the same four commands exercised by the discrete adapter test).
            self.lane_actions = [
                "keep_lane",
                "slow_down",
                "change_lane_left",
                "change_lane_right",
            ]
            self.to_real_action = lambda action: self.lane_actions[action[0]]
        else:
            raise Exception(
                f"DQN baseline does not support the '{self.action_type}' action type."
            )

        if self.observation_type == adapters.AdapterType.DefaultObservationVector:
            observation_space = adapters.space_from_type(self.observation_type)
            low_dim_states_size = observation_space["low_dim_states"].shape[0]
            social_capacity = observation_space["social_vehicles"].shape[0]
            num_social_features = observation_space["social_vehicles"].shape[1]

            # Get information to build the encoder.
            encoder_key = policy_params["social_vehicles"]["encoder_key"]
            social_policy_hidden_units = int(
                policy_params["social_vehicles"].get(
                    "social_policy_hidden_units", 0))
            social_policy_init_std = int(policy_params["social_vehicles"].get(
                "social_policy_init_std", 0))
            social_vehicle_config = get_social_vehicle_configs(
                encoder_key=encoder_key,
                num_social_features=num_social_features,
                social_capacity=social_capacity,
                seed=self.seed,
                social_policy_hidden_units=social_policy_hidden_units,
                social_policy_init_std=social_policy_init_std,
            )
            social_vehicle_encoder = social_vehicle_config["encoder"]
            social_feature_encoder_class = social_vehicle_encoder[
                "social_feature_encoder_class"]
            social_feature_encoder_params = social_vehicle_encoder[
                "social_feature_encoder_params"]

            # Calculate the state size based on the number of features (ego + social).
            state_size = low_dim_states_size
            if social_feature_encoder_class:
                state_size += social_feature_encoder_class(
                    **social_feature_encoder_params).output_dim
            else:
                state_size += social_capacity * num_social_features
            # Add the action size to account for the previous action.
            state_size += self.action_size
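            # Worked example: with the default vector observation space above
            # (47 ego features, 10 social vehicles x 4 features), no social
            # encoder, and the continuous branch's action_size of 2, this gives
            # 47 + 10 * 4 + 2 = 89.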

            network_class = DQNWithSocialEncoder
            network_params = {
                "num_actions": self.num_actions,
                "state_size": state_size,
                "social_feature_encoder_class": social_feature_encoder_class,
                "social_feature_encoder_params": social_feature_encoder_params,
            }
        elif self.observation_type == adapters.AdapterType.DefaultObservationImage:
            observation_space = adapters.space_from_type(self.observation_type)
            stack_size = observation_space.shape[0]
            image_shape = (observation_space.shape[1],
                           observation_space.shape[2])

            network_class = DQNCNN
            network_params = {
                "n_in_channels": stack_size,
                "image_dim": image_shape,
                "num_actions": self.num_actions,
            }
        else:
            raise Exception(
                f"DQN baseline does not support the '{self.observation_type}' "
                f"observation type.")

        self.prev_action = np.zeros(self.action_size)
        self.checkpoint_dir = checkpoint_dir
        torch.manual_seed(self.seed)
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.online_q_network = network_class(**network_params).to(self.device)
        self.target_q_network = network_class(**network_params).to(self.device)
        self.update_target_network()
        self.optimizers = torch.optim.Adam(
            params=self.online_q_network.parameters(), lr=self.lr)
        self.loss_func = nn.MSELoss(reduction="none")
        self.replay = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            observation_type=self.observation_type,
            device_name=self.device_name,
        )
        self.reset()
        if self.checkpoint_dir:
            self.load(self.checkpoint_dir)
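
As a small self-contained illustration of the discrete branch's mapping (mirroring the lane-action list and to_real_action lambda above, outside the class):

# Illustrative check only; lane_actions matches the commands used by the
# default discrete action adapter test.
lane_actions = ["keep_lane", "slow_down", "change_lane_left", "change_lane_right"]
to_real_action = lambda action: lane_actions[action[0]]
assert to_real_action([2]) == "change_lane_left"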
Example 8: SAC baseline policy initialization
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        # print("LOADING THE PARAMS", policy_params, checkpoint_dir)
        self.policy_params = policy_params
        self.gamma = float(policy_params["gamma"])
        self.critic_lr = float(policy_params["critic_lr"])
        self.actor_lr = float(policy_params["actor_lr"])
        self.critic_update_rate = int(policy_params["critic_update_rate"])
        self.policy_update_rate = int(policy_params["policy_update_rate"])
        self.warmup = int(policy_params["warmup"])
        self.seed = int(policy_params["seed"])
        self.batch_size = int(policy_params["batch_size"])
        self.hidden_units = int(policy_params["hidden_units"])
        self.tau = float(policy_params["tau"])
        self.initial_alpha = float(policy_params["initial_alpha"])
        self.logging_freq = int(policy_params["logging_freq"])
        self.action_size = 2
        self.prev_action = np.zeros(self.action_size)
        self.action_type = adapters.type_from_string(
            policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"])
        self.reward_type = adapters.type_from_string(
            policy_params["reward_type"])

        if self.action_type != adapters.AdapterType.DefaultActionContinuous:
            raise Exception(
                f"SAC baseline only supports the "
                f"{adapters.AdapterType.DefaultActionContinuous} action type.")
        if self.observation_type != adapters.AdapterType.DefaultObservationVector:
            raise Exception(
                f"SAC baseline only supports the "
                f"{adapters.AdapterType.DefaultObservationVector} observation type."
            )

        self.observation_space = adapters.space_from_type(
            self.observation_type)
        self.low_dim_states_size = self.observation_space[
            "low_dim_states"].shape[0]
        self.social_capacity = self.observation_space["social_vehicles"].shape[
            0]
        self.num_social_features = self.observation_space[
            "social_vehicles"].shape[1]

        self.encoder_key = policy_params["social_vehicles"]["encoder_key"]
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            encoder_key=self.encoder_key,
            num_social_features=self.num_social_features,
            social_capacity=self.social_capacity,
            seed=self.seed,
            social_policy_hidden_units=self.social_policy_hidden_units,
            social_policy_init_std=self.social_policy_init_std,
        )
        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = (policy_params["save_codes"]
                           if "save_codes" in policy_params else None)
        self.memory = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            observation_type=self.observation_type,
            device_name=self.device_name,
        )
        self.current_iteration = 0
        self.steps = 0
        self.init_networks()
        if checkpoint_dir:
            self.load(checkpoint_dir)
Example 9: TD3 baseline policy initialization
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.action_size = 2
        self.action_range = np.asarray([[-1.0, 1.0], [-1.0, 1.0]], dtype=np.float32)
        self.actor_lr = float(policy_params["actor_lr"])
        self.critic_lr = float(policy_params["critic_lr"])
        self.critic_wd = float(policy_params["critic_wd"])
        self.actor_wd = float(policy_params["actor_wd"])
        self.noise_clip = float(policy_params["noise_clip"])
        self.policy_noise = float(policy_params["policy_noise"])
        self.update_rate = int(policy_params["update_rate"])
        self.policy_delay = int(policy_params["policy_delay"])
        self.warmup = int(policy_params["warmup"])
        self.critic_tau = float(policy_params["critic_tau"])
        self.actor_tau = float(policy_params["actor_tau"])
        self.gamma = float(policy_params["gamma"])
        self.batch_size = int(policy_params["batch_size"])
        self.sigma = float(policy_params["sigma"])
        self.theta = float(policy_params["theta"])
        self.dt = float(policy_params["dt"])
        self.action_low = torch.tensor([[each[0] for each in self.action_range]])
        self.action_high = torch.tensor([[each[1] for each in self.action_range]])
        self.seed = int(policy_params["seed"])
        self.prev_action = np.zeros(self.action_size)
        self.action_type = adapters.type_from_string(policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"]
        )
        self.reward_type = adapters.type_from_string(policy_params["reward_type"])

        if self.action_type != adapters.AdapterType.DefaultActionContinuous:
            raise Exception(
                f"TD3 baseline only supports the "
                f"{adapters.AdapterType.DefaultActionContinuous} action type."
            )

        if self.observation_type == adapters.AdapterType.DefaultObservationVector:
            observation_space = adapters.space_from_type(self.observation_type)
            low_dim_states_size = observation_space["low_dim_states"].shape[0]
            social_capacity = observation_space["social_vehicles"].shape[0]
            num_social_features = observation_space["social_vehicles"].shape[1]

            # Get information to build the encoder.
            encoder_key = policy_params["social_vehicles"]["encoder_key"]
            social_policy_hidden_units = int(
                policy_params["social_vehicles"].get("social_policy_hidden_units", 0)
            )
            social_policy_init_std = int(
                policy_params["social_vehicles"].get("social_policy_init_std", 0)
            )
            social_vehicle_config = get_social_vehicle_configs(
                encoder_key=encoder_key,
                num_social_features=num_social_features,
                social_capacity=social_capacity,
                seed=self.seed,
                social_policy_hidden_units=social_policy_hidden_units,
                social_policy_init_std=social_policy_init_std,
            )
            social_vehicle_encoder = social_vehicle_config["encoder"]
            social_feature_encoder_class = social_vehicle_encoder[
                "social_feature_encoder_class"
            ]
            social_feature_encoder_params = social_vehicle_encoder[
                "social_feature_encoder_params"
            ]

            # Calculate the state size based on the number of features (ego + social).
            state_size = low_dim_states_size
            if social_feature_encoder_class:
                state_size += social_feature_encoder_class(
                    **social_feature_encoder_params
                ).output_dim
            else:
                state_size += social_capacity * num_social_features
            # Add the action size to account for the previous action.
            state_size += self.action_size

            actor_network_class = FCActorNetwork
            critic_network_class = FCCrtiicNetwork
            network_params = {
                "state_space": state_size,
                "action_space": self.action_size,
                "seed": self.seed,
                "social_feature_encoder": social_feature_encoder_class(
                    **social_feature_encoder_params
                )
                if social_feature_encoder_class
                else None,
            }
        elif self.observation_type == adapters.AdapterType.DefaultObservationImage:
            observation_space = adapters.space_from_type(self.observation_type)
            stack_size = observation_space.shape[0]
            image_shape = (observation_space.shape[1], observation_space.shape[2])

            actor_network_class = CNNActorNetwork
            critic_network_class = CNNCriticNetwork
            network_params = {
                "input_channels": stack_size,
                "input_dimension": image_shape,
                "action_size": self.action_size,
                "seed": self.seed,
            }
        else:
            raise Exception(
                f"TD3 baseline does not support the '{self.observation_type}' "
                f"observation type."
            )

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = (
            policy_params["save_codes"] if "save_codes" in policy_params else None
        )
        self.memory = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            observation_type=self.observation_type,
            device_name=self.device_name,
        )
        self.num_actor_updates = 0
        self.current_iteration = 0
        self.step_count = 0
        self.init_networks(actor_network_class, critic_network_class, network_params)
        if checkpoint_dir:
            self.load(checkpoint_dir)
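
The sigma, theta, and dt values read by this constructor are the usual parameters of Ornstein-Uhlenbeck exploration noise in DDPG/TD3-style agents. A minimal, self-contained sketch under that assumption (not the project's own noise class):

import numpy as np

class OUNoiseSketch:
    """Ornstein-Uhlenbeck process: dx = theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1)."""

    def __init__(self, size, theta, sigma, dt, mu=0.0):
        self.mu = mu * np.ones(size)
        self.theta, self.sigma, self.dt = theta, sigma, dt
        self.state = np.copy(self.mu)

    def reset(self):
        # Restart the process from its mean.
        self.state = np.copy(self.mu)

    def sample(self):
        # Mean-reverting step plus scaled Gaussian noise.
        dx = self.theta * (self.mu - self.state) * self.dt
        dx += self.sigma * np.sqrt(self.dt) * np.random.standard_normal(len(self.state))
        self.state = self.state + dx
        return self.state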