def train_agent(agent):
    r"""Builds the environments from the hard-coded PPO config and trains ``agent``.

    Args:
        agent: name of the agent to train; only ``"ppo_agent"`` is supported.

    Returns:
        None.

    Raises:
        ValueError: if ``agent`` is not a supported agent name.
    """
    # Fail fast, before any environment is constructed, instead of running
    # into an unbound ``trainer`` once everything has already been set up.
    if agent != "ppo_agent":
        raise ValueError(f"Unsupported agent: {agent!r}")

    config_file = r"/home/userone/workspace/bed1/ppo_custom/ppo_pointnav_example.yaml"
    config = get_config(config_file, None)
    env = construct_envs(config, get_env_class(config.ENV_NAME))

    # NOTE(review): these dataset overrides are applied after the environments
    # were constructed, so presumably only the trainer sees them -- confirm
    # that this ordering is intended.
    config.defrost()
    config.TASK_CONFIG.DATASET.DATA_PATH = config.DATA_PATH_SET
    config.TASK_CONFIG.DATASET.SCENES_DIR = config.SCENE_DIR_SET
    config.freeze()

    # Seed both Python's and NumPy's global RNGs from the task config.
    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)

    trainer = ppo_trainer.PPOTrainer(config)
    trainer.train(env)
Example #2
0
    def _load_real_data(self, config):
        """Fill ``self.memory_real`` with RGB/depth pairs from the playback env.

        A clone of ``config`` is switched to the ``PepperRobot`` dataset type
        and the ``PepperPlaybackEnv`` environment with a single process. The
        observation returned by ``reset`` is inserted first, followed by one
        observation per step for ``self.real_mem_size`` steps (action ``0``).

        Args:
            config: configuration to clone; the caller's object is not
                modified.
        """
        memory = self.memory_real

        playback_cfg = config.clone()
        playback_cfg.defrost()
        playback_cfg.TASK_CONFIG.DATASET.TYPE = "PepperRobot"
        playback_cfg.ENV_NAME = "PepperPlaybackEnv"
        playback_cfg.NUM_PROCESSES = 1
        playback_cfg.freeze()

        playback_envs = construct_envs(
            playback_cfg, get_env_class(playback_cfg.ENV_NAME)
        )

        # Seed the memory with the observation produced by the reset.
        obs_batch = playback_envs.reset()
        memory.insert([obs_batch[0]["rgb"], obs_batch[0]["depth"]])

        print(f"Loading {self.real_mem_size}  ROBOT observations ...")
        for _ in range(self.real_mem_size):
            step_results = playback_envs.step([0])

            # Transpose per-env (obs, reward, done, info) tuples into
            # per-field lists; only the observations are used.
            obs_batch, _rewards, _dones, _infos = (
                list(column) for column in zip(*step_results)
            )
            memory.insert([obs_batch[0]["rgb"], obs_batch[0]["depth"]])
        print(f"Loaded {self.real_mem_size} observations from ROBOT dataset")
Example #3
0
    def _init_envs(self, config=None):
        """Construct the vectorized environments and store them on ``self.envs``.

        Args:
            config: configuration to build the environments from; falls back
                to ``self.config`` when ``None``.
        """
        env_config = self.config if config is None else config
        env_class = get_env_class(env_config.ENV_NAME)

        self.envs = construct_envs(
            env_config,
            env_class,
            workers_ignore_signals=is_slurm_batch_job(),
        )
Example #4
0
 def __init__(self, config=None):
     """Build the environments, the agent and the rollout storage.

     Args:
         config: configuration forwarded to the parent constructor; the
             rest of the setup reads ``self.config``.
     """
     super().__init__(config)

     env_class = get_env_class(self.config.ENV_NAME)
     self.envs = construct_envs(self.config, env_class)
     self._setup_actor_critic_agent(self.config)

     # Rollout storage is sized from the PPO settings and the spaces of the
     # first environment.
     ppo_cfg = self.config.RL.PPO
     envs = self.envs
     self.rollout = RolloutStorage(
         ppo_cfg.num_steps,
         envs.num_envs,
         envs.observation_spaces[0],
         envs.action_spaces[0],
         ppo_cfg.hidden_size,
     )

     # Use the configured GPU when CUDA is available, otherwise the CPU.
     if torch.cuda.is_available():
         self.device = torch.device("cuda", self.config.TORCH_GPU_ID)
     else:
         self.device = torch.device("cpu")
     self.rollout.to(self.device)
def run_random_agent():
    """Run the random agent against environments built from the PPO config.

    Loads the hard-coded PointNav YAML, constructs the environments, points
    the dataset paths at ``DATA_PATH_SET`` / ``SCENE_DIR_SET``, seeds the
    Python and NumPy RNGs from the task seed and calls ``random_agent.run``
    with ``3`` as its last argument.
    """
    cfg = get_config(
        r"/home/userone/workspace/bed1/ppo_custom/ppo_pointnav_example.yaml",
        None,
    )
    envs = construct_envs(cfg, get_env_class(cfg.ENV_NAME))

    # Dataset overrides are applied after the environments are constructed.
    cfg.defrost()
    cfg.TASK_CONFIG.DATASET.DATA_PATH = cfg.DATA_PATH_SET
    cfg.TASK_CONFIG.DATASET.SCENES_DIR = cfg.SCENE_DIR_SET
    cfg.freeze()

    for seed_fn in (random.seed, np.random.seed):
        seed_fn(cfg.TASK_CONFIG.SEED)

    random_agent.run(cfg, envs, 3)
def trigger_transfer_learn(agent):
    """Build the environments from the hard-coded PPO config and run transfer training.

    Args:
        agent: name of the agent to train; only ``"ppo_agent"`` is supported.

    Returns:
        None.

    Raises:
        ValueError: if ``agent`` is not a supported agent name.
    """
    # Fail fast, before any environment is constructed, instead of running
    # into an unbound ``trainer`` once everything has already been set up.
    if agent != "ppo_agent":
        raise ValueError(f"Unsupported agent: {agent!r}")

    config_file = r"/home/userone/workspace/bed1/ppo_custom/ppo_pointnav_example.yaml"
    config = get_config(config_file, None)
    env = construct_envs(config, get_env_class(config.ENV_NAME))

    # NOTE(review): these dataset overrides are applied after the environments
    # were constructed, so presumably only the trainer sees them -- confirm
    # that this ordering is intended.
    config.defrost()
    config.TASK_CONFIG.DATASET.DATA_PATH = config.DATA_PATH_SET
    config.TASK_CONFIG.DATASET.SCENES_DIR = config.SCENE_DIR_SET
    config.freeze()

    # Seed both Python's and NumPy's global RNGs from the task config.
    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)

    trainer = transfer_ppo.PPOTrainer(config)
    trainer.train(env)
Example #7
0
    def __init__(
        self,
        config: Any,
        video_option: List[str],
        video_dir: str,
        vid_filename_metrics: Set[str],
        traj_save_dir: str = None,
        should_save_fn=None,
        writer=None,
    ) -> None:
        """Create the wrapped gym environment and remember the output options.

        Args:
            config: configuration; ``config.ENV_NAME`` selects the
                environment class and the whole object is passed to
                ``make_env_fn``.
            video_option: stored as ``self._video_option``.
            video_dir: stored as ``self._video_dir``.
            vid_filename_metrics: stored as ``self._vid_filename_metrics``.
            traj_save_dir: stored as ``self._traj_save_path``.
            should_save_fn: stored as ``self._should_save_fn``.
            writer: stored as ``self._writer``.
        """
        # Build the underlying environment, then wrap it while keeping the
        # original observations (save_orig_obs=True).
        base_env = make_env_fn(
            env_class=get_env_class(config.ENV_NAME), config=config
        )
        self._gym_env = HabGymWrapper(base_env, save_orig_obs=True)

        # Video-related settings.
        self._video_option = video_option
        self._video_dir = video_dir
        self._vid_filename_metrics = vid_filename_metrics
        self._writer = writer

        # Trajectory-saving settings.
        self._traj_save_path = traj_save_dir
        self._should_save_fn = should_save_fn
Example #8
0
def test_gym_wrapper_contract(config_file):
    """Check that the wrapped environment exposes array-based gym behaviour.

    Verifies a ``spaces.Box`` action space, ndarray observations matching the
    observation space shape after ``reset`` and ``step``, an ``(H, W, 3)``
    rendered frame, and an ``info`` mapping without nested dicts.
    """
    config = baselines_get_config(config_file)
    raw_env = habitat_baselines.utils.env_utils.make_env_fn(
        env_class=get_env_class(config.ENV_NAME), config=config
    )
    env = HabRenderWrapper(HabGymWrapper(raw_env))

    assert isinstance(env.action_space, spaces.Box)

    # Observation contract after reset.
    reset_obs = env.reset()
    assert isinstance(reset_obs, np.ndarray), f"Obs {reset_obs}"
    assert reset_obs.shape == env.observation_space.shape

    # Observation contract after one random step.
    step_obs, _, _, info = env.step(env.action_space.sample())
    assert isinstance(step_obs, np.ndarray), f"Obs {step_obs}"
    assert step_obs.shape == env.observation_space.shape

    # Rendered frame must be a 3-channel image.
    frame = env.render()
    assert isinstance(frame, np.ndarray)
    assert len(frame.shape) == 3 and frame.shape[-1] == 3

    # The info mapping must be flat.
    for value in info.values():
        assert not isinstance(value, dict)
def test_rearrange_task(test_cfg_path):
    """Run ten random-action episodes on the environment built from a config.

    Args:
        test_cfg_path: path of the baselines config used to build the
            environment.
    """
    config = baselines_get_config(test_cfg_path)
    env = habitat_baselines.utils.env_utils.make_env_fn(
        env_class=get_env_class(config.ENV_NAME), config=config
    )

    with env:
        for _ in range(10):
            env.reset()
            # Sample random actions until the environment reports done.
            while True:
                sampled = env.action_space.sample()
                habitat.logger.info(
                    f"Action : "
                    f"{sampled['action']}, "
                    f"args: {sampled['action_args']}."
                )
                _, _, done, info = env.step(action=sampled)
                if done:
                    break

            # Log the info of the episode's final step.
            logger.info(info)
def test_rearrange_task():
    """Run ten random-action episodes on the rearrange-pick DD-PPO config."""
    config = baselines_get_config(
        "habitat_baselines/config/rearrange/ddppo_rearrangepick.yaml")
    # The dataset-existence skip is currently disabled:
    # if not RearrangeDatasetV0.check_config_paths_exist(config.TASK_CONFIG.DATASET):
    #     pytest.skip("Test skipped as dataset files are missing.")

    env = habitat_baselines.utils.env_utils.make_env_fn(
        env_class=get_env_class(config.ENV_NAME), config=config)

    with env:
        for _ in range(10):
            env.reset()
            # Sample random actions until the environment reports done.
            while True:
                sampled = env.action_space.sample()
                habitat.logger.info(f"Action : "
                                    f"{sampled['action']}, "
                                    f"args: {sampled['action_args']}.")
                _, _, done, info = env.step(action=sampled)
                if done:
                    break

            # Log the info of the episode's final step.
            logger.info(info)
Example #11
0
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Loads the checkpoint, builds the environments and the agent, rolls the
        policy out until ``TEST_EPISODE_COUNT`` distinct episodes have been
        recorded (or every environment has been paused), then reports the mean
        reward, metric and success to the logger and to ``writer``.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging

        Returns:
            None
        """
        ckpt_dict = self.load_checkpoint(checkpoint_path,
                                         map_location=self.device)

        config = self._setup_eval_config(ckpt_dict["config"])
        ppo_cfg = config.RL.PPO

        # Extra measurements are only requested when videos are produced.
        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        logger.info(f"env config: {config}")
        # NOTE(review): the environments are built from ``self.config`` while
        # the eval ``config`` prepared above is the one that gets logged --
        # confirm which configuration the environments should use.
        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        # get name of performance metric, e.g. "spl"
        metric_name = self.config.TASK_CONFIG.TASK.MEASUREMENTS[0]
        metric_cfg = getattr(self.config.TASK_CONFIG.TASK, metric_name)
        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
        assert measure_type is not None, "invalid measurement type {}".format(
            metric_cfg.TYPE)
        self.metric_uuid = measure_type(None, None)._get_uuid()

        observations = self.envs.reset()
        batch = batch_obs(observations)
        for sensor in batch:
            batch[sensor] = batch[sensor].to(self.device)

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        # One independent frame buffer per environment.  ``[[]] * n`` would
        # repeat a reference to the *same* list n times, so frames appended
        # for one environment would show up in every other buffer as well.
        rgb_frames = [
            [] for _ in range(self.config.NUM_PROCESSES)
        ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        while (len(stats_episodes) < self.config.TEST_EPISODE_COUNT
               and self.envs.num_envs > 0):
            # Episodes being played during this step (before stepping).
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                _, actions, _, test_recurrent_hidden_states = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    prev_actions,
                    not_done_masks,
                    deterministic=False,
                )

                prev_actions.copy_(actions)

            outputs = self.envs.step([a[0].item() for a in actions])

            # Transpose per-env (obs, reward, done, info) tuples into
            # per-field lists.
            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations)
            for sensor in batch:
                batch[sensor] = batch[sensor].to(self.device)

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            # Episodes the environments hold after the step.
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                # An environment whose next episode was already recorded has
                # nothing new to contribute and gets paused below.
                if (
                        next_episodes[i].scene_id,
                        next_episodes[i].episode_id,
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    episode_stats = dict()
                    episode_stats[self.metric_uuid] = infos[i][
                        self.metric_uuid]
                    episode_stats["success"] = int(
                        infos[i][self.metric_uuid] > 0)
                    episode_stats["reward"] = current_episode_reward[i].item()
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metric_name=self.metric_uuid,
                            metric_value=infos[i][self.metric_uuid],
                            tb_writer=writer,
                        )

                        # Start a fresh buffer for the next episode.
                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            # pausing self.envs with no new episode
            if len(envs_to_pause) > 0:
                state_index = list(range(self.envs.num_envs))
                # Pop from the back so the remaining indices stay valid.
                for idx in reversed(envs_to_pause):
                    state_index.pop(idx)
                    self.envs.pause_at(idx)

                # indexing along the batch dimensions
                test_recurrent_hidden_states = test_recurrent_hidden_states[
                    state_index]
                not_done_masks = not_done_masks[state_index]
                current_episode_reward = current_episode_reward[state_index]
                prev_actions = prev_actions[state_index]

                for k, v in batch.items():
                    batch[k] = v[state_index]

                if len(self.config.VIDEO_OPTION) > 0:
                    rgb_frames = [rgb_frames[i] for i in state_index]

        # Sum every per-episode statistic, then average over the episodes.
        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = sum(
                [v[stat_key] for v in stats_episodes.values()])
        num_episodes = len(stats_episodes)

        episode_reward_mean = aggregated_stats["reward"] / num_episodes
        episode_metric_mean = aggregated_stats[self.metric_uuid] / num_episodes
        episode_success_mean = aggregated_stats["success"] / num_episodes

        logger.info(f"Average episode reward: {episode_reward_mean:.6f}")
        logger.info(f"Average episode success: {episode_success_mean:.6f}")
        logger.info(
            f"Average episode {self.metric_uuid}: {episode_metric_mean:.6f}")

        writer.add_scalars(
            "eval_reward",
            {"average reward": episode_reward_mean},
            checkpoint_index,
        )
        writer.add_scalars(
            f"eval_{self.metric_uuid}",
            {f"average {self.metric_uuid}": episode_metric_mean},
            checkpoint_index,
        )
        writer.add_scalars(
            "eval_success",
            {"average success": episode_success_mean},
            checkpoint_index,
        )

        self.envs.close()
Example #12
0
    def train(self) -> None:
        r"""Main method for training PPO.

        Builds the environments and the agent, then for ``NUM_UPDATES``
        iterations collects ``num_steps`` rollout steps, updates the agent,
        writes reward/loss scalars to tensorboard, logs throughput every
        ``LOG_INTERVAL`` updates and saves a checkpoint every
        ``CHECKPOINT_INTERVAL`` updates.

        Returns:
            None
        """

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        # Use the configured GPU when CUDA is available and fall back to the
        # CPU otherwise, so CPU-only machines can still run training.
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True)
        self._setup_actor_critic_agent(ppo_cfg)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        observations = self.envs.reset()
        batch = batch_obs(observations)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
        )
        # Seed step 0 of the rollout storage with the reset observations.
        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])
        rollouts.to(self.device)

        # Running totals, plus bounded windows of their snapshots used to
        # report statistics over the last ``reward_window_size`` updates.
        episode_rewards = torch.zeros(self.envs.num_envs, 1)
        episode_counts = torch.zeros(self.envs.num_envs, 1)
        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
        window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            for update in range(self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                for step in range(ppo_cfg.num_steps):
                    delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                        rollouts,
                        current_episode_reward,
                        episode_rewards,
                        episode_counts,
                    )
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent(
                    ppo_cfg, rollouts)
                pth_time += delta_pth_time

                # Snapshot the running totals for the windowed statistics.
                window_episode_reward.append(episode_rewards.clone())
                window_episode_counts.append(episode_counts.clone())

                losses = [value_loss, action_loss]
                stats = zip(
                    ["count", "reward"],
                    [window_episode_counts, window_episode_reward],
                )
                # Change over the window (newest minus oldest snapshot); with
                # a single snapshot the totals themselves are used.
                deltas = {
                    k:
                    ((v[-1] -
                      v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in stats
                }
                # Avoid dividing by zero when no episode finished.
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)

                writer.add_scalars(
                    "losses",
                    {k: l
                     for l, k in zip(losses, ["value", "policy"])},
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    window_rewards = (window_episode_reward[-1] -
                                      window_episode_reward[0]).sum()
                    window_counts = (window_episode_counts[-1] -
                                     window_episode_counts[0]).sum()

                    if window_counts > 0:
                        logger.info(
                            "Average window size {} reward: {:3f}".format(
                                len(window_episode_reward),
                                (window_rewards / window_counts).item(),
                            ))
                    else:
                        logger.info("No episodes finish in current window")

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                    count_checkpoints += 1

            self.envs.close()
Example #13
0
    def train(self):
        """Prepare environments, rollout storage and the LR scheduler for training.

        The steps visible here construct the environments, pick the device
        (configured GPU when CUDA is available, CPU otherwise), make sure the
        checkpoint folder exists, allocate the rollout storage, seed it with
        the observations returned by ``reset`` and create a ``LambdaLR``
        scheduler driven by ``linear_decay`` over ``NUM_UPDATES``.

        NOTE(review): ``self.agent`` is used below but not created in this
        part of the method -- presumably it is set up before ``train`` is
        called; confirm.
        """
        # Get environments for training
        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        #logger.info(
        #    "agent number of parameters: {}".format(
        #        sum(param.numel() for param in self.agent.parameters())
        #    )
        #)

        # Change for the actual value
        cfg = self.config.RL.PPO

        # Storage sized from the PPO settings and the spaces of the first
        # environment.
        rollouts = RolloutStorage(
            cfg.num_steps,
            self.envs.num_envs,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            cfg.hidden_size,
        )
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations)
        # Debug output: print the batched shape of every sensor.
        for sensor in rollouts.observations:
            print(batch[sensor].shape)

        # Copy the information to the wrapper
        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        # Per-environment running totals, kept as CPU tensors.
        episode_rewards = torch.zeros(self.envs.num_envs, 1)
        episode_counts = torch.zeros(self.envs.num_envs, 1)
        #current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        #window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
        #window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

        # Timing and progress counters.
        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        # Learning-rate multiplier computed by linear_decay from the
        # scheduler step and NUM_UPDATES.
        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )
        '''
Example #14
0
    def __init__(self,
                 wrapper,
                 config_file,
                 base_config_file=None,
                 horizon=None,
                 gamma=0.99,
                 width=None,
                 height=None):
        """
        Constructor. For more details on how to pass YAML configuration files,
        please see <MUSHROOM_RL PATH>/examples/habitat/README.md

        Args:
             wrapper (str): name of the wrapper class, looked up in this
                module's globals, used to convert observations and actions
                (e.g., HabitatRearrangeWrapper);
             config_file (str): path to the YAML file specifying the RL task
                configuration (see <HABITAT_LAB PATH>/habitat_baselines/configs/);
             base_config_file (str, None): path to an optional YAML file, used
                as 'BASE_TASK_CONFIG_PATH' in the first YAML
                (see <HABITAT_LAB PATH>/configs/); when None, ``config_file``
                itself is used;
             horizon (int, None): the horizon; when None, the value of
                MAX_EPISODE_STEPS from the config file is used;
             gamma (float, 0.99): the discount factor;
             width (int, None): width of the pixel observation. If None, the
                value specified in the config file is used.
             height (int, None): height of the pixel observation. If None, the
                value specified in the config file is used.

        """
        # MDP creation
        self._not_pybullet = False
        self._first = True

        base_path = (config_file
                     if base_config_file is None else base_config_file)
        config = get_config(config_paths=config_file,
                            opts=['BASE_TASK_CONFIG_PATH', base_path])

        config.defrost()

        episode_cfg = config.TASK_CONFIG.ENVIRONMENT
        if horizon is None:
            # Fall back to the horizon declared in the config file
            horizon = episode_cfg.MAX_EPISODE_STEPS
        # Hack to ignore gym time limit
        episode_cfg.MAX_EPISODE_STEPS = horizon + 1

        # Override HEIGHT / WIDTH on every TASK_CONFIG.SIMULATOR entry whose
        # key contains 'rgb' (case-insensitive)
        simulator_cfg = config['TASK_CONFIG']['SIMULATOR']
        for key in simulator_cfg:
            if 'rgb' not in key.lower():
                continue
            if height is not None:
                simulator_cfg[key]['HEIGHT'] = height
            if width is not None:
                simulator_cfg[key]['WIDTH'] = width

        config.freeze()

        raw_env = make_env_fn(env_class=get_env_class(config.ENV_NAME),
                              config=config)
        self.env = globals()[wrapper](raw_env)

        # First two entries of the wrapped observation shape
        self._img_size = self.env.observation_space.shape[0:2]

        # MDP properties
        action_space = self.env.action_space
        observation_space = Box(low=0.,
                                high=255.,
                                shape=(3, self._img_size[1],
                                       self._img_size[0]))
        mdp_info = MDPInfo(observation_space, action_space, gamma, horizon)

        # Discrete actions are unwrapped from their one-element container;
        # anything else is passed through unchanged
        self._convert_action = ((lambda a: a[0]) if isinstance(
            action_space, Discrete) else (lambda a: a))

        viewer_size = (self._img_size[1], self._img_size[0])
        self._viewer = ImageViewer(viewer_size, 1 / 10)

        Environment.__init__(self, mdp_info)
Example #15
0
    def _eval_checkpoint(self,
                         checkpoint_path: str,
                         writer: TensorboardWriter,
                         checkpoint_index: int = 0) -> None:
        r"""Evaluates a single checkpoint. Assumes episode IDs are unique.

        Builds an evaluation config, constructs the environments and the
        agent, runs the policy deterministically until ``EVAL.EPISODE_COUNT``
        episodes have been recorded (or no environment is left), then writes
        the per-key averages of the episode infos to a JSON file, the logger
        and ``writer``.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging

        Returns:
            None
        """
        logger.info(f"checkpoint_path: {checkpoint_path}")

        # Either reuse the config stored in the checkpoint or clone the
        # trainer's own config.
        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(
                self.load_checkpoint(checkpoint_path,
                                     map_location="cpu")["config"])
        else:
            config = self.config.clone()

        # Point dataset and NDTW/SDTW measures at the evaluation split and
        # disable episode shuffling.
        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.TASK.NDTW.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.TASK.SDTW.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
        if len(config.VIDEO_OPTION) > 0:
            # NOTE: the config was already defrosted above, so this second
            # defrost() is redundant.
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")

        config.freeze()

        # setup agent
        self.envs = construct_envs_auto_reset_false(
            config, get_env_class(config.ENV_NAME))
        self.device = (torch.device("cuda", config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        self._setup_actor_critic_agent(config.MODEL, True, checkpoint_path)

        observations = self.envs.reset()
        observations = transform_obs(
            observations, config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID)
        batch = batch_obs(observations, self.device)

        # Recurrent state, previous actions and masks all start at zero.
        eval_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            config.NUM_PROCESSES,
            config.MODEL.STATE_ENCODER.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(config.NUM_PROCESSES,
                                     1,
                                     device=self.device)

        stats_episodes = {}  # dict of dicts that stores stats per episode

        # rgb_frames only exists when a video option is configured; every
        # later use is guarded by the same condition.
        if len(config.VIDEO_OPTION) > 0:
            os.makedirs(config.VIDEO_DIR, exist_ok=True)
            rgb_frames = [[] for _ in range(config.NUM_PROCESSES)]

        self.actor_critic.eval()
        while (self.envs.num_envs > 0
               and len(stats_episodes) < config.EVAL.EPISODE_COUNT):
            # Episodes being played during this step (before stepping).
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                (_, actions, _,
                 eval_recurrent_hidden_states) = self.actor_critic.act(
                     batch,
                     eval_recurrent_hidden_states,
                     prev_actions,
                     not_done_masks,
                     deterministic=True,
                 )
                # actions = batch["vln_oracle_action_sensor"].long()
                prev_actions.copy_(actions)

            outputs = self.envs.step([a[0].item() for a in actions])
            # Transpose per-env (obs, reward, done, info) tuples into
            # per-field lists; rewards are not used.
            observations, _, dones, infos = [list(x) for x in zip(*outputs)]

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            # reset envs and observations if necessary
            for i in range(self.envs.num_envs):
                if len(config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    frame = append_text_to_image(
                        frame,
                        current_episodes[i].instruction.instruction_text)
                    rgb_frames[i].append(frame)

                if not dones[i]:
                    continue

                # Episode finished: record its info under the episode id and
                # reset this environment explicitly.
                stats_episodes[current_episodes[i].episode_id] = infos[i]
                observations[i] = self.envs.reset_at(i)[0]
                prev_actions[i] = torch.zeros(1, dtype=torch.long)

                if len(config.VIDEO_OPTION) > 0:
                    generate_video(
                        video_option=config.VIDEO_OPTION,
                        video_dir=config.VIDEO_DIR,
                        images=rgb_frames[i],
                        episode_id=current_episodes[i].episode_id,
                        checkpoint_idx=checkpoint_index,
                        metrics={
                            "SPL":
                            round(
                                stats_episodes[current_episodes[i].episode_id]
                                ["spl"], 6)
                        },
                        tb_writer=writer,
                    )

                    # Drop the video-only entries so they are not part of the
                    # aggregation below, then start a fresh frame buffer.
                    del stats_episodes[
                        current_episodes[i].episode_id]["top_down_map"]
                    del stats_episodes[
                        current_episodes[i].episode_id]["collisions"]
                    rgb_frames[i] = []

            observations = transform_obs(
                observations, config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID)
            batch = batch_obs(observations, self.device)

            # Environments whose next episode was already recorded are paused.
            envs_to_pause = []
            next_episodes = self.envs.current_episodes()

            for i in range(self.envs.num_envs):
                if next_episodes[i].episode_id in stats_episodes:
                    envs_to_pause.append(i)

            # NOTE(review): rgb_frames is not passed to _pause_envs, so it is
            # presumably not re-indexed when environments are paused --
            # confirm the frame buffers stay aligned with the env indices.
            (
                self.envs,
                eval_recurrent_hidden_states,
                not_done_masks,
                prev_actions,
                batch,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                eval_recurrent_hidden_states,
                not_done_masks,
                prev_actions,
                batch,
            )

        self.envs.close()

        # Average every info key over the recorded episodes.  The key set is
        # taken from the first episode; next() raises StopIteration when no
        # episode was recorded.
        aggregated_stats = {}
        num_episodes = len(stats_episodes)
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = (
                sum([v[stat_key]
                     for v in stats_episodes.values()]) / num_episodes)

        # The JSON file is written relative to the current working directory.
        split = config.TASK_CONFIG.DATASET.SPLIT
        with open(f"stats_ckpt_{checkpoint_index}_{split}.json", "w") as f:
            json.dump(aggregated_stats, f, indent=4)

        logger.info(f"Episodes evaluated: {num_episodes}")
        # Tensorboard steps are one-based: checkpoint index + 1.
        checkpoint_num = checkpoint_index + 1
        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.6f}")
            writer.add_scalar(f"eval_{split}_{k}", v, checkpoint_num)
Example #16
0
    def _update_dataset(self, data_it):
        r"""Collect trajectories and append them to the LMDB feature store.

        Runs the vectorized environments until ``DAGGER.UPDATE_SIZE`` episodes
        have been written to the LMDB database at ``self.lmdb_features_dir``
        (or, when episodes must be unique, until every environment has been
        paused).

        At every step each environment's action is replaced by the value of
        the ``vln_oracle_action_sensor`` observation with probability
        ``beta = DAGGER.P ** data_it`` and is otherwise the action sampled
        from ``self.actor_critic``. When ``beta == 1.0`` an environment is
        paused as soon as the episode it reports after finishing one has
        already been collected.

        Each stored record is ``[traj_obs, prev_actions, oracle_actions]``
        packed with ``msgpack_numpy``: ``traj_obs`` is the batched per-step
        observation dict (without ``vln_oracle_action_sensor``) converted to
        numpy arrays, and the other two are ``int64`` arrays with one entry
        per step.

        Args:
            data_it: index of the current collection iteration; used as the
                exponent applied to ``DAGGER.P``.

        Returns:
            None.
        """
        if torch.cuda.is_available():
            with torch.cuda.device(self.device):
                torch.cuda.empty_cache()

        if self.envs is None:
            self.envs = construct_envs(self.config,
                                       get_env_class(self.config.ENV_NAME))

        recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            self.config.MODEL.STATE_ENCODER.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)

        observations = self.envs.reset()
        observations = transform_obs(
            observations, self.config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID)
        batch = batch_obs(observations, self.device)

        # Per-environment-slot bookkeeping. `skips` and `dones` start as plain
        # lists and are replaced by the per-step values inside the loop.
        # NOTE(review): `episodes` and `observations` are indexed by slot and
        # are not passed through `_pause_envs`; if that helper re-indexes the
        # remaining environments these lists may no longer line up with
        # `batch` -- confirm against `_pause_envs`.
        episodes = [[] for _ in range(self.envs.num_envs)]
        skips = [False] * self.envs.num_envs
        dones = [False] * self.envs.num_envs

        # https://arxiv.org/pdf/1011.0686.pdf
        # Theoretically, any beta function is fine so long as it converges to
        # zero as data_it -> inf. The paper suggests starting with beta = 1 and
        # exponential decay.
        if self.config.DAGGER.P == 0.0:
            # in Python 0.0 ** 0.0 == 1.0, but we want 0.0
            beta = 0.0
        else:
            beta = self.config.DAGGER.P**data_it

        ensure_unique_episodes = beta == 1.0

        def hook_builder(tgt_tensor):
            # The returned forward hook overwrites `tgt_tensor` in place with
            # a CPU copy of the hooked module's output.
            def hook(m, i, o):
                tgt_tensor.set_(o.cpu())

            return hook

        rgb_features = None
        rgb_hook = None
        depth_features = None
        depth_hook = None

        try:
            if self.config.MODEL.RGB_ENCODER.cnn_type == "TorchVisionResNet50":
                rgb_features = torch.zeros((1, ), device="cpu")
                rgb_hook = self.actor_critic.net.rgb_encoder.layer_extract.register_forward_hook(
                    hook_builder(rgb_features))

            if self.config.MODEL.DEPTH_ENCODER.cnn_type == "VlnResnetDepthEncoder":
                depth_features = torch.zeros((1, ), device="cpu")
                depth_hook = self.actor_critic.net.depth_encoder.visual_encoder.register_forward_hook(
                    hook_builder(depth_features))

            collected_eps = 0
            if ensure_unique_episodes:
                ep_ids_collected = {
                    ep.episode_id for ep in self.envs.current_episodes()
                }

            with tqdm.tqdm(
                    total=self.config.DAGGER.UPDATE_SIZE
            ) as pbar, lmdb.open(
                    self.lmdb_features_dir,
                    map_size=int(self.config.DAGGER.LMDB_MAP_SIZE),
            ) as lmdb_env, torch.no_grad():
                # New records are keyed after the entries already stored.
                start_id = lmdb_env.stat()["entries"]
                txn = lmdb_env.begin(write=True)

                while collected_eps < self.config.DAGGER.UPDATE_SIZE:
                    if ensure_unique_episodes:
                        envs_to_pause = []
                        current_episodes = self.envs.current_episodes()

                    for i in range(self.envs.num_envs):
                        # Store a finished episode unless its last step was
                        # flagged as a skip (oracle action == -1).
                        if dones[i] and not skips[i]:
                            ep = episodes[i]
                            traj_obs = batch_obs([step[0] for step in ep],
                                                 device=torch.device("cpu"))
                            del traj_obs["vln_oracle_action_sensor"]
                            for k, v in traj_obs.items():
                                traj_obs[k] = v.numpy()

                            transposed_ep = [
                                traj_obs,
                                np.array([step[1] for step in ep],
                                         dtype=np.int64),
                                np.array([step[2] for step in ep],
                                         dtype=np.int64),
                            ]
                            txn.put(
                                str(start_id + collected_eps).encode(),
                                msgpack_numpy.packb(transposed_ep,
                                                    use_bin_type=True),
                            )

                            pbar.update()
                            collected_eps += 1

                            # Commit periodically and open a fresh write
                            # transaction for the following records.
                            if (collected_eps %
                                    self.config.DAGGER.LMDB_COMMIT_FREQUENCY
                                ) == 0:
                                txn.commit()
                                txn = lmdb_env.begin(write=True)

                            if ensure_unique_episodes:
                                if current_episodes[
                                        i].episode_id in ep_ids_collected:
                                    envs_to_pause.append(i)
                                else:
                                    ep_ids_collected.add(
                                        current_episodes[i].episode_id)

                        if dones[i]:
                            episodes[i] = []

                    if ensure_unique_episodes:
                        (
                            self.envs,
                            recurrent_hidden_states,
                            not_done_masks,
                            prev_actions,
                            batch,
                        ) = self._pause_envs(
                            envs_to_pause,
                            self.envs,
                            recurrent_hidden_states,
                            not_done_masks,
                            prev_actions,
                            batch,
                        )
                        if self.envs.num_envs == 0:
                            break

                    (_, actions, _,
                     recurrent_hidden_states) = self.actor_critic.act(
                         batch,
                         recurrent_hidden_states,
                         prev_actions,
                         not_done_masks,
                         deterministic=False,
                     )
                    # Element-wise: use the oracle action with probability
                    # beta, the policy's sampled action otherwise.
                    actions = torch.where(
                        torch.rand_like(actions, dtype=torch.float) < beta,
                        batch["vln_oracle_action_sensor"].long(),
                        actions,
                    )

                    for i in range(self.envs.num_envs):
                        # When an encoder hook is active, store its captured
                        # features in place of the raw sensor observation.
                        if rgb_features is not None:
                            observations[i]["rgb_features"] = rgb_features[i]
                            del observations[i]["rgb"]

                        if depth_features is not None:
                            observations[i]["depth_features"] = depth_features[
                                i]
                            del observations[i]["depth"]

                        episodes[i].append((
                            observations[i],
                            prev_actions[i].item(),
                            batch["vln_oracle_action_sensor"][i].item(),
                        ))

                    # An oracle action of -1 marks a step to skip: action 0 is
                    # sent to the environment instead.
                    skips = batch["vln_oracle_action_sensor"].long() == -1
                    actions = torch.where(skips, torch.zeros_like(actions),
                                          actions)
                    skips = skips.squeeze(-1).to(device="cpu",
                                                 non_blocking=True)

                    prev_actions.copy_(actions)

                    outputs = self.envs.step([a[0].item() for a in actions])
                    observations, rewards, dones, _ = [
                        list(x) for x in zip(*outputs)
                    ]

                    not_done_masks = torch.tensor(
                        [[0.0] if done else [1.0] for done in dones],
                        dtype=torch.float,
                        device=self.device,
                    )

                    observations = transform_obs(
                        observations,
                        self.config.TASK_CONFIG.TASK.INSTRUCTION_SENSOR_UUID)
                    batch = batch_obs(observations, self.device)

                txn.commit()

            self.envs.close()
            self.envs = None
        finally:
            # Always detach the hooks, including when collection fails, so
            # they do not keep firing on later forward passes of the model.
            if rgb_hook is not None:
                rgb_hook.remove()
            if depth_hook is not None:
                depth_hook.remove()
Example #17
0
    def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None:
        r"""Train ``self.q_network`` on batches drawn from a ``RolloutDataset``.

        For every batch a single time step is sampled at random, the network
        is run on that RGB frame, and the mean squared error between the
        network output and the frame scaled by ``1 / 255`` is minimised with
        ``self.optimizer``. Every 100 updates ``self.sync_model()`` is called
        and every ``CHECKPOINT_INTERVAL`` updates a checkpoint is saved.

        Args:
            ckpt_path: checkpoint file to resume from; only read when
                ``ckpt != -1``.
            ckpt: index of the checkpoint being resumed, or ``-1`` to start
                from scratch. Saved checkpoints are numbered from
                ``ckpt + 1``.
            start_updates: number of updates assumed to have been performed
                before resuming; used to derive the step counter when the
                checkpoint does not carry ``extra_state["step"]``.

        Returns:
            None
        """
        self.envs = construct_envs(
            self.config, get_env_class(self.config.ENV_NAME)
        )

        ppo_cfg = self.config.RL.PPO
        task_cfg = self.config.TASK_CONFIG.TASK
        aux_cfg = self.config.RL.AUX_TASKS

        self.device = (
            torch.device("cuda", self.config.TORCH_GPU_ID)
            if torch.cuda.is_available()
            else torch.device("cpu")
        )

        self._setup_dqn_agent(ppo_cfg, task_cfg, aux_cfg, [])

        self.dataset = RolloutDataset()
        self.dataloader = DataLoader(self.dataset, batch_size=16, num_workers=0)

        # The training loop below reads only from the dataloader, so the
        # environments are released as soon as the agent has been set up.
        self.envs.close()

        if self.config.RESUME_CURIOUS:
            # Warm-start: copy every "model_encoder" weight whose renamed key
            # exists in the q-network; all other weights are left untouched.
            weights = torch.load(self.config.RESUME_CURIOUS)['state_dict']
            state_dict = self.q_network.state_dict()

            weights_new = {}

            for k, v in weights.items():
                if "model_encoder" in k:
                    k = k.replace("model_encoder", "visual_resnet").replace("actor_critic.", "")
                    if k in state_dict:
                        weights_new[k] = v

            state_dict.update(weights_new)
            self.q_network.load_state_dict(state_dict)

        logger.info(
            "agent number of parameters: {}".format(
                sum(param.numel() for param in self.q_network.parameters())
            )
        )

        count_steps = 0
        count_checkpoints = 0
        if ckpt != -1:
            logger.info(
                f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly."
            )
            assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported"
            # This is the checkpoint we start saving at
            count_checkpoints = ckpt + 1
            count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES
            ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
            self.q_network.load_state_dict(ckpt_dict["state_dict"])
            self.q_network_target.load_state_dict(ckpt_dict["target_state_dict"])
            if "optim_state" in ckpt_dict:
                self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"])
            else:
                # Logger.warn is a deprecated alias of Logger.warning.
                logger.warning("No optimizer state loaded, results may be funky")
            if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
                # A step count stored in the checkpoint takes precedence over
                # the estimate derived from start_updates.
                count_steps = ckpt_dict["extra_state"]["step"]

        lr_scheduler = LambdaLR(
            optimizer=self.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        with TensorboardWriter(
            self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        ) as writer:

            update = 0
            for _ in range(self.config.NUM_EPOCHS):
                # Only the image sequence is consumed by the loss below; the
                # remaining batch fields are ignored.
                for im, _pointgoal, _action, _mask, _reward in self.dataloader:
                    im = collate(im).to(self.device).float()
                    nstep = im.size(1)

                    # Train on one randomly chosen time step of the sequence.
                    step = random.randint(0, nstep - 1)
                    output = self.q_network({'rgb': im[:, step]}, None, None)
                    mse_loss = torch.pow(output - im[:, step] / 255., 2).mean()
                    mse_loss.backward()

                    self.optimizer.step()
                    self.optimizer.zero_grad()

                    # Step the scheduler only after the optimizer: stepping it
                    # first skips the initial value of the schedule.
                    if ppo_cfg.use_linear_lr_decay:
                        lr_scheduler.step()

                    writer.add_scalar(
                        "loss",
                        mse_loss,
                        update,
                    )

                    if update % 10 == 0:
                        print("Update: {}, loss: {}".format(update, mse_loss))

                    if update % 100 == 0:
                        self.sync_model()

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"{self.checkpoint_prefix}.{count_checkpoints}.pth", dict(step=count_steps)
                        )
                        count_checkpoints += 1
                    update += 1
    def train(self) -> None:
        r"""Main method for training the mapper.

        Steps the vectorized environments through ``_collect_rollout_step``
        and, whenever the step counter of environment 0 is a multiple of
        ``ANS.MAPPER.num_mapper_steps``, updates the mapper through
        ``_update_mapper_agent``. Statistics are logged after every step, a
        checkpoint is saved every ``CHECKPOINT_INTERVAL`` updates, and once
        the step counter reaches ``T_EXP`` the random generators are
        re-seeded, the environments are reset and all state tensors are
        zeroed.

        Returns:
            None
        """
        self.envs = construct_envs(
            self.config,
            get_env_class(self.config.ENV_NAME),
            devices=self._assign_devices(),
        )

        # set up configurations
        ppo_cfg = self.config.RL.PPO
        ans_cfg = self.config.RL.ANS

        # set up device
        self.device = (
            torch.device("cuda", self.config.TORCH_GPU_ID)
            if torch.cuda.is_available()
            else torch.device("cpu")
        )

        # set up checkpoint folder
        os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True)

        # create rollouts for mapper
        self.mapper_rollouts = self._create_mapper_rollouts(ans_cfg)

        self.depth_projection_net = DepthProjectionNet(
            ans_cfg.SEMANTIC_ANTICIPATOR.EGO_PROJECTION
        )

        self._setup_anticipator(ppo_cfg, ans_cfg)

        logger.info(
            "mapper_agent number of parameters: {}".format(
                sum(param.numel() for param in self.mapper_agent.parameters())
            )
        )

        mapper_rollouts = self.mapper_rollouts

        # ===================== Create statistics buffers =====================
        statistics_dict = {}
        # Mapper statistics
        statistics_dict["mapper"] = defaultdict(
            lambda: deque(maxlen=ppo_cfg.loss_stats_window_size)
        )

        # Overall timing statistics
        t_start = time.time()
        env_time = 0
        pth_time = 0

        # ==================== Measuring memory consumption ===================
        total_memory_size = 0
        print("=================== Mapper rollouts ======================")
        for k, v in mapper_rollouts.observations.items():
            mem = v.element_size() * v.nelement() * 1e-9
            print(f"key: {k:<40s}, memory: {mem:>10.4f} GB")
            total_memory_size += mem
        print(f"Total memory: {total_memory_size:>10.4f} GB")

        # Resume checkpoint if available
        (
            num_updates_start,
            count_steps_start,
            count_checkpoints,
        ) = self.resume_checkpoint()
        count_steps = count_steps_start

        M = ans_cfg.overall_map_size
        # ==================== Create state variables =================
        state_estimates = {
            # Agent's pose estimate
            "pose_estimates": torch.zeros(self.envs.num_envs, 3).to(self.device),
            # Agent's map
            "map_states": torch.zeros(self.envs.num_envs, 2, M, M).to(self.device),
            "recurrent_hidden_states": torch.zeros(
                1, self.envs.num_envs, ans_cfg.LOCAL_POLICY.hidden_size
            ).to(self.device),
            "visited_states": torch.zeros(self.envs.num_envs, 1, M, M).to(self.device),
        }

        ground_truth_states = {
            # To measure area seen
            "visible_occupancy": torch.zeros(
                self.envs.num_envs, 2, M, M, device=self.device
            ),
            "pose": torch.zeros(self.envs.num_envs, 3, device=self.device),
            "prev_global_reward_metric": torch.zeros(
                self.envs.num_envs, 1, device=self.device
            ),
        }

        masks = torch.zeros(self.envs.num_envs, 1)
        episode_step_count = torch.zeros(self.envs.num_envs, 1, device=self.device)

        # ==================== Reset the environments =================
        observations = self.envs.reset()
        batch = self._prepare_batch(observations)
        prev_batch = batch

        # NOTE(review): nothing in this method writes to these tensors after
        # creation, so the windowed averages logged below stay at zero unless
        # they are updated elsewhere -- confirm.
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            local_reward=torch.zeros(self.envs.num_envs, 1),
            global_reward=torch.zeros(self.envs.num_envs, 1),
        )

        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size)
        )

        # Useful variables
        NUM_MAPPER_STEPS = ans_cfg.MAPPER.num_mapper_steps
        NUM_LOCAL_STEPS = ppo_cfg.num_local_steps
        NUM_GLOBAL_STEPS = ppo_cfg.num_global_steps
        GLOBAL_UPDATE_INTERVAL = NUM_GLOBAL_STEPS * ans_cfg.goal_interval
        NUM_GLOBAL_UPDATES_PER_EPISODE = self.config.T_EXP // GLOBAL_UPDATE_INTERVAL
        NUM_GLOBAL_UPDATES = (
            self.config.NUM_EPISODES
            * NUM_GLOBAL_UPDATES_PER_EPISODE
            // self.config.NUM_PROCESSES
        )
        # Sanity checks
        assert (
            NUM_MAPPER_STEPS % NUM_LOCAL_STEPS == 0
        ), "Mapper steps must be a multiple of global steps interval"
        assert (
            NUM_LOCAL_STEPS == ans_cfg.goal_interval
        ), "Local steps must be same as subgoal sampling interval"
        with TensorboardWriter(
            self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        ) as writer:
            for update in range(num_updates_start, NUM_GLOBAL_UPDATES):
                for _ in range(NUM_GLOBAL_STEPS):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                        prev_batch,
                        batch,
                        state_estimates,
                        ground_truth_states,
                    ) = self._collect_rollout_step(
                        batch,
                        prev_batch,
                        episode_step_count,
                        state_estimates,
                        ground_truth_states,
                        masks,
                        mapper_rollouts,
                    )
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                    # The mapper is updated whenever environment 0's step
                    # counter hits a multiple of NUM_MAPPER_STEPS.
                    update_mapper = (
                        episode_step_count[0].item() % NUM_MAPPER_STEPS == 0
                    )

                    # ------------------------ update mapper --------------------------
                    if update_mapper:
                        (
                            delta_pth_time,
                            update_metrics_mapper,
                        ) = self._update_mapper_agent(mapper_rollouts)

                        for k, v in update_metrics_mapper.items():
                            statistics_dict["mapper"][k].append(v)

                        # Count the mapper update time exactly once, and only
                        # when an update actually ran, so the rollout delta
                        # added above is never counted again.
                        pth_time += delta_pth_time

                    # -------------------------- log statistics -----------------------
                    for k, v in statistics_dict.items():
                        logger.info(
                            "=========== {:20s} ============".format(k + " stats")
                        )
                        for kp, vp in v.items():
                            if len(vp) > 0:
                                writer.add_scalar(f"{k}/{kp}", np.mean(vp), count_steps)
                                logger.info(f"{kp:25s}: {np.mean(vp).item():10.5f}")

                    for k, v in running_episode_stats.items():
                        window_episode_stats[k].append(v.clone())

                    # Change of each running statistic across the window.
                    deltas = {
                        k: (
                            (v[-1] - v[0]).sum().item()
                            if len(v) > 1
                            else v[0].sum().item()
                        )
                        for k, v in window_episode_stats.items()
                    }

                    # Guard the per-episode averages against division by zero.
                    deltas["count"] = max(deltas["count"], 1.0)

                    fps = (count_steps - count_steps_start) / (time.time() - t_start)
                    writer.add_scalar("fps", fps, count_steps)

                    if update > 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(update, fps))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time, count_steps)
                        )

                        logger.info(
                            "Average window size: {}  {}".format(
                                len(window_episode_stats["count"]),
                                "  ".join(
                                    "{}: {:.3f}".format(k, v / deltas["count"])
                                    for k, v in deltas.items()
                                    if k != "count"
                                ),
                            )
                        )

                episode_finished = (
                    episode_step_count[0].item() == self.config.T_EXP
                )

                # At episode termination, manually set masks to zeros.
                if episode_finished:
                    masks.fill_(0)

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(step=count_steps, update=update),
                    )
                    count_checkpoints += 1

                if episode_finished:
                    seed = int(time.time())
                    print('change random seed to {}'.format(seed))
                    random.seed(seed)
                    np.random.seed(seed)
                    torch.manual_seed(seed)

                    observations = self.envs.reset()
                    batch = self._prepare_batch(observations)
                    prev_batch = batch
                    # Reset episode step counter
                    episode_step_count.fill_(0)
                    # Reset states
                    for k in ground_truth_states.keys():
                        ground_truth_states[k].fill_(0)
                    for k in state_estimates.keys():
                        state_estimates[k].fill_(0)

            self.envs.close()
Example #19
0
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging

        Returns:
            None
        """

        from habitat_baselines.common.environments import get_env_class

        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        #  logger.info(f"env config: {config}")
        self.envs = construct_envs_habitat(config,
                                           get_env_class(config.ENV_NAME))
        self.observation_space = SpaceDict({
            "pointgoal_with_gps_compass":
            spaces.Box(low=0.0, high=1.0, shape=(2, ), dtype=np.float32)
        })

        if self.config.COLOR:
            self.observation_space = SpaceDict({
                "rgb":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(3, *self.config.RESOLUTION),
                    dtype=np.uint8,
                ),
                **self.observation_space.spaces,
            })

        if self.config.DEPTH:
            self.observation_space = SpaceDict({
                "depth":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(1, *self.config.RESOLUTION),
                    dtype=np.float32,
                ),
                **self.observation_space.spaces,
            })

        self.action_space = self.envs.action_spaces[0]
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic
        self.actor_critic.script_net()
        self.actor_critic.eval()

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.config.NUM_PROCESSES,
            self.actor_critic.num_recurrent_layers,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device,
                                     dtype=torch.bool)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES)
                      ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = sum(self.envs.number_of_episodes)
        else:
            total_num_eps = sum(self.envs.number_of_episodes)
            if total_num_eps < number_of_eval_episodes:
                logger.warn(
                    f"Config specified {number_of_eval_episodes} eval episodes"
                    ", dataset only has {total_num_eps}.")
                logger.warn(f"Evaluating with {total_num_eps} instead.")
                number_of_eval_episodes = total_num_eps

        evals_per_ep = 5
        count_per_ep = defaultdict(lambda: 0)

        pbar = tqdm.tqdm(total=number_of_eval_episodes * evals_per_ep)
        self.actor_critic.eval()
        while (len(stats_episodes) < (number_of_eval_episodes * evals_per_ep)
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                (_, dist_result,
                 test_recurrent_hidden_states) = self.actor_critic.act(
                     batch,
                     test_recurrent_hidden_states,
                     prev_actions,
                     not_done_masks,
                     deterministic=False,
                 )
                actions = dist_result["actions"]

                prev_actions.copy_(actions)
                actions = actions.to("cpu")

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations, device=self.device)

            not_done_masks = torch.tensor(
                [[False] if done else [True] for done in dones],
                dtype=torch.bool,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                next_count_key = (
                    next_episodes[i].scene_id,
                    next_episodes[i].episode_id,
                )

                if count_per_ep[next_count_key] == evals_per_ep:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    pbar.update()
                    episode_stats = dict()
                    episode_stats["reward"] = current_episode_reward[i].item()
                    episode_stats.update(
                        self._extract_scalars_from_info(infos[i]))
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    count_key = (
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )
                    count_per_ep[count_key] = count_per_ep[count_key] + 1

                    ep_key = (count_key, count_per_ep[count_key])
                    stats_episodes[ep_key] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metrics=self._extract_scalars_from_info(infos[i]),
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        self.envs.close()
        pbar.close()
        num_episodes = len(stats_episodes)
        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            values = [
                v[stat_key] for v in stats_episodes.values()
                if v[stat_key] is not None
            ]
            if len(values) > 0:
                aggregated_stats[stat_key] = sum(values) / len(values)
            else:
                aggregated_stats[stat_key] = 0

        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.4f}")

        step_id = checkpoint_index
        if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
            step_id = ckpt_dict["extra_state"]["step"]

        writer.add_scalars(
            "eval_reward",
            {"average reward": aggregated_stats["reward"]},
            step_id,
        )

        metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
        if len(metrics) > 0:
            writer.add_scalars("eval_metrics", metrics, step_id)

        self.num_frames = step_id

        time.sleep(30)
Example #20
0
    def train(self) -> None:
        r"""Main method for DD-PPO SLAM.

        In order, this method:

        1. Initialises the distributed process group, pins the torch and
           simulator GPU ids in the config to ``self.local_rank``, offsets the
           task seed by ``world_rank * NUM_PROCESSES`` and seeds ``random``,
           ``numpy`` and ``torch``.
        2. Builds the vectorised environments, resets them and casts every
           ``"semantic"`` observation to ``int32``.
        3. Sets up the actor-critic agent and a ``GlobalRolloutStorage``, and
           computes the initial ``self.global_goals`` from the first policy
           output.
        4. Optionally restores an interrupted run, then executes the
           collect-rollout / update-agent loop for ``config.NUM_UPDATES``
           updates. Rank 0 additionally writes tensorboard scalars, logs
           statistics and saves checkpoints.

        When the ``EXIT`` event is set the environments are closed, rank 0
        saves an interrupted state if ``REQUEUE`` is also set, the job is
        requeued and the method returns early.

        Returns:
            None
        """

        #####################################################################
        ## init distrib and configuration
        #####################################################################
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend
        )
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore(
            "rollout_tracker", tcp_store
        )
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        # Both torch and the simulator are placed on the GPU indexed by
        # this process' local rank.
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        print("********************* TORCH_GPU_ID: ", self.config.TORCH_GPU_ID)
        print("********************* SIMULATOR_GPU_ID: ", self.config.SIMULATOR_GPU_ID)

        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (
            self.world_rank * self.config.NUM_PROCESSES
        )
        self.config.freeze()

        seed = self.config.TASK_CONFIG.SEED
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        #####################################################################
        ## build distrib NavSLAMRLEnv environment
        #####################################################################
        print("#############################################################")
        print("## build distrib NavSLAMRLEnv environment")
        print("#############################################################")
        self.envs = construct_envs(
            self.config, get_env_class(self.config.ENV_NAME)
        )
        observations = self.envs.reset()
        print("*************************** observations len:", len(observations))

        # Cast the semantic observation to int32 in place and print the
        # distinct values present in each environment's first frame.
        for obs in observations:
            obs["semantic"] = obs["semantic"].astype(np.int32)
            print(list(set(obs["semantic"].ravel())))

        batch = batch_obs(observations, device=self.device)
        print("*************************** batch len:", len(batch))

        #####################################################################
        ## init actor_critic agent
        #####################################################################
        print("#############################################################")
        print("## init actor_critic agent")
        print("#############################################################")
        # The first two dimensions of "map_sum" are used to scale the policy
        # output into goal coordinates below.
        self.map_w = observations[0]["map_sum"].shape[0]
        self.map_h = observations[0]["map_sum"].shape[1]

        ppo_cfg = self.config.RL.PPO
        if self.world_rank == 0:
            # exist_ok avoids the check-then-create race of isdir + makedirs.
            os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True)

        self._setup_actor_critic_agent(observations, ppo_cfg)

        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info(
                "agent number of trainable parameters: {}".format(
                    sum(
                        param.numel()
                        for param in self.agent.parameters()
                        if param.requires_grad
                    )
                )
            )

        #####################################################################
        ## init Global Rollout Storage
        #####################################################################
        print("#############################################################")
        print("## init Global Rollout Storage")
        print("#############################################################")
        self.num_each_global_step = self.config.RL.SLAMDDPPO.num_each_global_step
        rollouts = GlobalRolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.obs_space,
            self.g_action_space,
        )
        rollouts.to(self.device)

        print('rollouts type:', type(rollouts))
        print('--------------------------')

        # Seed slot 0 of the storage with the observations from reset().
        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        with torch.no_grad():
            step_observation = {
                k: v[rollouts.step] for k, v in rollouts.observations.items()
            }

            _, actions, _ = self.actor_critic.act(
                step_observation,
                rollouts.prev_g_actions[0],
                rollouts.masks[0],
            )

        # Scale the two action components by the map extents to obtain
        # integer goal coordinates, one pair per action row.
        self.global_goals = [
            [
                int(action[0].item() * self.map_w),
                int(action[1].item() * self.map_h),
            ]
            for action in actions
        ]

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(
            self.envs.num_envs, 1, device=self.device
        )
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size)
        )

        print("*************************** current_episode_reward:", current_episode_reward)
        print("*************************** running_episode_stats:", running_episode_stats)

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        # NOTE(review): the state is loaded with no explicit filename here,
        # while the EXIT branch below saves it to a hard-coded path. Confirm
        # both resolve to the same file, otherwise a requeued job will not
        # resume from the state it saved.
        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"]
            )
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        # Debug bookkeeping: scene_id -> set of episode ids seen so far.
        # Nothing in this method reads it; a set keeps it bounded by the
        # number of distinct episodes instead of growing on every step.
        deif = {}
        with (
            TensorboardWriter(
                self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
            )
            if self.world_rank == 0
            else contextlib.suppress()
        ) as writer:
            # `writer` is only used inside the `world_rank == 0` branches.
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES
                    )

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ),
                            "/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth"
                        )
                    print("********************EXIT*********************")

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_global_rollout_step(
                        rollouts, current_episode_reward, running_episode_stats
                    )
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # Query the environments once per step rather than once
                    # per field access.
                    for episode in self.envs.current_episodes():
                        deif.setdefault(episode.scene_id, set()).add(
                            int(episode.episode_id)
                        )

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (
                        step
                        >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                    ) and int(num_rollouts_done_store.get("num_done")) > (
                        self.config.RL.DDPPO.sync_frac * self.world_size
                    ):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self._static_encoder:
                    self._encoder.eval()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                # Reduce the running episode statistics across workers in a
                # fixed (sorted) key order so every rank stacks them alike.
                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0
                )
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [value_loss, action_loss, count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[2].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    # The reduced losses are divided by the world size.
                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                    ]
                    # Change of each statistic over the reward window (or
                    # its single value when the window holds one entry).
                    deltas = {
                        k: (
                            (v[-1] - v[0]).sum().item()
                            if len(v) > 1
                            else v[0].sum().item()
                        )
                        for k, v in window_episode_stats.items()
                    }
                    # Guard the divisions below against a zero count.
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        dict(zip(["value", "policy"], losses)),
                        count_steps,
                    )

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info(
                            "update: {}\tfps: {:.3f}\t".format(
                                update,
                                count_steps
                                / ((time.time() - t_start) + prev_time),
                            )
                        )

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(
                                update, env_time, pth_time, count_steps
                            )
                        )
                        logger.info(
                            "Average window size: {}  {}".format(
                                len(window_episode_stats["count"]),
                                "  ".join(
                                    "{}: {:.3f}".format(k, v / deltas["count"])
                                    for k, v in deltas.items()
                                    if k != "count"
                                ),
                            )
                        )

                        # Uncomment to dump the episodes seen per scene:
                        # for k in deif:
                        #     print("deif: k", k, " : ", sorted(deif[k]))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )
                        print('=' * 20 + 'Save Model' + '=' * 20)
                        logger.info(
                            "Save Model : {}".format(count_checkpoints)
                        )
                        count_checkpoints += 1

            self.envs.close()
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Loads the mapper and local-policy weights from the checkpoint, runs
        up to ``TEST_EPISODE_COUNT`` navigation episodes (each capped at
        ``config.T_MAX`` steps) in a single environment, and reports the
        accumulated metrics overall and per difficulty bucket
        (easy / medium / hard, by the episode's geodesic distance). Results
        are written to the logger, to a results file under
        ``TENSORBOARD_DIR`` and to ``writer``.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging

        Returns:
            None

        Raises:
            AssertionError: if more than one environment was constructed.
        """
        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO
        ans_cfg = config.RL.ANS

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg, ans_cfg)

        # Convert the state_dict of mapper_agent to mapper
        mapper_dict = {
            k.replace("mapper.", ""): v
            for k, v in ckpt_dict["mapper_state_dict"].items()
        }
        # Converting the state_dict of local_agent to just the local_policy.
        local_dict = {
            k.replace("actor_critic.", ""): v
            for k, v in ckpt_dict["local_state_dict"].items()
        }
        # Strict = False is set to ignore to handle the case where
        # pose_estimator is not required.
        self.mapper.load_state_dict(mapper_dict, strict=False)
        self.local_actor_critic.load_state_dict(local_dict)

        # Set models to evaluation
        self.mapper.eval()
        self.local_actor_critic.eval()

        # NOTE(review): the episode count is read from self.config, not from
        # the (possibly checkpoint-derived) `config` above - confirm intended.
        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        total_num_eps = sum(self.envs.number_of_episodes)
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = total_num_eps
        elif total_num_eps < number_of_eval_episodes:
            logger.warn(
                f"Config specified {number_of_eval_episodes} eval episodes"
                f", dataset only has {total_num_eps}."
            )
            logger.warn(f"Evaluating with {total_num_eps} instead.")
            number_of_eval_episodes = total_num_eps

        # Side length of the square map state allocated per environment.
        M = ans_cfg.overall_map_size

        assert (
            self.envs.num_envs == 1
        ), "Number of environments needs to be 1 for evaluation"

        # Define metric accumulators
        metric_names = (
            "success_rate",
            "spl",
            "distance_to_goal",
            "time",
            "softspl",
        )
        # Navigation metrics over all episodes
        navigation_metrics = {name: Metric() for name in metric_names}
        # The same metrics, accumulated separately per difficulty bucket
        per_difficulty_navigation_metrics = {
            difficulty: {name: Metric() for name in metric_names}
            for difficulty in ("easy", "medium", "hard")
        }

        times_per_episode = deque()
        times_per_step = deque()

        # Define a simple function to return episode difficulty based on
        # the geodesic distance
        def classify_difficulty(gd):
            if gd < 5.0:
                return "easy"
            elif gd < 10.0:
                return "medium"
            else:
                return "hard"

        eval_start_time = time.time()
        # Reset environments only for the very first batch
        observations = self.envs.reset()
        for ep in range(number_of_eval_episodes):
            # ============================== Reset agent ==============================
            # Reset agent states
            state_estimates = {
                "pose_estimates": torch.zeros(self.envs.num_envs, 3).to(self.device),
                "map_states": torch.zeros(self.envs.num_envs, 2, M, M).to(self.device),
                "recurrent_hidden_states": torch.zeros(
                    1, self.envs.num_envs, ans_cfg.LOCAL_POLICY.hidden_size
                ).to(self.device),
            }
            # Reset ANS states
            self.ans_net.reset()
            self.not_done_masks = torch.zeros(self.envs.num_envs, 1, device=self.device)
            self.prev_actions = torch.zeros(self.envs.num_envs, 1, device=self.device)
            self.prev_batch = None
            self.ep_time = torch.zeros(self.envs.num_envs, 1, device=self.device)
            # =========================== Episode loop ================================
            ep_start_time = time.time()
            current_episodes = self.envs.current_episodes()
            for ep_step in range(self.config.T_MAX):
                step_start_time = time.time()
                # ============================ Action step ============================
                batch = self._prepare_batch(observations)
                if self.prev_batch is None:
                    # On the first step there is no previous batch yet, so
                    # the current one stands in for it.
                    self.prev_batch = copy.deepcopy(batch)

                prev_pose_estimates = state_estimates["pose_estimates"]
                with torch.no_grad():
                    (
                        _,
                        _,
                        mapper_outputs,
                        local_policy_outputs,
                        state_estimates,
                    ) = self.ans_net.act(
                        batch,
                        self.prev_batch,
                        state_estimates,
                        self.ep_time,
                        self.not_done_masks,
                        deterministic=ans_cfg.LOCAL_POLICY.deterministic_flag,
                    )
                    actions = local_policy_outputs["actions"]
                    # Make masks not done till reset (end of episode)
                    self.not_done_masks = torch.ones(
                        self.envs.num_envs, 1, device=self.device
                    )
                    self.prev_actions.copy_(actions)

                if ep_step == 0:
                    # Restore the pose estimate held before the first act()
                    # call of the episode.
                    state_estimates["pose_estimates"].copy_(prev_pose_estimates)

                self.ep_time += 1
                # Update prev batch
                for k, v in batch.items():
                    self.prev_batch[k].copy_(v)

                # Remap actions from exploration to navigation agent.
                actions_rmp = self._remap_actions(actions)

                # =========================== Environment step ========================
                outputs = self.envs.step([a[0].item() for a in actions_rmp])

                observations, _, dones, infos = [list(x) for x in zip(*outputs)]

                times_per_step.append(time.time() - step_start_time)
                # ============================ Process metrics ========================
                if dones[0]:
                    times_per_episode.append(time.time() - ep_start_time)
                    mins_per_episode = np.mean(times_per_episode).item() / 60.0
                    eta_completion = mins_per_episode * (
                        number_of_eval_episodes - ep - 1
                    )
                    secs_per_step = np.mean(times_per_step).item()
                    for i in range(self.envs.num_envs):
                        episode_id = int(current_episodes[i].episode_id)
                        curr_metrics = {
                            "spl": infos[i]["spl"],
                            "softspl": infos[i]["softspl"],
                            "success_rate": infos[i]["success"],
                            "time": ep_step + 1,
                            "distance_to_goal": infos[i]["distance_to_goal"],
                        }
                        # Estimate difficulty of episode
                        episode_difficulty = classify_difficulty(
                            current_episodes[i].info["geodesic_distance"]
                        )
                        difficulty_metrics = per_difficulty_navigation_metrics[
                            episode_difficulty
                        ]
                        for k, v in curr_metrics.items():
                            navigation_metrics[k].update(v, 1.0)
                            difficulty_metrics[k].update(v, 1.0)

                        # `ep` is 0-based, so ep + 1 episodes are finished.
                        logger.info(
                            f"====> {ep + 1}/{number_of_eval_episodes} done"
                        )
                        for k, v in curr_metrics.items():
                            logger.info(f"{k:25s} : {v:10.3f}")
                        logger.info("{:25s} : {:10d}".format("episode_id", episode_id))
                        logger.info(f"Time per episode: {mins_per_episode:.3f} mins")
                        logger.info(f"Time per step: {secs_per_step:.3f} secs")
                        logger.info(f"ETA: {eta_completion:.3f} mins")

                    # For navigation, terminate episode loop when dones is called
                    break
            # done-for

        default_results_path = (
            f"{self.config.TENSORBOARD_DIR}"
            f"/navigation_results_ckpt_{checkpoint_index}.txt"
        )
        if checkpoint_index == 0:
            # Best effort: name the results file after the index embedded in
            # the checkpoint file name, falling back to the generic name.
            try:
                eval_ckpt_idx = self.config.EVAL_CKPT_PATH_DIR.split("/")[-1].split(
                    "."
                )[1]
                logger.add_filehandler(
                    f"{self.config.TENSORBOARD_DIR}/navigation_results_ckpt_final_{eval_ckpt_idx}.txt"
                )
            except Exception:
                # Unlike a bare `except`, this lets KeyboardInterrupt and
                # SystemExit propagate.
                logger.add_filehandler(default_results_path)
        else:
            logger.add_filehandler(default_results_path)

        logger.info(
            f"======= Evaluating over {number_of_eval_episodes} episodes ============="
        )

        logger.info("=======> Navigation metrics")
        for k, v in navigation_metrics.items():
            logger.info(f"{k}: {v.get_metric():.3f}")
            writer.add_scalar(f"navigation/{k}", v.get_metric(), checkpoint_index)

        for diff, diff_metrics in per_difficulty_navigation_metrics.items():
            logger.info(f"=============== {diff:^10s} metrics ==============")
            for k, v in diff_metrics.items():
                logger.info(f"{k}: {v.get_metric():.3f}")
                writer.add_scalar(
                    f"{diff}_navigation/{k}", v.get_metric(), checkpoint_index
                )

        total_eval_time = (time.time() - eval_start_time) / 60.0
        logger.info(f"Total evaluation time: {total_eval_time:.3f} mins")
        self.envs.close()
Example #22
0
    def train(self) -> None:
        r"""Main method for training PPO.

        Constructs the environments, the agent and the rollout storage, then
        for each of ``self.config.NUM_UPDATES`` updates collects
        ``ppo_cfg.num_steps`` rollout steps and runs one agent update.
        Reward, metric and loss scalars are written to TensorBoard after
        every update, a text summary is logged every ``LOG_INTERVAL`` updates
        (except update 0) and a checkpoint is saved every
        ``CHECKPOINT_INTERVAL`` updates (including update 0).

        Returns:
            None
        """

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        # exist_ok=True avoids the check-then-create race when another
        # process creates the checkpoint folder at the same time.
        os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True)
        self._setup_actor_critic_agent(ppo_cfg)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
        )
        rollouts.to(self.device)

        # Seed slot 0 of the rollout storage with the initial observations.
        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        # One bounded history of stat snapshots per stat key.
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            for update in range(self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                # Collect one rollout of num_steps environment steps.
                for _ in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                for k, v in running_episode_stats.items():
                    window_episode_stats[k].append(v.clone())

                # Change of each running stat across the window (newest
                # snapshot minus oldest); with a single snapshot, use it as is.
                deltas = {
                    k: ((v[-1] - v[0]).sum().item()
                        if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                # Clamp so the per-episode averages below never divide by 0.
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items() if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                writer.add_scalars(
                    "losses",
                    {"value": value_loss, "policy": action_loss},
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    logger.info("Average window size: {}  {}".format(
                        len(window_episode_stats["count"]),
                        "  ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                  for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth",
                                         dict(step=count_steps))
                    count_checkpoints += 1

            self.envs.close()
Example #23
0
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Loads the checkpoint, builds environments on ``config.EVAL.SPLIT``,
        steps the policy until the requested number of distinct episodes has
        finished (or no environments remain), then logs the per-episode
        average of every collected statistic and writes the averages to
        TensorBoard.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging; also used
                as the TensorBoard step when the checkpoint carries no
                ``extra_state["step"]``

        Returns:
            None
        """
        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            # Extra measurements are requested only when videos are rendered.
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)
        self.actor_critic.eval()

        if self._static_encoder:
            self._encoder = self.agent.actor_critic.net.visual_encoder

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        if self._static_encoder:
            batch["visual_features"] = self._encoder(batch)
            batch["prev_visual_features"] = torch.zeros_like(
                batch["visual_features"])

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        # NOTE(review): the buffers below are sized from self.config, while
        # the environments were built from `config`; confirm NUM_PROCESSES
        # agrees between the two when EVAL.USE_CKPT_CONFIG is set.
        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES)
                      ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        # -1 means "evaluate every episode"; otherwise cap the request at
        # the number of episodes the environments report.
        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = sum(self.envs.number_of_episodes)
        else:
            total_num_eps = sum(self.envs.number_of_episodes)
            if total_num_eps < number_of_eval_episodes:
                # Both halves must be f-strings, otherwise the second
                # placeholder is printed literally.
                logger.warning(
                    f"Config specified {number_of_eval_episodes} eval episodes"
                    f", dataset only has {total_num_eps}.")
                logger.warning(f"Evaluating with {total_num_eps} instead.")
                number_of_eval_episodes = total_num_eps

        pbar = tqdm.tqdm(total=number_of_eval_episodes)
        self.actor_critic.eval()
        while (len(stats_episodes) < number_of_eval_episodes
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                # Keep the pre-step batch: its visual features become
                # "prev_visual_features" of the next batch.
                step_batch = batch
                (
                    _,
                    actions,
                    _,
                    test_recurrent_hidden_states,
                ) = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    prev_actions,
                    not_done_masks,
                    deterministic=False,
                )

                prev_actions.copy_(actions)

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations, device=self.device)

            if self._static_encoder:
                batch["prev_visual_features"] = step_batch["visual_features"]
                batch["visual_features"] = self._encoder(batch)

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                # Pause an env whose upcoming episode was already evaluated.
                if (
                        next_episodes[i].scene_id,
                        next_episodes[i].episode_id,
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    pbar.update()
                    episode_stats = dict()
                    episode_stats["reward"] = current_episode_reward[i].item()
                    episode_stats.update(
                        self._extract_scalars_from_info(infos[i]))
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metrics=self._extract_scalars_from_info(infos[i]),
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        # Release the progress bar once the evaluation loop is over.
        pbar.close()

        # Average every statistic over the evaluated episodes; the key set
        # is taken from the first recorded episode.
        num_episodes = len(stats_episodes)
        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = (
                sum([v[stat_key]
                     for v in stats_episodes.values()]) / num_episodes)

        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.4f}")

        # Prefer the training step stored in the checkpoint as the x-axis.
        step_id = checkpoint_index
        if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
            step_id = ckpt_dict["extra_state"]["step"]

        writer.add_scalars(
            "eval_reward",
            {"average reward": aggregated_stats["reward"]},
            step_id,
        )

        metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
        if len(metrics) > 0:
            writer.add_scalars("eval_metrics", metrics, step_id)

        self.envs.close()
    def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None:
        r"""Main method for training PPO.

        Sets up the environments, auxiliary tasks, rollout storage and agent,
        optionally restores agent/optimizer state from a checkpoint, then runs
        updates ``start_updates .. self.config.NUM_UPDATES - 1``. Each update
        collects ``ppo_cfg.num_steps`` rollout steps, updates the agent and
        writes entropy, reward, metric, loss and auxiliary-weight scalars to
        TensorBoard.

        Args:
            ckpt_path: path of the checkpoint to resume from; only read when
                ``ckpt != -1``.
            ckpt: index of the checkpoint being resumed, or -1 to start fresh.
                New checkpoints are numbered from ``ckpt + 1``.
            start_updates: index of the first update to run; also used to
                estimate the step count when the checkpoint stores none.

        Returns:
            None
        """
        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        task_cfg = self.config.TASK_CONFIG.TASK
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        # Initialize auxiliary tasks
        observation_space = self.envs.observation_spaces[0]
        aux_cfg = self.config.RL.AUX_TASKS
        init_aux_tasks, num_recurrent_memories, aux_task_strings = \
            self._setup_auxiliary_tasks(aux_cfg, ppo_cfg, task_cfg, observation_space)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            observation_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_memories=num_recurrent_memories)
        rollouts.to(self.device)

        # Seed slot 0 of the rollout storage with the initial observations.
        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg,
                                       init_aux_tasks)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        if ckpt != -1:
            logger.info(
                f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly."
            )
            assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported"
            # This is the checkpoint we start saving at
            count_checkpoints = ckpt + 1
            # Estimated step count; replaced below when the checkpoint
            # stores the exact value.
            count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES
            ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
            self.agent.load_state_dict(ckpt_dict["state_dict"])
            if "optim_state" in ckpt_dict:
                self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"])
            else:
                logger.warning(
                    "No optimizer state loaded, results may be funky")
            if "extra_state" in ckpt_dict and "step" in ckpt_dict[
                    "extra_state"]:
                count_steps = ckpt_dict["extra_state"]["step"]

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:

            for update in range(start_updates, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                # Collect one rollout of num_steps environment steps.
                for _ in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                    aux_task_losses,
                    aux_dist_entropy,
                    aux_weights,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                for k, v in running_episode_stats.items():
                    window_episode_stats[k].append(v.clone())

                # Change of each running stat across the window (newest
                # snapshot minus oldest); with a single snapshot, use it as is.
                deltas = {
                    k: ((v[-1] - v[0]).sum().item()
                        if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                # Clamp so the per-episode averages below never divide by 0.
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar(
                    "entropy",
                    dist_entropy,
                    count_steps,
                )

                writer.add_scalar("aux_entropy", aux_dist_entropy, count_steps)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items() if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                loss_names = ["value", "policy"] + aux_task_strings
                loss_values = [value_loss, action_loss] + aux_task_losses
                writer.add_scalars(
                    "losses",
                    dict(zip(loss_names, loss_values)),
                    count_steps,
                )

                writer.add_scalars(
                    "aux_weights",
                    dict(zip(aux_task_strings, aux_weights)),
                    count_steps,
                )

                # The stats start with only "count" and "reward"; "success"
                # exists only if it was added to running_episode_stats
                # (presumably by _collect_rollout_step; verify). Skip the
                # scalar instead of raising KeyError when it is absent.
                if "success" in deltas:
                    writer.add_scalar(
                        "success",
                        deltas["success"] / deltas["count"],
                        count_steps,
                    )

                # Log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info(
                        "update: {}\tvalue_loss: {}\t action_loss: {}\taux_task_loss: {} \t aux_entropy {}"
                        .format(update, value_loss, action_loss,
                                aux_task_losses, aux_dist_entropy))
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    logger.info("Average window size: {}  {}".format(
                        len(window_episode_stats["count"]),
                        "  ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                  for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"{self.checkpoint_prefix}.{count_checkpoints}.pth",
                        dict(step=count_steps))
                    count_checkpoints += 1

        self.envs.close()
Example #25
0
    def train(self) -> None:
        r"""Main method for training DAgger.

        Prepares the LMDB feature directory, builds the environments and the
        agent, then for each of ``DAGGER.ITERATIONS`` iterations (optionally)
        refreshes the dataset via ``_update_dataset`` and trains for
        ``DAGGER.EPOCHS`` epochs over it, logging the losses per batch and
        saving one checkpoint per epoch.

        Returns:
            None

        Raises:
            lmdb.Error: if ``DAGGER.PRELOAD_LMDB_FEATURES`` is set and the
                feature database cannot be opened.
        """
        os.makedirs(self.lmdb_features_dir, exist_ok=True)
        os.makedirs(self.config.CHECKPOINT_FOLDER, exist_ok=True)

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
            try:
                # Open only to verify the database is usable, and close the
                # handle right away: leaving it open leaks it, and the same
                # path is handed to IWTrajectoryDataset below (presumably
                # reopened there; verify).
                lmdb.open(self.lmdb_features_dir, readonly=True).close()
            except lmdb.Error as err:
                logger.error(
                    "Cannot open database for teacher forcing preload.")
                raise err
        else:
            # Not preloading: drop the database returned by open_db() so
            # collection starts from scratch.
            with lmdb.open(self.lmdb_features_dir,
                           map_size=int(self.config.DAGGER.LMDB_MAP_SIZE)
                           ) as lmdb_env, lmdb_env.begin(write=True) as txn:
                txn.drop(lmdb_env.open_db())

        split = self.config.TASK_CONFIG.DATASET.SPLIT
        self.config.defrost()
        self.config.TASK_CONFIG.TASK.NDTW.SPLIT = split
        self.config.TASK_CONFIG.TASK.SDTW.SPLIT = split

        # if doing teacher forcing, don't switch the scene until it is complete
        if self.config.DAGGER.P == 1.0:
            self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = (
                -1)
        self.config.freeze()

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
            # When preloading features, it's quicker to just load one env as
            # we just need the observation space from it.
            single_proc_config = self.config.clone()
            single_proc_config.defrost()
            single_proc_config.NUM_PROCESSES = 1
            single_proc_config.freeze()
            self.envs = construct_envs(single_proc_config,
                                       get_env_class(self.config.ENV_NAME))
        else:
            self.envs = construct_envs(self.config,
                                       get_env_class(self.config.ENV_NAME))

        self._setup_actor_critic_agent(
            self.config.MODEL,
            self.config.DAGGER.LOAD_FROM_CKPT,
            self.config.DAGGER.CKPT_TO_LOAD,
        )

        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.actor_critic.parameters())))
        logger.info("agent number of trainable parameters: {}".format(
            sum(p.numel() for p in self.actor_critic.parameters()
                if p.requires_grad)))

        if self.config.DAGGER.PRELOAD_LMDB_FEATURES:
            # The single env was only needed while setting up the agent.
            self.envs.close()
            del self.envs
            self.envs = None

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs,
                               purge_step=0) as writer:
            for dagger_it in range(self.config.DAGGER.ITERATIONS):
                step_id = 0
                if not self.config.DAGGER.PRELOAD_LMDB_FEATURES:
                    self._update_dataset(dagger_it + (
                        1 if self.config.DAGGER.LOAD_FROM_CKPT else 0))

                # Free cached GPU memory and garbage before training.
                if torch.cuda.is_available():
                    with torch.cuda.device(self.device):
                        torch.cuda.empty_cache()
                gc.collect()

                dataset = IWTrajectoryDataset(
                    self.lmdb_features_dir,
                    self.config.DAGGER.USE_IW,
                    inflection_weight_coef=self.config.MODEL.
                    inflection_weight_coef,
                    lmdb_map_size=self.config.DAGGER.LMDB_MAP_SIZE,
                    batch_size=self.config.DAGGER.BATCH_SIZE,
                )

                AuxLosses.activate()
                for epoch in tqdm.trange(self.config.DAGGER.EPOCHS):
                    diter = torch.utils.data.DataLoader(
                        dataset,
                        batch_size=self.config.DAGGER.BATCH_SIZE,
                        shuffle=False,
                        collate_fn=collate_fn,
                        pin_memory=False,
                        drop_last=True,  # drop last batch if smaller
                        num_workers=0,
                    )
                    for batch in tqdm.tqdm(diter,
                                           total=dataset.length //
                                           dataset.batch_size,
                                           leave=False):
                        (
                            observations_batch,
                            prev_actions_batch,
                            not_done_masks,
                            corrected_actions_batch,
                            weights_batch,
                        ) = batch
                        observations_batch = {
                            k: v.to(device=self.device, non_blocking=True)
                            for k, v in observations_batch.items()
                        }
                        try:
                            loss, action_loss, aux_loss = self._update_agent(
                                observations_batch,
                                prev_actions_batch.to(device=self.device,
                                                      non_blocking=True),
                                not_done_masks.to(device=self.device,
                                                  non_blocking=True),
                                corrected_actions_batch.to(device=self.device,
                                                           non_blocking=True),
                                weights_batch.to(device=self.device,
                                                 non_blocking=True),
                            )
                        except Exception:
                            # Fall back to per-sample updates. Only Exception
                            # is caught so KeyboardInterrupt / SystemExit
                            # still stop training.
                            logger.info(
                                "ERROR: failed to update agent. Updating agent with batch size of 1."
                            )
                            loss, action_loss, aux_loss = 0, 0, 0
                            prev_actions_batch = prev_actions_batch.cpu()
                            not_done_masks = not_done_masks.cpu()
                            corrected_actions_batch = corrected_actions_batch.cpu(
                            )
                            weights_batch = weights_batch.cpu()
                            observations_batch = {
                                k: v.cpu()
                                for k, v in observations_batch.items()
                            }
                            # Losses of the per-sample updates are summed.
                            for i in range(not_done_masks.size(0)):
                                output = self._update_agent(
                                    {
                                        k: v[i].to(device=self.device,
                                                   non_blocking=True)
                                        for k, v in observations_batch.items()
                                    },
                                    prev_actions_batch[i].to(
                                        device=self.device, non_blocking=True),
                                    not_done_masks[i].to(device=self.device,
                                                         non_blocking=True),
                                    corrected_actions_batch[i].to(
                                        device=self.device, non_blocking=True),
                                    weights_batch[i].to(device=self.device,
                                                        non_blocking=True),
                                )
                                loss += output[0]
                                action_loss += output[1]
                                aux_loss += output[2]

                        logger.info(f"train_loss: {loss}")
                        logger.info(f"train_action_loss: {action_loss}")
                        logger.info(f"train_aux_loss: {aux_loss}")
                        logger.info(f"Batches processed: {step_id}.")
                        logger.info(
                            f"On DAgger iter {dagger_it}, Epoch {epoch}.")
                        writer.add_scalar(f"train_loss_iter_{dagger_it}", loss,
                                          step_id)
                        writer.add_scalar(
                            f"train_action_loss_iter_{dagger_it}", action_loss,
                            step_id)
                        writer.add_scalar(f"train_aux_loss_iter_{dagger_it}",
                                          aux_loss, step_id)
                        step_id += 1

                    # Checkpoint index is global across DAgger iterations.
                    self.save_checkpoint(
                        f"ckpt.{dagger_it * self.config.DAGGER.EPOCHS + epoch}.pth"
                    )
                AuxLosses.deactivate()
    def _eval_checkpoint(self,
                         checkpoint_path: str,
                         writer: TensorboardWriter,
                         checkpoint_index: int = 0,
                         log_diagnostics=None,
                         output_dir='.',
                         label='.',
                         num_eval_runs=1) -> None:
        r"""Evaluates a single checkpoint.

        Loads the checkpoint, rebuilds the environments and the agent, rolls
        the policy out until every requested episode has been completed
        ``num_eval_runs`` times, then logs the per-metric averages to the
        logger and to tensorboard.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging. Pass -1 to
                take it from the second dot-separated field of the checkpoint
                file name (e.g. ``ckpt.12.pth``); that value is kept as the
                raw string.
            log_diagnostics: collection of ``Diagnostics`` entries to record
                per episode. When non-empty, the collected records are dumped
                to ``<output_dir>/<label>.json``. ``None`` (the default) is
                treated as "no diagnostics".
            output_dir: directory the diagnostics json is written to
            label: tag used for generated videos and for the json file name
            num_eval_runs: number of times each episode has to finish before
                its environment is paused

        Returns:
            None
        """
        # A fresh list per call instead of a shared mutable default argument.
        if log_diagnostics is None:
            log_diagnostics = []

        if checkpoint_index == -1:
            ckpt_file = checkpoint_path.split('/')[-1]
            split_info = ckpt_file.split('.')
            # NOTE: this is the raw file-name field (a str), not an int.
            checkpoint_index = split_info[1]
        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO
        task_cfg = config.TASK_CONFIG.TASK

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            # Videos need the map and collision measurements in `infos`.
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
        # pass in aux config if we're doing attention
        aux_cfg = self.config.RL.AUX_TASKS
        self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg)

        # Check if we accidentally recorded `visual_resnet` in our checkpoint and drop it (it's redundant with `visual_encoder`)
        ckpt_dict['state_dict'] = {
            k: v
            for k, v in ckpt_dict['state_dict'].items()
            if 'visual_resnet' not in k
        }
        self.agent.load_state_dict(ckpt_dict["state_dict"])

        logger.info("agent number of trainable parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters()
                if param.requires_grad)))

        self.actor_critic = self.agent.actor_critic

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        _, num_recurrent_memories, _ = self._setup_auxiliary_tasks(
            aux_cfg, ppo_cfg, task_cfg, is_eval=True)
        # Bind `aux_tasks` for every policy: the video branch below hands it
        # to `observations_to_image` even when the policy is not multi-belief,
        # which previously raised UnboundLocalError.
        aux_tasks = aux_cfg.tasks
        if self.config.RL.PPO.policy in MULTIPLE_BELIEF_CLASSES:
            # One recurrent memory per auxiliary task.
            num_recurrent_memories = len(aux_tasks)
            test_recurrent_hidden_states = test_recurrent_hidden_states.unsqueeze(
                2).repeat(1, 1, num_recurrent_memories, 1)

        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)

        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES)
                      ]  # type: List[List[np.ndarray]]

        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = sum(self.envs.number_of_episodes)
        else:
            total_num_eps = sum(self.envs.number_of_episodes)
            if total_num_eps < number_of_eval_episodes:
                logger.warning(
                    f"Config specified {number_of_eval_episodes} eval episodes"
                    f", dataset only has {total_num_eps}.")
                logger.warning(f"Evaluating with {total_num_eps} instead.")
                number_of_eval_episodes = total_num_eps

        # NOTE(review): `videos_cap` is computed but the active selection
        # below is the fixed `range(10)`; kept as-is to preserve behavior.
        videos_cap = 2  # number of videos to generate per checkpoint
        if len(log_diagnostics) > 0:
            videos_cap = 10
        # video_indices = random.sample(range(self.config.TEST_EPISODE_COUNT),
        # min(videos_cap, self.config.TEST_EPISODE_COUNT))
        video_indices = range(10)
        print(f"Videos: {video_indices}")

        total_stats = []
        dones_per_ep = dict()  # (scene_id, episode_id) -> finished run count

        # Logging more extensive evaluation stats for analysis
        if len(log_diagnostics) > 0:
            d_stats = {}
            for d in log_diagnostics:
                d_stats[d] = [
                    [] for _ in range(self.config.NUM_PROCESSES)
                ]  # stored as nested list envs x timesteps x k (# tasks)

        pbar = tqdm.tqdm(total=number_of_eval_episodes * num_eval_runs)
        self.agent.eval()
        while (len(stats_episodes) < number_of_eval_episodes * num_eval_runs
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()
            with torch.no_grad():
                weights_output = None
                if self.config.RL.PPO.policy in MULTIPLE_BELIEF_CLASSES:
                    # Buffer handed to `act` via `weights_output`; read back
                    # below for diagnostics and video frames.
                    weights_output = torch.empty(self.envs.num_envs,
                                                 len(aux_tasks))
                (
                    _,
                    actions,
                    _,
                    test_recurrent_hidden_states,
                ) = self.actor_critic.act(batch,
                                          test_recurrent_hidden_states,
                                          prev_actions,
                                          not_done_masks,
                                          deterministic=False,
                                          weights_output=weights_output)
                prev_actions.copy_(actions)

                for i in range(self.envs.num_envs):
                    if Diagnostics.actions in log_diagnostics:
                        d_stats[Diagnostics.actions][i].append(
                            prev_actions[i].item())
                    if Diagnostics.weights in log_diagnostics:
                        aux_weights = (None if weights_output is None else
                                       weights_output[i])
                        if aux_weights is not None:
                            d_stats[Diagnostics.weights][i].append(
                                aux_weights.half().tolist())

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations, device=self.device)

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                next_k = (
                    next_episodes[i].scene_id,
                    next_episodes[i].episode_id,
                )
                if dones_per_ep.get(next_k, 0) == num_eval_runs:
                    envs_to_pause.append(i)  # wait for the rest

                # episode ended
                if not_done_masks[i].item() == 0:
                    episode_stats = dict()

                    episode_stats["reward"] = current_episode_reward[i].item()
                    episode_stats.update(
                        self._extract_scalars_from_info(infos[i]))

                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats

                    k = (
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )
                    dones_per_ep[k] = dones_per_ep.get(k, 0) + 1

                    # Only the first run of an episode can produce a video.
                    if dones_per_ep.get(k, 0) == 1 and len(
                            self.config.VIDEO_OPTION) > 0 and len(
                                stats_episodes) in video_indices:
                        logger.info(f"Generating video {len(stats_episodes)}")
                        category = getattr(current_episodes[i],
                                           "object_category", "")
                        if category != "":
                            category += "_"
                        try:
                            generate_video(
                                video_option=self.config.VIDEO_OPTION,
                                video_dir=self.config.VIDEO_DIR,
                                images=rgb_frames[i],
                                episode_id=current_episodes[i].episode_id,
                                checkpoint_idx=checkpoint_index,
                                metrics=self._extract_scalars_from_info(
                                    infos[i]),
                                tag=f"{category}{label}",
                                tb_writer=writer,
                            )
                        except Exception as e:
                            # Best effort: a failed video must not abort eval.
                            logger.warning(str(e))
                    rgb_frames[i] = []

                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                        dones_per_ep[k],
                    )] = episode_stats

                    if len(log_diagnostics) > 0:
                        diagnostic_info = dict()
                        for metric in log_diagnostics:
                            # Hand over this episode's trace and start anew.
                            diagnostic_info[metric] = d_stats[metric][i]
                            d_stats[metric][i] = []
                        if Diagnostics.top_down_map in log_diagnostics:
                            top_down_map = torch.tensor([])
                            if len(self.config.VIDEO_OPTION) > 0:
                                top_down_map = infos[i]["top_down_map"]["map"]
                                top_down_map = maps.colorize_topdown_map(
                                    top_down_map, fog_of_war_mask=None)
                            diagnostic_info.update(
                                dict(top_down_map=top_down_map))
                        total_stats.append(
                            dict(
                                stats=episode_stats,
                                did_stop=bool(prev_actions[i] == 0),
                                episode_info=attr.asdict(current_episodes[i]),
                                info=diagnostic_info,
                            ))
                    pbar.update()

                # episode continues
                else:
                    if len(self.config.VIDEO_OPTION) > 0:
                        aux_weights = (None if weights_output is None else
                                       weights_output[i])
                        frame = observations_to_image(
                            observations[i], infos[i],
                            current_episode_reward[i].item(), aux_weights,
                            aux_tasks)
                        rgb_frames[i].append(frame)
                    if Diagnostics.gps in log_diagnostics:
                        d_stats[Diagnostics.gps][i].append(
                            observations[i]["gps"].tolist())
                    if Diagnostics.heading in log_diagnostics:
                        d_stats[Diagnostics.heading][i].append(
                            observations[i]["heading"].tolist())

            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        # Release the progress bar once the rollout loop is finished.
        pbar.close()

        num_episodes = len(stats_episodes)
        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = (
                sum([v[stat_key]
                     for v in stats_episodes.values()]) / num_episodes)

        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.4f}")

        # Prefer the training step stored in the checkpoint as the x-axis.
        step_id = checkpoint_index
        if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
            step_id = ckpt_dict["extra_state"]["step"]

        writer.add_scalars(
            "eval_reward",
            {"average reward": aggregated_stats["reward"]},
            step_id,
        )

        metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
        if len(metrics) > 0:
            writer.add_scalars("eval_metrics", metrics, step_id)
            logger.info("eval_metrics")
            logger.info(metrics)
        if len(log_diagnostics) > 0:
            os.makedirs(output_dir, exist_ok=True)
            eval_fn = f"{label}.json"
            with open(os.path.join(output_dir, eval_fn), 'w',
                      encoding='utf-8') as f:
                json.dump(total_stats, f, ensure_ascii=False, indent=4)
        self.envs.close()
Example #27
0
    def eval(self, checkpoint_path):
        r"""Evaluates a single checkpoint and prints the averaged results.

        Loads the checkpoint, builds the environments and the agent, rolls
        the policy out until ``TEST_EPISODE_COUNT`` distinct episodes have
        finished (or no environment is left running), then prints the mean
        reward, mean success and mean value of the first configured
        measurement.

        Args:
            checkpoint_path: path of checkpoint

        Returns:
            None
        """
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            # Videos need the map and collision measurements in `infos`.
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        self.env = construct_envs(config, get_env_class(config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        # get name of performance metric, e.g. "spl"
        metric_name = self.config.TASK_CONFIG.TASK.MEASUREMENTS[0]
        metric_cfg = getattr(self.config.TASK_CONFIG.TASK, metric_name)
        measure_type = baseline_registry.get_measure(metric_cfg.TYPE)
        assert measure_type is not None, "invalid measurement type {}".format(
            metric_cfg.TYPE)
        # Instantiated only to read the uuid under which the measure is
        # reported in `infos`.
        self.metric_uuid = measure_type(sim=None, task=None,
                                        config=None)._get_uuid()

        observations = self.env.reset()
        batch = batch_obs(observations, self.device)

        current_episode_reward = torch.zeros(self.env.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [[] for _ in range(self.config.NUM_PROCESSES)
                      ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        self.actor_critic.eval()
        while (len(stats_episodes) < self.config.TEST_EPISODE_COUNT
               and self.env.num_envs > 0):
            current_episodes = self.env.current_episodes()

            with torch.no_grad():
                (
                    _,
                    actions,
                    _,
                    test_recurrent_hidden_states,
                ) = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    prev_actions,
                    not_done_masks,
                    deterministic=False,
                )

                prev_actions.copy_(actions)

            outputs = self.env.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations, self.device)

            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.env.current_episodes()
            envs_to_pause = []
            n_envs = self.env.num_envs
            for i in range(n_envs):
                # Pause an env whose upcoming episode was already evaluated.
                if (
                        next_episodes[i].scene_id,
                        next_episodes[i].episode_id,
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    episode_stats = dict()
                    episode_stats[self.metric_uuid] = infos[i][
                        self.metric_uuid]
                    # Success is defined here as a strictly positive metric.
                    episode_stats["success"] = int(
                        infos[i][self.metric_uuid] > 0)
                    episode_stats["reward"] = current_episode_reward[i].item()
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=0,
                            metric_name=self.metric_uuid,
                            metric_value=infos[i][self.metric_uuid],
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            (
                self.env,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.env,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        # Sum every stat over all finished episodes, then average.
        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = sum(
                [v[stat_key] for v in stats_episodes.values()])
        num_episodes = len(stats_episodes)

        episode_reward_mean = aggregated_stats["reward"] / num_episodes
        episode_metric_mean = aggregated_stats[self.metric_uuid] / num_episodes
        episode_success_mean = aggregated_stats["success"] / num_episodes

        print(f"Average episode reward: {episode_reward_mean:.6f}")
        print(f"Average episode success: {episode_success_mean:.6f}")
        print(f"Average episode {self.metric_uuid}: {episode_metric_mean:.6f}")

        # Results go to stdout only; no tensorboard writer is involved here,
        # so the checkpoint's stored step is not needed.
        print("eval_reward", {"average reward": episode_reward_mean})
        print(
            f"eval_{self.metric_uuid}",
            {f"average {self.metric_uuid}": episode_metric_mean},
        )
        print("eval_success", {"average success": episode_success_mean})

        self.env.close()
Example #28
0
    def _eval_checkpoint_bruce(
        self,
        checkpoint_path: str,
        checkpoint_index: int = 0,
    ):
        r"""Prepares everything needed to step a checkpointed policy.

        Loads the checkpoint, builds the environments and the agent, resets
        the environments and returns the initial rollout state instead of
        running an evaluation loop. The environments stay open in
        ``self.envs``.

        Args:
            checkpoint_path: path of checkpoint
            checkpoint_index: index of cur checkpoint for logging (not used
                by this method)

        Returns:
            Tuple of ``(actor_critic, batch, not_done_masks,
            test_recurrent_hidden_states)``: the loaded policy, the batched
            observations from the initial reset, and zero-initialised
            not-done masks and recurrent hidden states.
        """
        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")
        # NOTE(review): dumps the entire checkpoint to stdout; looks like a
        # debugging leftover - confirm before removing.
        print(ckpt_dict)

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            # Videos need the map and collision measurements in `infos`.
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        return (self.actor_critic, batch, not_done_masks,
                test_recurrent_hidden_states)
Example #29
0
    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (self.world_rank *
                                         self.config.NUM_PROCESSES)
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
                and self.world_rank == 0):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        obs_space = self.envs.observation_spaces[0]
        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            obs_space = SpaceDict({
                "visual_features":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            })
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            obs_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        )
        rollouts.to(self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):

                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self._static_encoder:
                    self._encoder.eval()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [value_loss, action_loss, count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[2].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                    ]
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {k: l
                         for l, k in zip(losses, ["value", "policy"])},
                        count_steps,
                    )

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items() if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )
                        count_checkpoints += 1

            self.envs.close()
Example #30
0
    def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None:
        r"""Main method for distributed training of PPO with auxiliary tasks.

        Sets up the distributed process group, environments, rollout storage
        and agent, optionally restores state from a checkpoint or from an
        interrupted (requeued) job, and then runs the update loop: collect a
        rollout, update the agent, all-reduce statistics across workers and,
        on rank 0, log to TensorBoard and save checkpoints.

        Args:
            ckpt_path: path handed to ``self.load_checkpoint`` when resuming.
                Only read when ``ckpt != -1``.
            ckpt: index of the checkpoint being resumed from, or ``-1``
                (default) to start from scratch. New checkpoints are numbered
                starting at ``ckpt + 1``.
            start_updates: index of the first update to run. Overwritten by
                the saved value when an interrupted state is found.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        # Offset the seed by the rank so workers do not share a random stream.
        random.seed(self.config.TASK_CONFIG.SEED + self.world_rank)
        np.random.seed(self.config.TASK_CONFIG.SEED + self.world_rank)

        # Pin both torch and the simulator to this worker's local GPU.
        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        self.config.freeze()

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        task_cfg = self.config.TASK_CONFIG.TASK

        observation_space = self.envs.observation_spaces[0]
        aux_cfg = self.config.RL.AUX_TASKS
        (
            init_aux_tasks,
            num_recurrent_memories,
            aux_task_strings,
        ) = self._setup_auxiliary_tasks(aux_cfg, ppo_cfg, task_cfg,
                                        observation_space)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            observation_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_memories=num_recurrent_memories)
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        # Seed step 0 of the rollout storage with the initial observations.
        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg,
                                       init_aux_tasks)
        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),  # including bonus
        )
        # One bounded history of all-reduced stats per key; deltas over this
        # window are what gets logged.
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        prev_time = 0

        if ckpt != -1:
            logger.info(
                f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly."
            )
            assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported"
            # This is the checkpoint we start saving at
            count_checkpoints = ckpt + 1
            # Fallback estimate of frames seen so far, used only when the
            # checkpoint does not carry an exact step count (see below).
            # count_steps is summed over all workers, hence the world_size
            # factor.
            # NOTE(review): assumes NUM_PROCESSES is the per-worker env
            # count -- confirm against the config.
            count_steps = (start_updates * ppo_cfg.num_steps *
                           self.config.NUM_PROCESSES * self.world_size)
            ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
            self.agent.load_state_dict(ckpt_dict["state_dict"])
            if "optim_state" in ckpt_dict:
                self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"])
            else:
                logger.warning(
                    "No optimizer state loaded, results may be funky")
            if "extra_state" in ckpt_dict and "step" in ckpt_dict[
                    "extra_state"]:
                count_steps = ckpt_dict["extra_state"]["step"]

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        # State saved by a preempted job takes precedence over everything
        # set up above, including the start_updates argument.
        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_updates = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        # Only rank 0 opens a TensorBoard writer; the other ranks get a
        # do-nothing context manager and never touch `writer`.
        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:

            for update in range(start_updates, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    # Rank 0 alone persists the state needed to resume.
                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                    requeue_job()
                    return

                # Steps collected by THIS worker during this update; summed
                # over all workers below before being added to count_steps.
                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)
                self.agent.train()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                    aux_task_losses,
                    aux_dist_entropy,
                    aux_weights,
                ) = self._update_agent(ppo_cfg, rollouts)

                pth_time += delta_pth_time

                # Sum the running episode stats over all workers. Sorting the
                # keys gives every rank the same stacking order.
                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering],
                    0).to(self.device)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                # Layout: [entropy, aux entropy, value loss, action loss,
                #          *aux task losses, local step count].
                stats = torch.tensor(
                    [
                        dist_entropy,
                        aux_dist_entropy,
                    ] + [value_loss, action_loss] + aux_task_losses +
                    [count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                # len() rather than truthiness so this also works if
                # aux_weights is a tensor.
                # NOTE(review): all_reduce is a collective, so this assumes
                # every rank agrees on whether aux_weights is empty.
                if aux_weights is not None and len(aux_weights) > 0:
                    # all_reduce works in place, so the tensor must be kept;
                    # average it so the logged weights reflect all workers.
                    aux_weights_sum = torch.tensor(aux_weights,
                                                   device=self.device)
                    distrib.all_reduce(aux_weights_sum)
                    aux_weights = (aux_weights_sum / self.world_size).tolist()
                # Last slot holds the step count summed over all workers.
                count_steps += int(stats[-1].item())

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    # Turn the all-reduced sums into per-worker means
                    # (everything except the trailing step count).
                    avg_stats = [
                        stats[i].item() / self.world_size
                        for i in range(len(stats) - 1)
                    ]
                    losses = avg_stats[2:]
                    dist_entropy, aux_dist_entropy = avg_stats[:2]
                    # Change of each running stat across the window; with a
                    # single entry the value itself is used.
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    # Avoid dividing by zero before any episode has finished.
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    writer.add_scalar(
                        "entropy",
                        dist_entropy,
                        count_steps,
                    )

                    writer.add_scalar("aux_entropy", aux_dist_entropy,
                                      count_steps)

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {
                            k: l
                            for l, k in zip(losses, ["value", "policy"] +
                                            aux_task_strings)
                        },
                        count_steps,
                    )

                    # aux_weights may be None (see the guard above); zipping
                    # None would raise, so only log when there is something.
                    if aux_weights is not None and len(aux_weights) > 0:
                        writer.add_scalars(
                            "aux_weights",
                            {
                                k: l
                                for l, k in zip(aux_weights,
                                                aux_task_strings)
                            },
                            count_steps,
                        )

                    # Log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        # The loss values printed here are rank 0's own,
                        # not the cross-worker averages sent to TensorBoard.
                        formatted_aux_losses = [
                            "{:.3g}".format(l) for l in aux_task_losses
                        ]
                        logger.info(
                            "update: {}\tvalue_loss: {:.3g}\t action_loss: {:.3g}\taux_task_loss: {} \t aux_entropy {:.3g}\t"
                            .format(
                                update,
                                value_loss,
                                action_loss,
                                formatted_aux_losses,
                                aux_dist_entropy,
                            ))
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items() if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"{self.checkpoint_prefix}.{count_checkpoints}.pth",
                            dict(step=count_steps))
                        count_checkpoints += 1

        self.envs.close()