def test_default_action_discrete_adapter(self):
    ADAPTER_TYPE = adapters.AdapterType.DefaultActionDiscrete
    adapter = adapters.adapter_from_type(ADAPTER_TYPE)
    interface = adapters.required_interface_from_types(ADAPTER_TYPE)
    space = adapters.space_from_type(ADAPTER_TYPE)

    AVAILABLE_ACTIONS = [
        "keep_lane",
        "slow_down",
        "change_lane_left",
        "change_lane_right",
    ]

    agent, environment = prepare_test_agent_and_environment(
        required_interface=interface,
        action_adapter=adapter,
    )
    action_sequence, _, _, _ = run_experiment(agent, environment)

    for action in action_sequence:
        self.assertIsInstance(action, str)
        self.assertIn(action, AVAILABLE_ACTIONS)
        self.assertEqual(space.dtype, type(action))
        self.assertEqual(space.shape, ())
        self.assertTrue(space.contains(action))
def test_default_action_continuous_adapter(self):
    ADAPTER_TYPE = adapters.AdapterType.DefaultActionContinuous
    adapter = adapters.adapter_from_type(ADAPTER_TYPE)
    interface = adapters.required_interface_from_types(ADAPTER_TYPE)
    space = adapters.space_from_type(ADAPTER_TYPE)

    agent, environment = prepare_test_agent_and_environment(
        required_interface=interface,
        action_adapter=adapter,
    )
    action_sequence, _, _, _ = run_experiment(agent, environment)

    for action in action_sequence:
        self.assertIsInstance(action, np.ndarray)
        self.assertEqual(action.dtype, "float32")
        self.assertEqual(action.shape, (3,))
        self.assertGreaterEqual(action[0], 0.0)
        self.assertLessEqual(action[0], 1.0)
        self.assertGreaterEqual(action[1], 0.0)
        self.assertLessEqual(action[1], 1.0)
        self.assertGreaterEqual(action[2], -1.0)
        self.assertLessEqual(action[2], 1.0)
        self.assertEqual(space.dtype, action.dtype)
        self.assertEqual(space.shape, action.shape)
        self.assertTrue(space.contains(action))
def test_default_observation_vector_adapter(self):
    ADAPTER_TYPE = adapters.AdapterType.DefaultObservationVector
    adapter = adapters.adapter_from_type(ADAPTER_TYPE)
    interface = adapters.required_interface_from_types(ADAPTER_TYPE)
    space = adapters.space_from_type(ADAPTER_TYPE)

    agent, environment = prepare_test_agent_and_environment(
        required_interface=interface,
        observation_adapter=adapter,
    )
    _, _, observations_sequence, _ = run_experiment(agent, environment, max_steps=1)

    observations = observations_sequence[0]
    self.assertIsInstance(observations, dict)
    self.assertIn(AGENT_ID, observations)
    self.assertIn("low_dim_states", observations[AGENT_ID])
    self.assertIn("social_vehicles", observations[AGENT_ID])
    self.assertIsInstance(observations[AGENT_ID]["low_dim_states"], np.ndarray)
    self.assertIsInstance(observations[AGENT_ID]["social_vehicles"], np.ndarray)
    self.assertEqual(observations[AGENT_ID]["low_dim_states"].dtype, "float32")
    self.assertEqual(observations[AGENT_ID]["social_vehicles"].dtype, "float32")
    self.assertEqual(observations[AGENT_ID]["low_dim_states"].shape, (47,))
    self.assertEqual(observations[AGENT_ID]["social_vehicles"].shape, (10, 4))
    self.assertEqual(space.dtype, None)
    self.assertEqual(
        space["low_dim_states"].dtype,
        observations[AGENT_ID]["low_dim_states"].dtype,
    )
    self.assertEqual(
        space["social_vehicles"].dtype,
        observations[AGENT_ID]["social_vehicles"].dtype,
    )
    self.assertEqual(space.shape, None)
    self.assertEqual(
        space["low_dim_states"].shape,
        observations[AGENT_ID]["low_dim_states"].shape,
    )
    self.assertEqual(
        space["social_vehicles"].shape,
        observations[AGENT_ID]["social_vehicles"].shape,
    )
    self.assertTrue(space.contains(observations[AGENT_ID]))
def test_default_observation_image_adapter(self):
    ADAPTER_TYPE = adapters.AdapterType.DefaultObservationImage
    adapter = adapters.adapter_from_type(ADAPTER_TYPE)
    interface = adapters.required_interface_from_types(ADAPTER_TYPE)
    space = adapters.space_from_type(ADAPTER_TYPE)

    agent, environment = prepare_test_agent_and_environment(
        required_interface=interface,
        observation_adapter=adapter,
    )
    _, _, observations_sequence, _ = run_experiment(agent, environment, max_steps=1)

    observations = observations_sequence[0]
    self.assertIsInstance(observations, dict)
    self.assertIn(AGENT_ID, observations)
    self.assertIsInstance(observations[AGENT_ID], np.ndarray)
    self.assertEqual(observations[AGENT_ID].dtype, "float32")
    self.assertEqual(observations[AGENT_ID].shape, (4, 64, 64))
    self.assertEqual(space.dtype, observations[AGENT_ID].dtype)
    self.assertEqual(space.shape, observations[AGENT_ID].shape)
    self.assertTrue(space.contains(observations[AGENT_ID]))
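# A minimal sketch of the adapter helpers exercised by the tests above; the
# helper names and enum values come from the tests themselves, while the
# downstream wiring (what you do with the results) is assumed:
#
#     adapter_type = adapters.AdapterType.DefaultActionContinuous
#     action_adapter = adapters.adapter_from_type(adapter_type)        # adapter callable
#     action_space = adapters.space_from_type(adapter_type)            # matching gym space
#     interface = adapters.required_interface_from_types(adapter_type)  # SMARTS interface needs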
def train(
    task,
    num_episodes,
    max_episode_steps,
    rollout_fragment_length,
    policy,
    eval_info,
    timestep_sec,
    headless,
    seed,
    train_batch_size,
    sgd_minibatch_size,
    log_dir,
):
    agent_name = policy
    policy_params = load_yaml(f"ultra/baselines/{agent_name}/{agent_name}/params.yaml")

    action_type = adapters.type_from_string(policy_params["action_type"])
    observation_type = adapters.type_from_string(policy_params["observation_type"])
    reward_type = adapters.type_from_string(policy_params["reward_type"])

    if action_type != adapters.AdapterType.DefaultActionContinuous:
        raise Exception(
            f"RLlib training only supports the "
            f"{adapters.AdapterType.DefaultActionContinuous} action type."
        )
    if observation_type != adapters.AdapterType.DefaultObservationVector:
        # NOTE: The SMARTS observations adaptation that is done in ULTRA's Gym
        #       environment is not done in ULTRA's RLlib environment. If other
        #       observation adapters are used, they may raise an Exception.
        raise Exception(
            f"RLlib training only supports the "
            f"{adapters.AdapterType.DefaultObservationVector} observation type."
        )

    action_space = adapters.space_from_type(adapter_type=action_type)
    observation_space = adapters.space_from_type(adapter_type=observation_type)

    action_adapter = adapters.adapter_from_type(adapter_type=action_type)
    info_adapter = adapters.adapter_from_type(
        adapter_type=adapters.AdapterType.DefaultInfo
    )
    observation_adapter = adapters.adapter_from_type(adapter_type=observation_type)
    reward_adapter = adapters.adapter_from_type(adapter_type=reward_type)

    params_seed = policy_params["seed"]
    encoder_key = policy_params["social_vehicles"]["encoder_key"]
    num_social_features = observation_space["social_vehicles"].shape[1]
    social_capacity = observation_space["social_vehicles"].shape[0]
    social_policy_hidden_units = int(
        policy_params["social_vehicles"].get("social_policy_hidden_units", 0)
    )
    social_policy_init_std = int(
        policy_params["social_vehicles"].get("social_policy_init_std", 0)
    )
    social_vehicle_config = get_social_vehicle_configs(
        encoder_key=encoder_key,
        num_social_features=num_social_features,
        social_capacity=social_capacity,
        seed=params_seed,
        social_policy_hidden_units=social_policy_hidden_units,
        social_policy_init_std=social_policy_init_std,
    )

    ModelCatalog.register_custom_model("fc_model", CustomFCModel)
    config = RllibAgent.rllib_default_config(agent_name)

    rllib_policies = {
        "default_policy": (
            None,
            observation_space,
            action_space,
            {
                "model": {
                    "custom_model": "fc_model",
                    "custom_model_config": {
                        "social_vehicle_config": social_vehicle_config
                    },
                }
            },
        )
    }
    agent_specs = {
        "AGENT-007": AgentSpec(
            interface=AgentInterface(
                waypoints=Waypoints(lookahead=20),
                neighborhood_vehicles=NeighborhoodVehicles(200),
                action=ActionSpaceType.Continuous,
                rgb=False,
                max_episode_steps=max_episode_steps,
                debug=True,
            ),
            agent_params={},
            agent_builder=None,
            action_adapter=action_adapter,
            info_adapter=info_adapter,
            observation_adapter=observation_adapter,
            reward_adapter=reward_adapter,
        )
    }

    tune_config = {
        "env": RLlibUltraEnv,
        "log_level": "WARN",
        "callbacks": Callbacks,
        "framework": "torch",
        "num_workers": 1,
        "train_batch_size": train_batch_size,
        "sgd_minibatch_size": sgd_minibatch_size,
        "rollout_fragment_length": rollout_fragment_length,
        "in_evaluation": True,
        "evaluation_num_episodes": eval_info["eval_episodes"],
        # Evaluation occurs after every `eval_rate` training iterations (episodes).
        "evaluation_interval": eval_info["eval_rate"],
        "evaluation_config": {
            "env_config": {
                "seed": seed,
                "scenario_info": task,
                "headless": headless,
                "eval_mode": True,
                "ordered_scenarios": False,
"agent_specs": agent_specs, "timestep_sec": timestep_sec, }, "explore": False, }, "env_config": { "seed": seed, "scenario_info": task, "headless": headless, "eval_mode": False, "ordered_scenarios": False, "agent_specs": agent_specs, "timestep_sec": timestep_sec, }, "multiagent": { "policies": rllib_policies }, } config.update(tune_config) agent = RllibAgent( agent_name=agent_name, env=RLlibUltraEnv, config=tune_config, logger_creator=log_creator(log_dir), ) # Iteration value in trainer.py (self._iterations) is the technically the number of episodes for i in range(num_episodes): results = agent.train() agent.log_evaluation_metrics( results) # Evaluation metrics will now be displayed on Tensorboard
def __init__(
    self,
    policy_params=None,
    checkpoint_dir=None,
):
    self.policy_params = policy_params
    self.batch_size = int(policy_params["batch_size"])
    self.hidden_units = int(policy_params["hidden_units"])
    self.mini_batch_size = int(policy_params["mini_batch_size"])
    self.epoch_count = int(policy_params["epoch_count"])
    self.gamma = float(policy_params["gamma"])
    self.l = float(policy_params["l"])
    self.eps = float(policy_params["eps"])
    self.actor_tau = float(policy_params["actor_tau"])
    self.critic_tau = float(policy_params["critic_tau"])
    self.entropy_tau = float(policy_params["entropy_tau"])
    self.logging_freq = int(policy_params["logging_freq"])
    self.current_iteration = 0
    self.current_log_prob = None
    self.current_value = None
    self.seed = int(policy_params["seed"])
    self.lr = float(policy_params["lr"])
    self.log_probs = []
    self.values = []
    self.rewards = []
    self.actions = []
    self.states = []
    self.terminals = []
    self.action_size = 2
    self.prev_action = np.zeros(self.action_size)

    self.action_type = adapters.type_from_string(policy_params["action_type"])
    self.observation_type = adapters.type_from_string(
        policy_params["observation_type"]
    )
    self.reward_type = adapters.type_from_string(policy_params["reward_type"])

    if self.action_type != adapters.AdapterType.DefaultActionContinuous:
        raise Exception(
            f"PPO baseline only supports the "
            f"{adapters.AdapterType.DefaultActionContinuous} action type."
        )
    if self.observation_type != adapters.AdapterType.DefaultObservationVector:
        raise Exception(
            f"PPO baseline only supports the "
            f"{adapters.AdapterType.DefaultObservationVector} observation type."
        )

    self.observation_space = adapters.space_from_type(self.observation_type)
    self.low_dim_states_size = self.observation_space["low_dim_states"].shape[0]
    self.social_capacity = self.observation_space["social_vehicles"].shape[0]
    self.num_social_features = self.observation_space["social_vehicles"].shape[1]

    self.encoder_key = policy_params["social_vehicles"]["encoder_key"]
    self.social_policy_hidden_units = int(
        policy_params["social_vehicles"].get("social_policy_hidden_units", 0)
    )
    self.social_policy_init_std = int(
        policy_params["social_vehicles"].get("social_policy_init_std", 0)
    )
    self.social_vehicle_config = get_social_vehicle_configs(
        encoder_key=self.encoder_key,
        num_social_features=self.num_social_features,
        social_capacity=self.social_capacity,
        seed=self.seed,
        social_policy_hidden_units=self.social_policy_hidden_units,
        social_policy_init_std=self.social_policy_init_std,
    )
    self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
    self.social_feature_encoder_class = self.social_vehicle_encoder[
        "social_feature_encoder_class"
    ]
    self.social_feature_encoder_params = self.social_vehicle_encoder[
        "social_feature_encoder_params"
    ]

    # Calculate the state size based on the number of features (ego + social).
    self.state_size = self.low_dim_states_size
    if self.social_feature_encoder_class:
        self.state_size += self.social_feature_encoder_class(
            **self.social_feature_encoder_params
        ).output_dim
    else:
        self.state_size += self.social_capacity * self.num_social_features
    # Add the action size to account for the previous action.
    self.state_size += self.action_size

    # others
    self.checkpoint_dir = checkpoint_dir
    self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(self.device_name)
    self.save_codes = (
        policy_params["save_codes"] if "save_codes" in policy_params else None
    )

    # PPO
    self.ppo_net = PPONetwork(
        self.action_size,
        self.state_size,
        hidden_units=self.hidden_units,
        init_std=self.social_policy_init_std,
        seed=self.seed,
        social_feature_encoder_class=self.social_feature_encoder_class,
        social_feature_encoder_params=self.social_feature_encoder_params,
    ).to(self.device)
    self.optimizer = torch.optim.Adam(self.ppo_net.parameters(), lr=self.lr)
    self.step_count = 0
    if self.checkpoint_dir:
        self.load(self.checkpoint_dir)
def __init__(
    self,
    policy_params=None,
    checkpoint_dir=None,
):
    self.policy_params = policy_params
    self.lr = float(policy_params["lr"])
    self.seed = int(policy_params["seed"])
    self.train_step = int(policy_params["train_step"])
    self.target_update = float(policy_params["target_update"])
    self.warmup = int(policy_params["warmup"])
    self.gamma = float(policy_params["gamma"])
    self.batch_size = int(policy_params["batch_size"])
    self.use_ddqn = policy_params["use_ddqn"]
    self.sticky_actions = int(policy_params["sticky_actions"])
    self.epsilon_obj = EpsilonExplore(1.0, 0.05, 100000)
    self.step_count = 0
    self.update_count = 0
    self.num_updates = 0
    self.current_sticky = 0
    self.current_iteration = 0

    self.action_type = adapters.type_from_string(policy_params["action_type"])
    self.observation_type = adapters.type_from_string(
        policy_params["observation_type"]
    )
    self.reward_type = adapters.type_from_string(policy_params["reward_type"])

    if self.action_type == adapters.AdapterType.DefaultActionContinuous:
        discrete_action_spaces = [
            np.asarray([-0.25, 0.0, 0.5, 0.75, 1.0]),
            np.asarray(
                [-1.0, -0.75, -0.5, -0.25, -0.1, 0.0, 0.1, 0.25, 0.5, 0.75, 1.0]
            ),
        ]
        self.index2actions = [
            merge_discrete_action_spaces([discrete_action_space])[0]
            for discrete_action_space in discrete_action_spaces
        ]
        self.action2indexs = [
            merge_discrete_action_spaces([discrete_action_space])[1]
            for discrete_action_space in discrete_action_spaces
        ]
        self.merge_action_spaces = 0
        self.num_actions = [
            len(discrete_action_space)
            for discrete_action_space in discrete_action_spaces
        ]
        self.action_size = 2
        self.to_real_action = to_3d_action
    elif self.action_type == adapters.AdapterType.DefaultActionDiscrete:
        discrete_action_spaces = [[0], [1], [2], [3]]
        index_to_actions = [
            discrete_action_space.tolist()
            if not isinstance(discrete_action_space, list)
            else discrete_action_space
            for discrete_action_space in discrete_action_spaces
        ]
        action_to_indexs = {
            str(discrete_action): index
            for discrete_action, index in zip(
                index_to_actions, np.arange(len(index_to_actions)).astype(int)
            )
        }
        self.index2actions = [index_to_actions]
        self.action2indexs = [action_to_indexs]
        self.merge_action_spaces = -1
        self.num_actions = [len(index_to_actions)]
        self.action_size = 1
        self.to_real_action = lambda action: self.lane_actions[action[0]]
    else:
        raise Exception(
            f"DQN baseline does not support the '{self.action_type}' action type."
        )

    if self.observation_type == adapters.AdapterType.DefaultObservationVector:
        observation_space = adapters.space_from_type(self.observation_type)
        low_dim_states_size = observation_space["low_dim_states"].shape[0]
        social_capacity = observation_space["social_vehicles"].shape[0]
        num_social_features = observation_space["social_vehicles"].shape[1]

        # Get information to build the encoder.
        encoder_key = policy_params["social_vehicles"]["encoder_key"]
        social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units", 0)
        )
        social_policy_init_std = int(
            policy_params["social_vehicles"].get("social_policy_init_std", 0)
        )
        social_vehicle_config = get_social_vehicle_configs(
            encoder_key=encoder_key,
            num_social_features=num_social_features,
            social_capacity=social_capacity,
            seed=self.seed,
            social_policy_hidden_units=social_policy_hidden_units,
            social_policy_init_std=social_policy_init_std,
        )
        social_vehicle_encoder = social_vehicle_config["encoder"]
        social_feature_encoder_class = social_vehicle_encoder[
            "social_feature_encoder_class"
        ]
        social_feature_encoder_params = social_vehicle_encoder[
            "social_feature_encoder_params"
        ]

        # Calculate the state size based on the number of features (ego + social).
        state_size = low_dim_states_size
        if social_feature_encoder_class:
            state_size += social_feature_encoder_class(
                **social_feature_encoder_params
            ).output_dim
        else:
            state_size += social_capacity * num_social_features
        # Add the action size to account for the previous action.
        state_size += self.action_size

        network_class = DQNWithSocialEncoder
        network_params = {
            "num_actions": self.num_actions,
            "state_size": state_size,
            "social_feature_encoder_class": social_feature_encoder_class,
            "social_feature_encoder_params": social_feature_encoder_params,
        }
    elif self.observation_type == adapters.AdapterType.DefaultObservationImage:
        observation_space = adapters.space_from_type(self.observation_type)
        stack_size = observation_space.shape[0]
        image_shape = (observation_space.shape[1], observation_space.shape[2])

        network_class = DQNCNN
        network_params = {
            "n_in_channels": stack_size,
            "image_dim": image_shape,
            "num_actions": self.num_actions,
        }
    else:
        raise Exception(
            f"DQN baseline does not support the '{self.observation_type}' "
            f"observation type."
        )

    self.prev_action = np.zeros(self.action_size)
    self.checkpoint_dir = checkpoint_dir
    torch.manual_seed(self.seed)
    self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(self.device_name)
    self.online_q_network = network_class(**network_params).to(self.device)
    self.target_q_network = network_class(**network_params).to(self.device)
    self.update_target_network()
    self.optimizers = torch.optim.Adam(
        params=self.online_q_network.parameters(), lr=self.lr
    )
    self.loss_func = nn.MSELoss(reduction="none")
    self.replay = ReplayBuffer(
        buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
        batch_size=int(policy_params["replay_buffer"]["batch_size"]),
        observation_type=self.observation_type,
        device_name=self.device_name,
    )
    self.reset()

    if self.checkpoint_dir:
        self.load(self.checkpoint_dir)
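# Worked example of the state-size arithmetic above, assuming the default
# vector observation space from the adapter tests ((47,) low_dim_states and
# (10, 4) social_vehicles), the continuous action setup (action_size == 2),
# and no social feature encoder:
#
#     state_size = 47          # low_dim_states_size
#     state_size += 10 * 4     # social_capacity * num_social_features
#     state_size += 2          # previous action
#     # state_size == 89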
def __init__(
    self,
    policy_params=None,
    checkpoint_dir=None,
):
    self.policy_params = policy_params
    self.gamma = float(policy_params["gamma"])
    self.critic_lr = float(policy_params["critic_lr"])
    self.actor_lr = float(policy_params["actor_lr"])
    self.critic_update_rate = int(policy_params["critic_update_rate"])
    self.policy_update_rate = int(policy_params["policy_update_rate"])
    self.warmup = int(policy_params["warmup"])
    self.seed = int(policy_params["seed"])
    self.batch_size = int(policy_params["batch_size"])
    self.hidden_units = int(policy_params["hidden_units"])
    self.tau = float(policy_params["tau"])
    self.initial_alpha = float(policy_params["initial_alpha"])
    self.logging_freq = int(policy_params["logging_freq"])
    self.action_size = 2
    self.prev_action = np.zeros(self.action_size)

    self.action_type = adapters.type_from_string(policy_params["action_type"])
    self.observation_type = adapters.type_from_string(
        policy_params["observation_type"]
    )
    self.reward_type = adapters.type_from_string(policy_params["reward_type"])

    if self.action_type != adapters.AdapterType.DefaultActionContinuous:
        raise Exception(
            f"SAC baseline only supports the "
            f"{adapters.AdapterType.DefaultActionContinuous} action type."
        )
    if self.observation_type != adapters.AdapterType.DefaultObservationVector:
        raise Exception(
            f"SAC baseline only supports the "
            f"{adapters.AdapterType.DefaultObservationVector} observation type."
        )

    self.observation_space = adapters.space_from_type(self.observation_type)
    self.low_dim_states_size = self.observation_space["low_dim_states"].shape[0]
    self.social_capacity = self.observation_space["social_vehicles"].shape[0]
    self.num_social_features = self.observation_space["social_vehicles"].shape[1]

    self.encoder_key = policy_params["social_vehicles"]["encoder_key"]
    self.social_policy_hidden_units = int(
        policy_params["social_vehicles"].get("social_policy_hidden_units", 0)
    )
    self.social_policy_init_std = int(
        policy_params["social_vehicles"].get("social_policy_init_std", 0)
    )
    self.social_vehicle_config = get_social_vehicle_configs(
        encoder_key=self.encoder_key,
        num_social_features=self.num_social_features,
        social_capacity=self.social_capacity,
        seed=self.seed,
        social_policy_hidden_units=self.social_policy_hidden_units,
        social_policy_init_std=self.social_policy_init_std,
    )
    self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
    self.social_feature_encoder_class = self.social_vehicle_encoder[
        "social_feature_encoder_class"
    ]
    self.social_feature_encoder_params = self.social_vehicle_encoder[
        "social_feature_encoder_params"
    ]

    # others
    self.checkpoint_dir = checkpoint_dir
    self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(self.device_name)
    self.save_codes = (
        policy_params["save_codes"] if "save_codes" in policy_params else None
    )
    self.memory = ReplayBuffer(
        buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
        batch_size=int(policy_params["replay_buffer"]["batch_size"]),
        observation_type=self.observation_type,
        device_name=self.device_name,
    )
    self.current_iteration = 0
    self.steps = 0
    self.init_networks()
    if checkpoint_dir:
        self.load(checkpoint_dir)
def __init__(
    self,
    policy_params=None,
    checkpoint_dir=None,
):
    self.policy_params = policy_params
    self.action_size = 2
    self.action_range = np.asarray([[-1.0, 1.0], [-1.0, 1.0]], dtype=np.float32)
    self.actor_lr = float(policy_params["actor_lr"])
    self.critic_lr = float(policy_params["critic_lr"])
    self.critic_wd = float(policy_params["critic_wd"])
    self.actor_wd = float(policy_params["actor_wd"])
    self.noise_clip = float(policy_params["noise_clip"])
    self.policy_noise = float(policy_params["policy_noise"])
    self.update_rate = int(policy_params["update_rate"])
    self.policy_delay = int(policy_params["policy_delay"])
    self.warmup = int(policy_params["warmup"])
    self.critic_tau = float(policy_params["critic_tau"])
    self.actor_tau = float(policy_params["actor_tau"])
    self.gamma = float(policy_params["gamma"])
    self.batch_size = int(policy_params["batch_size"])
    self.sigma = float(policy_params["sigma"])
    self.theta = float(policy_params["theta"])
    self.dt = float(policy_params["dt"])
    self.action_low = torch.tensor([[each[0] for each in self.action_range]])
    self.action_high = torch.tensor([[each[1] for each in self.action_range]])
    self.seed = int(policy_params["seed"])
    self.prev_action = np.zeros(self.action_size)

    self.action_type = adapters.type_from_string(policy_params["action_type"])
    self.observation_type = adapters.type_from_string(
        policy_params["observation_type"]
    )
    self.reward_type = adapters.type_from_string(policy_params["reward_type"])

    if self.action_type != adapters.AdapterType.DefaultActionContinuous:
        raise Exception(
            f"TD3 baseline only supports the "
            f"{adapters.AdapterType.DefaultActionContinuous} action type."
        )

    if self.observation_type == adapters.AdapterType.DefaultObservationVector:
        observation_space = adapters.space_from_type(self.observation_type)
        low_dim_states_size = observation_space["low_dim_states"].shape[0]
        social_capacity = observation_space["social_vehicles"].shape[0]
        num_social_features = observation_space["social_vehicles"].shape[1]

        # Get information to build the encoder.
        encoder_key = policy_params["social_vehicles"]["encoder_key"]
        social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units", 0)
        )
        social_policy_init_std = int(
            policy_params["social_vehicles"].get("social_policy_init_std", 0)
        )
        social_vehicle_config = get_social_vehicle_configs(
            encoder_key=encoder_key,
            num_social_features=num_social_features,
            social_capacity=social_capacity,
            seed=self.seed,
            social_policy_hidden_units=social_policy_hidden_units,
            social_policy_init_std=social_policy_init_std,
        )
        social_vehicle_encoder = social_vehicle_config["encoder"]
        social_feature_encoder_class = social_vehicle_encoder[
            "social_feature_encoder_class"
        ]
        social_feature_encoder_params = social_vehicle_encoder[
            "social_feature_encoder_params"
        ]

        # Calculate the state size based on the number of features (ego + social).
        state_size = low_dim_states_size
        if social_feature_encoder_class:
            state_size += social_feature_encoder_class(
                **social_feature_encoder_params
            ).output_dim
        else:
            state_size += social_capacity * num_social_features
        # Add the action size to account for the previous action.
        state_size += self.action_size

        actor_network_class = FCActorNetwork
        critic_network_class = FCCrtiicNetwork
        network_params = {
            "state_space": state_size,
            "action_space": self.action_size,
            "seed": self.seed,
            "social_feature_encoder": social_feature_encoder_class(
                **social_feature_encoder_params
            )
            if social_feature_encoder_class
            else None,
        }
    elif self.observation_type == adapters.AdapterType.DefaultObservationImage:
        observation_space = adapters.space_from_type(self.observation_type)
        stack_size = observation_space.shape[0]
        image_shape = (observation_space.shape[1], observation_space.shape[2])

        actor_network_class = CNNActorNetwork
        critic_network_class = CNNCriticNetwork
        network_params = {
            "input_channels": stack_size,
            "input_dimension": image_shape,
            "action_size": self.action_size,
            "seed": self.seed,
        }
    else:
        raise Exception(
            f"TD3 baseline does not support the '{self.observation_type}' "
            f"observation type."
        )

    # others
    self.checkpoint_dir = checkpoint_dir
    self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(self.device_name)
    self.save_codes = (
        policy_params["save_codes"] if "save_codes" in policy_params else None
    )
    self.memory = ReplayBuffer(
        buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
        batch_size=int(policy_params["replay_buffer"]["batch_size"]),
        observation_type=self.observation_type,
        device_name=self.device_name,
    )
    self.num_actor_updates = 0
    self.current_iteration = 0
    self.step_count = 0
    self.init_networks(actor_network_class, critic_network_class, network_params)
    if checkpoint_dir:
        self.load(checkpoint_dir)
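# Note: with the default image observation space checked in the adapter tests
# ((4, 64, 64) float32 frames), the image branch above would resolve to
# stack_size == 4 and image_shape == (64, 64).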