Example #1
    def forward(self, inputs: TensorType) -> TensorType:
        L = list(inputs.size())[1]  # length of segment
        H = self._num_heads  # number of attention heads
        D = self._head_dim  # attention head dimension

        qkv = self._qkv_layer(inputs)

        queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
        queries = queries[:, -L:]  # only query based on the segment

        queries = torch.reshape(queries, [-1, L, H, D])
        keys = torch.reshape(keys, [-1, L, H, D])
        values = torch.reshape(values, [-1, L, H, D])

        score = torch.einsum("bihd,bjhd->bijh", queries, keys)
        score = score / D**0.5

        # causal mask of the same length as the sequence
        mask = sequence_mask(torch.arange(1, L + 1),
                             dtype=score.dtype).to(score.device)
        mask = mask[None, :, :, None]
        mask = mask.float()

        masked_score = score * mask + 1e30 * (mask - 1.)
        wmat = nn.functional.softmax(masked_score, dim=2)

        out = torch.einsum("bijh,bjhd->bihd", wmat, values)
        shape = list(out.size())[:2] + [H * D]
        out = torch.reshape(out, shape)
        return self._linear_layer(out)
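Note: all of these snippets call a `sequence_mask` helper that is imported from RLlib's torch utilities and is not shown here. The following is a rough, self-contained sketch of the behavior the call sites assume (the real helper's signature and edge-case handling may differ):

import torch


def sequence_mask(lengths, maxlen=None, dtype=None, time_major=False):
    """Boolean mask with True wherever the (0-based) position is < length.

    E.g. lengths=torch.arange(1, L + 1) yields the lower-triangular
    (causal) L x L mask used in the attention snippet above.
    """
    maxlen = int(maxlen) if maxlen is not None else int(lengths.max())
    mask = (torch.arange(maxlen, device=lengths.device)[None, :]
            < lengths[:, None])
    if time_major:
        mask = mask.t()
    if dtype is not None:
        mask = mask.type(dtype)
    return mask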
Example #2
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[ActionDistribution],
    train_batch: SampleBatch,
) -> TensorType:
    logits, _ = model(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(log_probs * train_batch[Postprocessing.ADVANTAGES],
                            valid_mask))

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS],
                    valid_mask,
                ),
                2.0,
            ))
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] -
                  entropy * policy.entropy_coeff)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["entropy"] = entropy
    model.tower_stats["pi_err"] = pi_err
    model.tower_stats["value_err"] = value_err

    return total_loss
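The recurrent branch above flattens a [B, max_seq_len] sequence mask so that it lines up with the already-flattened per-timestep tensors. A tiny, self-contained illustration of that masking pattern (toy numbers, not tied to RLlib's actual batch layout):

import torch

# Hypothetical mini-batch: 2 sequences padded to max_seq_len = 4.
seq_lens = torch.tensor([4, 2])
B, max_seq_len = len(seq_lens), 4

# Same construction as in the loss: [B, max_seq_len] -> flattened [-1].
mask_orig = torch.arange(max_seq_len)[None, :] < seq_lens[:, None]
valid_mask = mask_orig.reshape(-1)
# -> tensor([True, True, True, True, True, True, False, False])

# Flattened per-timestep terms (e.g. log_probs * advantages) have shape
# [B * max_seq_len]; masked_select keeps only the 6 valid steps.
per_step = torch.randn(B * max_seq_len)
masked_sum = torch.sum(torch.masked_select(per_step, valid_mask))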
Example #3
    def forward(self,
                inputs: TensorType,
                memory: TensorType = None) -> TensorType:
        T = list(inputs.size())[1]  # length of segment (time)
        H = self._num_heads  # number of attention heads
        d = self._head_dim  # attention head dimension

        # Add previous memory chunk (as const, w/o gradient) to input.
        # Tau (number of (prev) time slices in each memory chunk).
        Tau = list(memory.shape)[1]
        inputs = torch.cat((memory.detach(), inputs), dim=1)

        # Apply the Layer-Norm.
        if self._input_layernorm is not None:
            inputs = self._input_layernorm(inputs)

        qkv = self._qkv_layer(inputs)

        queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
        # Cut out Tau memory timesteps from query.
        queries = queries[:, -T:]

        queries = torch.reshape(queries, [-1, T, H, d])
        keys = torch.reshape(keys, [-1, Tau + T, H, d])
        values = torch.reshape(values, [-1, Tau + T, H, d])

        R = self._pos_proj(self._rel_pos_embedding(Tau + T))
        R = torch.reshape(R, [Tau + T, H, d])

        # b = batch
        # i = query (input) time index (T steps)
        # j = key/value time index (Tau + T steps: memory + input)
        # h = attention head
        # d = head dim (over which we reduce-sum)
        score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
        pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R)
        score = score + self.rel_shift(pos_score)
        score = score / d**0.5

        # causal mask of the same length as the sequence
        mask = sequence_mask(torch.arange(Tau + 1, Tau + T + 1),
                             dtype=score.dtype).to(score.device)
        mask = mask[None, :, :, None]

        masked_score = score * mask + 1e30 * (mask.float() - 1.0)
        wmat = nn.functional.softmax(masked_score, dim=2)

        out = torch.einsum("bijh,bjhd->bihd", wmat, values)
        shape = list(out.shape)[:2] + [H * d]
        out = torch.reshape(out, shape)

        return self._linear_layer(out)
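`self.rel_shift` (like `self._rel_pos_embedding`) is defined elsewhere on the module and not shown. Below is a sketch of the usual Transformer-XL style relative shift the call site appears to assume: pad one slot along the key axis, fold it back via a reshape, and drop the padding row. Shapes follow `pos_score` above, i.e. [B, T, Tau + T, H]:

import torch


def rel_shift(x: torch.Tensor) -> torch.Tensor:
    """Transformer-XL style shift of relative-position scores.

    After the shift, column j of row i holds the score for relative
    distance i - j. Assumes x has shape [B, Tq, Tk, H].
    """
    B, Tq, Tk, H = x.size()
    x = torch.nn.functional.pad(x, (0, 0, 1, 0))  # -> [B, Tq, Tk + 1, H]
    x = x.reshape(B, Tk + 1, Tq, H)[:, 1:]        # drop the padding row
    return x.reshape(B, Tq, Tk, H)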
Example #4
def build_vtrace_loss(policy, model, dist_class, train_batch):
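    """Constructs the V-trace loss for the given train batch."""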
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space, gym.spaces.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
                               *args, **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
    if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
        unpacked_behaviour_logits = torch.split(behaviour_logits,
                                                list(output_hidden_shape),
                                                dim=1)
        unpacked_outputs = torch.split(model_out,
                                       list(output_hidden_shape),
                                       dim=1)
    else:
        unpacked_behaviour_logits = torch.chunk(behaviour_logits,
                                                output_hidden_shape,
                                                dim=1)
        unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    # Prepare actions for loss.
    loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions,
                                                                    dim=1)

    # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
    drop_last = policy.config["vtrace_drop_last_ts"]
    loss = VTraceLoss(
        actions=_make_time_major(loss_actions, drop_last=drop_last),
        actions_logp=_make_time_major(action_dist.logp(actions),
                                      drop_last=drop_last),
        actions_entropy=_make_time_major(action_dist.entropy(),
                                         drop_last=drop_last),
        dones=_make_time_major(dones, drop_last=drop_last),
        behaviour_action_logp=_make_time_major(behaviour_action_logp,
                                               drop_last=drop_last),
        behaviour_logits=_make_time_major(unpacked_behaviour_logits,
                                          drop_last=drop_last),
        target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=drop_last),
        values=_make_time_major(values, drop_last=drop_last),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical if is_multidiscrete else dist_class,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=drop_last),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"],
    )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["pi_loss"] = loss.pi_loss
    model.tower_stats["vf_loss"] = loss.vf_loss
    model.tower_stats["entropy"] = loss.entropy
    model.tower_stats["mean_entropy"] = loss.mean_entropy
    model.tower_stats["total_loss"] = loss.total_loss

    values_batched = make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        values,
        drop_last=policy.config["vtrace"] and drop_last,
    )
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(loss.value_targets, [-1]),
        torch.reshape(values_batched, [-1]))

    return loss.total_loss
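`make_time_major` is imported from the IMPALA policy module and not shown here. As a rough sketch of the transformation the call sites assume (flat batch-major [B * T, ...] to time-major [T, B, ...], optionally dropping the last time step); the non-recurrent fallback via `rollout_fragment_length` is an assumption:

import torch


def make_time_major(policy, seq_lens, tensor, drop_last=False):
    """Swap a flat batch-major tensor [B * T, ...] to time-major [T, B, ...]."""
    if isinstance(tensor, (list, tuple)):
        return [make_time_major(policy, seq_lens, t, drop_last) for t in tensor]

    if policy.is_recurrent():
        B = seq_lens.shape[0]
        T = tensor.shape[0] // B
    else:
        # Assumed: non-recurrent batches are fragments of fixed length.
        T = policy.config["rollout_fragment_length"]
        B = tensor.shape[0] // T

    rs = torch.reshape(tensor, [B, T] + list(tensor.shape[1:]))
    res = torch.transpose(rs, 0, 1)  # [B, T, ...] -> [T, B, ...]
    return res[:-1] if drop_last else res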
Example #5
    def loss(self, model: ModelV2, dist_class: Type[ActionDistribution],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss for Proximal Policy Objective.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distr. class.
            train_batch: The training data.

        Returns:
            The PPO loss tensor given the input batch.
        """

        logits, state = model(train_batch)
        curr_action_dist = dist_class(logits, model)

        # RNN case: Mask away 0-padded chunks at end of time axis.
        if state:
            B = len(train_batch[SampleBatch.SEQ_LENS])
            max_seq_len = logits.shape[0] // B
            mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                 max_seq_len,
                                 time_major=model.is_time_major())
            mask = torch.reshape(mask, [-1])
            num_valid = torch.sum(mask)

            def reduce_mean_valid(t):
                return torch.sum(t[mask]) / num_valid

        # non-RNN case: No masking.
        else:
            mask = None
            reduce_mean_valid = torch.mean

        prev_action_dist = dist_class(
            train_batch[SampleBatch.ACTION_DIST_INPUTS], model)

        logp_ratio = torch.exp(
            curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
            train_batch[SampleBatch.ACTION_LOGP])

        # Only calculate kl loss if necessary (kl-coeff > 0.0).
        if self.config["kl_coeff"] > 0.0:
            action_kl = prev_action_dist.kl(curr_action_dist)
            mean_kl_loss = reduce_mean_valid(action_kl)
        else:
            mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device)

        curr_entropy = curr_action_dist.entropy()
        mean_entropy = reduce_mean_valid(curr_entropy)

        surrogate_loss = torch.min(
            train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
            train_batch[Postprocessing.ADVANTAGES] *
            torch.clamp(logp_ratio, 1 - self.config["clip_param"],
                        1 + self.config["clip_param"]))
        mean_policy_loss = reduce_mean_valid(-surrogate_loss)

        # Compute a value function loss.
        if self.config["use_critic"]:
            prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
            value_fn_out = model.value_function()
            vf_loss1 = torch.pow(
                value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
            vf_clipped = prev_value_fn_out + torch.clamp(
                value_fn_out - prev_value_fn_out,
                -self.config["vf_clip_param"], self.config["vf_clip_param"])
            vf_loss2 = torch.pow(
                vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
            vf_loss = torch.max(vf_loss1, vf_loss2)
            mean_vf_loss = reduce_mean_valid(vf_loss)
        # Ignore the value function.
        else:
            vf_loss = mean_vf_loss = 0.0

        total_loss = reduce_mean_valid(-surrogate_loss +
                                       self.config["vf_loss_coeff"] * vf_loss -
                                       self.entropy_coeff * curr_entropy)

        # Add mean_kl_loss (already processed through `reduce_mean_valid`),
        # if necessary.
        if self.config["kl_coeff"] > 0.0:
            total_loss += self.kl_coeff * mean_kl_loss

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_policy_loss"] = mean_policy_loss
        model.tower_stats["mean_vf_loss"] = mean_vf_loss
        model.tower_stats["vf_explained_var"] = explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS], model.value_function())
        model.tower_stats["mean_entropy"] = mean_entropy
        model.tower_stats["mean_kl_loss"] = mean_kl_loss

        return total_loss
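The heart of this loss is PPO's clipped surrogate objective. A tiny illustration of what the min/clamp combination does, with made-up numbers and clip_param = 0.2:

import torch

advantages = torch.tensor([1.0, 1.0, -1.0])
logp_ratio = torch.tensor([1.5, 0.5, 1.5])  # exp(logp_new - logp_old)
clip_param = 0.2

surrogate = torch.min(
    advantages * logp_ratio,
    advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param),
)
# -> tensor([ 1.2000,  0.5000, -1.5000])
# Positive advantage: the gain is capped once the ratio exceeds 1.2.
# Negative advantage: the penalty is not clipped away (stays at -1.5).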
Example #6
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get(SampleBatch.SEQ_LENS)

    model_out_t, state_in_t = model(
        SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS],
            prev_actions=train_batch[SampleBatch.PREV_ACTIONS],
            prev_rewards=train_batch[SampleBatch.PREV_REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_tp1 = model.select_state(state_in_tp1, ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = target_model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    target_states_in_tp1 = target_model.select_state(target_state_in_tp1,
                                                     ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        log_pis_t = F.log_softmax(
            action_dist_inputs_t,
            dim=-1,
        )
        policy_t = torch.exp(log_pis_t)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        log_pis_tp1 = F.log_softmax(
            action_dist_inputs_tp1,
            -1,
        )
        policy_tp1 = torch.exp(log_pis_tp1)

        # Q-values.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)
        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1,
                                             target_states_in_tp1["q"],
                                             seq_lens)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t,
                                                  states_in_t["twin_q"],
                                                  seq_lens)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"], seq_lens)
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        action_dist_t = action_dist_class(
            action_dist_inputs_t,
            model,
        )
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        action_dist_tp1 = action_dist_class(
            action_dist_inputs_tp1,
            model,
        )
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t,
                states_in_t["twin_q"],
                seq_lens,
                train_batch[SampleBatch.ACTIONS],
            )

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, states_in_t["q"],
                                               seq_lens, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1,
                                             target_states_in_tp1["q"],
                                             seq_lens, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1,
                target_states_in_tp1["twin_q"],
                seq_lens,
                policy_tp1,
            )
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           (policy.config["gamma"]**policy.config["n_step"]) *
                           q_tp1_best_masked).detach()

    # BURNIN #
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False

    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                          huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                              huber_loss(twin_td_error)))
    td_error = td_error * seq_mask

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t -
                                       q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t * seq_mask[..., None]
    model.tower_stats["policy_t"] = policy_t * seq_mask[..., None]
    model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None]
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    # Store per time chunk (b/c we need only one mean
    # prioritized replay weight per stored sequence).
    model.tower_stats["td_error"] = torch.mean(td_error.reshape([-1, T]),
                                               dim=-1)

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
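`huber_loss` here (and in the R2D2 loss further below) comes from RLlib's torch utilities. A sketch of the standard Huber form the call sites assume, i.e. quadratic within `delta` and linear beyond it:

import torch


def huber_loss(x: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    """Element-wise Huber loss: 0.5 * x^2 for |x| < delta, linear outside."""
    return torch.where(
        torch.abs(x) < delta,
        0.5 * torch.pow(x, 2.0),
        delta * (torch.abs(x) - 0.5 * delta),
    )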
Example #7
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss function.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distr. class.
            train_batch: The training data.

        Returns:
            The A3C loss tensor given the input batch.
        """
        logits, _ = model(train_batch)
        values = model.value_function()

        if self.is_recurrent():
            B = len(train_batch[SampleBatch.SEQ_LENS])
            max_seq_len = logits.shape[0] // B
            mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                      max_seq_len)
            valid_mask = torch.reshape(mask_orig, [-1])
        else:
            valid_mask = torch.ones_like(values, dtype=torch.bool)

        dist = dist_class(logits, model)
        log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
        pi_err = -torch.sum(
            torch.masked_select(
                log_probs * train_batch[Postprocessing.ADVANTAGES],
                valid_mask))

        # Compute a value function loss.
        if self.config["use_critic"]:
            value_err = 0.5 * torch.sum(
                torch.pow(
                    torch.masked_select(
                        values.reshape(-1) -
                        train_batch[Postprocessing.VALUE_TARGETS],
                        valid_mask,
                    ),
                    2.0,
                ))
        # Ignore the value function.
        else:
            value_err = 0.0

        entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

        total_loss = (pi_err + value_err * self.config["vf_loss_coeff"] -
                      entropy * self.entropy_coeff)

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["entropy"] = entropy
        model.tower_stats["pi_err"] = pi_err
        model.tower_stats["value_err"] = value_err

        return total_loss
Example #8
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[ActionDistribution],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss for APPO.

        With IS modifications and V-trace for Advantage Estimation.

        Args:
            model (ModelV2): The Model to calculate the loss for.
            dist_class (Type[ActionDistribution]): The action distr. class.
            train_batch (SampleBatch): The training data.

        Returns:
            Union[TensorType, List[TensorType]]: A single loss tensor or a list
                of loss tensors.
        """
        target_model = self.target_models[model]

        model_out, _ = model(train_batch)
        action_dist = dist_class(model_out, model)

        if isinstance(self.action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            output_hidden_shape = [self.action_space.n]
        elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
            is_multidiscrete = True
            output_hidden_shape = self.action_space.nvec.astype(np.int32)
        else:
            is_multidiscrete = False
            output_hidden_shape = 1

        def _make_time_major(*args, **kwargs):
            return make_time_major(
                self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs
            )

        actions = train_batch[SampleBatch.ACTIONS]
        dones = train_batch[SampleBatch.DONES]
        rewards = train_batch[SampleBatch.REWARDS]
        behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

        target_model_out, _ = target_model(train_batch)

        prev_action_dist = dist_class(behaviour_logits, model)
        values = model.value_function()
        values_time_major = _make_time_major(values)

        drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"]

        if self.is_recurrent():
            max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
            mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
            mask = torch.reshape(mask, [-1])
            mask = _make_time_major(mask, drop_last=drop_last)
            num_valid = torch.sum(mask)

            def reduce_mean_valid(t):
                return torch.sum(t[mask]) / num_valid

        else:
            reduce_mean_valid = torch.mean

        if self.config["vtrace"]:
            logger.debug(
                "Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})"
            )

            old_policy_behaviour_logits = target_model_out.detach()
            old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

            if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
                unpacked_behaviour_logits = torch.split(
                    behaviour_logits, list(output_hidden_shape), dim=1
                )
                unpacked_old_policy_behaviour_logits = torch.split(
                    old_policy_behaviour_logits, list(output_hidden_shape), dim=1
                )
            else:
                unpacked_behaviour_logits = torch.chunk(
                    behaviour_logits, output_hidden_shape, dim=1
                )
                unpacked_old_policy_behaviour_logits = torch.chunk(
                    old_policy_behaviour_logits, output_hidden_shape, dim=1
                )

            # Prepare actions for loss.
            loss_actions = (
                actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)
            )

            # Prepare KL for loss.
            action_kl = _make_time_major(
                old_policy_action_dist.kl(action_dist), drop_last=drop_last
            )

            # Compute vtrace on the CPU for better perf.
            vtrace_returns = vtrace.multi_from_logits(
                behaviour_policy_logits=_make_time_major(
                    unpacked_behaviour_logits, drop_last=drop_last
                ),
                target_policy_logits=_make_time_major(
                    unpacked_old_policy_behaviour_logits, drop_last=drop_last
                ),
                actions=torch.unbind(
                    _make_time_major(loss_actions, drop_last=drop_last), dim=2
                ),
                discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float())
                * self.config["gamma"],
                rewards=_make_time_major(rewards, drop_last=drop_last),
                values=values_time_major[:-1] if drop_last else values_time_major,
                bootstrap_value=values_time_major[-1],
                dist_class=TorchCategorical if is_multidiscrete else dist_class,
                model=model,
                clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
                clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
            )

            actions_logp = _make_time_major(
                action_dist.logp(actions), drop_last=drop_last
            )
            prev_actions_logp = _make_time_major(
                prev_action_dist.logp(actions), drop_last=drop_last
            )
            old_policy_actions_logp = _make_time_major(
                old_policy_action_dist.logp(actions), drop_last=drop_last
            )
            is_ratio = torch.clamp(
                torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
            )
            logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
            self._is_ratio = is_ratio

            advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
            surrogate_loss = torch.min(
                advantages * logp_ratio,
                advantages
                * torch.clamp(
                    logp_ratio,
                    1 - self.config["clip_param"],
                    1 + self.config["clip_param"],
                ),
            )

            mean_kl_loss = reduce_mean_valid(action_kl)
            mean_policy_loss = -reduce_mean_valid(surrogate_loss)

            # The value function loss.
            value_targets = vtrace_returns.vs.to(values_time_major.device)
            if drop_last:
                delta = values_time_major[:-1] - value_targets
            else:
                delta = values_time_major - value_targets
            mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

            # The entropy loss.
            mean_entropy = reduce_mean_valid(
                _make_time_major(action_dist.entropy(), drop_last=drop_last)
            )

        else:
            logger.debug("Using PPO surrogate loss (vtrace=False)")

            # Prepare KL for Loss
            action_kl = _make_time_major(prev_action_dist.kl(action_dist))

            actions_logp = _make_time_major(action_dist.logp(actions))
            prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
            logp_ratio = torch.exp(actions_logp - prev_actions_logp)

            advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
            surrogate_loss = torch.min(
                advantages * logp_ratio,
                advantages
                * torch.clamp(
                    logp_ratio,
                    1 - self.config["clip_param"],
                    1 + self.config["clip_param"],
                ),
            )

            mean_kl_loss = reduce_mean_valid(action_kl)
            mean_policy_loss = -reduce_mean_valid(surrogate_loss)

            # The value function loss.
            value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS])
            delta = values_time_major - value_targets
            mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

            # The entropy loss.
            mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))

        # The summed weighted loss
        total_loss = (
            mean_policy_loss
            + mean_vf_loss * self.config["vf_loss_coeff"]
            - mean_entropy * self.entropy_coeff
        )

        # Optional additional KL Loss
        if self.config["use_kl_loss"]:
            total_loss += self.kl_coeff * mean_kl_loss

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_policy_loss"] = mean_policy_loss
        model.tower_stats["mean_kl_loss"] = mean_kl_loss
        model.tower_stats["mean_vf_loss"] = mean_vf_loss
        model.tower_stats["mean_entropy"] = mean_entropy
        model.tower_stats["value_targets"] = value_targets
        model.tower_stats["vf_explained_var"] = explained_variance(
            torch.reshape(value_targets, [-1]),
            torch.reshape(
                values_time_major[:-1] if drop_last else values_time_major, [-1]
            ),
        )

        return total_loss
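`explained_variance`, used for the `vf_explained_var` stats in several of these losses, is another small helper from RLlib's torch utilities. A sketch of the conventional definition (the exact clipping and reduction details are an assumption):

import torch


def explained_variance(y: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """1 - Var[y - pred] / Var[y], clipped below at -1.

    1.0 means the value function predicts the targets perfectly, 0.0 means
    it is no better than predicting the mean, negative values are worse.
    """
    y_var = torch.var(y, dim=0)
    diff_var = torch.var(y - pred, dim=0)
    return torch.clamp(1.0 - diff_var / y_var, min=-1.0)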
Example #9
def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(policy,
                                  model,
                                  train_batch,
                                  state_batches=state_batches,
                                  seq_lens=train_batch.get(
                                      SampleBatch.SEQ_LENS),
                                  explore=False,
                                  is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(policy,
                                         target_model,
                                         train_batch,
                                         state_batches=state_batches,
                                         seq_lens=train_batch.get(
                                             SampleBatch.SEQ_LENS),
                                         explore=False,
                                         is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=q.device)) *
        one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)

    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=q_target.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat([
            q_target_best[1:],
            torch.tensor([0.0], device=q_target_best.device)
        ])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        if burn_in > 0 and burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        total_loss = reduce_mean_valid(weights * huber_loss(td_error))

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_q"] = reduce_mean_valid(q_selected)
        model.tower_stats["min_q"] = torch.min(q_selected)
        model.tower_stats["max_q"] = torch.max(q_selected)
        model.tower_stats["mean_td_error"] = reduce_mean_valid(td_error)
        # Store per time chunk (b/c we need only one mean
        # prioritized replay weight per stored sequence).
        model.tower_stats["td_error"] = torch.mean(td_error, dim=-1)

    return total_loss
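`h_function` and `h_inverse` implement the invertible value rescaling used by R2D2; they are imported from the policy module and not shown above. A sketch of the standard form from the R2D2 paper (RLlib's own helpers may handle `epsilon` edge cases differently):

import torch


def h_function(x: torch.Tensor, epsilon: float = 1.0) -> torch.Tensor:
    """Invertible rescaling: h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x."""
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x


def h_inverse(x: torch.Tensor, epsilon: float = 1.0) -> torch.Tensor:
    """Closed-form inverse of h_function."""
    return torch.sign(x) * (
        ((torch.sqrt(1.0 + 4.0 * epsilon * (torch.abs(x) + 1.0 + epsilon)) - 1.0)
         / (2.0 * epsilon)) ** 2 - 1.0)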