def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
    r"""Creates the policy network and the PPO agent that wraps it.

    The log file handler is registered first; the policy is then built from
    the first environment's observation/action spaces, moved to
    ``self.device``, and handed to ``PPO`` together with the hyperparameters
    read from ``ppo_cfg``.

    Args:
        ppo_cfg: config node holding ``hidden_size`` and the PPO
            hyperparameters listed below.

    Returns:
        None
    """
    logger.add_filehandler(self.config.LOG_FILE)

    envs = self.envs
    policy = PointNavBaselinePolicy(
        observation_space=envs.observation_spaces[0],
        action_space=envs.action_spaces[0],
        hidden_size=ppo_cfg.hidden_size,
        goal_sensor_uuid=self.config.TASK_CONFIG.TASK.GOAL_SENSOR_UUID,
    )
    self.actor_critic = policy
    policy.to(self.device)

    # Each PPO keyword argument has the same name as the config field it
    # is read from, so the names are listed once and looked up via getattr.
    ppo_param_names = (
        "clip_param",
        "ppo_epoch",
        "num_mini_batch",
        "value_loss_coef",
        "entropy_coef",
        "lr",
        "eps",
        "max_grad_norm",
        "use_normalized_advantage",
    )
    self.agent = PPO(
        actor_critic=self.actor_critic,
        **{name: getattr(ppo_cfg, name) for name in ppo_param_names},
    )
def __init__(self, config: Config):
    r"""Builds the policy network and, if configured, restores its weights.

    Args:
        config: config node. This constructor reads ``GOAL_SENSOR_UUID``,
            ``INPUT_TYPE``, ``RESOLUTION``, ``PTH_GPU_ID``, ``HIDDEN_SIZE``,
            ``RANDOM_SEED`` and ``MODEL_PATH`` from it.
    """
    self.goal_sensor_uuid = config.GOAL_SENSOR_UUID

    # The goal sensor is always part of the observation space; the image
    # modalities are added depending on INPUT_TYPE ("rgbd" enables both).
    float32_info = np.finfo(np.float32)
    spaces = {
        self.goal_sensor_uuid: Box(
            low=float32_info.min,
            high=float32_info.max,
            shape=(2,),
            dtype=np.float32,
        )
    }

    if config.INPUT_TYPE in ["depth", "rgbd"]:
        spaces["depth"] = Box(
            low=0,
            high=1,
            shape=(config.RESOLUTION, config.RESOLUTION, 1),
            dtype=np.float32,
        )

    if config.INPUT_TYPE in ["rgb", "rgbd"]:
        spaces["rgb"] = Box(
            low=0,
            high=255,
            shape=(config.RESOLUTION, config.RESOLUTION, 3),
            dtype=np.uint8,
        )

    observation_spaces = Dict(spaces)
    action_spaces = Discrete(4)

    # Use the configured GPU when CUDA is available, otherwise the CPU.
    self.device = (
        torch.device("cuda:{}".format(config.PTH_GPU_ID))
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    self.hidden_size = config.HIDDEN_SIZE

    # Seed Python's and torch's RNGs from the config.
    random.seed(config.RANDOM_SEED)
    torch.random.manual_seed(config.RANDOM_SEED)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True

    self.actor_critic = PointNavBaselinePolicy(
        observation_space=observation_spaces,
        action_space=action_spaces,
        hidden_size=self.hidden_size,
        goal_sensor_uuid=self.goal_sensor_uuid,
    )
    self.actor_critic.to(self.device)

    if config.MODEL_PATH:
        # NOTE(review): torch.load is presumably pickle-based, so
        # MODEL_PATH should only point at trusted checkpoints - confirm.
        ckpt = torch.load(config.MODEL_PATH, map_location=self.device)

        # Keep only the actor_critic weights and strip their prefix. The
        # filter must test for the exact prefix that is sliced off; a key
        # that merely contains "actor_critic" somewhere else would
        # otherwise be cut at the wrong position and yield a bogus name.
        prefix = "actor_critic."
        self.actor_critic.load_state_dict({
            k[len(prefix):]: v
            for k, v in ckpt["state_dict"].items()
            if k.startswith(prefix)
        })
    else:
        habitat.logger.error("Model checkpoint wasn't loaded, evaluating "
                             "a random model.")

    # Per-episode state; left unset (None) by the constructor.
    self.test_recurrent_hidden_states = None
    self.not_done_masks = None
    self.prev_actions = None