    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")
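        # Each worker adds 1 to "num_done" when it finishes its rollout (see the
        # preemption check inside the rollout loop below), and rank 0 resets the
        # counter to "0" after every update.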

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (self.world_rank *
                                         self.config.NUM_PROCESSES)
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
                and self.world_rank == 0):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.agent.init_distributed(find_unused_params=True)
        if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
            self.belief_predictor.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))
            if ppo_cfg.use_belief_predictor:
                logger.info(
                    "belief predictor number of trainable parameters: {}".
                    format(
                        sum(param.numel()
                            for param in self.belief_predictor.parameters()
                            if param.requires_grad)))
            logger.info(f"config: {self.config}")

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        obs_space = self.envs.observation_spaces[0]
        if ppo_cfg.use_external_memory:
            memory_dim = self.actor_critic.net.memory_dim
        else:
            memory_dim = None

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            obs_space,
            self.action_space,
            ppo_cfg.hidden_size,
            ppo_cfg.use_external_memory,
            ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size + ppo_cfg.num_steps,
            ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
            memory_dim,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        )
        rollouts.to(self.device)

        if self.config.RL.PPO.use_belief_predictor:
            self.belief_predictor.update(batch, None)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )
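        # linear_decay is assumed here to be the usual habitat-baselines helper,
        # i.e. roughly:
        #
        #     def linear_decay(epoch: int, total_num_updates: int) -> float:
        #         return 1.0 - (epoch / float(total_num_updates))
        #
        # so LambdaLR scales the base learning rate linearly from 1x toward 0
        # over NUM_UPDATES updates.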

        # Try to resume at previous checkpoint (independent of interrupted states)
        count_steps_start, count_checkpoints, start_update = (
            self.try_to_resume_checkpoint())
        count_steps = count_steps_start

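        # If a previous run was preempted, load_interrupted_state() returns the
        # state saved by save_interrupted_state() in the EXIT/REQUEUE branch
        # below, so training resumes where it left off.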
        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            if self.config.RL.PPO.use_belief_predictor:
                self.belief_predictor.load_state_dict(
                    interrupted_state["belief_predictor"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        state_dict = dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        )
                        if self.config.RL.PPO.use_belief_predictor:
                            state_dict['belief_predictor'] = (
                                self.belief_predictor.state_dict())
                        save_interrupted_state(state_dict)

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()
                if self.config.RL.PPO.use_belief_predictor:
                    self.belief_predictor.eval()
                for step in range(ppo_cfg.num_steps):

                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self.config.RL.PPO.use_belief_predictor:
                    self.belief_predictor.train()
                    self.belief_predictor.set_eval_encoders()
                if self._static_smt_encoder:
                    self.actor_critic.net.set_eval_encoders()

                if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
                    location_predictor_loss, prediction_accuracy = self.train_belief_predictor(
                        rollouts)
                else:
                    location_predictor_loss = 0
                    prediction_accuracy = 0
                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [
                        value_loss, action_loss, dist_entropy,
                        location_predictor_loss, prediction_accuracy,
                        count_steps_delta
                    ],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[5].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                        stats[2].item() / self.world_size,
                        stats[3].item() / self.world_size,
                        stats[4].item() / self.world_size,
                    ]
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar("Metrics/reward",
                                      deltas["reward"] / deltas["count"],
                                      count_steps)

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        for metric, value in metrics.items():
                            writer.add_scalar(f"Metrics/{metric}", value,
                                              count_steps)

                    writer.add_scalar("Policy/value_loss", losses[0],
                                      count_steps)
                    writer.add_scalar("Policy/policy_loss", losses[1],
                                      count_steps)
                    writer.add_scalar("Policy/entropy_loss", losses[2],
                                      count_steps)
                    writer.add_scalar("Policy/predictor_loss", losses[3],
                                      count_steps)
                    writer.add_scalar("Policy/predictor_accuracy", losses[4],
                                      count_steps)
                    writer.add_scalar('Policy/learning_rate',
                                      lr_scheduler.get_lr()[0], count_steps)

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            (count_steps - count_steps_start) /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items() if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )
                        count_checkpoints += 1

            self.envs.close()
Example #2
    def init_rpc(
        name,
        backend=BackendType.PROCESS_GROUP,
        rank=-1,
        world_size=None,
        rpc_backend_options=None,
    ):
        r"""
        Initializes RPC primitives such as the local RPC agent
        and distributed autograd, which immediately makes the current
        process ready to send and receive RPCs.

        Arguments:
            backend (BackendType, optional): The type of RPC backend
                implementation. Supported values include
                ``BackendType.PROCESS_GROUP`` (the default) and
                ``BackendType.TENSORPIPE``. See :ref:`rpc-backends` for more
                information.
            name (str): a globally unique name of this node. (e.g.,
                ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
                The name may only contain numbers, letters, underscores, colons,
                and/or dashes, and must be shorter than 128 characters.
            rank (int): a globally unique id/rank of this node.
            world_size (int): The number of workers in the group.
            rpc_backend_options (RpcBackendOptions, optional): The options
                passed to the RpcAgent constructor. It must be an agent-specific
                subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
                and contains agent-specific initialization configurations. By
                default, for all agents, it sets the default timeout to 60
                seconds and performs the rendezvous with an underlying process
                group initialized using ``init_method = "env://"``,
                meaning that environment variables ``MASTER_ADDR`` and
                ``MASTER_PORT`` need to be set properly. See
                :ref:`rpc-backends` for more information and find which options
                are available.
        """

        if not rpc_backend_options:
            # default construct a set of RPC backend options.
            rpc_backend_options = backend_registry.construct_rpc_backend_options(
                backend)

        # Rendezvous.
        # This rendezvous state is sometimes destroyed before all processes
        # finish handshaking. To avoid that issue, we make it global to
        # keep it alive.
        global rendezvous_iterator
        rendezvous_iterator = torch.distributed.rendezvous(
            rpc_backend_options.init_method, rank=rank, world_size=world_size)
        store, _, _ = next(rendezvous_iterator)

        # Use a PrefixStore to distinguish multiple invocations.
        with _init_counter_lock:
            global _init_counter
            store = dist.PrefixStore(
                str('rpc_prefix_{}'.format(_init_counter)), store)
            _init_counter += 1

        # Initialize autograd before RPC since _init_rpc_backend guarantees all
        # processes sync via the store. If we initialize autograd after RPC,
        # there could be a race where some nodes might have initialized autograd
        # and others might not have. As a result, a node calling
        # torch.distributed.autograd.backward() would run into errors since
        # other nodes might not have been initialized.
        dist_autograd._init(rank)

        _set_profiler_node_id(rank)
        # Initialize RPC.
        api._init_rpc_backend(backend, store, name, rank, world_size,
                              rpc_backend_options)
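A minimal usage sketch for the init_rpc above (the worker names, port, and two-process setup are illustrative assumptions; with the default init_method of "env://", MASTER_ADDR and MASTER_PORT must be set before the call):

import os
import torch
import torch.distributed.rpc as rpc

def run(rank, world_size):
    # The default rendezvous reads MASTER_ADDR/MASTER_PORT from the environment.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(name=f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        # Run torch.add remotely on worker1 and block until the result returns.
        result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
        print(result)  # tensor([4., 4.])
    rpc.shutdown()

if __name__ == "__main__":
    torch.multiprocessing.spawn(run, args=(2,), nprocs=2)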
Example #3
    def _create_store(self):
        return c10d.PrefixStore(self.prefix, self.tcpstore)
Example #4
    def _create_store(self):
        return dist.PrefixStore(self.prefix, self.filestore)
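Both _create_store helpers above wrap an existing store in a PrefixStore so that keys from separate invocations do not collide. A minimal single-process sketch (the host, port, and prefix names are illustrative assumptions):

import torch.distributed as dist

# One TCPStore server and two prefixed views of it.
tcp_store = dist.TCPStore("127.0.0.1", 29500, 1, True)  # host, port, world_size, is_master
run_a = dist.PrefixStore("run_a", tcp_store)
run_b = dist.PrefixStore("run_b", tcp_store)

run_a.set("num_done", "0")
run_b.set("num_done", "5")
print(run_a.get("num_done"))  # b'0' -- prefixes keep the key spaces separate
print(run_b.get("num_done"))  # b'5'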
Example #5
    def init_rpc(
        name,
        backend=None,
        rank=-1,
        world_size=None,
        rpc_backend_options=None,
    ):
        r"""
        Initializes RPC primitives such as the local RPC agent
        and distributed autograd, which immediately makes the current
        process ready to send and receive RPCs.

        Args:
            backend (BackendType, optional): The type of RPC backend
                implementation. Supported values include
                ``BackendType.TENSORPIPE`` (the default) and
                ``BackendType.PROCESS_GROUP``. See :ref:`rpc-backends` for more
                information.
            name (str): a globally unique name of this node. (e.g.,
                ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
                The name may only contain numbers, letters, underscores, colons,
                and/or dashes, and must be shorter than 128 characters.
            rank (int): a globally unique id/rank of this node.
            world_size (int): The number of workers in the group.
            rpc_backend_options (RpcBackendOptions, optional): The options
                passed to the RpcAgent constructor. It must be an agent-specific
                subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
                and contains agent-specific initialization configurations. By
                default, for all agents, it sets the default timeout to 60
                seconds and performs the rendezvous with an underlying process
                group initialized using ``init_method = "env://"``,
                meaning that environment variables ``MASTER_ADDR`` and
                ``MASTER_PORT`` need to be set properly. See
                :ref:`rpc-backends` for more information and find which options
                are available.
        """

        if backend is not None and not isinstance(
                backend, backend_registry.BackendType):
            raise TypeError("Argument backend must be a member of BackendType")

        if rpc_backend_options is not None and not isinstance(
                rpc_backend_options, RpcBackendOptions):
            raise TypeError(
                "Argument rpc_backend_options must be an instance of RpcBackendOptions"
            )

        # To avoid breaking users that passed a ProcessGroupRpcBackendOptions
        # without specifying the backend as PROCESS_GROUP when that was the
        # default, we try to detect the backend from the options when only the
        # latter is passed.
        if backend is None and rpc_backend_options is not None:
            for candidate_backend in BackendType:
                if isinstance(
                        rpc_backend_options,
                        type(
                            backend_registry.construct_rpc_backend_options(
                                candidate_backend)),
                ):
                    backend = candidate_backend
                    break
            else:
                raise TypeError(
                    f"Could not infer backend for options {rpc_backend_options}"
                )
            # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
            if backend != BackendType.TENSORPIPE:  # type: ignore[attr-defined]
                logger.warning(
                    f"RPC was initialized with no explicit backend but with options "  # type: ignore[attr-defined]
                    f"corresponding to {backend}, hence that backend will be used "
                    f"instead of the default {BackendType.TENSORPIPE}. To silence this "
                    f"warning pass `backend={backend}` explicitly.")

        if backend is None:
            backend = BackendType.TENSORPIPE  # type: ignore[attr-defined]

        if backend == BackendType.PROCESS_GROUP:  # type: ignore[attr-defined]
            logger.warning(
                "RPC was initialized with the PROCESS_GROUP backend which is "
                "deprecated and slated to be removed and superseded by the TENSORPIPE "
                "backend. It is recommended to migrate to the TENSORPIPE backend."
            )

        if rpc_backend_options is None:
            # default construct a set of RPC backend options.
            rpc_backend_options = backend_registry.construct_rpc_backend_options(
                backend)

        # Rendezvous.
        # This rendezvous state is sometimes destroyed before all processes
        # finish handshaking. To avoid that issue, we make it global to
        # keep it alive.
        global rendezvous_iterator
        rendezvous_iterator = torch.distributed.rendezvous(
            rpc_backend_options.init_method, rank=rank, world_size=world_size)
        store, _, _ = next(rendezvous_iterator)

        # Use a PrefixStore to distinguish multiple invocations.
        with _init_counter_lock:
            global _init_counter
            store = dist.PrefixStore(
                str('rpc_prefix_{}'.format(_init_counter)), store)
            _init_counter += 1

        # Initialize autograd before RPC since _init_rpc_backend guarantees all
        # processes sync via the store. If we initialize autograd after RPC,
        # there could be a race where some nodes might have initialized autograd
        # and others might not have. As a result, a node calling
        # torch.distributed.autograd.backward() would run into errors since
        # other nodes might not have been initialized.
        dist_autograd._init(rank)

        _set_profiler_node_id(rank)
        # Initialize RPC.
        _init_rpc_backend(backend, store, name, rank, world_size,
                          rpc_backend_options)
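A sketch of the same call with explicit TensorPipe backend options instead of environment variables (the TCP address, thread count, and single-process world size are illustrative assumptions):

import torch.distributed.rpc as rpc

options = rpc.TensorPipeRpcBackendOptions(
    init_method="tcp://127.0.0.1:29500",  # rendezvous without MASTER_ADDR/MASTER_PORT
    num_worker_threads=8,
    rpc_timeout=60,  # seconds
)
rpc.init_rpc("worker0", rank=0, world_size=1, rpc_backend_options=options)
rpc.shutdown()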
Example #6
    def train(self, ckpt_path="", ckpt=-1, start_updates=0) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        random.seed(self.config.TASK_CONFIG.SEED + self.world_rank)
        np.random.seed(self.config.TASK_CONFIG.SEED + self.world_rank)

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        self.config.freeze()

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        task_cfg = self.config.TASK_CONFIG.TASK

        observation_space = self.envs.observation_spaces[0]
        aux_cfg = self.config.RL.AUX_TASKS
        init_aux_tasks, num_recurrent_memories, aux_task_strings = self._setup_auxiliary_tasks(
            aux_cfg, ppo_cfg, task_cfg, observation_space)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            observation_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_memories=num_recurrent_memories)
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        self._setup_actor_critic_agent(ppo_cfg, task_cfg, aux_cfg,
                                       init_aux_tasks)
        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),  # including bonus
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        prev_time = 0

        if ckpt != -1:
            logger.info(
                f"Resuming runs at checkpoint {ckpt}. Timing statistics are not tracked properly."
            )
            assert ppo_cfg.use_linear_lr_decay is False and ppo_cfg.use_linear_clip_decay is False, "Resuming with decay not supported"
            # This is the checkpoint we start saving at
            count_checkpoints = ckpt + 1
            count_steps = start_updates * ppo_cfg.num_steps * self.config.NUM_PROCESSES
            ckpt_dict = self.load_checkpoint(ckpt_path, map_location="cpu")
            self.agent.load_state_dict(ckpt_dict["state_dict"])
            if "optim_state" in ckpt_dict:
                self.agent.optimizer.load_state_dict(ckpt_dict["optim_state"])
            else:
                logger.warn("No optimizer state loaded, results may be funky")
            if "extra_state" in ckpt_dict and "step" in ckpt_dict[
                    "extra_state"]:
                count_steps = ckpt_dict["extra_state"]["step"]

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_updates = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:

            for update in range(start_updates, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)
                self.agent.train()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                    aux_task_losses,
                    aux_dist_entropy,
                    aux_weights,
                ) = self._update_agent(ppo_cfg, rollouts)

                pth_time += delta_pth_time

                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering],
                    0).to(self.device)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [
                        dist_entropy,
                        aux_dist_entropy,
                    ] + [value_loss, action_loss] + aux_task_losses +
                    [count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                if aux_weights is not None and len(aux_weights) > 0:
                    distrib.all_reduce(
                        torch.tensor(aux_weights, device=self.device))
                count_steps += stats[-1].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    avg_stats = [
                        stats[i].item() / self.world_size
                        for i in range(len(stats) - 1)
                    ]
                    losses = avg_stats[2:]
                    dist_entropy, aux_dist_entropy = avg_stats[:2]
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    writer.add_scalar(
                        "entropy",
                        dist_entropy,
                        count_steps,
                    )

                    writer.add_scalar("aux_entropy", aux_dist_entropy,
                                      count_steps)

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {
                            k: l
                            for l, k in zip(losses, ["value", "policy"] +
                                            aux_task_strings)
                        },
                        count_steps,
                    )

                    writer.add_scalars(
                        "aux_weights",
                        {k: l
                         for l, k in zip(aux_weights, aux_task_strings)},
                        count_steps,
                    )

                    # Log stats
                    formatted_aux_losses = [
                        "{:.3g}".format(l) for l in aux_task_losses
                    ]
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info(
                            "update: {}\tvalue_loss: {:.3g}\t action_loss: {:.3g}\taux_task_loss: {} \t aux_entropy {:.3g}\t"
                            .format(
                                update,
                                value_loss,
                                action_loss,
                                formatted_aux_losses,
                                aux_dist_entropy,
                            ))
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items() if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"{self.checkpoint_prefix}.{count_checkpoints}.pth",
                            dict(step=count_steps))
                        count_checkpoints += 1

        self.envs.close()
Example #7
    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        import apex

        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend
        )
        # add_signal_handlers()
        self.timing = Timing()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        set_cpus(self.local_rank, self.world_size)

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += self.world_rank * self.config.SIM_BATCH_SIZE
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        double_buffered = False
        self._num_worker_groups = self.config.NUM_PARALLEL_SCENES

        self._depth = self.config.DEPTH
        self._color = self.config.COLOR

        if self.config.TASK.lower() == "pointnav":
            self.observation_space = SpaceDict(
                {
                    "pointgoal_with_gps_compass": spaces.Box(
                        low=0.0, high=1.0, shape=(2,), dtype=np.float32
                    )
                }
            )
        else:
            self.observation_space = SpaceDict({})

        self.action_space = spaces.Discrete(4)

        if self._color:
            self.observation_space = SpaceDict(
                {
                    "rgb": spaces.Box(
                        low=np.finfo(np.float32).min,
                        high=np.finfo(np.float32).max,
                        shape=(3, *self.config.RESOLUTION),
                        dtype=np.uint8,
                    ),
                    **self.observation_space.spaces,
                }
            )

        if self._depth:
            self.observation_space = SpaceDict(
                {
                    "depth": spaces.Box(
                        low=np.finfo(np.float32).min,
                        high=np.finfo(np.float32).max,
                        shape=(1, *self.config.RESOLUTION),
                        dtype=np.float32,
                    ),
                    **self.observation_space.spaces,
                }
            )

        ppo_cfg = self.config.RL.PPO
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0:
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.count_steps = 0
        burn_steps = 0
        burn_time = 0
        count_checkpoints = 0
        prev_time = 0
        self.update = 0

        LR_SCALE = (
            max(
                np.sqrt(
                    ppo_cfg.num_steps
                    * self.config.SIM_BATCH_SIZE
                    * ppo_cfg.num_accumulate_steps
                    / ppo_cfg.num_mini_batch
                    * self.world_size
                    / (128 * 2)
                ),
                1.0,
            )
            if (self.config.RL.DDPPO.scale_lr and not self.config.RL.PPO.ada_scale)
            else 1.0
        )
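        # Worked example (hypothetical config): with num_steps=32,
        # SIM_BATCH_SIZE=128, num_accumulate_steps=1, num_mini_batch=2 and
        # world_size=8, each optimizer step sees 32 * 128 * 1 / 2 * 8 = 16384
        # samples, so LR_SCALE = sqrt(16384 / 256) = 8.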

        def cosine_decay(x):
            if x < 1:
                return (np.cos(x * np.pi) + 1.0) / 2.0
            else:
                return 0.0

        def warmup_fn(x):
            return LR_SCALE * (0.5 + 0.5 * x)

        def decay_fn(x):
            return LR_SCALE * (DECAY_TARGET + (1 - DECAY_TARGET) * cosine_decay(x))

        DECAY_TARGET = (
            0.01 / LR_SCALE
            if self.config.RL.PPO.ada_scale or True
            else (0.25 / LR_SCALE if self.config.RL.DDPPO.scale_lr else 1.0)
        )
        DECAY_PERCENT = 1.0 if self.config.RL.PPO.ada_scale or True else 0.5
        WARMUP_PERCENT = (
            0.01
            if (self.config.RL.DDPPO.scale_lr and not self.config.RL.PPO.ada_scale)
            else 0.0
        )

        def lr_fn():
            x = self.percent_done()
            if x < WARMUP_PERCENT:
                return warmup_fn(x / WARMUP_PERCENT)
            else:
                return decay_fn((x - WARMUP_PERCENT) / DECAY_PERCENT)

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer, lr_lambda=lambda x: lr_fn()
        )
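        # Net effect of lr_fn: the multiplier ramps from 0.5 * LR_SCALE to
        # LR_SCALE over the warmup fraction of training, then cosine-decays to
        # DECAY_TARGET * LR_SCALE (a factor of 0.01 here) by the end.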

        interrupted_state = load_interrupted_state(resume_from=self.resume_from)
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])

        self.agent.init_amp(self.config.SIM_BATCH_SIZE)
        self.actor_critic.init_trt(self.config.SIM_BATCH_SIZE)
        self.actor_critic.script_net()
        self.agent.init_distributed(find_unused_params=False)

        if self.world_rank == 0:
            logger.info(
                "agent number of trainable parameters: {}".format(
                    sum(
                        param.numel()
                        for param in self.agent.parameters()
                        if param.requires_grad
                    )
                )
            )

        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            self.observation_space = SpaceDict(
                {
                    "visual_features": spaces.Box(
                        low=np.finfo(np.float32).min,
                        high=np.finfo(np.float32).max,
                        shape=self._encoder.output_shape,
                        dtype=np.float32,
                    ),
                    **self.observation_space,
                }
            )
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)

        nenvs = self.config.SIM_BATCH_SIZE
        rollouts = DoubleBufferedRolloutStorage(
            ppo_cfg.num_steps,
            nenvs,
            self.observation_space,
            self.action_space,
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.num_recurrent_layers,
            use_data_aug=ppo_cfg.use_data_aug,
            aug_type=ppo_cfg.aug_type,
            double_buffered=double_buffered,
            vtrace=ppo_cfg.vtrace,
        )
        rollouts.to(self.device)
        rollouts.to_fp16()

        self._warmup(rollouts)

        (
            self.envs,
            self._observations,
            self._rewards,
            self._masks,
            self._rollout_infos,
            self._syncs,
        ) = construct_envs(
            self.config,
            num_worker_groups=self.config.NUM_PARALLEL_SCENES,
            double_buffered=double_buffered,
        )

        def _setup_render_and_populate_initial_frame():
            for idx in range(2 if double_buffered else 1):
                self.envs.reset(idx)

                batch = self._observations[idx]
                self._syncs[idx].wait()

                tree_copy_in_place(
                    tree_select(0, rollouts[idx].storage_buffers["observations"]),
                    batch,
                )

        _setup_render_and_populate_initial_frame()

        current_episode_reward = torch.zeros(nenvs, 1)
        running_episode_stats = dict(
            count=torch.zeros(nenvs, 1,), reward=torch.zeros(nenvs, 1,),
        )

        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size)
        )
        time_per_frame_window = deque(maxlen=ppo_cfg.reward_window_size)

        buffer_ranges = []
        for i in range(2 if double_buffered else 1):
            start_ind = buffer_ranges[-1].stop if i > 0 else 0
            buffer_ranges.append(
                slice(
                    start_ind,
                    start_ind
                    + self.config.SIM_BATCH_SIZE // (2 if double_buffered else 1),
                )
            )

        if interrupted_state is not None:
            requeue_stats = interrupted_state["requeue_stats"]

            self.count_steps = requeue_stats["count_steps"]
            self.update = requeue_stats["start_update"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            prev_time = requeue_stats["prev_time"]
            burn_steps = requeue_stats["burn_steps"]
            burn_time = requeue_stats["burn_time"]

            self.agent.ada_scale.load_state_dict(interrupted_state["ada_scale_state"])

            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            if "amp_state" in interrupted_state:
                apex.amp.load_state_dict(interrupted_state["amp_state"])

            if "grad_scaler_state" in interrupted_state:
                self.agent.grad_scaler.load_state_dict(
                    interrupted_state["grad_scaler_state"]
                )

        with (
            TensorboardWriter(
                self.config.TENSORBOARD_DIR,
                flush_secs=self.flush_secs,
                purge_step=int(self.count_steps),
            )
            if self.world_rank == 0
            else contextlib.suppress()
        ) as writer:
            distrib.barrier()
            t_start = time.time()
            while not self.is_done():
                t_rollout_start = time.time()
                if self.update == BURN_IN_UPDATES:
                    burn_time = t_rollout_start - t_start
                    burn_steps = self.count_steps

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        self.percent_done(), final_decay=ppo_cfg.decay_factor,
                    )

                if (
                    not BPS_BENCHMARK
                    and (REQUEUE.is_set() or ((self.update + 1) % 100) == 0)
                    and self.world_rank == 0
                ):
                    requeue_stats = dict(
                        count_steps=self.count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=self.update,
                        prev_time=(time.time() - t_start) + prev_time,
                        burn_time=burn_time,
                        burn_steps=burn_steps,
                    )

                    def _cast(param):
                        if "Half" in param.type():
                            param = param.to(dtype=torch.float32)

                        return param

                    save_interrupted_state(
                        dict(
                            state_dict={
                                k: _cast(v) for k, v in self.agent.state_dict().items()
                            },
                            ada_scale_state=self.agent.ada_scale.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                            grad_scaler_state=self.agent.grad_scaler.state_dict(),
                        )
                    )

                if EXIT.is_set():
                    self._observations = None
                    self._rewards = None
                    self._masks = None
                    self._rollout_infos = None
                    self._syncs = None

                    del self.envs
                    self.envs = None

                    requeue_job()
                    return

                self.agent.eval()

                count_steps_delta = self._n_buffered_sampling(
                    rollouts,
                    current_episode_reward,
                    running_episode_stats,
                    buffer_ranges,
                    ppo_cfg.num_steps,
                    num_rollouts_done_store,
                )

                num_rollouts_done_store.add("num_done", 1)

                if not rollouts.vtrace:
                    self._compute_returns(ppo_cfg, rollouts)

                (value_loss, action_loss, dist_entropy) = self._update_agent(rollouts)

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                lr_scheduler.step()

                with self.timing.add_time("Logging"):
                    stats_ordering = list(sorted(running_episode_stats.keys()))
                    stats = torch.stack(
                        [running_episode_stats[k] for k in stats_ordering], 0,
                    ).to(device=self.device)
                    distrib.all_reduce(stats)
                    stats = stats.to(device="cpu")

                    for i, k in enumerate(stats_ordering):
                        window_episode_stats[k].append(stats[i])

                    stats = torch.tensor(
                        [
                            value_loss,
                            action_loss,
                            count_steps_delta,
                            *self.envs.swap_stats,
                        ],
                        device=self.device,
                    )
                    distrib.all_reduce(stats)
                    stats = stats.to(device="cpu")
                    count_steps_delta = int(stats[2].item())
                    self.count_steps += count_steps_delta

                    time_per_frame_window.append(
                        (time.time() - t_rollout_start) / count_steps_delta
                    )

                    if self.world_rank == 0:
                        losses = [
                            stats[0].item() / self.world_size,
                            stats[1].item() / self.world_size,
                        ]
                        deltas = {
                            k: (
                                (v[-1] - v[0]).sum().item()
                                if len(v) > 1
                                else v[0].sum().item()
                            )
                            for k, v in window_episode_stats.items()
                        }
                        deltas["count"] = max(deltas["count"], 1.0)

                        writer.add_scalar(
                            "reward",
                            deltas["reward"] / deltas["count"],
                            self.count_steps,
                        )

                        # Check to see if there are any metrics
                        # that haven't been logged yet
                        metrics = {
                            k: v / deltas["count"]
                            for k, v in deltas.items()
                            if k not in {"reward", "count"}
                        }
                        if len(metrics) > 0:
                            writer.add_scalars("metrics", metrics, self.count_steps)

                        writer.add_scalars(
                            "losses",
                            {k: l for l, k in zip(losses, ["value", "policy"])},
                            self.count_steps,
                        )

                        optim = self.agent.optimizer
                        writer.add_scalar(
                            "optimizer/base_lr",
                            optim.param_groups[-1]["lr"],
                            self.count_steps,
                        )
                        if "gain" in optim.param_groups[-1]:
                            for idx, group in enumerate(optim.param_groups):
                                writer.add_scalar(
                                    f"optimizer/lr_{idx}",
                                    group["lr"] * group["gain"],
                                    self.count_steps,
                                )
                                writer.add_scalar(
                                    f"optimizer/gain_{idx}",
                                    group["gain"],
                                    self.count_steps,
                                )

                        # log stats
                        if (
                            self.update > 0
                            and self.update % self.config.LOG_INTERVAL == 0
                        ):
                            logger.info(
                                "update: {}\twindow fps: {:.3f}\ttotal fps: {:.3f}\tframes: {}".format(
                                    self.update,
                                    1.0
                                    / (
                                        sum(time_per_frame_window)
                                        / len(time_per_frame_window)
                                    ),
                                    (self.count_steps - burn_steps)
                                    / ((time.time() - t_start) + prev_time - burn_time),
                                    self.count_steps,
                                )
                            )

                            logger.info(
                                "swap percent: {:.3f}\tscenes in use: {:.3f}\tenvs per scene: {:.3f}".format(
                                    stats[3].item() / self.world_size,
                                    stats[4].item() / self.world_size,
                                    stats[5].item() / self.world_size,
                                )
                            )

                            logger.info(
                                "Average window size: {}  {}".format(
                                    len(window_episode_stats["count"]),
                                    "  ".join(
                                        "{}: {:.3f}".format(k, v / deltas["count"])
                                        for k, v in deltas.items()
                                        if k != "count"
                                    ),
                                )
                            )

                            logger.info(self.timing)
                            # self.envs.print_renderer_stats()

                        # checkpoint model
                        if self.should_checkpoint():
                            self.save_checkpoint(
                                f"ckpt.{count_checkpoints}.pth",
                                dict(
                                    step=self.count_steps,
                                    wall_clock_time=(
                                        (time.time() - t_start) + prev_time
                                    ),
                                ),
                            )
                            count_checkpoints += 1

                self.update += 1

            self.save_checkpoint(
                "ckpt.done.pth",
                dict(
                    step=self.count_steps,
                    wall_clock_time=((time.time() - t_start) + prev_time),
                ),
            )
            self._observations = None
            self._rewards = None
            self._masks = None
            self._rollout_infos = None
            self._syncs = None
            del self.envs
            self.envs = None
Example #8
    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        random.seed(self.config.TASK_CONFIG.SEED + self.world_rank)
        np.random.seed(self.config.TASK_CONFIG.SEED + self.world_rank)

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        self.config.freeze()

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
                and self.world_rank == 0):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))

        observations = self.envs.reset()
        batch = batch_obs(observations)

        obs_space = self.envs.observation_spaces[0]
        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            obs_space = SpaceDict({
                "visual_features":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            })
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            obs_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        )
        rollouts.to(self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        episode_rewards = torch.zeros(self.envs.num_envs,
                                      1,
                                      device=self.device)
        episode_counts = torch.zeros(self.envs.num_envs, 1, device=self.device)
        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)
        window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
        window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

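        # Learning-rate scheduler that linearly decays the LR over NUM_UPDATES;
        # it is only stepped when use_linear_lr_decay is enabled below.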
        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

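        # Resume optimizer, scheduler, and bookkeeping state if a previous run saved an
        # interrupted-state snapshot (e.g. before being requeued by SLURM).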
        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                    requeue_job()
                    return

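                # Collect one rollout (up to num_steps steps per environment) with the policy in eval mode.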
                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):

                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(
                        rollouts,
                        current_episode_reward,
                        episode_rewards,
                        episode_counts,
                    )
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self._static_encoder:
                    self._encoder.eval()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

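                # Sum episode rewards and counts across all workers so logging reflects global progress.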
                stats = torch.stack([episode_rewards, episode_counts], 0)
                distrib.all_reduce(stats)

                window_episode_reward.append(stats[0].clone())
                window_episode_counts.append(stats[1].clone())

                stats = torch.tensor(
                    [value_loss, action_loss, count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[2].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                    ]
                    stats = zip(
                        ["count", "reward"],
                        [window_episode_counts, window_episode_reward],
                    )
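                    # Windowed deltas: newest minus oldest entry in each rolling window
                    # (falls back to the single entry when the window only has one element).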
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in stats
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    writer.add_scalars(
                        "losses",
                        {k: l
                         for l, k in zip(losses, ["value", "policy"])},
                        count_steps,
                    )

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))

                        window_rewards = (window_episode_reward[-1] -
                                          window_episode_reward[0]).sum()
                        window_counts = (window_episode_counts[-1] -
                                         window_episode_counts[0]).sum()

                        if window_counts > 0:
                            logger.info(
                                "Average window size {} reward: {:3f}".format(
                                    len(window_episode_reward),
                                    (window_rewards / window_counts).item(),
                                ))
                        else:
                            logger.info("No episodes finish in current window")

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )
                        count_checkpoints += 1

            self.envs.close()
Example No. 9
    def train(self) -> None:
        r"""Main method for DD-PPO SLAM.

        Returns:
            None
        """

        #####################################################################
        ## init distrib and configuration
        #####################################################################
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend
        )
        # self.local_rank = 1
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore(
            "rollout_tracker", tcp_store
        )
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()  # global rank of this worker
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank  # GPU index within this node
        self.config.SIMULATOR_GPU_ID = self.local_rank
        print("********************* TORCH_GPU_ID: ", self.config.TORCH_GPU_ID)
        print("********************* SIMULATOR_GPU_ID: ", self.config.SIMULATOR_GPU_ID)

        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (
            self.world_rank * self.config.NUM_PROCESSES
        )
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")


        #####################################################################
        ## build distrib NavSLAMRLEnv environment
        #####################################################################
        print("#############################################################")
        print("## build distrib NavSLAMRLEnv environment")
        print("#############################################################")
        self.envs = construct_envs(
            self.config, get_env_class(self.config.ENV_NAME)
        )
        observations = self.envs.reset()
        print("*************************** observations len:", len(observations))

        # Cast semantic observations to int32 and print the unique label ids (debugging)
        for i in range(len(observations)):
            observations[i]["semantic"] = observations[i]["semantic"].astype(np.int32)
            se = list(set(observations[i]["semantic"].ravel()))
            print(se)
        # print("*************************** observations type:", observations)
        # print("*************************** observations type:", observations[0]["map_sum"].shape) # 480*480*23
        # print("*************************** observations curr_pose:", observations[0]["curr_pose"]) # []

        batch = batch_obs(observations, device=self.device)
        print("*************************** batch len:", len(batch))
        # print("*************************** batch:", batch)

        # print("************************************* current_episodes:", (self.envs.current_episodes()))

        #####################################################################
        ## init actor_critic agent
        #####################################################################  
        print("#############################################################")
        print("## init actor_critic agent")
        print("#############################################################")
        self.map_w = observations[0]["map_sum"].shape[0]
        self.map_h = observations[0]["map_sum"].shape[1]
        # print("map_: ", observations[0]["curr_pose"].shape)


        ppo_cfg = self.config.RL.PPO
        if (
            not os.path.isdir(self.config.CHECKPOINT_FOLDER)
            and self.world_rank == 0
        ):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(observations, ppo_cfg)

        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info(
                "agent number of trainable parameters: {}".format(
                    sum(
                        param.numel()
                        for param in self.agent.parameters()
                        if param.requires_grad
                    )
                )
            )

        #####################################################################
        ## init Global Rollout Storage
        #####################################################################  
        print("#############################################################")
        print("## init Global Rollout Storage")
        print("#############################################################") 
        self.num_each_global_step = self.config.RL.SLAMDDPPO.num_each_global_step
        rollouts = GlobalRolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.obs_space,
            self.g_action_space,
        )
        rollouts.to(self.device)

        print('rollouts type:', type(rollouts))
        print('--------------------------')
        # for k in rollouts.keys():
        # print("rollouts: {0}".format(rollouts.observations))

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

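        # Run the policy once on the initial observations; its outputs are scaled by the
        # map width/height below to form the first set of global goals.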
        with torch.no_grad():
            step_observation = {
                k: v[rollouts.step] for k, v in rollouts.observations.items()
            }
    
            _, actions, _ = self.actor_critic.act(
                step_observation,
                rollouts.prev_g_actions[0],
                rollouts.masks[0],
            )

        self.global_goals = [
            [int(action[0].item() * self.map_w),
             int(action[1].item() * self.map_h)]
            for action in actions
        ]

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(
            self.envs.num_envs, 1, device=self.device
        )
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size)
        )

        print("*************************** current_episode_reward:", current_episode_reward)
        print("*************************** running_episode_stats:", running_episode_stats)
        # print("*************************** window_episode_stats:", window_episode_stats)


        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        # interrupted_state = load_interrupted_state("/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth")
        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"]
            )
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        episodes_per_scene = {}  # scene_id -> episode ids seen during rollouts (only used by the debug print below)
        with (
            TensorboardWriter(
                self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
            )
            if self.world_rank == 0
            else contextlib.suppress()
        ) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES
                    )
                # print("************************************* current_episodes:", type(self.envs.count_episodes()))
                
                # print(EXIT.is_set())
                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ),
                            "/home/cirlab1/userdir/ybg/projects/habitat-api/data/interrup.pth"
                        )
                    print("********************EXIT*********************")

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()
                for step in range(ppo_cfg.num_steps):
                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_global_rollout_step(
                        rollouts, current_episode_reward, running_episode_stats
                    )
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # print("************************************* current_episodes:")

                    for ep in self.envs.current_episodes():
                        # print(" ", ep.episode_id, " ", ep.scene_id, " ", ep.object_category)
                        if ep.scene_id not in episodes_per_scene:
                            episodes_per_scene[ep.scene_id] = [int(ep.episode_id)]
                        else:
                            episodes_per_scene[ep.scene_id].append(int(ep.episode_id))


                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (
                        step
                        >= ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                    ) and int(num_rollouts_done_store.get("num_done")) > (
                        self.config.RL.DDPPO.sync_frac * self.world_size
                    ):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self._static_encoder:
                    self._encoder.eval()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

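                # Sum per-worker episode stats across ranks before appending to the logging windows.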
                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0
                )
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [value_loss, action_loss, count_steps_delta],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[2].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                    ]
                    deltas = {
                        k: (
                            (v[-1] - v[0]).sum().item()
                            if len(v) > 1
                            else v[0].sum().item()
                        )
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {k: l for l, k in zip(losses, ["value", "policy"])},
                        count_steps,
                    )

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info(
                            "update: {}\tfps: {:.3f}\t".format(
                                update,
                                count_steps
                                / ((time.time() - t_start) + prev_time),
                            )
                        )

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(
                                update, env_time, pth_time, count_steps
                            )
                        )
                        logger.info(
                            "Average window size: {}  {}".format(
                                len(window_episode_stats["count"]),
                                "  ".join(
                                    "{}: {:.3f}".format(k, v / deltas["count"])
                                    for k, v in deltas.items()
                                    if k != "count"
                                ),
                            )
                        )

                        # for k in episodes_per_scene:
                        #     episodes_per_scene[k] = list(set(episodes_per_scene[k]))
                        #     episodes_per_scene[k].sort()
                        #     print("episodes_per_scene:", k, ":", episodes_per_scene[k])

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )
                        print('=' * 20 + 'Save Model' + '=' * 20)
                        logger.info(
                            "Save Model : {}".format(count_checkpoints)
                        )
                        count_checkpoints += 1

            self.envs.close()
Example No. 10
    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        self.config.TASK_CONFIG.TASK.POINTGOAL_WITH_EGO_PREDICTION_SENSOR.MODEL.GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (self.world_rank *
                                         self.config.NUM_PROCESSES)
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(
            self.config,
            get_env_class(self.config.ENV_NAME),
            workers_ignore_signals=True,
        )

        ppo_cfg = self.config.RL.PPO
        if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
                and self.world_rank == 0):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        obs_space = self.envs.observation_spaces[0]
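        # With a frozen visual encoder, expose precomputed current (and zero-initialized
        # previous) visual features as additional observations.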
        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            obs_space = SpaceDict({
                "visual_features":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                "prev_visual_features":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            })
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)
                batch["prev_visual_features"] = torch.zeros_like(
                    batch["visual_features"])

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            obs_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        )
        rollouts.to(self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])
            rollouts.previous_observations[sensor][0].copy_(
                torch.zeros_like(batch[sensor]))

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                self.actor_critic.update_ib_beta(count_steps)

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()

                for step in range(ppo_cfg.num_steps):

                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self._static_encoder:
                    self._encoder.eval()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                    aux_loss,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

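                # Gather loss terms (including the auxiliary egomotion and information losses)
                # together with the step count for a single all_reduce.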
                stats = torch.tensor(
                    [
                        value_loss,
                        action_loss,
                        aux_loss,
                        AuxLosses.get_loss("egomotion_error"),
                        AuxLosses.get_loss("information"),
                        0.1 * dist_entropy,
                        count_steps_delta,
                    ],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[-1].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[i].item() / self.world_size
                        for i in range(stats.size(0) - 1)
                    ]
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    print(deltas["reward"])

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {
                            k: l
                            for l, k in zip(losses, [
                                "value", "policy", "aux", "egomotion_error",
                                "information", "entropy"
                            ])
                        },
                        count_steps,
                    )

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items() if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )

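                        # Alongside each checkpoint, snapshot the full training state so the
                        # run can resume from this point if it is interrupted.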
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )

                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                        count_checkpoints += 1

            self.envs.close()