def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None: r"""Sets up actor critic and agent for PPO. Args: ppo_cfg: config node with relevant params Returns: None """ logger.add_filehandler(self.config.LOG_FILE) policy = baseline_registry.get_policy(self.config.RL.POLICY.name) observation_space = self.obs_space self.obs_transforms = get_active_obs_transforms(self.config) observation_space = apply_obs_transforms_obs_space( observation_space, self.obs_transforms) self.actor_critic = policy.from_config(self.config, observation_space, self.envs.action_spaces[0]) self.obs_space = observation_space self.actor_critic.to(self.device) if (self.config.RL.DDPPO.pretrained_encoder or self.config.RL.DDPPO.pretrained): pretrained_state = torch.load( self.config.RL.DDPPO.pretrained_weights, map_location="cpu") if self.config.RL.DDPPO.pretrained: self.actor_critic.load_state_dict({ k[len("actor_critic."):]: v for k, v in pretrained_state["state_dict"].items() }) elif self.config.RL.DDPPO.pretrained_encoder: prefix = "actor_critic.net.visual_encoder." self.actor_critic.net.visual_encoder.load_state_dict({ k[len(prefix):]: v for k, v in pretrained_state["state_dict"].items() if k.startswith(prefix) }) if not self.config.RL.DDPPO.train_encoder: self._static_encoder = True for param in self.actor_critic.net.visual_encoder.parameters(): param.requires_grad_(False) if self.config.RL.DDPPO.reset_critic: nn.init.orthogonal_(self.actor_critic.critic.fc.weight) nn.init.constant_(self.actor_critic.critic.fc.bias, 0) self.agent = (DDPPO if self._is_distributed else PPO)( actor_critic=self.actor_critic, clip_param=ppo_cfg.clip_param, ppo_epoch=ppo_cfg.ppo_epoch, num_mini_batch=ppo_cfg.num_mini_batch, value_loss_coef=ppo_cfg.value_loss_coef, entropy_coef=ppo_cfg.entropy_coef, lr=ppo_cfg.lr, eps=ppo_cfg.eps, max_grad_norm=ppo_cfg.max_grad_norm, use_normalized_advantage=ppo_cfg.use_normalized_advantage, )
def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None: r"""Sets up actor critic and agent for PPO. Args: ppo_cfg: config node with relevant params Returns: None """ logger.add_filehandler(self.config.LOG_FILE) policy = baseline_registry.get_policy(self.config.RL.POLICY.name) self.actor_critic = policy( observation_space=self.envs.observation_spaces[0], action_space=self.envs.action_spaces[0], hidden_size=ppo_cfg.hidden_size, ) self.actor_critic.to(self.device) self.agent = PPO( actor_critic=self.actor_critic, clip_param=ppo_cfg.clip_param, ppo_epoch=ppo_cfg.ppo_epoch, num_mini_batch=ppo_cfg.num_mini_batch, value_loss_coef=ppo_cfg.value_loss_coef, entropy_coef=ppo_cfg.entropy_coef, lr=ppo_cfg.lr, eps=ppo_cfg.eps, max_grad_norm=ppo_cfg.max_grad_norm, use_normalized_advantage=ppo_cfg.use_normalized_advantage, )
def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None: r"""Sets up actor critic and agent for DD-PPO. Args: ppo_cfg: config node with relevant params Returns: None """ logger.add_filehandler(self.config.LOG_FILE) policy = baseline_registry.get_policy(self.config.RL.POLICY.name) self.actor_critic = policy( observation_space=self.envs.observation_spaces[0], action_space=self.envs.action_spaces[0], hidden_size=ppo_cfg.hidden_size, rnn_type=self.config.RL.DDPPO.rnn_type, num_recurrent_layers=self.config.RL.DDPPO.num_recurrent_layers, backbone=self.config.RL.DDPPO.backbone, normalize_visual_inputs="rgb" in self.envs.observation_spaces[0].spaces, force_blind_policy=self.config.FORCE_BLIND_POLICY, ) self.actor_critic.to(self.device) if (self.config.RL.DDPPO.pretrained_encoder or self.config.RL.DDPPO.pretrained): pretrained_state = torch.load( self.config.RL.DDPPO.pretrained_weights, map_location="cpu") if self.config.RL.DDPPO.pretrained: self.actor_critic.load_state_dict({ k[len("actor_critic."):]: v for k, v in pretrained_state["state_dict"].items() }) elif self.config.RL.DDPPO.pretrained_encoder: prefix = "actor_critic.net.visual_encoder." self.actor_critic.net.visual_encoder.load_state_dict({ k[len(prefix):]: v for k, v in pretrained_state["state_dict"].items() if k.startswith(prefix) }) if not self.config.RL.DDPPO.train_encoder: self._static_encoder = True for param in self.actor_critic.net.visual_encoder.parameters(): param.requires_grad_(False) if self.config.RL.DDPPO.reset_critic: nn.init.orthogonal_(self.actor_critic.critic.fc.weight) nn.init.constant_(self.actor_critic.critic.fc.bias, 0) self.agent = DDPPO( actor_critic=self.actor_critic, clip_param=ppo_cfg.clip_param, ppo_epoch=ppo_cfg.ppo_epoch, num_mini_batch=ppo_cfg.num_mini_batch, value_loss_coef=ppo_cfg.value_loss_coef, entropy_coef=ppo_cfg.entropy_coef, lr=ppo_cfg.lr, eps=ppo_cfg.eps, max_grad_norm=ppo_cfg.max_grad_norm, use_normalized_advantage=ppo_cfg.use_normalized_advantage, )
def __init__(self, config: Config) -> None:
    image_size = config.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER
    if "ObjectNav" in config.TASK_CONFIG.TASK.TYPE:
        OBJECT_CATEGORIES_NUM = 20
        spaces = {
            "objectgoal": Box(
                low=0,
                high=OBJECT_CATEGORIES_NUM,
                shape=(1,),
                dtype=np.int64,
            ),
            "compass": Box(
                low=-np.pi, high=np.pi, shape=(1,), dtype=np.float32
            ),
            "gps": Box(
                low=np.finfo(np.float32).min,
                high=np.finfo(np.float32).max,
                shape=(2,),
                dtype=np.float32,
            ),
        }
    else:
        spaces = {
            "pointgoal": Box(
                low=np.finfo(np.float32).min,
                high=np.finfo(np.float32).max,
                shape=(2,),
                dtype=np.float32,
            )
        }

    if config.INPUT_TYPE in ["depth", "rgbd"]:
        spaces["depth"] = Box(
            low=0,
            high=1,
            shape=(image_size.HEIGHT, image_size.WIDTH, 1),
            dtype=np.float32,
        )

    if config.INPUT_TYPE in ["rgb", "rgbd"]:
        spaces["rgb"] = Box(
            low=0,
            high=255,
            shape=(image_size.HEIGHT, image_size.WIDTH, 3),
            dtype=np.uint8,
        )

    observation_spaces = SpaceDict(spaces)
    action_spaces = (
        Discrete(6)
        if "ObjectNav" in config.TASK_CONFIG.TASK.TYPE
        else Discrete(4)
    )

    self.obs_transforms = get_active_obs_transforms(config)
    observation_spaces = apply_obs_transforms_obs_space(
        observation_spaces, self.obs_transforms
    )

    self.device = (
        torch.device("cuda:{}".format(config.PTH_GPU_ID))
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    self.hidden_size = config.RL.PPO.hidden_size

    random.seed(config.RANDOM_SEED)
    np.random.seed(config.RANDOM_SEED)
    _seed_numba(config.RANDOM_SEED)
    torch.random.manual_seed(config.RANDOM_SEED)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True  # type: ignore

    policy = baseline_registry.get_policy(config.RL.POLICY.name)
    self.actor_critic = policy.from_config(
        config, observation_spaces, action_spaces
    )
    self.actor_critic.to(self.device)

    if config.MODEL_PATH:
        ckpt = torch.load(config.MODEL_PATH, map_location=self.device)
        # Filter only actor_critic weights
        self.actor_critic.load_state_dict(
            {
                k[len("actor_critic."):]: v
                for k, v in ckpt["state_dict"].items()
                if "actor_critic" in k
            }
        )
    else:
        habitat.logger.error(
            "Model checkpoint wasn't loaded, evaluating a random model."
        )

    self.test_recurrent_hidden_states: Optional[torch.Tensor] = None
    self.not_done_masks: Optional[torch.Tensor] = None
    self.prev_actions: Optional[torch.Tensor] = None
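# The three Optional tensors above are populated at the start of each
# episode by a companion reset() on the same agent class. A minimal sketch,
# assuming a recurrent policy that exposes net.num_recurrent_layers (shapes
# use a batch of one environment; the bool mask dtype follows recent
# habitat-baselines conventions and may be float in older versions):
import torch


def reset(self) -> None:
    self.test_recurrent_hidden_states = torch.zeros(
        1,  # one environment
        self.actor_critic.net.num_recurrent_layers,
        self.hidden_size,
        device=self.device,
    )
    # Zeros signal "episode just started" so the RNN state is not carried
    # over from the previous episode.
    self.not_done_masks = torch.zeros(
        1, 1, device=self.device, dtype=torch.bool
    )
    self.prev_actions = torch.zeros(
        1, 1, dtype=torch.long, device=self.device
    )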