Example #1
    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (self.world_rank *
                                         self.config.NUM_PROCESSES)
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
                and self.world_rank == 0):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
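        # find_unused_params=True lets the distributed reducer tolerate parameters
        # that receive no gradient in an update (e.g. a frozen visual encoder).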
        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        obs_space = self.envs.observation_spaces[0]
        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            obs_space = SpaceDict({
                "visual_features":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            })
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            obs_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        )
        rollouts.to(self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here, otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):

                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

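                # Register that this worker's rollout is finished (or was preempted)
                # so other workers can read the count in the straggler check above.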
                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self._static_encoder:
                    self._encoder.eval()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

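                # Sum the per-worker running episode stats across all ranks so the
                # logged windows reflect every environment, not just rank 0's.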
                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [value_loss, action_loss, count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[2].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                    ]
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {k: l
                         for l, k in zip(losses, ["value", "policy"])},
                        count_steps,
                    )

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items() if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )
                        count_checkpoints += 1

            self.envs.close()
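
The `linear_decay` helper used above for both the learning-rate schedule and the clip-parameter decay is not shown in this example. Below is a minimal sketch of what it presumably computes, assuming a straight linear ramp from 1.0 to 0.0 over `NUM_UPDATES`; the actual habitat_baselines helper may differ in detail.

def linear_decay(epoch: int, total_num_updates: int) -> float:
    # Multiplier that anneals linearly from 1.0 at update 0 to 0.0 at the final
    # update; LambdaLR multiplies the base learning rate by this value.
    return 1.0 - (epoch / float(total_num_updates))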
Example #2
def _worker_fn(world_rank: int, world_size: int, port: int,
               unused_params: bool):
    device = (torch.device("cuda")
              if torch.cuda.is_available() else torch.device("cpu"))
    tcp_store = distrib.TCPStore(  # type: ignore
        "127.0.0.1", port, world_size, world_rank == 0)
    distrib.init_process_group("gloo",
                               store=tcp_store,
                               rank=world_rank,
                               world_size=world_size)

    config = get_config("habitat_baselines/config/test/ppo_pointnav_test.yaml")
    obs_space = gym.spaces.Dict({
        IntegratedPointGoalGPSAndCompassSensor.cls_uuid:
        gym.spaces.Box(
            low=np.finfo(np.float32).min,
            high=np.finfo(np.float32).max,
            shape=(2, ),
            dtype=np.float32,
        )
    })
    action_space = ActionSpace({"move": EmptySpace()})
    actor_critic = PointNavBaselinePolicy.from_config(config, obs_space,
                                                      action_space)
    # This adds some arbitrary parameters that aren't part of the computation
    # graph, so they will mess up DDP if they aren't correctly ignored by it
    if unused_params:
        actor_critic.unused = nn.Linear(64, 64)

    actor_critic.to(device=device)
    ppo_cfg = config.RL.PPO
    agent = DDPPO(
        actor_critic=actor_critic,
        clip_param=ppo_cfg.clip_param,
        ppo_epoch=ppo_cfg.ppo_epoch,
        num_mini_batch=ppo_cfg.num_mini_batch,
        value_loss_coef=ppo_cfg.value_loss_coef,
        entropy_coef=ppo_cfg.entropy_coef,
        lr=ppo_cfg.lr,
        eps=ppo_cfg.eps,
        max_grad_norm=ppo_cfg.max_grad_norm,
        use_normalized_advantage=ppo_cfg.use_normalized_advantage,
    )
    agent.init_distributed()
    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        2,
        obs_space,
        action_space,
        ppo_cfg.hidden_size,
        num_recurrent_layers=actor_critic.net.num_recurrent_layers,
        is_double_buffered=False,
    )
    rollouts.to(device)

    for k, v in rollouts.buffers["observations"].items():
        rollouts.buffers["observations"][k] = torch.randn_like(v)

    # Add two steps so batching works
    rollouts.advance_rollout()
    rollouts.advance_rollout()

    # Get a single batch
    batch = next(rollouts.recurrent_generator(rollouts.buffers["returns"], 1))

    # Call eval actions through the internal wrapper that is used in
    # agent.update
    value, action_log_probs, dist_entropy, _ = agent._evaluate_actions(
        batch["observations"],
        batch["recurrent_hidden_states"],
        batch["prev_actions"],
        batch["masks"],
        batch["actions"],
    )
    # Backprop on things
    (value.mean() + action_log_probs.mean() + dist_entropy.mean()).backward()

    # Make sure all ranks have very similar gradients
    for param in actor_critic.parameters():
        if param.grad is not None:
            grads = [param.grad.detach().clone() for _ in range(world_size)]
            distrib.all_gather(grads, grads[world_rank])

            for i in range(world_size):
                assert torch.isclose(grads[i], grads[world_rank]).all()
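
The `_worker_fn` above expects to be launched once per rank. Below is a sketch of how it might be driven from a test, assuming `torch.multiprocessing.spawn`; the port value is illustrative, not taken from the original test suite.

import torch.multiprocessing as mp


def run_ddppo_reduce_test(world_size: int = 2,
                          port: int = 8748,
                          unused_params: bool = True) -> None:
    # mp.spawn passes the process index as the first positional argument, which
    # lines up with _worker_fn's world_rank parameter; the remaining args are shared.
    mp.spawn(
        _worker_fn,
        args=(world_size, port, unused_params),
        nprocs=world_size,
        join=True,
    )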
Example #3
class PPOTrainer(BaseRLTrainer):
    r"""Trainer class for PPO algorithm
    Paper: https://arxiv.org/abs/1707.06347.
    """
    supported_tasks = ["Nav-v0"]

    SHORT_ROLLOUT_THRESHOLD: float = 0.25
    _is_distributed: bool
    _obs_batching_cache: ObservationBatchingCache
    envs: VectorEnv
    agent: PPO
    actor_critic: Policy

    def __init__(self, config=None):
        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            config = interrupted_state["config"]

        super().__init__(config)
        self.actor_critic = None
        self.agent = None
        self.envs = None
        self.obs_transforms = []

        self._static_encoder = False
        self._encoder = None
        self._obs_space = None

        # Distributed if the world size would be
        # greater than 1
        self._is_distributed = get_distrib_size()[2] > 1
        self._obs_batching_cache = ObservationBatchingCache()

    @property
    def obs_space(self):
        if self._obs_space is None and self.envs is not None:
            self._obs_space = self.envs.observation_spaces[0]

        return self._obs_space

    @obs_space.setter
    def obs_space(self, new_obs_space):
        self._obs_space = new_obs_space

    def _all_reduce(self, t: torch.Tensor) -> torch.Tensor:
        r"""All reduce helper method that moves things to the correct
        device and only runs if distributed
        """
        if not self._is_distributed:
            return t

        orig_device = t.device
        t = t.to(device=self.device)
        torch.distributed.all_reduce(t)

        return t.to(device=orig_device)

    def _setup_actor_critic_agent(self, ppo_cfg: Config) -> None:
        r"""Sets up actor critic and agent for PPO.

        Args:
            ppo_cfg: config node with relevant params

        Returns:
            None
        """
        logger.add_filehandler(self.config.LOG_FILE)

        policy = baseline_registry.get_policy(self.config.RL.POLICY.name)
        observation_space = self.obs_space
        self.obs_transforms = get_active_obs_transforms(self.config)
        observation_space = apply_obs_transforms_obs_space(
            observation_space, self.obs_transforms
        )
        self.actor_critic = policy.from_config(
            self.config, observation_space, self.envs.action_spaces[0]
        )
        self.obs_space = observation_space
        self.actor_critic.to(self.device)

        if (
            self.config.RL.DDPPO.pretrained_encoder
            or self.config.RL.DDPPO.pretrained
        ):
            pretrained_state = torch.load(
                self.config.RL.DDPPO.pretrained_weights, map_location="cpu"
            )

        if self.config.RL.DDPPO.pretrained:
            self.actor_critic.load_state_dict(
                {
                    k[len("actor_critic.") :]: v
                    for k, v in pretrained_state["state_dict"].items()
                }
            )
        elif self.config.RL.DDPPO.pretrained_encoder:
            prefix = "actor_critic.net.visual_encoder."
            self.actor_critic.net.visual_encoder.load_state_dict(
                {
                    k[len(prefix) :]: v
                    for k, v in pretrained_state["state_dict"].items()
                    if k.startswith(prefix)
                }
            )

        if not self.config.RL.DDPPO.train_encoder:
            self._static_encoder = True
            for param in self.actor_critic.net.visual_encoder.parameters():
                param.requires_grad_(False)

        if self.config.RL.DDPPO.reset_critic:
            nn.init.orthogonal_(self.actor_critic.critic.fc.weight)
            nn.init.constant_(self.actor_critic.critic.fc.bias, 0)

        self.agent = (DDPPO if self._is_distributed else PPO)(
            actor_critic=self.actor_critic,
            clip_param=ppo_cfg.clip_param,
            ppo_epoch=ppo_cfg.ppo_epoch,
            num_mini_batch=ppo_cfg.num_mini_batch,
            value_loss_coef=ppo_cfg.value_loss_coef,
            entropy_coef=ppo_cfg.entropy_coef,
            lr=ppo_cfg.lr,
            eps=ppo_cfg.eps,
            max_grad_norm=ppo_cfg.max_grad_norm,
            use_normalized_advantage=ppo_cfg.use_normalized_advantage,
        )

    def _init_envs(self, config=None):
        if config is None:
            config = self.config

        self.envs = construct_envs(
            config,
            get_env_class(config.ENV_NAME),
            workers_ignore_signals=is_slurm_batch_job(),
        )

    def _init_train(self):
        if self.config.RL.DDPPO.force_distributed:
            self._is_distributed = True

        if is_slurm_batch_job():
            add_signal_handlers()

        if self._is_distributed:
            local_rank, tcp_store = init_distrib_slurm(
                self.config.RL.DDPPO.distrib_backend
            )
            if rank0_only():
                logger.info(
                    "Initialized DD-PPO with {} workers".format(
                        torch.distributed.get_world_size()
                    )
                )

            self.config.defrost()
            self.config.TORCH_GPU_ID = local_rank
            self.config.SIMULATOR_GPU_ID = local_rank
            # Multiply by the number of simulators to make sure they also get unique seeds
            self.config.TASK_CONFIG.SEED += (
                torch.distributed.get_rank() * self.config.NUM_ENVIRONMENTS
            )
            self.config.freeze()

            random.seed(self.config.TASK_CONFIG.SEED)
            np.random.seed(self.config.TASK_CONFIG.SEED)
            torch.manual_seed(self.config.TASK_CONFIG.SEED)
            self.num_rollouts_done_store = torch.distributed.PrefixStore(
                "rollout_tracker", tcp_store
            )
            self.num_rollouts_done_store.set("num_done", "0")

        if rank0_only() and self.config.VERBOSE:
            logger.info(f"config: {self.config}")

        profiling_wrapper.configure(
            capture_start_step=self.config.PROFILING.CAPTURE_START_STEP,
            num_steps_to_capture=self.config.PROFILING.NUM_STEPS_TO_CAPTURE,
        )

        self._init_envs()

        ppo_cfg = self.config.RL.PPO
        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.config.TORCH_GPU_ID)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        if rank0_only() and not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        if self._is_distributed:
            self.agent.init_distributed(find_unused_params=True)

        logger.info(
            "agent number of parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters())
            )
        )

        obs_space = self.obs_space
        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            obs_space = spaces.Dict(
                {
                    "visual_features": spaces.Box(
                        low=np.finfo(np.float32).min,
                        high=np.finfo(np.float32).max,
                        shape=self._encoder.output_shape,
                        dtype=np.float32,
                    ),
                    **obs_space.spaces,
                }
            )

        self._nbuffers = 2 if ppo_cfg.use_double_buffered_sampler else 1
        self.rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            obs_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
            is_double_buffered=ppo_cfg.use_double_buffered_sampler,
        )
        self.rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(
            observations, device=self.device, cache=self._obs_batching_cache
        )
        batch = apply_obs_transforms_batch(batch, self.obs_transforms)

        if self._static_encoder:
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)

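        # Seed step 0 of the rollout buffers with the initial (transformed) observations.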
        self.rollouts.buffers["observations"][0] = batch

        self.current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        self.running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        self.window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size)
        )

        self.env_time = 0.0
        self.pth_time = 0.0
        self.t_start = time.time()

    @rank0_only
    @profiling_wrapper.RangeContext("save_checkpoint")
    def save_checkpoint(
        self, file_name: str, extra_state: Optional[Dict] = None
    ) -> None:
        r"""Save checkpoint with specified name.

        Args:
            file_name: file name for checkpoint
            extra_state: extra state to include in the checkpoint (optional)

        Returns:
            None
        """
        checkpoint = {
            "state_dict": self.agent.state_dict(),
            "config": self.config,
        }
        if extra_state is not None:
            checkpoint["extra_state"] = extra_state

        torch.save(
            checkpoint, os.path.join(self.config.CHECKPOINT_FOLDER, file_name)
        )

    def load_checkpoint(self, checkpoint_path: str, *args, **kwargs) -> Dict:
        r"""Load checkpoint of specified path as a dict.

        Args:
            checkpoint_path: path of target checkpoint
            *args: additional positional args
            **kwargs: additional keyword args

        Returns:
            dict containing checkpoint info
        """
        return torch.load(checkpoint_path, *args, **kwargs)

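    # Metrics that should not be logged as scalars (e.g. the top-down map image,
    # per-step collision flags) are skipped during scalar extraction.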
    METRICS_BLACKLIST = {"top_down_map", "collisions.is_collision"}

    @classmethod
    def _extract_scalars_from_info(
        cls, info: Dict[str, Any]
    ) -> Dict[str, float]:
        result = {}
        for k, v in info.items():
            if k in cls.METRICS_BLACKLIST:
                continue

            if isinstance(v, dict):
                result.update(
                    {
                        k + "." + subk: subv
                        for subk, subv in cls._extract_scalars_from_info(
                            v
                        ).items()
                        if (k + "." + subk) not in cls.METRICS_BLACKLIST
                    }
                )
            # Things that are scalar-like will have an np.size of 1.
            # Strings also have an np.size of 1, so explicitly ban those
            elif np.size(v) == 1 and not isinstance(v, str):
                result[k] = float(v)

        return result

    @classmethod
    def _extract_scalars_from_infos(
        cls, infos: List[Dict[str, Any]]
    ) -> Dict[str, List[float]]:

        results = defaultdict(list)
        for i in range(len(infos)):
            for k, v in cls._extract_scalars_from_info(infos[i]).items():
                results[k].append(v)

        return results

    def _compute_actions_and_step_envs(self, buffer_index: int = 0):
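        # With double-buffered sampling, each buffer index owns a contiguous slice
        # of the environments so that action computation for one slice can overlap
        # with simulation of the other.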
        num_envs = self.envs.num_envs
        env_slice = slice(
            int(buffer_index * num_envs / self._nbuffers),
            int((buffer_index + 1) * num_envs / self._nbuffers),
        )

        t_sample_action = time.time()

        # sample actions
        with torch.no_grad():
            step_batch = self.rollouts.buffers[
                self.rollouts.current_rollout_step_idxs[buffer_index],
                env_slice,
            ]

            profiling_wrapper.range_push("compute actions")
            (
                values,
                actions,
                actions_log_probs,
                recurrent_hidden_states,
            ) = self.actor_critic.act(
                step_batch["observations"],
                step_batch["recurrent_hidden_states"],
                step_batch["prev_actions"],
                step_batch["masks"],
            )

        # NB: Move actions to CPU.  If CUDA tensors are
        # sent in to env.step(), that will create CUDA contexts
        # in the subprocesses.
        # For backwards compatibility, we also call .item() to convert to
        # an int
        actions = actions.to(device="cpu")
        self.pth_time += time.time() - t_sample_action

        profiling_wrapper.range_pop()  # compute actions

        t_step_env = time.time()

        for index_env, act in zip(
            range(env_slice.start, env_slice.stop), actions.unbind(0)
        ):
            self.envs.async_step_at(index_env, act.item())

        self.env_time += time.time() - t_step_env

        self.rollouts.insert(
            next_recurrent_hidden_states=recurrent_hidden_states,
            actions=actions,
            action_log_probs=actions_log_probs,
            value_preds=values,
            buffer_index=buffer_index,
        )

    def _collect_environment_result(self, buffer_index: int = 0):
        num_envs = self.envs.num_envs
        env_slice = slice(
            int(buffer_index * num_envs / self._nbuffers),
            int((buffer_index + 1) * num_envs / self._nbuffers),
        )

        t_step_env = time.time()
        outputs = [
            self.envs.wait_step_at(index_env)
            for index_env in range(env_slice.start, env_slice.stop)
        ]

        observations, rewards_l, dones, infos = [
            list(x) for x in zip(*outputs)
        ]

        self.env_time += time.time() - t_step_env

        t_update_stats = time.time()
        batch = batch_obs(
            observations, device=self.device, cache=self._obs_batching_cache
        )
        batch = apply_obs_transforms_batch(batch, self.obs_transforms)

        rewards = torch.tensor(
            rewards_l,
            dtype=torch.float,
            device=self.current_episode_reward.device,
        )
        rewards = rewards.unsqueeze(1)

        not_done_masks = torch.tensor(
            [[not done] for done in dones],
            dtype=torch.bool,
            device=self.current_episode_reward.device,
        )
        done_masks = torch.logical_not(not_done_masks)

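        # Accumulate this step's rewards, then fold any episodes that just finished
        # into the running stats; done_masks zeroes the contribution of episodes
        # that are still in progress.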
        self.current_episode_reward[env_slice] += rewards
        current_ep_reward = self.current_episode_reward[env_slice]
        self.running_episode_stats["reward"][env_slice] += current_ep_reward.where(done_masks, current_ep_reward.new_zeros(()))  # type: ignore
        self.running_episode_stats["count"][env_slice] += done_masks.float()  # type: ignore
        for k, v_k in self._extract_scalars_from_infos(infos).items():
            v = torch.tensor(
                v_k,
                dtype=torch.float,
                device=self.current_episode_reward.device,
            ).unsqueeze(1)
            if k not in self.running_episode_stats:
                self.running_episode_stats[k] = torch.zeros_like(
                    self.running_episode_stats["count"]
                )

            self.running_episode_stats[k][env_slice] += v.where(done_masks, v.new_zeros(()))  # type: ignore

        self.current_episode_reward[env_slice].masked_fill_(done_masks, 0.0)

        if self._static_encoder:
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)

        self.rollouts.insert(
            next_observations=batch,
            rewards=rewards,
            next_masks=not_done_masks,
            buffer_index=buffer_index,
        )

        self.rollouts.advance_rollout(buffer_index)

        self.pth_time += time.time() - t_update_stats

        return env_slice.stop - env_slice.start

    @profiling_wrapper.RangeContext("_collect_rollout_step")
    def _collect_rollout_step(self):
        self._compute_actions_and_step_envs()
        return self._collect_environment_result()

    @profiling_wrapper.RangeContext("_update_agent")
    def _update_agent(self):
        ppo_cfg = self.config.RL.PPO
        t_update_model = time.time()
        with torch.no_grad():
            step_batch = self.rollouts.buffers[
                self.rollouts.current_rollout_step_idx
            ]

            next_value = self.actor_critic.get_value(
                step_batch["observations"],
                step_batch["recurrent_hidden_states"],
                step_batch["prev_actions"],
                step_batch["masks"],
            )

        self.rollouts.compute_returns(
            next_value, ppo_cfg.use_gae, ppo_cfg.gamma, ppo_cfg.tau
        )

        self.agent.train()

        value_loss, action_loss, dist_entropy = self.agent.update(
            self.rollouts
        )

        self.rollouts.after_update()
        self.pth_time += time.time() - t_update_model

        return (
            value_loss,
            action_loss,
            dist_entropy,
        )

    def _coalesce_post_step(
        self, losses: Dict[str, float], count_steps_delta: int
    ) -> Dict[str, float]:
        stats_ordering = sorted(self.running_episode_stats.keys())
        stats = torch.stack(
            [self.running_episode_stats[k] for k in stats_ordering], 0
        )

        stats = self._all_reduce(stats)

        for i, k in enumerate(stats_ordering):
            self.window_episode_stats[k].append(stats[i])

        if self._is_distributed:
            loss_name_ordering = sorted(losses.keys())
            stats = torch.tensor(
                [losses[k] for k in loss_name_ordering] + [count_steps_delta],
                device="cpu",
                dtype=torch.float32,
            )
            stats = self._all_reduce(stats)
            count_steps_delta = int(stats[-1].item())
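            # Average the summed losses over workers; the step count was already
            # extracted above as a sum, so every collected frame is counted once.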
            stats /= torch.distributed.get_world_size()

            losses = {
                k: stats[i].item() for i, k in enumerate(loss_name_ordering)
            }

        if self._is_distributed and rank0_only():
            self.num_rollouts_done_store.set("num_done", "0")

        self.num_steps_done += count_steps_delta

        return losses

    @rank0_only
    def _training_log(
        self, writer, losses: Dict[str, float], prev_time: int = 0
    ):
        deltas = {
            k: (
                (v[-1] - v[0]).sum().item()
                if len(v) > 1
                else v[0].sum().item()
            )
            for k, v in self.window_episode_stats.items()
        }
        deltas["count"] = max(deltas["count"], 1.0)

        writer.add_scalar(
            "reward",
            deltas["reward"] / deltas["count"],
            self.num_steps_done,
        )

        # Check to see if there are any metrics
        # that haven't been logged yet
        metrics = {
            k: v / deltas["count"]
            for k, v in deltas.items()
            if k not in {"reward", "count"}
        }
        if len(metrics) > 0:
            writer.add_scalars("metrics", metrics, self.num_steps_done)

        writer.add_scalars(
            "losses",
            losses,
            self.num_steps_done,
        )

        # log stats
        if self.num_updates_done % self.config.LOG_INTERVAL == 0:
            logger.info(
                "update: {}\tfps: {:.3f}\t".format(
                    self.num_updates_done,
                    self.num_steps_done
                    / ((time.time() - self.t_start) + prev_time),
                )
            )

            logger.info(
                "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                "frames: {}".format(
                    self.num_updates_done,
                    self.env_time,
                    self.pth_time,
                    self.num_steps_done,
                )
            )

            logger.info(
                "Average window size: {}  {}".format(
                    len(self.window_episode_stats["count"]),
                    "  ".join(
                        "{}: {:.3f}".format(k, v / deltas["count"])
                        for k, v in deltas.items()
                        if k != "count"
                    ),
                )
            )

    def should_end_early(self, rollout_step) -> bool:
        if not self._is_distributed:
            return False
        # This is where the preemption of workers happens.  If a
        # worker detects it will be a straggler, it preempts itself!
        return (
            rollout_step
            >= self.config.RL.PPO.num_steps * self.SHORT_ROLLOUT_THRESHOLD
        ) and int(self.num_rollouts_done_store.get("num_done")) >= (
            self.config.RL.DDPPO.sync_frac * torch.distributed.get_world_size()
        )

    @profiling_wrapper.RangeContext("train")
    def train(self) -> None:
        r"""Main method for training DD/PPO.

        Returns:
            None
        """

        self._init_train()

        count_checkpoints = 0
        prev_time = 0

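        # Unlike the linear_decay(update, NUM_UPDATES) schedule in Example #1, this
        # multiplier anneals the learning rate based on overall training progress.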
        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: 1 - self.percent_done(),
        )

        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"]
            )
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            self.env_time = requeue_stats["env_time"]
            self.pth_time = requeue_stats["pth_time"]
            self.num_steps_done = requeue_stats["num_steps_done"]
            self.num_updates_done = requeue_stats["num_updates_done"]
            self._last_checkpoint_percent = requeue_stats[
                "_last_checkpoint_percent"
            ]
            count_checkpoints = requeue_stats["count_checkpoints"]
            prev_time = requeue_stats["prev_time"]

            self._last_checkpoint_percent = requeue_stats[
                "_last_checkpoint_percent"
            ]

        ppo_cfg = self.config.RL.PPO

        with (
            TensorboardWriter(
                self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
            )
            if rank0_only()
            else contextlib.suppress()
        ) as writer:
            while not self.is_done():
                profiling_wrapper.on_start_step()
                profiling_wrapper.range_push("train update")

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * (
                        1 - self.percent_done()
                    )

                if EXIT.is_set():
                    profiling_wrapper.range_pop()  # train update

                    self.envs.close()

                    if REQUEUE.is_set() and rank0_only():
                        requeue_stats = dict(
                            env_time=self.env_time,
                            pth_time=self.pth_time,
                            count_checkpoints=count_checkpoints,
                            num_steps_done=self.num_steps_done,
                            num_updates_done=self.num_updates_done,
                            _last_checkpoint_percent=self._last_checkpoint_percent,
                            prev_time=(time.time() - self.t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            )
                        )

                    requeue_job()
                    return

                self.agent.eval()
                count_steps_delta = 0
                profiling_wrapper.range_push("rollouts loop")

                profiling_wrapper.range_push("_collect_rollout_step")
                for buffer_index in range(self._nbuffers):
                    self._compute_actions_and_step_envs(buffer_index)

                for step in range(ppo_cfg.num_steps):
                    is_last_step = (
                        self.should_end_early(step + 1)
                        or (step + 1) == ppo_cfg.num_steps
                    )

                    for buffer_index in range(self._nbuffers):
                        count_steps_delta += self._collect_environment_result(
                            buffer_index
                        )

                        if (buffer_index + 1) == self._nbuffers:
                            profiling_wrapper.range_pop()  # _collect_rollout_step

                        if not is_last_step:
                            if (buffer_index + 1) == self._nbuffers:
                                profiling_wrapper.range_push(
                                    "_collect_rollout_step"
                                )

                            self._compute_actions_and_step_envs(buffer_index)

                    if is_last_step:
                        break

                profiling_wrapper.range_pop()  # rollouts loop

                if self._is_distributed:
                    self.num_rollouts_done_store.add("num_done", 1)

                (
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent()

                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()  # type: ignore

                self.num_updates_done += 1
                losses = self._coalesce_post_step(
                    dict(value_loss=value_loss, action_loss=action_loss),
                    count_steps_delta,
                )

                self._training_log(writer, losses, prev_time)

                # checkpoint model
                if rank0_only() and self.should_checkpoint():
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth",
                        dict(
                            step=self.num_steps_done,
                            wall_time=(time.time() - self.t_start) + prev_time,
                        ),
                    )
                    count_checkpoints += 1

                profiling_wrapper.range_pop()  # train update

            self.envs.close()

    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging

        Returns:
            None
        """
        if self._is_distributed:
            raise RuntimeError("Evaluation does not support distributed mode")

        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        if config.VERBOSE:
            logger.info(f"env config: {config}")

        self._init_envs(config)
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        observations = self.envs.reset()
        batch = batch_obs(
            observations, device=self.device, cache=self._obs_batching_cache
        )
        batch = apply_obs_transforms_batch(batch, self.obs_transforms)

        current_episode_reward = torch.zeros(
            self.envs.num_envs, 1, device="cpu"
        )

        test_recurrent_hidden_states = torch.zeros(
            self.config.NUM_ENVIRONMENTS,
            self.actor_critic.net.num_recurrent_layers,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(
            self.config.NUM_ENVIRONMENTS,
            1,
            device=self.device,
            dtype=torch.long,
        )
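        # Masks start at zero so the recurrent policy treats the first step in every
        # environment as the beginning of a new episode.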
        not_done_masks = torch.zeros(
            self.config.NUM_ENVIRONMENTS,
            1,
            device=self.device,
            dtype=torch.bool,
        )
        stats_episodes: Dict[
            Any, Any
        ] = {}  # dict of dicts that stores stats per episode

        rgb_frames = [
            [] for _ in range(self.config.NUM_ENVIRONMENTS)
        ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = sum(self.envs.number_of_episodes)
        else:
            total_num_eps = sum(self.envs.number_of_episodes)
            if total_num_eps < number_of_eval_episodes:
                logger.warn(
                    f"Config specified {number_of_eval_episodes} eval episodes"
                    f", dataset only has {total_num_eps}."
                )
                logger.warn(f"Evaluating with {total_num_eps} instead.")
                number_of_eval_episodes = total_num_eps

        pbar = tqdm.tqdm(total=number_of_eval_episodes)
        self.actor_critic.eval()
        while (
            len(stats_episodes) < number_of_eval_episodes
            and self.envs.num_envs > 0
        ):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                (
                    _,
                    actions,
                    _,
                    test_recurrent_hidden_states,
                ) = self.actor_critic.act(
                    batch,
                    test_recurrent_hidden_states,
                    prev_actions,
                    not_done_masks,
                    deterministic=False,
                )

                prev_actions.copy_(actions)  # type: ignore

            # NB: Move actions to CPU.  If CUDA tensors are
            # sent in to env.step(), that will create CUDA contexts
            # in the subprocesses.
            # For backwards compatibility, we also call .item() to convert to
            # an int
            step_data = [a.item() for a in actions.to(device="cpu")]

            outputs = self.envs.step(step_data)

            observations, rewards_l, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(
                observations,
                device=self.device,
                cache=self._obs_batching_cache,
            )
            batch = apply_obs_transforms_batch(batch, self.obs_transforms)

            not_done_masks = torch.tensor(
                [[not done] for done in dones],
                dtype=torch.bool,
                device="cpu",
            )

            rewards = torch.tensor(
                rewards_l, dtype=torch.float, device="cpu"
            ).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                if (
                    next_episodes[i].scene_id,
                    next_episodes[i].episode_id,
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not not_done_masks[i].item():
                    pbar.update()
                    episode_stats = {}
                    episode_stats["reward"] = current_episode_reward[i].item()
                    episode_stats.update(
                        self._extract_scalars_from_info(infos[i])
                    )
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[
                        (
                            current_episodes[i].scene_id,
                            current_episodes[i].episode_id,
                        )
                    ] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metrics=self._extract_scalars_from_info(infos[i]),
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    # TODO move normalization / channel changing out of the policy and undo it here
                    frame = observations_to_image(
                        {k: v[i] for k, v in batch.items()}, infos[i]
                    )
                    rgb_frames[i].append(frame)

            not_done_masks = not_done_masks.to(device=self.device)
            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        num_episodes = len(stats_episodes)
        aggregated_stats = {}
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = (
                sum(v[stat_key] for v in stats_episodes.values())
                / num_episodes
            )

        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.4f}")

        step_id = checkpoint_index
        if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
            step_id = ckpt_dict["extra_state"]["step"]

        writer.add_scalars(
            "eval_reward",
            {"average reward": aggregated_stats["reward"]},
            step_id,
        )

        metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
        if len(metrics) > 0:
            writer.add_scalars("eval_metrics", metrics, step_id)

        self.envs.close()
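
Example #3 uses `rank0_only` both as a boolean check (`if rank0_only():`) and as a decorator (`@rank0_only`), but the helper itself is not shown. Below is a minimal sketch of a helper supporting both usages, assuming `torch.distributed` is the only backend that needs to be consulted; the real habitat_baselines utility may be implemented differently.

import functools
from typing import Any, Callable, Optional

import torch.distributed as distrib


def rank0_only(fn: Optional[Callable] = None) -> Any:
    def _is_rank0() -> bool:
        # Treat the non-distributed case as rank 0 so single-GPU runs still log.
        if not distrib.is_available() or not distrib.is_initialized():
            return True
        return distrib.get_rank() == 0

    if fn is None:
        # Called as rank0_only(): behave like a boolean check.
        return _is_rank0()

    # Used as @rank0_only: run the wrapped function only on rank 0.
    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if _is_rank0():
            return fn(*args, **kwargs)
        return None

    return _wrapper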
Example #4
    def train(self) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        self.device = torch.device("cuda", self.config.TORCH_GPU_ID)
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)
        self._setup_actor_critic_agent(ppo_cfg)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        observations = self.envs.reset()
        batch = batch_obs(observations)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
        )
        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])
        rollouts.to(self.device)

        episode_rewards = torch.zeros(self.envs.num_envs, 1)
        episode_counts = torch.zeros(self.envs.num_envs, 1)
        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
        window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            for update in range(self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                for step in range(ppo_cfg.num_steps):
                    delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                        rollouts,
                        current_episode_reward,
                        episode_rewards,
                        episode_counts,
                    )
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent(
                    ppo_cfg, rollouts)
                pth_time += delta_pth_time

                window_episode_reward.append(episode_rewards.clone())
                window_episode_counts.append(episode_counts.clone())

                losses = [value_loss, action_loss]
                stats = zip(
                    ["count", "reward"],
                    [window_episode_counts, window_episode_reward],
                )
                deltas = {
                    k:
                    ((v[-1] -
                      v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in stats
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)

                writer.add_scalars(
                    "losses",
                    {k: l
                     for l, k in zip(losses, ["value", "policy"])},
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    window_rewards = (window_episode_reward[-1] -
                                      window_episode_reward[0]).sum()
                    window_counts = (window_episode_counts[-1] -
                                     window_episode_counts[0]).sum()

                    if window_counts > 0:
                        logger.info(
                            "Average window size {} reward: {:3f}".format(
                                len(window_episode_reward),
                                (window_rewards / window_counts).item(),
                            ))
                    else:
                        logger.info("No episodes finish in current window")

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                    count_checkpoints += 1

            self.envs.close()
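
`batch_obs` appears in every example but is never defined here. Below is an illustrative simplification of what it does, assuming each observation is a dict of array-like sensor readings; it ignores the caching and pinned-memory handling of the real habitat_baselines implementation.

from collections import defaultdict
from typing import Any, Dict, List, Optional

import torch


def batch_obs_sketch(observations: List[Dict[str, Any]],
                     device: Optional[torch.device] = None
                     ) -> Dict[str, torch.Tensor]:
    # Group each sensor's readings across environments, then stack them into a
    # single [num_envs, ...] tensor per sensor and move the result to `device`.
    grouped = defaultdict(list)
    for obs in observations:
        for sensor, value in obs.items():
            grouped[sensor].append(torch.as_tensor(value))

    return {
        sensor: torch.stack(values, dim=0).to(device=device)
        for sensor, values in grouped.items()
    }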
Example #5
    def train(self) -> None:
        r"""Main method for training PPO.
        Returns:
            None
        """

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)
        self._setup_actor_critic_agent(ppo_cfg)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
        )
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here, otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            for update in range(self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                for k, v in running_episode_stats.items():
                    window_episode_stats[k].append(v.clone())

                deltas = {
                    k:
                    ((v[-1] -
                      v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items() if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                losses = [value_loss, action_loss]
                writer.add_scalars(
                    "losses",
                    {k: l
                     for l, k in zip(losses, ["value", "policy"])},
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    logger.info("Average window size: {}  {}".format(
                        len(window_episode_stats["count"]),
                        "  ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                  for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth",
                                         dict(step=count_steps))
                    count_checkpoints += 1

            self.envs.close()
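
`linear_decay` is imported elsewhere in these trainers and its definition is not shown. A plausible stand-in, consistent with how it is called together with `LambdaLR` (an update index plus the total number of updates, returning a multiplicative learning-rate factor that anneals to zero), is sketched below; the optimizer here is a toy placeholder, not the PPO agent's optimizer.

import torch
from torch.optim.lr_scheduler import LambdaLR

def linear_decay(epoch: int, total_num_updates: int) -> float:
    """Multiplicative LR factor that anneals linearly from 1.0 to 0.0."""
    return 1.0 - (epoch / float(total_num_updates))

# toy optimizer standing in for self.agent.optimizer
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=2.5e-4)
NUM_UPDATES = 4
lr_scheduler = LambdaLR(optimizer,
                        lr_lambda=lambda x: linear_decay(x, NUM_UPDATES))

for update in range(NUM_UPDATES):
    optimizer.step()      # normally: the PPO update on the collected rollout
    lr_scheduler.step()   # lr becomes base_lr * linear_decay(update + 1, NUM_UPDATES)
    print(update, optimizer.param_groups[0]["lr"])
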
    def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """
        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        task_cfg = self.config.TASK_CONFIG.TASK
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        # Initialize auxiliary tasks
        observation_space = self.envs.observation_spaces[0]
        aux_cfg = self.config.RL.AUX_TASKS
        init_aux_tasks, num_recurrent_memories, aux_task_strings = \
            self._setup_auxiliary_tasks(aux_cfg, ppo_cfg, task_cfg, observation_space)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            observation_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_memories=num_recurrent_memories)
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg,
                                       init_aux_tasks)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        if ckpt != -1:
            logger.info(
                f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly."
            )
            assert (ppo_cfg.use_linear_lr_decay is False
                    and ppo_cfg.use_linear_clip_decay is False
                    ), "Resuming with decay not supported"
            # This is the checkpoint we start saving at
            count_checkpoints = ckpt + 1
            count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES
            ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
            self.agent.load_state_dict(ckpt_dict["state_dict"])
            if "optim_state" in ckpt_dict:
                self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"])
            else:
                logger.warn("No optimizer state loaded, results may be funky")
            if "extra_state" in ckpt_dict and "step" in ckpt_dict[
                    "extra_state"]:
                count_steps = ckpt_dict["extra_state"]["step"]

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:

            for update in range(start_updates, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                    aux_task_losses,
                    aux_dist_entropy,
                    aux_weights,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                for k, v in running_episode_stats.items():
                    window_episode_stats[k].append(v.clone())

                deltas = {
                    k:
                    ((v[-1] -
                      v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar(
                    "entropy",
                    dist_entropy,
                    count_steps,
                )

                writer.add_scalar("aux_entropy", aux_dist_entropy, count_steps)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items() if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)

                losses = [value_loss, action_loss] + aux_task_losses
                writer.add_scalars(
                    "losses",
                    {
                        k: l
                        for l, k in zip(losses, ["value", "policy"] +
                                        aux_task_strings)
                    },
                    count_steps,
                )

                writer.add_scalars(
                    "aux_weights",
                    {k: l
                     for l, k in zip(aux_weights, aux_task_strings)},
                    count_steps,
                )

                writer.add_scalar(
                    "success",
                    deltas["success"] / deltas["count"],
                    count_steps,
                )

                # Log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info(
                        "update: {}\tvalue_loss: {}\t action_loss: {}\taux_task_loss: {} \t aux_entropy {}"
                        .format(update, value_loss, action_loss,
                                aux_task_losses, aux_dist_entropy))
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    logger.info("Average window size: {}  {}".format(
                        len(window_episode_stats["count"]),
                        "  ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                  for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"{self.checkpoint_prefix}.{count_checkpoints}.pth",
                        dict(step=count_steps))
                    count_checkpoints += 1

        self.envs.close()
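
The resume path above relies on the convention that `save_checkpoint(..., dict(step=count_steps))` stores the extra dict under an "extra_state" key next to "state_dict" (and possibly "optim_state"). Below is a minimal sketch of that round trip using plain `torch.save`/`torch.load`; the helper names, file name, and model are assumptions for illustration, not the trainers' actual implementations.

import torch

def save_checkpoint_sketch(path, model, optimizer, extra_state):
    # mirrors the key layout the resume code above expects
    torch.save(
        dict(state_dict=model.state_dict(),
             optim_state=optimizer.state_dict(),
             extra_state=extra_state),
        path,
    )

def load_checkpoint_sketch(path):
    return torch.load(path, map_location="cpu")

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=2.5e-4)
save_checkpoint_sketch("ckpt.0.pth", model, optimizer, dict(step=12345))

ckpt_dict = load_checkpoint_sketch("ckpt.0.pth")
model.load_state_dict(ckpt_dict["state_dict"])
if "optim_state" in ckpt_dict:
    optimizer.load_state_dict(ckpt_dict["optim_state"])
count_steps = ckpt_dict.get("extra_state", {}).get("step", 0)
print(count_steps)  # 12345
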
Example #7
    def train(self):
        # Get environments for training
        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))

        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        #logger.info(
        #    "agent number of parameters: {}".format(
        #        sum(param.numel() for param in self.agent.parameters())
        #    )
        #)

        # Change for the actual value
        cfg = self.config.RL.PPO

        rollouts = RolloutStorage(
            cfg.num_steps,
            self.envs.num_envs,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            cfg.hidden_size,
        )
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations)
        # Debug: print the batched shape of every sensor observation
        for sensor in rollouts.observations:
            print(batch[sensor].shape)

        # Copy the information to the wrapper
        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        episode_rewards = torch.zeros(self.envs.num_envs, 1)
        episode_counts = torch.zeros(self.envs.num_envs, 1)
        #current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        #window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
        #window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )
        '''
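
`batch_obs`, used throughout these examples, turns the list of per-environment observation dicts returned by `envs.reset()` / `envs.step()` into one tensor per sensor stacked along the environment dimension, which is what gets copied into `rollouts.observations[sensor]`. A minimal sketch of that behaviour with made-up sensor names and shapes; the real habitat helper also handles device placement and dtype details, so treat this only as an illustration.

import numpy as np
import torch
from collections import defaultdict

def batch_obs_sketch(observations, device=None):
    """Stack a list of per-env observation dicts into per-sensor tensors."""
    batch = defaultdict(list)
    for obs in observations:
        for sensor, value in obs.items():
            batch[sensor].append(torch.as_tensor(value))
    return {
        sensor: torch.stack(values, dim=0).to(device=device, dtype=torch.float)
        for sensor, values in batch.items()
    }

# two fake environments with an RGB sensor and a pointgoal sensor
observations = [
    {"rgb": np.zeros((4, 4, 3), dtype=np.uint8), "pointgoal": np.array([1.0, 0.5])},
    {"rgb": np.ones((4, 4, 3), dtype=np.uint8), "pointgoal": np.array([2.0, -0.5])},
]
batch = batch_obs_sketch(observations)
print(batch["rgb"].shape, batch["pointgoal"].shape)  # [2, 4, 4, 3] and [2, 2]
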
Example #8
    def train(self) -> None:
        r"""
        Main method for training PPO
        Returns:
            None
        """
        assert (
            self.config is not None
        ), "trainer is not properly initialized, need to specify config file"

        self.envs = construct_envs(self.config, NavRLEnv)

        ppo_cfg = self.config.TRAINER.RL.PPO
        self.device = torch.device("cuda", ppo_cfg.pth_gpu_id)
        if not os.path.isdir(ppo_cfg.checkpoint_folder):
            os.makedirs(ppo_cfg.checkpoint_folder)
        self._setup_actor_critic_agent(ppo_cfg)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        observations = self.envs.reset()
        batch = batch_obs(observations)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
        )
        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])
        rollouts.to(self.device)

        episode_rewards = torch.zeros(self.envs.num_envs, 1)
        episode_counts = torch.zeros(self.envs.num_envs, 1)
        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
        window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        with (get_tensorboard_writer(
                log_dir=ppo_cfg.tensorboard_dir,
                purge_step=count_steps,
                flush_secs=30,
        )) as writer:
            for update in range(ppo_cfg.num_updates):
                if ppo_cfg.use_linear_lr_decay:
                    update_linear_schedule(
                        self.agent.optimizer,
                        update,
                        ppo_cfg.num_updates,
                        ppo_cfg.lr,
                    )

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * (
                        1 - update / ppo_cfg.num_updates)

                for step in range(ppo_cfg.num_steps):
                    t_sample_action = time.time()
                    # sample actions
                    with torch.no_grad():
                        step_observation = {
                            k: v[step]
                            for k, v in rollouts.observations.items()
                        }

                        (
                            values,
                            actions,
                            actions_log_probs,
                            recurrent_hidden_states,
                        ) = self.actor_critic.act(
                            step_observation,
                            rollouts.recurrent_hidden_states[step],
                            rollouts.masks[step],
                        )
                    pth_time += time.time() - t_sample_action

                    t_step_env = time.time()

                    outputs = self.envs.step([a[0].item() for a in actions])
                    observations, rewards, dones, infos = [
                        list(x) for x in zip(*outputs)
                    ]

                    env_time += time.time() - t_step_env

                    t_update_stats = time.time()
                    batch = batch_obs(observations)
                    rewards = torch.tensor(rewards, dtype=torch.float)
                    rewards = rewards.unsqueeze(1)

                    masks = torch.tensor(
                        [[0.0] if done else [1.0] for done in dones],
                        dtype=torch.float,
                    )

                    current_episode_reward += rewards
                    episode_rewards += (1 - masks) * current_episode_reward
                    episode_counts += 1 - masks
                    current_episode_reward *= masks

                    rollouts.insert(
                        batch,
                        recurrent_hidden_states,
                        actions,
                        actions_log_probs,
                        values,
                        rewards,
                        masks,
                    )

                    count_steps += self.envs.num_envs
                    pth_time += time.time() - t_update_stats

                window_episode_reward.append(episode_rewards.clone())
                window_episode_counts.append(episode_counts.clone())

                t_update_model = time.time()
                with torch.no_grad():
                    last_observation = {
                        k: v[-1]
                        for k, v in rollouts.observations.items()
                    }
                    next_value = self.actor_critic.get_value(
                        last_observation,
                        rollouts.recurrent_hidden_states[-1],
                        rollouts.masks[-1],
                    ).detach()

                rollouts.compute_returns(next_value, ppo_cfg.use_gae,
                                         ppo_cfg.gamma, ppo_cfg.tau)

                value_loss, action_loss, dist_entropy = self.agent.update(
                    rollouts)

                rollouts.after_update()
                pth_time += time.time() - t_update_model

                losses = [value_loss, action_loss]
                stats = zip(
                    ["count", "reward"],
                    [window_episode_counts, window_episode_reward],
                )
                deltas = {
                    k:
                    ((v[-1] -
                      v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in stats
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)

                writer.add_scalars(
                    "losses",
                    {k: l
                     for l, k in zip(losses, ["value", "policy"])},
                    count_steps,
                )

                # log stats
                if update > 0 and update % ppo_cfg.log_interval == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    window_rewards = (window_episode_reward[-1] -
                                      window_episode_reward[0]).sum()
                    window_counts = (window_episode_counts[-1] -
                                     window_episode_counts[0]).sum()

                    if window_counts > 0:
                        logger.info(
                            "Average window size {} reward: {:3f}".format(
                                len(window_episode_reward),
                                (window_rewards / window_counts).item(),
                            ))
                    else:
                        logger.info("No episodes finish in current window")

                # checkpoint model
                if update % ppo_cfg.checkpoint_interval == 0:
                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                    count_checkpoints += 1
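
The four lines around `current_episode_reward` in the rollout loop above implement per-environment episode accounting with masks (`0.0` when an episode just ended, `1.0` otherwise): the running reward keeps accumulating, and when an episode finishes its total is folded into `episode_rewards` and the running counter is reset. A small numeric sketch of exactly that bookkeeping, with two toy environments and hand-written rewards and dones:

import torch

num_envs = 2
current_episode_reward = torch.zeros(num_envs, 1)
episode_rewards = torch.zeros(num_envs, 1)
episode_counts = torch.zeros(num_envs, 1)

# step 1: env 0 gets reward 1.0, env 1 gets 2.0, neither episode ends
# step 2: env 0 gets 0.5 and finishes, env 1 gets 1.0 and keeps going
steps = [
    (torch.tensor([[1.0], [2.0]]), [False, False]),
    (torch.tensor([[0.5], [1.0]]), [True, False]),
]

for rewards, dones in steps:
    masks = torch.tensor([[0.0] if done else [1.0] for done in dones])
    current_episode_reward += rewards
    episode_rewards += (1 - masks) * current_episode_reward  # add finished totals
    episode_counts += 1 - masks                              # count finished episodes
    current_episode_reward *= masks                          # reset finished envs

print(episode_rewards)         # env 0 contributed 1.5; env 1 still running -> 0.0
print(episode_counts)          # [[1.], [0.]]
print(current_episode_reward)  # [[0.], [3.]]
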
Example #9
    def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        random.seed(self.config.TASK_CONFIG.SEED + self.world_rank)
        np.random.seed(self.config.TASK_CONFIG.SEED + self.world_rank)

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        self.config.freeze()

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        task_cfg = self.config.TASK_CONFIG.TASK

        observation_space = self.envs.observation_spaces[0]
        aux_cfg = self.config.RL.AUX_TASKS
        init_aux_tasks, num_recurrent_memories, aux_task_strings = self._setup_auxiliary_tasks(
            aux_cfg, ppo_cfg, task_cfg, observation_space)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            observation_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_memories=num_recurrent_memories)
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg,
                                       init_aux_tasks)
        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),  # including bonus
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        prev_time = 0

        if ckpt != -1:
            logger.info(
                f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly."
            )
            assert (ppo_cfg.use_linear_lr_decay is False
                    and ppo_cfg.use_linear_clip_decay is False
                    ), "Resuming with decay not supported"
            # This is the checkpoint we start saving at
            count_checkpoints = ckpt + 1
            count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES
            ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
            self.agent.load_state_dict(ckpt_dict["state_dict"])
            if "optim_state" in ckpt_dict:
                self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"])
            else:
                logger.warn("No optimizer state loaded, results may be funky")
            if "extra_state" in ckpt_dict and "step" in ckpt_dict[
                    "extra_state"]:
                count_steps = ckpt_dict["extra_state"]["step"]

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_updates = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:

            for update in range(start_updates, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    # accumulate locally; summed across workers via all_reduce below
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)
                self.agent.train()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                    aux_task_losses,
                    aux_dist_entropy,
                    aux_weights,
                ) = self._update_agent(ppo_cfg, rollouts)

                pth_time += delta_pth_time

                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering],
                    0).to(self.device)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [
                        dist_entropy,
                        aux_dist_entropy,
                    ] + [value_loss, action_loss] + aux_task_losses +
                    [count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                if aux_weights is not None and len(aux_weights) > 0:
                    distrib.all_reduce(
                        torch.tensor(aux_weights, device=self.device))
                count_steps += stats[-1].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    avg_stats = [
                        stats[i].item() / self.world_size
                        for i in range(len(stats) - 1)
                    ]
                    losses = avg_stats[2:]
                    dist_entropy, aux_dist_entropy = avg_stats[:2]
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    writer.add_scalar(
                        "entropy",
                        dist_entropy,
                        count_steps,
                    )

                    writer.add_scalar("aux_entropy", aux_dist_entropy,
                                      count_steps)

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {
                            k: l
                            for l, k in zip(losses, ["value", "policy"] +
                                            aux_task_strings)
                        },
                        count_steps,
                    )

                    writer.add_scalars(
                        "aux_weights",
                        {k: l
                         for l, k in zip(aux_weights, aux_task_strings)},
                        count_steps,
                    )

                    # Log stats
                    formatted_aux_losses = [
                        "{:.3g}".format(l) for l in aux_task_losses
                    ]
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info(
                            "update: {}\tvalue_loss: {:.3g}\t action_loss: {:.3g}\taux_task_loss: {} \t aux_entropy {:.3g}\t"
                            .format(
                                update,
                                value_loss,
                                action_loss,
                                formatted_aux_losses,
                                aux_dist_entropy,
                            ))
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items() if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"{self.checkpoint_prefix}.{count_checkpoints}.pth",
                            dict(step=count_steps))
                        count_checkpoints += 1

        self.envs.close()
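
In the distributed variant above, each worker packs its loss scalars and local step count into a tensor, `distrib.all_reduce` sums that tensor across workers, and rank 0 divides by `world_size` to log averages while keeping the frame count as a sum (which is why the last entry is excluded from `avg_stats`). The sketch below simulates two workers with a plain element-wise sum instead of a real process group, purely to make the reduction arithmetic concrete.

import torch

world_size = 2

# per-worker [dist_entropy, aux_dist_entropy, value_loss, action_loss, count_steps_delta]
worker_stats = [
    torch.tensor([0.9, 0.1, 0.50, 0.30, 1024.0]),
    torch.tensor([1.1, 0.3, 0.70, 0.50, 1024.0]),
]

# distrib.all_reduce(stats) leaves every rank holding the element-wise sum
stats = torch.stack(worker_stats, dim=0).sum(dim=0)

# rank 0 averages everything except the step count, which stays a sum
avg_stats = [stats[i].item() / world_size for i in range(len(stats) - 1)]
count_steps_delta = stats[-1].item()

print(avg_stats)          # roughly [1.0, 0.2, 0.6, 0.4]
print(count_steps_delta)  # 2048.0
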
Example #10
    def benchmark(self) -> None:
        if TIME_DEBUG: s = time.time()
        #self.config.defrost()
        #self.config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_EPISODES = 10
        #self.config.freeze()
        if torch.cuda.device_count() <= 1:
            self.config.defrost()
            self.config.TORCH_GPU_ID = 0
            self.config.SIMULATOR_GPU_ID = 0
            self.config.freeze()
        self.envs = construct_envs(self.config, eval(self.config.ENV_NAME))
        if ADD_IL:
            self.il_envs = construct_envs(self.config,
                                          eval(self.config.ENV_NAME),
                                          no_val=True)
        self.collect_mode = 'RL'

        if TIME_DEBUG: s = log_time(s, 'construct envs')
        ppo_cfg = self.config.RL.PPO
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)
        self._setup_actor_critic_agent(ppo_cfg)
        # if 'SMT' in self.config.POLICY:
        #     sd = torch.load('visual_embedding18.pth')
        #     self.actor_critic.net.visual_encoder.load_state_dict(sd['visual_encoder'])
        #     self.actor_critic.net.prev_action_embedding.load_state_dict(sd['prev_action_embedding'])
        #     self.actor_critic.net.visual_encoder.cuda()
        #     self.actor_critic.net.prev_action_embedding.cuda()
        #     self.envs.setup_embedding_network(self.actor_critic.net.visual_encoder, self.actor_critic.net.prev_action_embedding)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        num_train_processes, num_val_processes = self.config.NUM_PROCESSES, self.config.NUM_VAL_PROCESSES
        total_processes = num_train_processes + num_val_processes
        OBS_LIST = self.config.OBS_TO_SAVE
        self.num_processes = num_train_processes
        rollouts = RolloutStorage(ppo_cfg.num_steps,
                                  num_train_processes,
                                  self.envs.observation_spaces[0],
                                  self.envs.action_spaces[0],
                                  ppo_cfg.hidden_size,
                                  self.actor_critic.net.num_recurrent_layers,
                                  OBS_LIST=OBS_LIST)
        rollouts.to(self.device)

        batch = self.envs.reset()

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(
                batch[sensor][:num_train_processes])
        self.last_observations = batch
        self.last_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers, total_processes,
            ppo_cfg.hidden_size).to(self.device)
        self.last_prev_actions = torch.zeros(
            total_processes, rollouts.prev_actions.shape[-1]).to(self.device)
        self.last_masks = torch.zeros(total_processes, 1).to(self.device)

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None
        if ADD_IL:
            rollouts2 = RolloutStorage(
                ppo_cfg.num_steps,
                num_train_processes,
                self.il_envs.observation_spaces[0],
                self.il_envs.action_spaces[0],
                ppo_cfg.hidden_size,
                self.actor_critic.net.num_recurrent_layers,
                OBS_LIST=OBS_LIST)
            rollouts2.to(self.device)
            batch2 = self.il_envs.reset()
            for sensor in rollouts2.observations:
                rollouts2.observations[sensor][0].copy_(
                    batch2[sensor][:num_train_processes])
            self.saved_last_obs = batch2
            self.saved_last_recurrent_hidden_states = torch.zeros(
                self.actor_critic.net.num_recurrent_layers, total_processes,
                ppo_cfg.hidden_size).to(self.device)
            self.saved_last_prev_actions = torch.zeros(
                total_processes,
                rollouts2.prev_actions.shape[-1]).to(self.device)
            self.saved_last_masks = torch.zeros(total_processes,
                                                1).to(self.device)
            batch2 = None
        else:
            rollouts2 = None

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0 if not hasattr(self,
                                       'resume_steps') else self.resume_steps
        count_checkpoints = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )
        if TIME_DEBUG: s = log_time(s, 'setup all')

        for update in range(100):
            if ppo_cfg.use_linear_lr_decay:
                lr_scheduler.step()

            if ppo_cfg.use_linear_clip_decay:
                self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                    update, self.config.NUM_UPDATES)
            if TIME_DEBUG: s = log_time(s, 'collect rollout start')
            modes = ['RL']
            if ADD_IL:
                modes += ['IL']
            for collect_mode in modes:
                self.collect_mode = collect_mode
                use_rollouts = rollouts if self.collect_mode == 'RL' else rollouts2
                if ADD_IL: self.exchange_lasts()
                for step in range(ppo_cfg.num_steps):

                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(use_rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps
                #print(delta_env_time, delta_pth_time)
            if TIME_DEBUG: s = log_time(s, 'collect rollout done')
            (delta_pth_time, value_loss, action_loss, dist_entropy,
             il_loss) = self._update_agent(ppo_cfg, rollouts, rollouts2)
            #print(delta_pth_time)
            pth_time += delta_pth_time
            use_rollouts.after_update()
            if TIME_DEBUG: s = log_time(s, 'update agent')
            for k, v in running_episode_stats.items():
                window_episode_stats[k].append(v.clone())

            deltas = {
                k: ((v[-1][:self.num_processes] -
                     v[0][:self.num_processes]).sum().item()
                    if len(v) > 1 else v[0][:self.num_processes].sum().item())
                for k, v in window_episode_stats.items()
            }

            deltas["count"] = max(deltas["count"], 1.0)
            #self.write_tb('train', writer, deltas, count_steps, losses)

            eval_deltas = {
                k: ((v[-1][self.num_processes:] -
                     v[0][self.num_processes:]).sum().item()
                    if len(v) > 1 else v[0][self.num_processes:].sum().item())
                for k, v in window_episode_stats.items()
            }
            eval_deltas["count"] = max(eval_deltas["count"], 1.0)

            #self.write_tb('val', writer, eval_deltas, count_steps)

            # log stats
            if update > 0 and update % self.config.LOG_INTERVAL == 0:
                logger.info("update: {}\tfps: {:.3f}\t".format(
                    update, count_steps / (time.time() - t_start)))

                logger.info(
                    "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                    "frames: {}".format(update, env_time, pth_time,
                                        count_steps))
                logger.info("Average window size: {}  {}".format(
                    len(window_episode_stats["count"]),
                    "  ".join("{}: {:.3f}".format(k, v / deltas["count"])
                              for k, v in deltas.items() if k != "count"),
                ))
                logger.info("validation metrics: {}".format(
                    "  ".join("{}: {:.3f}".format(k, v / eval_deltas["count"])
                              for k, v in eval_deltas.items()
                              if k != "count"), ))

        self.envs.close()
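
`TIME_DEBUG` and `log_time` in the benchmark above are simple wall-clock instrumentation: `log_time(s, msg)` receives the previous timestamp and returns a fresh one, so each call reports the elapsed time of the section it closes. Their definitions are not included in these examples; a plausible stand-in consistent with how they are called would be:

import time

TIME_DEBUG = True

def log_time(prev: float, msg: str) -> float:
    """Print the time elapsed since `prev` for section `msg`, return a new timestamp."""
    now = time.time()
    print(f"[time] {msg}: {now - prev:.3f}s")
    return now

if TIME_DEBUG: s = time.time()
time.sleep(0.1)                       # stands in for construct_envs(...)
if TIME_DEBUG: s = log_time(s, 'construct envs')
time.sleep(0.05)                      # stands in for the rest of the setup
if TIME_DEBUG: s = log_time(s, 'setup all')
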
Example #11
    def train(self) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """

        self.add_new_based_on_cfg()

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)
        self._setup_actor_critic_agent(ppo_cfg, train=True)

        if self.config.PRETRAINED_CHECKPOINT_PATH:
            ckpt_dict = self.load_checkpoint(
                self.config.PRETRAINED_CHECKPOINT_PATH, map_location="cpu")
            self.agent.load_state_dict(ckpt_dict["state_dict"], strict=False)

        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers)
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs_augment_aux(observations, self.envs.get_shared_mem())

        for sensor in rollouts.observations:
            if sensor in batch:
                rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        info_data_keys = ["discovered", "collisions_wall", "collisions_prox"]

        log_data_keys = [
            "episode_rewards", "episode_go_rewards", "episode_counts",
            "current_episode_reward", "current_episode_go_reward"
        ] + info_data_keys

        log_data = dict(
            {k: torch.zeros(self.envs.num_envs, 1)
             for k in log_data_keys})
        info_data = dict({k: log_data[k] for k in info_data_keys})

        win_keys = log_data_keys
        win_keys.pop(win_keys.index("current_episode_reward"))
        win_keys.pop(win_keys.index("current_episode_go_reward"))

        windows = dict({
            k: deque(maxlen=ppo_cfg.reward_window_size)
            for k in log_data.keys()
        })

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        train_steps = min(self.config.NUM_UPDATES,
                          self.config.HARD_NUM_UPDATES)

        log_interval = self.config.LOG_INTERVAL
        num_updates = self.config.NUM_UPDATES
        agent = self.agent
        ckpt_interval = self.config.CHECKPOINT_INTERVAL

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            for update in range(train_steps):
                if ppo_cfg.use_linear_clip_decay:
                    agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, num_updates)

                for step in range(ppo_cfg.num_steps):
                    delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                        rollouts, log_data["current_episode_reward"],
                        log_data["current_episode_go_reward"],
                        log_data["episode_rewards"],
                        log_data["episode_go_rewards"],
                        log_data["episode_counts"], info_data)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                delta_pth_time, value_loss, action_loss, dist_entropy,\
                    aux_loss = self._update_agent(ppo_cfg, rollouts)

                # TODO check if LR is init
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                pth_time += delta_pth_time

                # ==================================================================================
                # -- Log data for window averaging
                for k, v in windows.items():
                    windows[k].append(log_data[k].clone())

                value_names = ["value", "policy", "entropy"] + list(
                    aux_loss.keys())
                losses = [value_loss, action_loss, dist_entropy] + list(
                    aux_loss.values())

                stats = zip(list(windows.keys()), list(windows.values()))
                deltas = {
                    k:
                    ((v[-1] -
                      v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in stats
                }
                act_ep = deltas["episode_counts"]
                counts = max(act_ep, 1.0)
                # pre-multiply so the division below leaves the raw episode count intact
                deltas["episode_counts"] *= counts

                for k, v in deltas.items():
                    deltas[k] = v / counts
                    writer.add_scalar(k, deltas[k], count_steps)

                writer.add_scalars("losses",
                                   {k: l
                                    for l, k in zip(losses, value_names)},
                                   count_steps)

                # log stats
                if update > 0 and update % log_interval == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    if act_ep > 0:
                        log_txt = f"Average window size {len(windows['episode_counts'])}"
                        for k, v in deltas.items():
                            log_txt += f" | {k}: {v:.3f}"

                        logger.info(log_txt)
                        logger.info(
                            f"Aux losses: {list(zip(value_names, losses))}")
                    else:
                        logger.info("No episodes finish in current window")
                # ==================================================================================

                # checkpoint model
                if update % ckpt_interval == 0:
                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                    count_checkpoints += 1

            self.envs.close()
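
The `{k: l for l, k in zip(losses, value_names)}` comprehension passed to `writer.add_scalars` above simply pairs each loss value with its curve name so TensorBoard plots all of them under one parent tag. A tiny standalone example of that grouping with `torch.utils.tensorboard.SummaryWriter`; the log directory, names, and values are illustrative only.

from torch.utils.tensorboard import SummaryWriter

value_names = ["value", "policy", "entropy"]
losses = [0.42, 0.13, 1.7]
count_steps = 1000

with SummaryWriter(log_dir="tb_example") as writer:
    # one parent tag "losses", one child curve per name
    writer.add_scalars("losses",
                       {k: l for l, k in zip(losses, value_names)},
                       count_steps)
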
Example #12
    def train(self) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """
        if TIME_DEBUG: s = time.time()
        self.envs = construct_envs(
            self.config, eval(self.config.ENV_NAME)
        )
        if TIME_DEBUG: s = log_time(s, 'construct envs')
        ppo_cfg = self.config.RL.PPO
        self.device = (
            torch.device("cuda", self.config.TORCH_GPU_ID)
            if torch.cuda.is_available()
            else torch.device("cpu")
        )
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)
        self._setup_actor_critic_agent(ppo_cfg)
        if 'SMT' in self.config.POLICY:
            sd = (torch.load('visual_embedding18_explore.pth')
                  if 'Explore' in self.config.POLICY else
                  torch.load('visual_embedding18.pth'))
            self.actor_critic.net.visual_encoder.load_state_dict(sd['visual_encoder'])
            self.actor_critic.net.prev_action_embedding.load_state_dict(sd['prev_action_embedding'])
            self.actor_critic.net.visual_encoder.cuda()
            self.actor_critic.net.prev_action_embedding.cuda()
            total_num = self.config.NUM_PROCESSES + self.config.NUM_VAL_PROCESSES
            args_list = {
                'visual_encoder': self.actor_critic.net.visual_encoder,
                'prev_action_embedding': self.actor_critic.net.prev_action_embedding,
            }
            self.envs.call(['setup_embedding_network'] * total_num,
                           [args_list] * total_num)
        logger.info(
            "agent number of parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters())
            )
        )

        num_train_processes, num_val_processes = self.config.NUM_PROCESSES, self.config.NUM_VAL_PROCESSES
        total_processes = num_train_processes + num_val_processes
        OBS_LIST = self.config.OBS_TO_SAVE
        self.num_processes = num_train_processes
        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            num_train_processes,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            self.actor_critic.net.num_recurrent_layers,
            OBS_LIST = OBS_LIST
        )
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(
                batch[sensor][:num_train_processes]
            )
        self.last_observations = batch
        self.last_recurrent_hidden_states = torch.zeros(
            self.actor_critic.net.num_recurrent_layers,
            total_processes,
            ppo_cfg.hidden_size,
        ).to(self.device)
        self.last_prev_actions = torch.zeros(
            total_processes, rollouts.prev_actions.shape[-1]
        ).to(self.device)
        self.last_masks = torch.zeros(total_processes, 1).to(self.device)

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size)
        )

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0 if not hasattr(self, 'resume_steps') else self.resume_steps
        start_steps = 0 if not hasattr(self, 'resume_steps') else self.resume_steps
        count_checkpoints = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )
        if TIME_DEBUG: s = log_time(s, 'setup all')
        with TensorboardWriter(
            self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        ) as writer:
            for update in range(self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES
                    )
                if TIME_DEBUG: s = log_time(s, 'collect rollout start')
                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(
                        rollouts, current_episode_reward, running_episode_stats
                    )
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps
                if TIME_DEBUG: s = log_time(s, 'collect rollout done')
                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                    il_loss
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time
                rollouts.after_update()
                if TIME_DEBUG: s = log_time(s, 'update agent')
                for k, v in running_episode_stats.items():
                    window_episode_stats[k].append(v.clone())
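                # The first `self.num_processes` rows of each stat tensor
                # correspond to the training envs and the remaining rows to
                # the validation envs; the two delta dicts below slice the
                # window stats accordingly before logging to tensorboard.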

                deltas = {
                    k: (
                        (v[-1][:self.num_processes] - v[0][:self.num_processes]).sum().item()
                        if len(v) > 1
                        else v[0][:self.num_processes].sum().item()
                    )
                    for k, v in window_episode_stats.items()
                }


                deltas["count"] = max(deltas["count"], 1.0)
                losses = [value_loss, action_loss, dist_entropy, il_loss]
                self.write_tb('train', writer, deltas, count_steps, losses)

                eval_deltas = {
                    k: (
                        (v[-1][self.num_processes:] - v[0][self.num_processes:]).sum().item()
                        if len(v) > 1
                        else v[0][self.num_processes:].sum().item()
                    )
                    for k, v in window_episode_stats.items()
                }
                eval_deltas["count"] = max(eval_deltas["count"], 1.0)

                self.write_tb('val', writer, eval_deltas, count_steps)

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info(
                        "update: {}\tfps: {:.3f}\t".format(
                            update, (count_steps - start_steps) / (time.time() - t_start)
                        )
                    )

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(
                            update, env_time, pth_time, count_steps
                        )
                    )
                    logger.info(
                        "Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items()
                                if k != "count"
                            ),
                        )
                    )
                    logger.info(
                        "validation metrics: {}".format(
                            "  ".join(
                                "{}: {:.3f}".format(k, v / eval_deltas["count"])
                                for k, v in eval_deltas.items()
                                if k != "count"
                            ),
                        )
                    )


                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(
                        f"ckpt.{count_checkpoints}.pth", dict(step=count_steps)
                    )
                    count_checkpoints += 1

            self.envs.close()

    def train(self) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """
        self.add_new_based_on_cfg()

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)
        self._setup_actor_critic_agent(ppo_cfg, train=True)

        if self.config.PRETRAINED_CHECKPOINT_PATH:
            ckpt_dict = self.load_checkpoint(
                self.config.PRETRAINED_CHECKPOINT_PATH, map_location="cpu")
            self.agent.load_state_dict(ckpt_dict["state_dict"], strict=False)

        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers)
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs_augment_aux(observations)

        for sensor in rollouts.observations:
            if sensor in batch:
                rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        episode_rewards = torch.zeros(self.envs.num_envs, 1)
        episode_go_rewards = torch.zeros(self.envs.num_envs,
                                         1)  # Grid oracle rewards
        episode_counts = torch.zeros(self.envs.num_envs, 1)
        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        current_episode_go_reward = torch.zeros(self.envs.num_envs,
                                                1)  # Grid oracle rewards
        window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
        window_episode_go_reward = deque(maxlen=ppo_cfg.reward_window_size)
        window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        train_steps = min(self.config.NUM_UPDATES,
                          self.config.HARD_NUM_UPDATES)
        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            for update in range(train_steps):
                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                for step in range(ppo_cfg.num_steps):
                    delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                        rollouts,
                        current_episode_reward,
                        current_episode_go_reward,
                        episode_rewards,
                        episode_go_rewards,
                        episode_counts,
                    )
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                    aux_loss,
                ) = self._update_agent(ppo_cfg, rollouts)

                # TODO: check whether the LR scheduler is properly initialized
                # before it is stepped here
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                pth_time += delta_pth_time

                window_episode_reward.append(episode_rewards.clone())
                window_episode_go_reward.append(episode_go_rewards.clone())
                window_episode_counts.append(episode_counts.clone())

                value_names = ["value", "policy", "entropy"] + list(
                    aux_loss.keys())
                losses = [value_loss, action_loss, dist_entropy] + list(
                    aux_loss.values())

                stats = zip(
                    ["count", "reward", "reward_go"],
                    [
                        window_episode_counts, window_episode_reward,
                        window_episode_go_reward
                    ],
                )
                deltas = {
                    k: (
                        (v[-1] - v[0]).sum().item()
                        if len(v) > 1
                        else v[0].sum().item()
                    )
                    for k, v in stats
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)

                writer.add_scalar("reward_go",
                                  deltas["reward_go"] / deltas["count"],
                                  count_steps)

                writer.add_scalars(
                    "losses",
                    dict(zip(value_names, losses)),
                    count_steps,
                )

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    window_rewards = (window_episode_reward[-1] -
                                      window_episode_reward[0]).sum()
                    window_go_rewards = (window_episode_go_reward[-1] -
                                         window_episode_go_reward[0]).sum()
                    window_counts = (window_episode_counts[-1] -
                                     window_episode_counts[0]).sum()

                    if window_counts > 0:
                        logger.info(
                            "Average window size {} reward: {:3f} reward_go: {:3f}"
                            .format(
                                len(window_episode_reward),
                                (window_rewards / window_counts).item(),
                                (window_go_rewards / window_counts).item(),
                            ))
                        logger.info(
                            f"Aux losses: {list(zip(value_names, losses))}")
                    else:
                        logger.info("No episodes finish in current window")

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                    count_checkpoints += 1

            self.envs.close()
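Both train() loops above anneal the learning rate with LambdaLR driven by a linear_decay helper. As a point of reference, here is a minimal sketch of such a helper, assuming the usual habitat-baselines-style signature (current update index, total number of updates); the exact helper in the source repository may differ.

def linear_decay(epoch: int, total_num_updates: int) -> float:
    """Multiplier that anneals linearly from 1.0 at update 0 to 0.0 at the
    final update; LambdaLR multiplies the base learning rate by this value."""
    return 1.0 - (epoch / float(total_num_updates))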
Example #14
0
class PPOTrainer_(BaseRLTrainer):
    supported_tasks = ["Nav-v0"]

    def __init__(self, config=None):
        super().__init__(config)
        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))
        self._setup_actor_critic_agent(self.config)
        self.rollout = RolloutStorage(self.config.RL.PPO.num_steps,
                                      self.envs.num_envs,
                                      self.envs.observation_spaces[0],
                                      self.envs.action_spaces[0],
                                      self.config.RL.PPO.hidden_size)
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        self.rollout.to(self.device)

    def _setup_actor_critic_agent(self, ppo_cfg):
        print(self.config)
        # len(observation_spaces)==NUM_PROCESSES==4?
        print(self.envs.observation_spaces[0])
        print(self.envs.action_spaces[0])
        self.actor_critic = PointNavBaselinePolicy_(
            cnn_parameter={
                "observation_spaces": self.envs.observation_spaces[0],
                "feature_dim": self.config.RL.PPO.cnn_output_size
            },
            depth_decoder_parameter={},
            rnn_parameter={
                "input_dim": 0,
                "hidden_dim": self.config.RL.PPO.hidden_size,
                "n_layer": 1
            },
            actor_parameter={
                "action_spaces": 4,
            },
            critic_parameter={},
            use_splitnet_auxiliary=self.config.RL.PPO.use_splitnet_auxiliary,
        )
        self.agent = PPO_(
            self.actor_critic,
            self.config,
        )

    def _collect_rollout_step(self, current_episode_reward,
                              running_episode_stats):
        '''Collect a single rollout step across all environments.

        Observations contain the sensors:
            "rgb", "depth", "pointgoal_with_gps_compass"
        '''
        ### PASS DATA
        with torch.no_grad():
            obs = {
                k: v[self.rollout.step]
                for k, v in self.rollout.observations.items()
            }
            hidden_states = self.rollout.recurrent_hidden_states[
                self.rollout.step]
            inverse_dones = self.rollout.masks[self.rollout.step]
            (actions_log_probs, actions, value, distributions,
             rnn_hidden_state) = (self.actor_critic.act(
                 obs, hidden_states, inverse_dones))
        ### PASS ENV
        res = self.envs.step([action.item() for action in actions])

        ### Process
        observations, rewards, dones, infos = [list(x) for x in zip(*res)]
        observations = batch_obs(observations, self.device)
        rewards = torch.tensor(rewards, dtype=torch.float,
                               device=self.device).unsqueeze(-1)
        dones = torch.tensor(dones, dtype=torch.float, device=self.device)
        inverse_dones = torch.tensor([[0] if done else [1] for done in dones],
                                     dtype=torch.float,
                                     device=self.device)
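        # Masks convention: inverse_dones is 1.0 while an episode continues and
        # 0.0 right after it ends, so multiplying hidden states and accumulated
        # rewards by it resets them at episode boundaries.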
        self.rollout.insert(
            observations,
            rnn_hidden_state,
            actions,
            actions_log_probs,
            value,
            rewards,
            inverse_dones,
        )
        # print("rewards:", rewards)
        # print("dones:", dones)
        current_episode_reward += rewards
        running_episode_stats["reward"] += (
            1 - inverse_dones) * current_episode_reward
        running_episode_stats["count"] += 1 - inverse_dones
        for k, v in self._extract_scalars_from_infos(infos).items():
            v = torch.tensor(v,
                             dtype=torch.float,
                             device=current_episode_reward.device).unsqueeze(1)
            if k not in running_episode_stats:
                running_episode_stats[k] = torch.zeros_like(
                    running_episode_stats["count"])

            running_episode_stats[k] += (1 - inverse_dones) * v

        current_episode_reward *= inverse_dones
        return

    def _save_checkpoint(self, checkpoint_name):
        checkpoint_dict = {
            "state_dict": self.agent.state_dict(),
            "config": self.config,
        }
        torch.save(
            checkpoint_dict,
            os.path.join(self.config.CHECKPOINT_FOLDER, checkpoint_name),
        )

    def load_checkpoint(self, checkpoint_path, *args, **kwargs):
        return torch.load(checkpoint_path, *args, **kwargs)

    def _update_agent(self):
        with torch.no_grad():
            last_obs = {
                k: v[self.rollout.step]
                for k, v in self.rollout.observations.items()
            }
            hidden_states = self.rollout.recurrent_hidden_states[
                self.rollout.step]
            inverse_dones = self.rollout.masks[self.rollout.step]
            next_value = self.actor_critic.get_value(last_obs, hidden_states,
                                                     inverse_dones).detach()

        self.rollout.compute_returns(next_value, self.config.RL.PPO.use_gae,
                                     self.config.RL.PPO.gamma,
                                     self.config.RL.PPO.tau)
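        # compute_returns bootstraps from next_value; with use_gae enabled the
        # advantages presumably follow the standard GAE recursion
        #   delta_t = r_t + gamma * V(s_{t+1}) * mask_{t+1} - V(s_t)
        #   A_t     = delta_t + gamma * tau * mask_{t+1} * A_{t+1}
        # (assuming the usual RolloutStorage semantics).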
        loss, loss_auxiliary = self.agent.update(self.rollout)
        self.rollout.after_update()

        return loss, loss_auxiliary

    def train(self) -> None:
        #### init
        obs = self.envs.reset()
        #### cache the initial obs first
        batch = batch_obs(obs, device=self.device)
        for sensor in self.rollout.observations:
            self.rollout.observations[sensor][0].copy_(batch[sensor])
        #### Para
        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )
        #### PPO LOG PARA
        count_steps = 0
        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=self.config.RL.PPO.reward_window_size))

        #### start training
        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            for epoch in range(self.config.NUM_UPDATES):
                #### decay
                if self.config.RL.PPO.use_linear_lr_decay:
                    lr_scheduler.step()
                if (epoch + 1) % self.config.CHECKPOINT_INTERVAL == 0:
                    self.agent.entropy_coef = self.agent.entropy_coef * 0.9
                print(self.agent.entropy_coef)

                #### collect rollout
                print("=== collect rollout ===")
                for step in range(self.config.RL.PPO.num_steps):
                    self._collect_rollout_step(current_episode_reward,
                                               running_episode_stats)
                    count_steps += self.envs.num_envs
                #### update the agent
                loss, loss_auxiliary = self._update_agent()
                #### LOGGER
                writer.add_scalars("loss", loss, epoch * self.envs.num_envs)
                writer.add_scalars("loss_auxiliary", loss_auxiliary,
                                   epoch * self.envs.num_envs)

                #### PPO LOG PARA
                for k, v in running_episode_stats.items():
                    window_episode_stats[k].append(v.clone())

                deltas = {
                    k: (
                        (v[-1] - v[0]).sum().item()
                        if len(v) > 1
                        else v[0].sum().item()
                    )
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("reward", deltas["reward"] / deltas["count"],
                                  count_steps)
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items() if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    writer.add_scalars("metrics", metrics, count_steps)
                #### PPO LOG PARA

                if epoch % 1 == 0:  # log every update
                    print("count_steps:", count_steps)
                    print("deltas:", deltas)
                    print("metrics:", metrics)
                if epoch % self.config.CHECKPOINT_INTERVAL == 0:
                    self._save_checkpoint(f"checkpoint.{epoch}.pth")

        self.envs.close()

    # Called once per checkpoint.
    # TEST_EPISODE_COUNT controls how many episodes are evaluated.
    # While an episode is still running, observations are stored into
    # rgb_frames and post-processed (top_down_map, ...).
    # Once done, generate_video writes the result to disk and tensorboard.
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging

        Returns:
            None
        """
        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        logger.info(f"env config: {config}")
        self.envs = construct_envs(config, get_env_class(config.ENV_NAME))
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            1,
            self.config.NUM_PROCESSES,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames = [
            [] for _ in range(self.config.NUM_PROCESSES)
        ]  # type: List[List[np.ndarray]]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        pbar = tqdm.tqdm(total=self.config.TEST_EPISODE_COUNT)
        self.actor_critic.eval()
        while (len(stats_episodes) < self.config.TEST_EPISODE_COUNT
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                (_, actions, _, _,
                 test_recurrent_hidden_states) = (self.actor_critic.act(
                     batch, test_recurrent_hidden_states, not_done_masks))

                prev_actions.copy_(actions)

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations, device=self.device)
            not_done_masks = torch.tensor(
                [[0.0] if done else [1.0] for done in dones],
                dtype=torch.float,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                if (
                        next_episodes[i].scene_id,
                        next_episodes[i].episode_id,
                ) in stats_episodes:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    pbar.update()
                    episode_stats = dict()
                    episode_stats["reward"] = current_episode_reward[i].item()
                    episode_stats.update(
                        self._extract_scalars_from_info(infos[i]))
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    stats_episodes[(
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metrics=self._extract_scalars_from_info(infos[i]),
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        num_episodes = len(stats_episodes)
        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            aggregated_stats[stat_key] = (
                sum([v[stat_key]
                     for v in stats_episodes.values()]) / num_episodes)

        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.4f}")

        step_id = checkpoint_index
        if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
            step_id = ckpt_dict["extra_state"]["step"]

        writer.add_scalars(
            "eval_reward",
            {"average reward": aggregated_stats["reward"]},
            step_id,
        )

        metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
        if len(metrics) > 0:
            writer.add_scalars("eval_metrics", metrics, step_id)

        self.envs.close()

    METRICS_BLACKLIST = {"top_down_map", "collisions.is_collision"}

    @classmethod
    def _extract_scalars_from_info(cls, info: Dict[str,
                                                   Any]) -> Dict[str, float]:
        result = {}
        for k, v in info.items():

            if k in cls.METRICS_BLACKLIST:
                continue

            if isinstance(v, dict):
                result.update({
                    k + "." + subk: subv
                    for subk, subv in cls._extract_scalars_from_info(
                        v).items()
                    if (k + "." + subk) not in cls.METRICS_BLACKLIST
                })
            # Things that are scalar-like will have an np.size of 1.
            # Strings also have an np.size of 1, so explicitly ban those
            elif np.size(v) == 1 and not isinstance(v, str):
                result[k] = float(v)

        return result

    @classmethod
    def _extract_scalars_from_infos(
            cls, infos: List[Dict[str, Any]]) -> Dict[str, List[float]]:

        results = defaultdict(list)
        for i in range(len(infos)):
            for k, v in cls._extract_scalars_from_info(infos[i]).items():
                results[k].append(v)

        return results
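
As a usage illustration for the flattening helpers above, the snippet below runs _extract_scalars_from_info on a made-up info dict (the keys and values are hypothetical and only demonstrate the rules): nested dicts become dotted keys, while blacklisted and non-scalar entries are dropped. It assumes it runs in the same module as PPOTrainer_.

if __name__ == "__main__":
    # Hypothetical info dict, only for demonstrating the flattening rules.
    sample_info = {
        "distance_to_goal": 1.25,
        "success": 1,
        "collisions": {"count": 3, "is_collision": 0},
        "top_down_map": {"fog_of_war_mask": 0},  # blacklisted, skipped entirely
    }
    print(PPOTrainer_._extract_scalars_from_info(sample_info))
    # -> {'distance_to_goal': 1.25, 'success': 1.0, 'collisions.count': 3.0}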