Example #1
# Imports assumed for the Ray RLlib 1.x API targeted by these examples
import gym
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.agents.ppo.ppo_torch_policy import setup_mixins


def ppo_init(policy: Policy, obs_space: gym.spaces.Space,
             action_space: gym.spaces.Space,
             config: TrainerConfigDict) -> None:
    """ TODO: Write documentation.
    """
    # Call base implementation
    setup_mixins(policy, obs_space, action_space, config)

    # Add previous observation in viewer requirements for CAPS loss computation
    # TODO: Remove update of `policy.model.view_requirements` after ray fix
    caps_view_requirements = {
        "_prev_obs":
        ViewRequirement(data_col="obs",
                        space=obs_space,
                        shift=-1,
                        used_for_compute_actions=False)
    }
    policy.model.view_requirements.update(caps_view_requirements)
    policy.view_requirements.update(caps_view_requirements)

    # Initialize extra loss
    policy._mean_symmetric_policy_loss = 0.0
    policy._mean_temporal_caps_loss = 0.0
    policy._mean_spatial_caps_loss = 0.0
    policy._mean_global_caps_loss = 0.0
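
A minimal sketch of how these hooks could be plugged into a custom policy, assuming the Ray RLlib 1.x policy-builder API; the class name CAPSPPOTorchPolicy is arbitrary, and ppo_loss is the loss function from Example #2 below.

from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy

# Override the loss initializer and the loss function of the stock PPO policy
CAPSPPOTorchPolicy = PPOTorchPolicy.with_updates(
    before_loss_init=ppo_init,
    loss_fn=ppo_loss)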
Example #2
# Imports assumed for the Ray RLlib 1.x API targeted by these examples
from typing import List, Type, Union

import torch

from ray.rllib.agents.ppo.ppo_torch_policy import ppo_surrogate_loss
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import (
    TorchDiagGaussian, TorchDistributionWrapper)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import TensorType


def ppo_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TorchDistributionWrapper],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """ TODO: Write documentation.
    """
    # Compute original ppo loss
    total_loss = ppo_surrogate_loss(policy, model, dist_class, train_batch)

    # Shallow copy the input batch.
    # Be careful to access fields through the original batch, so that the set
    # of accessed keys is tracked properly. It is used later on to discard
    # unnecessary components of the policy's view requirements.
    train_batch_copy = train_batch.copy(shallow=True)

    # Extract the mean of the predicted action from the logits.
    # There is no need to perform the model forward pass again since the
    # original PPO loss already did it, so just fetch the last output.
    action_logits = model._last_output
    if issubclass(dist_class, TorchDiagGaussian):
        action_mean_true, _ = torch.chunk(action_logits, 2, dim=1)
    else:
        action_dist = dist_class(action_logits, model)
        action_mean_true = action_dist.deterministic_sample()

    if policy.config["caps_temporal_reg"] > 0.0:
        # Compute the mean action corresponding to the previous observation
        observation_prev = train_batch["_prev_obs"]
        train_batch_copy["obs"] = observation_prev
        action_logits_prev, _ = model(train_batch_copy)
        if issubclass(dist_class, TorchDiagGaussian):
            action_mean_prev, _ = torch.chunk(action_logits_prev, 2, dim=1)
        else:
            action_dist_prev = dist_class(action_logits_prev, model)
            action_mean_prev = action_dist_prev.deterministic_sample()

        # Minimize the difference between successive action means
        policy._mean_temporal_caps_loss = torch.mean(
            (action_mean_prev - action_mean_true)**2)

        # Add temporal smoothness loss to total loss
        total_loss += policy.config["caps_temporal_reg"] * \
            policy._mean_temporal_caps_loss

    if policy.config["caps_spatial_reg"] > 0.0 or \
            policy.config["symmetric_policy_reg"] > 0.0:
        # Generate a noisy observation based on the specified sensitivity
        offset = 0
        observation_true = train_batch["obs"]
        observation_noisy = observation_true.clone()
        batch_dim = observation_true.shape[:-1]
        observation_space = policy.observation_space.original_space
        for scale in observation_space.sensitivity.values():
            scale = torch.from_numpy(scale.copy()).to(
                dtype=torch.float32, device=observation_true.device)
            unit_noise = torch.randn((*batch_dim, len(scale)),
                                     device=observation_true.device)
            slice_idx = slice(offset, offset + len(scale))
            observation_noisy[..., slice_idx].addcmul_(scale, unit_noise)
            offset += len(scale)

        # Compute the mean action corresponding to the noisy observation
        train_batch_copy["obs"] = observation_noisy
        action_logits_noisy, _ = model(train_batch_copy)
        if issubclass(dist_class, TorchDiagGaussian):
            action_mean_noisy, _ = torch.chunk(action_logits_noisy, 2, dim=1)
        else:
            action_dist_noisy = dist_class(action_logits_noisy, model)
            action_mean_noisy = action_dist_noisy.deterministic_sample()

    if policy.config["caps_spatial_reg"] > 0.0:
        # Minimize the difference between the original action mean and the
        # one corresponding to the noisy observation.
        policy._mean_spatial_caps_loss = torch.mean(
            (action_mean_noisy - action_mean_true)**2)

        # Add spatial smoothness loss to total loss
        total_loss += policy.config["caps_spatial_reg"] * \
            policy._mean_spatial_caps_loss

    if policy.config["caps_global_reg"] > 0.0:
        # Minimize the magnitude of the action mean
        policy._mean_global_caps_loss = torch.mean(action_mean_true**2)

        # Add global smoothness loss to total loss
        total_loss += policy.config["caps_global_reg"] * \
            policy._mean_global_caps_loss

    if policy.config["symmetric_policy_reg"] > 0.0:
        # Compute the mirrored observation
        offset = 0
        observation_mirror = torch.empty_like(observation_true)
        observation_space = policy.observation_space.original_space
        for mirror_mat in observation_space.mirror_mat.values():
            mirror_mat = torch.from_numpy(mirror_mat.T.copy()).to(
                dtype=torch.float32, device=observation_true.device)
            slice_idx = slice(offset, offset + len(mirror_mat))
            torch.mm(observation_true[..., slice_idx],
                     mirror_mat,
                     out=observation_mirror[..., slice_idx])
            offset += len(mirror_mat)

        # Compute the mean action corresponding to the mirrored observation,
        # then map it back to the original action frame
        train_batch_copy["obs"] = observation_mirror
        action_logits_mirror, _ = model(train_batch_copy)
        if issubclass(dist_class, TorchDiagGaussian):
            action_mean_mirror, _ = torch.chunk(action_logits_mirror, 2, dim=1)
        else:
            action_dist_mirror = dist_class(action_logits_mirror, model)
            action_mean_mirror = action_dist_mirror.deterministic_sample()
        action_mirror_mat = policy.action_space.mirror_mat
        action_mirror_mat = torch.from_numpy(action_mirror_mat.T.copy()).to(
            dtype=torch.float32, device=observation_true.device)
        action_mean_mirror = action_mean_mirror @ action_mirror_mat

        # Minimize the asymmetry of the policy output
        policy._mean_symmetric_policy_loss = torch.mean(
            (action_mean_mirror - action_mean_true)**2)

        # Add policy symmetry loss to total loss
        total_loss += policy.config["symmetric_policy_reg"] * \
            policy._mean_symmetric_policy_loss

    return total_loss
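
A minimal usage sketch, assuming the Ray RLlib 1.x API: the regularization coefficients read by ppo_loss must be present in the trainer configuration, for instance by extending PPO's default config. The values below are placeholders, and each term is disabled when its weight is 0.0. The spatial and symmetry terms additionally assume that the observation space exposes sensitivity and mirror_mat attributes and that the action space exposes a mirror_mat attribute, as used above.

from ray.rllib.agents.ppo import DEFAULT_CONFIG

config = DEFAULT_CONFIG.copy()
config.update({
    "caps_temporal_reg": 0.01,     # weight of the temporal smoothness loss
    "caps_spatial_reg": 0.01,      # weight of the spatial smoothness loss
    "caps_global_reg": 0.001,      # weight of the global magnitude loss
    "symmetric_policy_reg": 0.01,  # weight of the policy symmetry loss
})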