def load_model():
    # Observation space the policy expects: a 256x256x1 depth image plus a
    # 2-D pointgoal_with_gps_compass reading.
    depth_256_space = SpaceDict({
        'depth':
        spaces.Box(low=0., high=1., shape=(256, 256, 1)),
        'pointgoal_with_gps_compass':
        spaces.Box(
            low=np.finfo(np.float32).min,
            high=np.finfo(np.float32).max,
            shape=(2, ),
            dtype=np.float32,
        )
    })

    # GAUSSIAN / DISCRETE_4 / DISCRETE_6 are module-level flags selecting how
    # actions are parameterized (continuous 2-D vs. 4- or 6-way categorical).
    if GAUSSIAN:
        action_space = spaces.Box(np.array([float('-inf'),
                                            float('-inf')]),
                                  np.array([float('inf'),
                                            float('inf')]))
        action_distribution = 'gaussian'
        dim_actions = 2
    elif DISCRETE_4:
        action_space = spaces.Discrete(4)
        action_distribution = 'categorical'
        dim_actions = 4
    elif DISCRETE_6:
        action_space = spaces.Discrete(6)
        action_distribution = 'categorical'
        dim_actions = 6

    model = PointNavResNetPolicy(observation_space=depth_256_space,
                                 action_space=action_space,
                                 hidden_size=512,
                                 rnn_type='LSTM',
                                 num_recurrent_layers=2,
                                 backbone='resnet50',
                                 normalize_visual_inputs=False,
                                 action_distribution=action_distribution,
                                 dim_actions=dim_actions)
    model.to(torch.device("cpu"))

    # Weights are stored as JSON (parameter name -> nested lists), so rebuild
    # tensors for the actor_critic.* entries before loading the state dict.
    with open(WEIGHTS_PATH, 'r') as f:
        data_dict = json.load(f)
    model.load_state_dict({
        k[len("actor_critic."):]: torch.tensor(v)
        for k, v in data_dict.items() if k.startswith("actor_critic.")
    })

    return model
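A minimal single-step inference sketch with the model returned above, assuming the habitat-baselines Policy.act() interface (observations, hidden states, previous actions, masks -> value, action, log-probs, hidden states); the zero-filled observations, batch size 1, and the discrete prev_actions shape are illustrative assumptions.

import torch

# One forward step with the loaded policy (sketch only).
model = load_model()
model.eval()

observations = {
    "depth": torch.zeros(1, 256, 256, 1),
    "pointgoal_with_gps_compass": torch.zeros(1, 2),
}
hidden_states = torch.zeros(model.net.num_recurrent_layers, 1, 512)
prev_actions = torch.zeros(1, 1, dtype=torch.long)  # shape assumes a discrete head
masks = torch.zeros(1, 1)  # 0 at the start of an episode, 1 afterwards

with torch.no_grad():
    _, action, _, hidden_states = model.act(
        observations, hidden_states, prev_actions, masks, deterministic=True
    )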
Example 2
    def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
        r"""Sets up actor critic and agent for DD-PPO.

        Args:
            ppo_cfg: config node with relevant params

        Returns:
            None
        """
        logger.add_filehandler(self.config.LOG_FILE)

        self.actor_critic = PointNavResNetPolicy(
            observation_space=self.envs.observation_spaces[0],
            action_space=self.envs.action_spaces[0],
            hidden_size=ppo_cfg.hidden_size,
            rnn_type=self.config.RL.DDPPO.rnn_type,
            num_recurrent_layers=self.config.RL.DDPPO.num_recurrent_layers,
            backbone=self.config.RL.DDPPO.backbone,
            goal_sensor_uuid=self.config.TASK_CONFIG.TASK.GOAL_SENSOR_UUID,
            normalize_visual_inputs="rgb"
            in self.envs.observation_spaces[0].spaces,
        )
        self.actor_critic.to(self.device)

        if (self.config.RL.DDPPO.pretrained_encoder
                or self.config.RL.DDPPO.pretrained):
            pretrained_state = torch.load(
                self.config.RL.DDPPO.pretrained_weights, map_location="cpu")

        if self.config.RL.DDPPO.pretrained:
            self.actor_critic.load_state_dict({
                k[len("actor_critic."):]: v
                for k, v in pretrained_state["state_dict"].items()
            })
        elif self.config.RL.DDPPO.pretrained_encoder:
            prefix = "actor_critic.net.visual_encoder."
            self.actor_critic.net.visual_encoder.load_state_dict({
                k[len(prefix):]: v
                for k, v in pretrained_state["state_dict"].items()
                if k.startswith(prefix)
            })

        if not self.config.RL.DDPPO.train_encoder:
            self._static_encoder = True
            for param in self.actor_critic.net.visual_encoder.parameters():
                param.requires_grad_(False)

        if self.config.RL.DDPPO.reset_critic:
            nn.init.orthogonal_(self.actor_critic.critic.fc.weight)
            nn.init.constant_(self.actor_critic.critic.fc.bias, 0)

        self.agent = DDPPO(
            actor_critic=self.actor_critic,
            clip_param=ppo_cfg.clip_param,
            ppo_epoch=ppo_cfg.ppo_epoch,
            num_mini_batch=ppo_cfg.num_mini_batch,
            value_loss_coef=ppo_cfg.value_loss_coef,
            entropy_coef=ppo_cfg.entropy_coef,
            lr=ppo_cfg.lr,
            eps=ppo_cfg.eps,
            max_grad_norm=ppo_cfg.max_grad_norm,
            use_normalized_advantage=ppo_cfg.use_normalized_advantage,
        )
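For reference, a sketch of the RL.DDPPO config fields this method consumes, written as a bare yacs CfgNode (habitat's Config builds on yacs); the values are placeholders, not the repository defaults.

from yacs.config import CfgNode as CN

# Illustrative RL.DDPPO node with placeholder values.
ddppo = CN()
ddppo.rnn_type = "LSTM"
ddppo.num_recurrent_layers = 2
ddppo.backbone = "resnet50"
ddppo.pretrained = False            # load the full actor-critic state dict
ddppo.pretrained_encoder = False    # load only the visual encoder weights
ddppo.pretrained_weights = "path/to/checkpoint.pth"
ddppo.train_encoder = True          # if False, the visual encoder is frozen
ddppo.reset_critic = True           # re-initialize the critic head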
Example 3
    def __init__(self, config: Config):
        if "ObjectNav" in config.TASK_CONFIG.TASK.TYPE:
            OBJECT_CATEGORIES_NUM = 20
            spaces = {
                "objectgoal": Box(
                    low=0,
                    high=OBJECT_CATEGORIES_NUM,
                    shape=(1,),
                    dtype=np.int64),
                "compass": Box(
                    low=-np.pi,
                    high=np.pi,
                    shape=(1,),
                    dtype=np.float64),  # np.float is a deprecated alias; use an explicit dtype
                "gps": Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(2,),
                    dtype=np.float32,
                ),
            }
        else:
            spaces = {
                "pointgoal": Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(2,),
                    dtype=np.float32,
                )
            }

        if config.INPUT_TYPE in ["depth", "rgbd"]:
            spaces["depth"] = Box(
                low=0,
                high=1,
                shape=(config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.HEIGHT,
                       config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH, 1),
                dtype=np.float32,
            )

        if config.INPUT_TYPE in ["rgb", "rgbd"]:
            spaces["rgb"] = Box(
                low=0,
                high=255,
                shape=(config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.HEIGHT,
                       config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.WIDTH, 3),
                dtype=np.uint8,
            )
        observation_spaces = Dict(spaces)

        action_space = Discrete(len(config.TASK_CONFIG.TASK.POSSIBLE_ACTIONS))

        self.device = torch.device("cuda:{}".format(config.TORCH_GPU_ID))
        self.hidden_size = config.RL.PPO.hidden_size

        random.seed(config.RANDOM_SEED)
        torch.random.manual_seed(config.RANDOM_SEED)
        torch.backends.cudnn.deterministic = True

        self.actor_critic = PointNavResNetPolicy(
            observation_space=observation_spaces,
            action_space=action_space,
            hidden_size=self.hidden_size,
            normalize_visual_inputs="rgb" if config.INPUT_TYPE in ["rgb", "rgbd"] else False,
        )
        self.actor_critic.to(self.device)

        if config.MODEL_PATH:
            ckpt = torch.load(config.MODEL_PATH, map_location=self.device)
            print(f"Checkpoint loaded: {config.MODEL_PATH}")
            #  Filter only actor_critic weights
            self.actor_critic.load_state_dict(
                {
                    k.replace("actor_critic.", ""): v
                    for k, v in ckpt["state_dict"].items()
                    if "actor_critic" in k
                }
            )
        else:
            habitat.logger.error(
                "Model checkpoint wasn't loaded, evaluating a random model."
            )

        self.test_recurrent_hidden_states = None
        self.not_done_masks = None
        self.prev_actions = None
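The three buffers left as None above are typically populated by companion reset()/act() methods; a sketch in the spirit of habitat_baselines' PPOAgent follows, with the tensor shapes, the deterministic flag, and the returned action format treated as assumptions.

    def reset(self):
        # Fresh recurrent state, episode mask, and previous action for a new episode.
        self.test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            1,
            self.hidden_size,
            device=self.device,
        )
        self.not_done_masks = torch.zeros(1, 1, device=self.device)
        self.prev_actions = torch.zeros(
            1, 1, dtype=torch.long, device=self.device
        )

    def act(self, observations):
        # Batch the single observation dict and run one policy step.
        batch = {
            k: torch.as_tensor(v, device=self.device).unsqueeze(0)
            for k, v in observations.items()
        }
        with torch.no_grad():
            (
                _,
                action,
                _,
                self.test_recurrent_hidden_states,
            ) = self.actor_critic.act(
                batch,
                self.test_recurrent_hidden_states,
                self.prev_actions,
                self.not_done_masks,
                deterministic=False,
            )
            self.prev_actions.copy_(action)
        # After the first step the episode is "not done", so the mask becomes 1.
        self.not_done_masks = torch.ones(1, 1, device=self.device)
        return {"action": action.item()}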