Example #1
def actor_critic_loss(policy, model, dist_class, train_batch):
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    policy.pi_err = -torch.sum(
        torch.masked_select(log_probs * train_batch[Postprocessing.ADVANTAGES],
                            valid_mask))
    policy.value_err = 0.5 * torch.sum(
        torch.pow(
            torch.masked_select(
                values.reshape(-1) - train_batch[Postprocessing.VALUE_TARGETS],
                valid_mask), 2.0))
    policy.entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))
    overall_err = (
        policy.pi_err + policy.value_err * policy.config["vf_loss_coeff"] -
        policy.entropy * policy.config["entropy_coeff"])
    return overall_err
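
Note: sequence_mask in these snippets is RLlib's torch utility (ray.rllib.utils.torch_utils.sequence_mask in recent versions). The following is a minimal, self-contained sketch, assuming equivalent semantics, of how the flattened boolean mask drops zero-padded timesteps from the summed loss; per_step_loss is a hypothetical stand-in for log_probs * advantages.

import torch

def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    # Boolean [B, maxlen] mask: True for valid timesteps, False for padding.
    if maxlen is None:
        maxlen = int(lengths.max())
    return (torch.arange(maxlen)[None, :] < lengths[:, None]).to(dtype)

lengths = torch.tensor([3, 1])                               # two sequences, padded to length 4
valid_mask = sequence_mask(lengths, maxlen=4).reshape(-1)    # flatten to [B * T], as in the loss above
per_step_loss = torch.arange(8, dtype=torch.float32)         # stand-in for log_probs * advantages
print(torch.masked_select(per_step_loss, valid_mask).sum())  # only the 4 valid steps contribute
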
Example #2
    def forward(self, inputs: TensorType) -> TensorType:
        L = list(inputs.size())[1]  # length of segment
        H = self._num_heads  # number of attention heads
        D = self._head_dim  # attention head dimension

        qkv = self._qkv_layer(inputs)

        queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
        queries = queries[:, -L:]  # only query based on the segment

        queries = torch.reshape(queries, [-1, L, H, D])
        keys = torch.reshape(keys, [-1, L, H, D])
        values = torch.reshape(values, [-1, L, H, D])

        score = torch.einsum("bihd,bjhd->bijh", queries, keys)
        score = score / D**0.5

        # causal mask of the same length as the sequence
        mask = sequence_mask(torch.arange(1, L + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]
        mask = mask.float()

        masked_score = score * mask + 1e30 * (mask - 1.)
        wmat = nn.functional.softmax(masked_score, dim=2)

        out = torch.einsum("bijh,bjhd->bihd", wmat, values)
        shape = list(out.size())[:2] + [H * D]
        out = torch.reshape(out, shape)
        return self._linear_layer(out)
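
The call sequence_mask(torch.arange(1, L + 1)) above builds the causal (lower-triangular) attention mask: query position i may only attend to key positions 0..i. A minimal sketch of the equivalent comparison, under the same assumed semantics as the sketch after Example #1:

import torch

L = 4
# Row i has length i + 1, so column j is valid iff j <= i: a causal mask.
causal = (torch.arange(L)[None, :] < torch.arange(1, L + 1)[:, None]).int()
print(causal)
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
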
Example #3
def actor_critic_loss(policy, model, dist_class, train_batch):
    # If policy is recurrent, mask out padded sequences
    # and calculate batch size
    if policy.is_recurrent():
        seq_lens = train_batch['seq_lens']
        max_seq_len = torch.max(seq_lens)
        mask_orig = sequence_mask(seq_lens, max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
        batch_size = seq_lens.shape[0] * max_seq_len
    else:
        mask = torch.ones_like(train_batch[SampleBatch.REWARDS])
        batch_size = mask.shape[0]

    logits, _ = model.from_batch(train_batch)
    values = model.value_function()
    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS])
    policy.entropy = -torch.sum(dist.entropy() * mask) / batch_size
    policy.pi_err = -torch.sum(train_batch[Postprocessing.ADVANTAGES] *
                               log_probs.reshape(-1) * mask) / batch_size
    policy.value_err = torch.sum(
        torch.pow(
            (values.reshape(-1) - train_batch[Postprocessing.VALUE_TARGETS]) *
            mask, 2.0)) / batch_size
    overall_err = sum([
        policy.pi_err,
        policy.config['vf_loss_coeff'] * policy.value_err,
        policy.config['entropy_coeff'] * policy.entropy,
    ])
    return overall_err
Example #4
def ppo_surrogate_loss(policy, model, dist_class, train_batch):
    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)

    mask = None
    if state:
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])

    policy.loss_obj = PPOLoss(
        dist_class,
        model,
        train_batch[Postprocessing.VALUE_TARGETS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[SampleBatch.ACTIONS],
        train_batch[SampleBatch.ACTION_DIST_INPUTS],
        train_batch[SampleBatch.ACTION_LOGP],
        train_batch[SampleBatch.VF_PREDS],
        action_dist,
        model.value_function(),
        policy.kl_coeff,
        mask,
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        use_gae=policy.config["use_gae"],
    )

    return policy.loss_obj.loss
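
PPOLoss itself is not part of this listing. Its core term is presumably the standard clipped surrogate, which Example #12 below computes inline; a compact sketch of just that term, with clip_param as the only assumed hyperparameter:

import torch

def clipped_surrogate(logp, logp_old, advantages, clip_param=0.3):
    # Pessimistic PPO objective: take the minimum of the unclipped and clipped terms.
    ratio = torch.exp(logp - logp_old)
    return torch.min(advantages * ratio,
                     advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param))
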
Example #5
def actor_critic_loss(policy, model, dist_class, train_batch):
    assert policy.is_recurrent(), "policy must be recurrent"

    seq_lens = train_batch['seq_lens']
    batch_size = seq_lens.shape[0]
    max_seq_len = torch.max(seq_lens)
    mask_orig = sequence_mask(seq_lens, max_seq_len)
    mask = torch.reshape(mask_orig, [-1])

    horizon = policy.config['fun_horizon']

    manager_horizon_mask = mask_orig.clone()
    manager_horizon_mask[:, -horizon:] = False
    manager_horizon_mask = manager_horizon_mask.reshape(-1)

    # Hacky way of passing data from sample batch to train batch
    model.random_select = train_batch['random_select'].reshape(
        (batch_size, -1))
    model.random_goal = train_batch['random_goal'].reshape(
        (batch_size, max_seq_len, -1))

    logits, _ = model.from_batch(train_batch)
    manager_values, worker_values = model.value_function()
    manager_latent_state, manager_goal = model.manager_features()

    manager_latent_state_future = torch.roll(manager_latent_state, -horizon, 1)
    manager_latent_state_diff = (manager_latent_state_future -
                                 manager_latent_state).detach()

    policy.manager_loss = 10.0 * -torch.sum(
        train_batch['manager_advantages'] * F.cosine_similarity(
            manager_latent_state_diff, manager_goal, dim=-1).reshape(-1) *
        manager_horizon_mask) / (batch_size * max_seq_len)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS])
    policy.entropy = 3e-4 * -torch.sum(dist.entropy() * mask) / (batch_size *
                                                                 max_seq_len)
    policy.pi_err = 0.1 * -torch.sum(
        train_batch['worker_advantages'] * log_probs.reshape(-1) * mask) / (
            batch_size * max_seq_len)

    policy.manager_value_err = torch.sum(
        torch.pow(
            (manager_values.reshape(-1) - train_batch['manager_value_targets'])
            * mask, 2.0)) / (batch_size * max_seq_len)
    policy.worker_value_err = 0.01 * torch.sum(
        torch.pow(
            (worker_values.reshape(-1) - train_batch['worker_value_targets']) *
            mask, 2.0)) / (batch_size * max_seq_len)

    overall_err = sum([
        policy.pi_err,
        policy.manager_value_err,
        policy.worker_value_err,
        policy.entropy,
        policy.manager_loss,
    ])
    return overall_err
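
The manager loss above compares each latent state with the state horizon steps ahead using torch.roll. Since torch.roll wraps around, the last horizon entries of the difference are meaningless, which is why manager_horizon_mask zeroes out the final horizon timesteps. A small numeric illustration:

import torch

horizon = 2
latent = torch.arange(5, dtype=torch.float32).reshape(1, 5, 1)  # [B, T, d]
future = torch.roll(latent, -horizon, 1)                        # latent state horizon steps ahead
print((future - latent).squeeze(-1))
# tensor([[ 2.,  2.,  2., -3., -3.]])  <- last horizon entries wrapped around
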
Example #6
def loss_fn(policy: Policy, model: ModelV2,
            dist_class: TorchDistributionWrapper, sample_batch: SampleBatch):
    max_seq_len = sample_batch['seq_lens'].max().item()
    mask = sequence_mask(sample_batch['seq_lens'],
                         max_seq_len,
                         time_major=model.is_time_major()).view((-1, 1))
    mean_reg = sample_batch['seq_lens'].sum() * model.nbr_agents
    actions = sample_batch['actions'].view(
        (sample_batch['actions'].shape[0], model.nbr_agents,
         -1))[:, :, :1].to(torch.long)
    actions = add_time_dimension(actions,
                                 max_seq_len=max_seq_len,
                                 framework='torch',
                                 time_major=True).reshape_as(actions)

    logits_pi, _ = model(sample_batch, [
        sample_batch['state_in_0'],
    ], sample_batch['seq_lens'])
    logits_pi = logits_pi.view((logits_pi.shape[0], model.nbr_agents, -1))
    logits_pi_action = logits_pi[:, :, :model.nbr_actions]
    log_pi_action = nn.functional.log_softmax(logits_pi_action, dim=-1)
    pi_action = torch.exp(log_pi_action)
    log_pi_action_selected = torch.gather(log_pi_action, -1,
                                          actions).squeeze(-1)

    q_values = model.q_values(sample_batch, target=False)
    q_values = add_time_dimension(q_values,
                                  max_seq_len=max_seq_len,
                                  framework="torch",
                                  time_major=True).reshape_as(q_values)
    q_values_selected = torch.gather(q_values, -1, actions).squeeze(-1)
    q_values_target = sample_batch[Postprocessing.VALUE_TARGETS]
    q_values_target = add_time_dimension(
        q_values_target,
        max_seq_len=max_seq_len,
        framework="torch",
        time_major=True).reshape_as(q_values_target)
    td_error = q_values_selected - q_values_target

    with torch.no_grad():
        coma_avg = q_values_selected - (pi_action * q_values).sum(-1)
    entropy = -(log_pi_action * pi_action).sum(-1)

    critic_loss = torch.pow(mask * td_error, 2.0)
    actor_loss = mask * coma_avg * log_pi_action_selected
    entropy = mask * entropy

    policy.actor_loss = -actor_loss.sum() / mean_reg
    policy.critic_loss = critic_loss.sum() / mean_reg
    policy.entropy = entropy.sum() / mean_reg

    pi_loss = policy.actor_loss - policy.config[
        'entropy_coeff'] * policy.entropy

    return pi_loss, policy.critic_loss
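
coma_avg above is the COMA-style counterfactual advantage: the Q-value of the selected action minus the policy-weighted average over all candidate actions. A minimal numeric sketch for a single agent and state:

import torch

q_values = torch.tensor([[1.0, 3.0, 2.0]])                   # Q(s, a) for each candidate action
pi = torch.softmax(torch.tensor([[0.0, 1.0, 0.0]]), dim=-1)  # current policy probabilities
taken = torch.tensor([[1]])                                  # action actually selected
q_taken = q_values.gather(-1, taken).squeeze(-1)
coma_advantage = q_taken - (pi * q_values).sum(-1)           # subtract the counterfactual baseline
print(coma_advantage)                                        # ~0.64: better than the policy average
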
Example #7
    def forward(self, inputs: TensorType,
                memory: TensorType = None) -> TensorType:
        T = list(inputs.size())[1]  # length of segment (time)
        H = self._num_heads  # number of attention heads
        d = self._head_dim  # attention head dimension

        # Add previous memory chunk (as const, w/o gradient) to input.
        # Tau (number of (prev) time slices in each memory chunk).
        Tau = list(memory.shape)[1] if memory is not None else 0
        if memory is not None:
            memory.requires_grad_(False)
            inputs = torch.cat((memory, inputs), dim=1)

        # Apply the Layer-Norm.
        if self._input_layernorm is not None:
            inputs = self._input_layernorm(inputs)

        qkv = self._qkv_layer(inputs)

        queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
        # Cut out Tau memory timesteps from query.
        queries = queries[:, -T:]

        queries = torch.reshape(queries, [-1, T, H, d])
        keys = torch.reshape(keys, [-1, T + Tau, H, d])
        values = torch.reshape(values, [-1, T + Tau, H, d])

        R = self._pos_proj(self._rel_pos_encoder)
        R = torch.reshape(R, [T + Tau, H, d])

        # b=batch
        # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
        # h=head
        # d=head-dim (over which we will reduce-sum)
        score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
        pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R)
        score = score + self.rel_shift(pos_score)
        score = score / d**0.5

        # causal mask of the same length as the sequence
        mask = sequence_mask(
            torch.arange(Tau + 1, T + Tau + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]

        masked_score = score * mask + 1e30 * (mask.to(torch.float32) - 1.)
        wmat = nn.functional.softmax(masked_score, dim=2)

        out = torch.einsum("bijh,bjhd->bihd", wmat, values)
        shape = list(out.shape)[:2] + [H * d]
        out = torch.reshape(out, shape)

        return self._linear_layer(out)
Example #8
def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(log_probs * train_batch[Postprocessing.ADVANTAGES],
                            valid_mask))

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS], valid_mask),
                2.0))
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] -
                  entropy * policy.config["entropy_coeff"])

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["entropy"] = entropy
    model.tower_stats["pi_err"] = pi_err
    model.tower_stats["value_err"] = value_err

    return total_loss
Example #9
def ppo_surrogate_loss(policy, model, dist_class, train_batch):
    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)

    mask = None
    if state:
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])

    if policy.config["use_aux_loss"]:
        aux_input = {
            "new_obs": train_batch[SampleBatch.NEXT_OBS],
            "actions": train_batch[SampleBatch.ACTIONS],
            "dones": train_batch["dones"],
            "rewards": train_batch["rewards"]
        }
        aux_loss = policy.model.aux_loss(aux_input)
        train_batch["aux_loss"] = aux_loss

    policy.loss_obj = PPOLoss(
        dist_class,
        model,
        train_batch[Postprocessing.VALUE_TARGETS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[SampleBatch.ACTIONS],
        train_batch[SampleBatch.ACTION_DIST_INPUTS],
        train_batch[SampleBatch.ACTION_LOGP],
        train_batch[SampleBatch.VF_PREDS],
        action_dist,
        model.value_function(),
        policy.kl_coeff,
        mask,
        train_batch.get("aux_loss", None),
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        aux_loss_coeff=policy.config["aux_loss_coeff"],
        use_gae=policy.config["use_gae"],
    )

    return policy.loss_obj.loss
Example #10
def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(log_probs * train_batch[Postprocessing.ADVANTAGES],
                            valid_mask))

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS], valid_mask),
                2.0))
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] -
                  entropy * policy.config["entropy_coeff"])

    policy.entropy = entropy
    policy.pi_err = pi_err
    policy.value_err = value_err

    return total_loss
Example #11
def build_vtrace_loss(policy, model, dist_class, train_batch):
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space, gym.spaces.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
                               *args, **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
    if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
        unpacked_behaviour_logits = torch.split(behaviour_logits,
                                                list(output_hidden_shape),
                                                dim=1)
        unpacked_outputs = torch.split(model_out,
                                       list(output_hidden_shape),
                                       dim=1)
    else:
        unpacked_behaviour_logits = torch.chunk(behaviour_logits,
                                                output_hidden_shape,
                                                dim=1)
        unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    # Prepare actions for loss.
    loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions,
                                                                    dim=1)

    # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
    loss = VTraceLoss(
        actions=_make_time_major(loss_actions, drop_last=True),
        actions_logp=_make_time_major(action_dist.logp(actions),
                                      drop_last=True),
        actions_entropy=_make_time_major(action_dist.entropy(),
                                         drop_last=True),
        dones=_make_time_major(dones, drop_last=True),
        behaviour_action_logp=_make_time_major(behaviour_action_logp,
                                               drop_last=True),
        behaviour_logits=_make_time_major(unpacked_behaviour_logits,
                                          drop_last=True),
        target_logits=_make_time_major(unpacked_outputs, drop_last=True),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=True),
        values=_make_time_major(values, drop_last=True),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical if is_multidiscrete else dist_class,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=True),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["pi_loss"] = loss.pi_loss
    model.tower_stats["vf_loss"] = loss.vf_loss
    model.tower_stats["entropy"] = loss.entropy
    model.tower_stats["mean_entropy"] = loss.mean_entropy
    model.tower_stats["total_loss"] = loss.total_loss

    values_batched = make_time_major(policy,
                                     train_batch.get(SampleBatch.SEQ_LENS),
                                     values,
                                     drop_last=policy.config["vtrace"])
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(loss.value_targets, [-1]),
        torch.reshape(values_batched, [-1]))

    return loss.total_loss
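
make_time_major is not shown in this listing. Per the comment above, inputs arrive flattened as [B * T] (batch-major, zero-padded) and are reshaped to [T, B] for the V-trace calculation, optionally dropping the last timestep. A rough sketch of that reshaping, with to_time_major as a hypothetical stand-in:

import torch

def to_time_major(flat, B, T, drop_last=False):
    # [B * T, ...] batch-major and padded -> [T, B, ...]; optionally drop the final timestep.
    out = flat.reshape([B, T] + list(flat.shape[1:])).transpose(0, 1)
    return out[:-1] if drop_last else out

x = torch.arange(6)                 # two padded sequences of length 3, flattened batch-major
print(to_time_major(x, B=2, T=3))
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])
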
Example #12
def ppo_surrogate_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for Proximal Policy Objective.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    logits, state = model.from_batch(train_batch, is_training=True)
    curr_action_dist = dist_class(logits, model)

    # RNN case: Mask away 0-padded chunks at end of time axis.
    if state:
        B = len(train_batch["seq_lens"])
        max_seq_len = logits.shape[0] // B
        mask = sequence_mask(train_batch["seq_lens"],
                             max_seq_len,
                             time_major=model.is_time_major())
        mask = torch.reshape(mask, [-1])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = torch.mean

    prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                                  model)

    logp_ratio = torch.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
        train_batch[SampleBatch.ACTION_LOGP])
    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = reduce_mean_valid(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    surrogate_loss = torch.min(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] *
        torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                    1 + policy.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    if policy.config["use_gae"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        vf_loss1 = torch.pow(
            value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_clipped = prev_value_fn_out + torch.clamp(
            value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
            policy.config["vf_clip_param"])
        vf_loss2 = torch.pow(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = torch.max(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl +
                                       policy.config["vf_loss_coeff"] *
                                       vf_loss -
                                       policy.entropy_coeff * curr_entropy)
    else:
        mean_vf_loss = 0.0
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl -
                                       policy.entropy_coeff * curr_entropy)

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._vf_explained_var = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS],
        policy.model.value_function())
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl

    return total_loss
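
explained_variance above is RLlib's diagnostic comparing value-function predictions with the value targets. A sketch of the standard definition it approximates (the library version may clip or guard the denominator differently):

import torch

def explained_variance(y_true, y_pred):
    # 1 - Var[y_true - y_pred] / Var[y_true]; 1.0 means the predictions explain the targets
    # perfectly, 0.0 means they do no better than predicting the mean.
    var_y = torch.var(y_true)
    return 1.0 - torch.var(y_true - y_pred) / torch.clamp(var_y, min=1e-8)
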
Example #13
def build_CAT_vtrace_loss(policy, model, dist_class, train_batch):
    action_space_parts = model.action_space_parts

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    # Repeat the output_hidden_shape depending on the number of actions that have been generated
    # output_hidden_shape = np.tile(output_hidden_shape, action_repeats)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    invalid_action_mask = train_batch['invalid_action_mask']
    autoregressive_actions = policy.config['autoregressive_actions']

    if 'seq_lens' in train_batch:
        max_seq_len = policy.config['rollout_fragment_length']
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    actions_per_step = policy.config["actions_per_step"]

    states = []
    i = 0
    while "state_in_{}".format(i) in train_batch:
        states.append(train_batch["state_in_{}".format(i)])
        i += 1

    seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else []

    model.observation_features_module(train_batch, states, seq_lens)
    action_features, _ = model.action_features_module(train_batch, states, seq_lens)

    previous_action = None
    embedded_action = None
    logp_list = []
    entropy_list = []
    logits_list = []

    multi_actions = torch.chunk(actions, actions_per_step, dim=1)
    multi_invalid_action_mask = torch.chunk(invalid_action_mask, actions_per_step, dim=1)
    for a in range(actions_per_step):
        if autoregressive_actions:
            if a == 0:
                batch_size = action_features.shape[0]
                previous_action = torch.zeros([batch_size, len(action_space_parts)]).to(action_features.device)
            else:
                previous_action = multi_actions[a-1]

            embedded_action = model.embed_action_module(previous_action)

        logits = model.action_module(action_features, embedded_action)
        logits += torch.maximum(torch.tensor(torch.finfo().min), torch.log(multi_invalid_action_mask[a]))
        cat = TorchMultiCategorical(logits, model, action_space_parts)

        logits_list.append(logits)
        logp_list.append(cat.logp(multi_actions[a]))
        entropy_list.append(cat.entropy())

    logp = torch.stack(logp_list, dim=1).sum(dim=1)
    entropy = torch.stack(entropy_list, dim=1).sum(dim=1)
    target_logits = torch.hstack(logits_list)

    unpack_shape = np.tile(action_space_parts, actions_per_step)

    unpacked_behaviour_logits = torch.split(behaviour_logits, list(unpack_shape), dim=1)
    unpacked_outputs = torch.split(target_logits, list(unpack_shape), dim=1)

    values = model.value_function()

    # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
    policy.loss = VTraceLoss(
        actions=_make_time_major(actions, drop_last=True),
        actions_logp=_make_time_major(logp, drop_last=True),
        actions_entropy=_make_time_major(entropy, drop_last=True),
        dones=_make_time_major(dones, drop_last=True),
        behaviour_action_logp=_make_time_major(
            behaviour_action_logp, drop_last=True),
        behaviour_logits=_make_time_major(
            unpacked_behaviour_logits, drop_last=True),
        target_logits=_make_time_major(unpacked_outputs, drop_last=True),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=True),
        values=_make_time_major(values, drop_last=True),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=True),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])

    return policy.loss.total_loss
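
The logits += torch.maximum(...) line in Example #13 implements invalid-action masking: adding log(mask) pushes invalid actions toward minus infinity, and clamping at the most negative finite float avoids propagating -inf/NaN. A standalone sketch:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
invalid_action_mask = torch.tensor([[1.0, 0.0, 1.0]])   # 0 marks an invalid action
# log(0) = -inf would poison later arithmetic, so clamp at the most negative finite float.
masked_logits = logits + torch.maximum(
    torch.tensor(torch.finfo(logits.dtype).min), torch.log(invalid_action_mask))
print(torch.softmax(masked_logits, dim=-1))             # the invalid action gets ~0 probability
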
Example #14
def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(policy,
                                  model,
                                  train_batch,
                                  state_batches=state_batches,
                                  seq_lens=train_batch.get("seq_lens"),
                                  explore=False,
                                  is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(policy,
                                         policy.target_q_model,
                                         train_batch,
                                         state_batches=state_batches,
                                         seq_lens=train_batch.get("seq_lens"),
                                         explore=False,
                                         is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=policy.device))
        * one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)

    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=policy.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat(
            [q_target_best[1:],
             torch.tensor([0.0], device=policy.device)])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        if burn_in > 0 and burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        policy._td_error = td_error.reshape([-1])
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": torch.min(q_selected),
            "max_q": torch.max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss
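
h_function and h_inverse are not included in the snippet; R2D2 rescales TD targets with the invertible value transform of Pohlen et al. (2018). A sketch under the assumption that the helpers follow that standard form (config["h_function_epsilon"] supplies the real epsilon; 1e-2 here is illustrative):

import torch

def h_function(x, epsilon=1e-2):
    # Invertible value rescaling: h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x.
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x

def h_inverse(x, epsilon=1e-2):
    # Closed-form inverse of h_function.
    return torch.sign(x) * (
        ((torch.sqrt(1.0 + 4.0 * epsilon * (torch.abs(x) + 1.0 + epsilon)) - 1.0)
         / (2.0 * epsilon)) ** 2 - 1.0)

x = torch.tensor([-10.0, 0.0, 10.0])
print(torch.allclose(h_inverse(h_function(x)), x, atol=1e-4))  # True
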
Example #15
def build_vtrace_loss(policy, model, dist_class, train_batch):
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space, gym.spaces.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
    if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
        unpacked_behaviour_logits = torch.split(behaviour_logits,
                                                list(output_hidden_shape),
                                                dim=1)
        unpacked_outputs = torch.split(model_out,
                                       list(output_hidden_shape),
                                       dim=1)
    else:
        unpacked_behaviour_logits = torch.chunk(behaviour_logits,
                                                output_hidden_shape,
                                                dim=1)
        unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    # Prepare actions for loss.
    loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions,
                                                                    dim=1)

    # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
    loss = VTraceLoss(
        actions=_make_time_major(loss_actions, drop_last=True),
        actions_logp=_make_time_major(action_dist.logp(actions),
                                      drop_last=True),
        actions_entropy=_make_time_major(action_dist.entropy(),
                                         drop_last=True),
        dones=_make_time_major(dones, drop_last=True),
        behaviour_action_logp=_make_time_major(behaviour_action_logp,
                                               drop_last=True),
        behaviour_logits=_make_time_major(unpacked_behaviour_logits,
                                          drop_last=True),
        target_logits=_make_time_major(unpacked_outputs, drop_last=True),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=True),
        values=_make_time_major(values, drop_last=True),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical if is_multidiscrete else dist_class,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=True),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])

    # Store loss object only for multi-GPU tower 0.
    if model is policy.model_gpu_towers[0]:
        policy.loss = loss

    return loss.total_loss
Example #16
def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = policy.target_model.from_batch(train_batch)
    old_policy_behaviour_logits = target_model_out.detach()

    unpacked_behaviour_logits = torch.split(behaviour_logits,
                                            output_hidden_shape,
                                            dim=1)
    unpacked_old_policy_behaviour_logits = torch.split(
        old_policy_behaviour_logits, output_hidden_shape, dim=1)
    unpacked_outputs = torch.split(model_out, output_hidden_shape, dim=1)
    old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
    prev_action_dist = dist_class(behaviour_logits, policy.model)
    values = policy.model.value_function()

    policy.model_vars = policy.model.variables()
    policy.target_model_vars = policy.target_model.variables()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"]) - 1
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])
    else:
        mask = torch.ones_like(rewards)

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        # Prepare actions for loss
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(
            actions, dim=1)

        # Prepare KL for Loss
        mean_kl = _make_time_major(
            old_policy_action_dist.multi_kl(action_dist), drop_last=True)

        policy.loss = VTraceSurrogateLoss(
            actions=_make_time_major(loss_actions, drop_last=True),
            prev_actions_logp=_make_time_major(prev_action_dist.logp(actions),
                                               drop_last=True),
            actions_logp=_make_time_major(action_dist.logp(actions),
                                          drop_last=True),
            old_policy_actions_logp=_make_time_major(
                old_policy_action_dist.logp(actions), drop_last=True),
            action_kl=torch.mean(mean_kl, dim=0)
            if is_multidiscrete else mean_kl,
            actions_entropy=_make_time_major(action_dist.multi_entropy(),
                                             drop_last=True),
            dones=_make_time_major(dones, drop_last=True),
            behaviour_logits=_make_time_major(unpacked_behaviour_logits,
                                              drop_last=True),
            old_policy_behaviour_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=True),
            target_logits=_make_time_major(unpacked_outputs, drop_last=True),
            discount=policy.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=True),
            values=_make_time_major(values, drop_last=True),
            bootstrap_value=_make_time_major(values)[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=policy.model,
            valid_mask=_make_time_major(mask, drop_last=True),
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            entropy_coeff=policy.config["entropy_coeff"],
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=policy.
            config["vtrace_clip_pg_rho_threshold"],
            clip_param=policy.config["clip_param"],
            cur_kl_coeff=policy.kl_coeff,
            use_kl_loss=policy.config["use_kl_loss"])
    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for Loss
        mean_kl = _make_time_major(prev_action_dist.multi_kl(action_dist))

        policy.loss = PPOSurrogateLoss(
            prev_actions_logp=_make_time_major(prev_action_dist.logp(actions)),
            actions_logp=_make_time_major(action_dist.logp(actions)),
            action_kl=torch.mean(mean_kl, dim=0)
            if is_multidiscrete else mean_kl,
            actions_entropy=_make_time_major(action_dist.multi_entropy()),
            values=_make_time_major(values),
            valid_mask=_make_time_major(mask),
            advantages=_make_time_major(
                train_batch[Postprocessing.ADVANTAGES]),
            value_targets=_make_time_major(
                train_batch[Postprocessing.VALUE_TARGETS]),
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            entropy_coeff=policy.config["entropy_coeff"],
            clip_param=policy.config["clip_param"],
            cur_kl_coeff=policy.kl_coeff,
            use_kl_loss=policy.config["use_kl_loss"])

    return policy.loss.total_loss
Example #17
def appo_surrogate_loss(policy: Policy, model: ModelV2,
                        dist_class: Type[TorchDistributionWrapper],
                        train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = policy.target_models[model]

    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = target_model.from_batch(train_batch)

    prev_action_dist = dist_class(behaviour_logits, model)
    values = model.value_function()
    values_time_major = _make_time_major(values)

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])
        mask = _make_time_major(mask, drop_last=policy.config["vtrace"])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    else:
        reduce_mean_valid = torch.mean

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        old_policy_behaviour_logits = target_model_out.detach()
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
            unpacked_behaviour_logits = torch.split(behaviour_logits,
                                                    list(output_hidden_shape),
                                                    dim=1)
            unpacked_old_policy_behaviour_logits = torch.split(
                old_policy_behaviour_logits, list(output_hidden_shape), dim=1)
        else:
            unpacked_behaviour_logits = torch.chunk(behaviour_logits,
                                                    output_hidden_shape,
                                                    dim=1)
            unpacked_old_policy_behaviour_logits = torch.chunk(
                old_policy_behaviour_logits, output_hidden_shape, dim=1)

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(
            actions, dim=1)

        # Prepare KL for loss.
        action_kl = _make_time_major(old_policy_action_dist.kl(action_dist),
                                     drop_last=True)

        # Compute vtrace on the CPU for better perf.
        vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=_make_time_major(unpacked_behaviour_logits,
                                                     drop_last=True),
            target_policy_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=True),
            actions=torch.unbind(_make_time_major(loss_actions,
                                                  drop_last=True),
                                 dim=2),
            discounts=(1.0 - _make_time_major(dones, drop_last=True).float()) *
            policy.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=True),
            values=values_time_major[:-1],  # drop-last=True
            bootstrap_value=values_time_major[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"]
        )

        actions_logp = _make_time_major(action_dist.logp(actions),
                                        drop_last=True)
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions),
                                             drop_last=True)
        old_policy_actions_logp = _make_time_major(
            old_policy_action_dist.logp(actions), drop_last=True)
        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
        policy._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                        1 + policy.config["clip_param"]))

        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = vtrace_returns.vs.to(values_time_major.device)
        delta = values_time_major[:-1] - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy(), drop_last=True))

    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for Loss
        action_kl = _make_time_major(prev_action_dist.kl(action_dist))

        actions_logp = _make_time_major(action_dist.logp(actions))
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
        logp_ratio = torch.exp(actions_logp - prev_actions_logp)

        advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                        1 + policy.config["clip_param"]))

        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = _make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy()))

    # The summed weighted loss
    total_loss = mean_policy_loss + \
        mean_vf_loss * policy.config["vf_loss_coeff"] - \
        mean_entropy * policy.config["entropy_coeff"]

    # Optional additional KL Loss
    if policy.config["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl

    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_kl = mean_kl
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._value_targets = value_targets
    policy._vf_explained_var = explained_variance(
        torch.reshape(value_targets, [-1]),
        torch.reshape(
            values_time_major[:-1]
            if policy.config["vtrace"] else values_time_major, [-1]),
    )

    return total_loss
Example #18
def drq_ppo_surrogate_loss(policy, model, dist_class, train_batch):
    """ loss function for PPO with input augmentations
    """

    # NOTE: averaged augmented loss for ppo
    aug_num = policy.config["aug_num"]
    aug_loss = 0
    orig_cur_obs = train_batch[SampleBatch.CUR_OBS].clone()

    for _ in range(aug_num):
        # do augmentation, overwrite last augmented obs
        aug_cur_obs = model.trans(orig_cur_obs.permute(0, 3, 1,
                                                       2).float()).permute(
                                                           0, 2, 3, 1)
        train_batch[SampleBatch.CUR_OBS] = aug_cur_obs

        # forward with augmented obs
        logits, state = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)

        mask = None
        if state:
            max_seq_len = torch.max(train_batch["seq_lens"])
            mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
            mask = torch.reshape(mask, [-1])

        policy.loss_obj = PPOLoss(
            dist_class,
            model,
            train_batch[Postprocessing.VALUE_TARGETS],
            train_batch[Postprocessing.ADVANTAGES],
            train_batch[SampleBatch.ACTIONS],
            train_batch[SampleBatch.ACTION_DIST_INPUTS],
            train_batch[SampleBatch.ACTION_LOGP],
            train_batch[SampleBatch.VF_PREDS],
            action_dist,
            model.value_function(),
            policy.kl_coeff,
            mask,
            entropy_coeff=policy.entropy_coeff,
            clip_param=policy.config["clip_param"],
            vf_clip_param=policy.config["vf_clip_param"],
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            use_gae=policy.config["use_gae"],
        )
        # accumulate loss with augmented obs
        aug_loss += policy.loss_obj.loss
    return aug_loss / aug_num
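
model.trans is defined elsewhere in this project; DrQ-style augmentation typically pads the image batch and takes a random crop (a random shift), applied on NCHW tensors as the permute above suggests. A simplified sketch of such a transform, applying a single shift to the whole batch for brevity (real implementations usually shift each sample independently):

import torch
import torch.nn as nn
import torch.nn.functional as F

class RandomShift(nn.Module):
    """Pad-and-random-crop augmentation for NCHW image batches (DrQ-style)."""

    def __init__(self, pad: int = 4):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        n, c, h, w = x.shape
        x = F.pad(x, [self.pad] * 4, mode="replicate")
        top = int(torch.randint(0, 2 * self.pad + 1, (1,)))
        left = int(torch.randint(0, 2 * self.pad + 1, (1,)))
        return x[:, :, top:top + h, left:left + w]
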
Example #19
def actor_critic_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get("seq_lens")

    model_out_t, state_in_t = model(
        {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "prev_actions": train_batch[SampleBatch.PREV_ACTIONS],
            "prev_rewards": train_batch[SampleBatch.PREV_REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "prev_actions": train_batch[SampleBatch.ACTIONS],
            "prev_rewards": train_batch[SampleBatch.REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    states_in_tp1 = model.select_state(state_in_tp1, ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = policy.target_model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "prev_actions": train_batch[SampleBatch.ACTIONS],
            "prev_rewards": train_batch[SampleBatch.REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    target_states_in_tp1 = \
        policy.target_model.select_state(target_state_in_tp1,
                                         ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = F.log_softmax(model.get_policy_output(
            model_out_t, states_in_t["policy"], seq_lens)[0],
                                  dim=-1)
        policy_t = torch.exp(log_pis_t)
        log_pis_tp1 = F.log_softmax(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)[0]
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                                 target_states_in_tp1["q"],
                                                 seq_lens)[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(model_out_t,
                                               states_in_t["twin_q"],
                                               seq_lens)[0]
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens)[0]
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t, states_in_t["policy"],
                                    seq_lens)[0], policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                 train_batch[SampleBatch.ACTIONS])[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens,
                train_batch[SampleBatch.ACTIONS])[0]

        # Q-values for current policy in given current state.
        q_t_det_policy = model.get_q_values(model_out_t, states_in_t["q"],
                                            seq_lens, policy_t)[0]
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)[0]
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                                 target_states_in_tp1["q"],
                                                 seq_lens, policy_tp1)[0]
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"], seq_lens,
                policy_tp1)[0]
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           (policy.config["gamma"]**policy.config["n_step"]) *
                           q_tp1_best_masked).detach()

    # BURN-IN MASKING #
    B = state_batches[0].shape[0]  # Number of sequences in the batch.
    T = q_t_selected.shape[0] // B  # Padded (max) sequence length.
    seq_mask = sequence_mask(train_batch["seq_lens"], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if 0 < burn_in < T:
        seq_mask[:, :burn_in] = False

    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                          huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                              huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: The papers use alpha directly; here we optimize its log (log_alpha).
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach()),
                dim=-1))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t -
                                       q_t_det_policy)

    # Save for stats function.
    policy.q_t = q_t * seq_mask[..., None]
    policy.policy_t = policy_t * seq_mask[..., None]
    policy.log_pis_t = log_pis_t * seq_mask[..., None]

    # Store the td-error in the model so that, for multi-GPU training, we do
    # not overwrite it during the parallel loss phase. The TD-error tensors
    # can then be concatenated in the final stats and retrieved per batch item.
    model.td_error = td_error * seq_mask

    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy

    # Return all loss terms corresponding to our optimizers.
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])
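# A minimal sketch (assumed shapes, not the RLlib API) of the burn-in masking
# used in the loss above: padded timesteps and the first `burn_in` steps of
# every sequence are dropped from the mean. `toy_sequence_mask` is a small
# stand-in for RLlib's `sequence_mask` utility.
import torch


def toy_sequence_mask(seq_lens, max_len):
    """Boolean mask of shape [num_seqs, max_len]; True where t < seq_len."""
    t = torch.arange(max_len, device=seq_lens.device)[None, :]
    return t < seq_lens[:, None]


def masked_mean_with_burn_in(values, seq_lens, max_len, burn_in):
    """Mean of `values` ([num_seqs * max_len]) over valid, post-burn-in steps."""
    mask = toy_sequence_mask(seq_lens, max_len)
    if 0 < burn_in < max_len:
        mask[:, :burn_in] = False
    mask = mask.reshape(-1)
    return torch.sum(values[mask]) / torch.sum(mask)


# Example: 2 sequences of lengths 5 and 3, padded to T=5, burn-in of 2 steps.
vals = torch.arange(10, dtype=torch.float32)
print(masked_mean_with_burn_in(vals, torch.tensor([5, 3]), max_len=5, burn_in=2))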
Ejemplo n.º 20
0
def mu_zero_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
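    """Constructs a PPO-style clipped surrogate loss plus MuZero-style
    reward-prediction and MCTS-policy cross-entropy terms.

    Returns:
        Union[TensorType, List[TensorType]]: A single total loss tensor.
    """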

    logits, state = model.from_batch(train_batch, is_training=True)
    curr_action_dist = dist_class(logits, model)

    mcts_policy = dist_class(train_batch["mcts_policy"], model)

    # RNN case: Mask away 0-padded chunks at end of time axis.
    if state:
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"],
                             max_seq_len,
                             time_major=model.is_time_major())
        mask = torch.reshape(mask, [-1])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = torch.mean

    prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                                  model)

    logp = curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
    logp_ratio = torch.exp(logp - train_batch[SampleBatch.ACTION_LOGP])

    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = reduce_mean_valid(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    surrogate_loss = torch.min(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] *
        torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                    1 + policy.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    if policy.config["use_gae"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        vf_loss1 = torch.pow(
            value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_clipped = prev_value_fn_out + torch.clamp(
            value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
            policy.config["vf_clip_param"])
        vf_loss2 = torch.pow(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = torch.max(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
        total_loss = reduce_mean_valid(
            -surrogate_loss * policy.config["surrogate_coeff"] +
            policy.kl_coeff * action_kl +
            policy.config["vf_loss_coeff"] * vf_loss -
            policy.entropy_coeff * curr_entropy)
    else:
        mean_vf_loss = 0.0
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl -
                                       policy.entropy_coeff * curr_entropy)

    pred_reward = model.reward_function(policy_logits=logits)

    # Rewards need to be float for proper backpropagation.
    reward_loss = torch.nn.functional.mse_loss(
        pred_reward, train_batch[SampleBatch.REWARDS].float())

    mcts_loss = torch.mean(-mcts_policy.dist.probs *
                           torch.log(curr_action_dist.dist.probs))

    total_loss += reward_loss + mcts_loss

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl
    policy._mean_reward_loss = reward_loss
    policy._mcts_loss = mcts_loss

    return total_loss
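# A minimal sketch of the MCTS-policy cross-entropy term from the loss above,
# written directly against probability tensors instead of RLlib distribution
# wrappers. It sums over actions before averaging over the batch (the standard
# cross-entropy), whereas the snippet above takes an element-wise mean.
# The tensors below are made-up examples, not real MCTS output.
import torch


def mcts_cross_entropy(mcts_probs, policy_logits):
    """Mean cross-entropy H(pi_mcts, pi_theta) over the batch."""
    log_policy = torch.log_softmax(policy_logits, dim=-1)
    return torch.mean(torch.sum(-mcts_probs * log_policy, dim=-1))


# Example: batch of 2 states, 3 actions.
mcts_probs = torch.tensor([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])
policy_logits = torch.randn(2, 3)
print(mcts_cross_entropy(mcts_probs, policy_logits))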