示例#1
0
    def update(self, rollouts: RolloutStorage) -> Tuple[float, float, float]:
        """Optimise the actor-critic on ``rollouts`` with the clipped PPO loss.

        Runs ``self.ppo_epoch`` passes. In each pass, every mini-batch yielded
        by ``rollouts.recurrent_generator(advantages, self.num_mini_batch)``
        triggers one optimizer step on::

            value_loss * value_loss_coef + action_loss
                - dist_entropy * entropy_coef

        The ``before_backward`` / ``after_backward`` hooks wrap the backward
        pass and the ``before_step`` / ``after_step`` hooks wrap the optimizer
        step.

        :param rollouts: storage the advantages and mini-batches are read from.
        :return: ``(value_loss, action_loss, dist_entropy)``; each entry is the
            sum of the per-mini-batch scalars divided by
            ``self.ppo_epoch * self.num_mini_batch``.
        """
        advantages = self.get_advantages(rollouts)

        value_loss_sum = 0.0
        action_loss_sum = 0.0
        entropy_sum = 0.0

        for _epoch in range(self.ppo_epoch):
            profiling_wrapper.range_push("PPO.update epoch")
            minibatches = rollouts.recurrent_generator(
                advantages, self.num_mini_batch)

            for batch in minibatches:
                outputs = self.actor_critic.evaluate_actions(
                    batch["observations"],
                    batch["recurrent_hidden_states"],
                    batch["prev_actions"],
                    batch["masks"],
                    batch["actions"],
                )
                # The fourth returned element is not needed here.
                values, action_log_probs, dist_entropy, _ = outputs

                # Policy term: element-wise minimum of the unclipped and the
                # clipped probability-ratio objectives, negated to get a loss.
                ratio = (action_log_probs - batch["action_log_probs"]).exp()
                clipped_ratio = ratio.clamp(1.0 - self.clip_param,
                                            1.0 + self.clip_param)
                unclipped_obj = ratio * batch["advantages"]
                clipped_obj = clipped_ratio * batch["advantages"]
                action_loss = -(torch.min(unclipped_obj, clipped_obj).mean())

                if not self.use_clipped_value_loss:
                    value_loss = 0.5 * (batch["returns"] - values).pow(2)
                else:
                    # Bound how far the new prediction may move away from the
                    # stored one, then keep the larger of the squared errors.
                    delta = torch.clamp(values - batch["value_preds"],
                                        -self.clip_param, self.clip_param)
                    clipped_values = batch["value_preds"] + delta
                    raw_sq_err = (values - batch["returns"]).pow(2)
                    clipped_sq_err = (clipped_values -
                                      batch["returns"]).pow(2)
                    value_loss = 0.5 * torch.max(raw_sq_err, clipped_sq_err)

                value_loss = value_loss.mean()
                dist_entropy = dist_entropy.mean()

                self.optimizer.zero_grad()
                # Entropy is subtracted, so higher entropy lowers the loss.
                total_loss = (value_loss * self.value_loss_coef + action_loss -
                              dist_entropy * self.entropy_coef)

                self.before_backward(total_loss)
                total_loss.backward()
                self.after_backward(total_loss)

                self.before_step()
                self.optimizer.step()
                self.after_step()

                value_loss_sum += value_loss.item()
                action_loss_sum += action_loss.item()
                entropy_sum += dist_entropy.item()

            profiling_wrapper.range_pop()  # PPO.update epoch

        num_updates = self.ppo_epoch * self.num_mini_batch
        return (
            value_loss_sum / num_updates,
            action_loss_sum / num_updates,
            entropy_sum / num_updates,
        )
示例#2
0
def _worker_fn(world_rank: int, world_size: int, port: int,
               unused_params: bool):
    """Per-rank worker: backprop once through a DDPPO agent and compare grads.

    Joins a ``gloo`` process group through a TCP store on ``127.0.0.1:port``,
    builds a ``PointNavBaselinePolicy`` wrapped in a ``DDPPO`` agent, pushes
    one batch of random observations through ``agent._evaluate_actions``,
    back-propagates, and asserts that every parameter gradient is element-wise
    close on all ranks.

    :param world_rank: rank of this worker; rank 0 is passed as the store's
        master flag.
    :param world_size: total number of workers in the group.
    :param port: TCP port used by the store.
    :param unused_params: when true, attach an extra ``nn.Linear`` to the
        policy that takes no part in the computation graph.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    is_master = world_rank == 0
    store = distrib.TCPStore(  # type: ignore
        "127.0.0.1", port, world_size, is_master)
    distrib.init_process_group(
        "gloo", store=store, rank=world_rank, world_size=world_size)

    config = get_config("habitat_baselines/config/test/ppo_pointnav_test.yaml")

    # Single 2-element float32 observation spanning the full float32 range.
    sensor_uuid = IntegratedPointGoalGPSAndCompassSensor.cls_uuid
    float32_info = np.finfo(np.float32)
    pointgoal_box = gym.spaces.Box(
        low=float32_info.min,
        high=float32_info.max,
        shape=(2, ),
        dtype=np.float32,
    )
    obs_space = gym.spaces.Dict({sensor_uuid: pointgoal_box})
    action_space = ActionSpace({"move": EmptySpace()})

    actor_critic = PointNavBaselinePolicy.from_config(
        config, obs_space, action_space)
    if unused_params:
        # Extra parameters outside the computation graph: they will break DDP
        # unless it correctly ignores them.
        actor_critic.unused = nn.Linear(64, 64)
    actor_critic.to(device=device)

    ppo = config.RL.PPO
    agent = DDPPO(
        actor_critic=actor_critic,
        clip_param=ppo.clip_param,
        ppo_epoch=ppo.ppo_epoch,
        num_mini_batch=ppo.num_mini_batch,
        value_loss_coef=ppo.value_loss_coef,
        entropy_coef=ppo.entropy_coef,
        lr=ppo.lr,
        eps=ppo.eps,
        max_grad_norm=ppo.max_grad_norm,
        use_normalized_advantage=ppo.use_normalized_advantage,
    )
    agent.init_distributed()

    rollouts = RolloutStorage(
        ppo.num_steps,
        2,
        obs_space,
        action_space,
        ppo.hidden_size,
        num_recurrent_layers=actor_critic.net.num_recurrent_layers,
        is_double_buffered=False,
    )
    rollouts.to(device)

    # Overwrite every observation buffer with random values.
    for sensor_key, buffer in rollouts.buffers["observations"].items():
        rollouts.buffers["observations"][sensor_key] = torch.randn_like(buffer)

    # Two steps are needed so that batching works.
    for _ in range(2):
        rollouts.advance_rollout()

    # Draw exactly one batch, using the stored returns as the advantages.
    batch_iter = rollouts.recurrent_generator(rollouts.buffers["returns"], 1)
    batch = next(batch_iter)

    # Go through the same internal wrapper that agent.update uses.
    value, action_log_probs, dist_entropy, _ = agent._evaluate_actions(
        batch["observations"],
        batch["recurrent_hidden_states"],
        batch["prev_actions"],
        batch["masks"],
        batch["actions"],
    )

    # Any scalar that depends on all three outputs will do for backprop.
    scalar = value.mean() + action_log_probs.mean() + dist_entropy.mean()
    scalar.backward()

    # Every rank must end up with (nearly) identical gradients.
    for param in actor_critic.parameters():
        if param.grad is None:
            continue

        gathered = [param.grad.detach().clone() for _ in range(world_size)]
        distrib.all_gather(gathered, gathered[world_rank])

        for rank_grad in gathered:
            assert torch.isclose(rank_grad, gathered[world_rank]).all()
示例#3
0
    def update(self, rollouts: RolloutStorage) -> Tuple[float, float, float]:
        """Optimise the actor-critic on ``rollouts`` with the clipped PPO loss.

        Runs ``self.ppo_epoch`` passes. In each pass, every 9-element sample
        yielded by
        ``rollouts.recurrent_generator(advantages, self.num_mini_batch)``
        triggers one optimizer step on::

            value_loss * value_loss_coef + action_loss
                - dist_entropy * entropy_coef

        The ``before_backward`` / ``after_backward`` hooks wrap the backward
        pass and the ``before_step`` / ``after_step`` hooks wrap the optimizer
        step.

        :param rollouts: storage the advantages and mini-batches are read from.
        :return: ``(value_loss, action_loss, dist_entropy)``; each entry is the
            sum of the per-mini-batch scalars divided by
            ``self.ppo_epoch * self.num_mini_batch``.
        """
        advantages = self.get_advantages(rollouts)

        value_loss_sum = 0.0
        action_loss_sum = 0.0
        entropy_sum = 0.0

        for _epoch in range(self.ppo_epoch):
            minibatches = rollouts.recurrent_generator(
                advantages, self.num_mini_batch)

            for (
                    observations,
                    hidden_states,
                    actions,
                    prev_actions,
                    old_values,
                    returns,
                    masks,
                    old_log_probs,
                    adv_targ,
            ) in minibatches:
                # Evaluate the whole mini-batch in a single call; the fourth
                # returned element is not needed here.
                outputs = self.actor_critic.evaluate_actions(
                    observations,
                    hidden_states,
                    prev_actions,
                    masks,
                    actions,
                )
                values, action_log_probs, dist_entropy, _ = outputs

                # Policy term: element-wise minimum of the unclipped and the
                # clipped probability-ratio objectives, negated to get a loss.
                ratio = (action_log_probs - old_log_probs).exp()
                clipped_ratio = ratio.clamp(1.0 - self.clip_param,
                                            1.0 + self.clip_param)
                unclipped_obj = ratio * adv_targ
                clipped_obj = clipped_ratio * adv_targ
                action_loss = -torch.min(unclipped_obj, clipped_obj).mean()

                if not self.use_clipped_value_loss:
                    squared_error = (returns - values).pow(2)
                else:
                    # Bound how far the new prediction may move away from the
                    # stored one, then keep the larger of the squared errors.
                    delta = torch.clamp(values - old_values,
                                        -self.clip_param, self.clip_param)
                    clipped_values = old_values + delta
                    raw_sq_err = (values - returns).pow(2)
                    clipped_sq_err = (clipped_values - returns).pow(2)
                    squared_error = torch.max(raw_sq_err, clipped_sq_err)
                value_loss = 0.5 * squared_error.mean()

                self.optimizer.zero_grad()
                # Entropy is subtracted, so higher entropy lowers the loss.
                total_loss = (value_loss * self.value_loss_coef + action_loss -
                              dist_entropy * self.entropy_coef)

                self.before_backward(total_loss)
                total_loss.backward()
                self.after_backward(total_loss)

                self.before_step()
                self.optimizer.step()
                self.after_step()

                value_loss_sum += value_loss.item()
                action_loss_sum += action_loss.item()
                entropy_sum += dist_entropy.item()

        num_updates = self.ppo_epoch * self.num_mini_batch
        return (
            value_loss_sum / num_updates,
            action_loss_sum / num_updates,
            entropy_sum / num_updates,
        )