Example #1
def centralized_critic_postprocessing(
    policy, sample_batch, other_agent_batches=None, episode=None
):
    pytorch = policy.config["framework"] == "torch"
    if (pytorch and hasattr(policy, "compute_central_vf")) or (
        not pytorch and policy.loss_initialized()
    ):
        assert other_agent_batches is not None
        [(_, opponent_batch)] = list(other_agent_batches.values())

        # also record the opponent obs and actions in the trajectory
        sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
        sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]

        # overwrite default VF prediction with the central VF
        if pytorch:
            sample_batch[SampleBatch.VF_PREDS] = (
                policy.compute_central_vf(
                    convert_to_torch_tensor(
                        sample_batch[SampleBatch.CUR_OBS], policy.device
                    ),
                    convert_to_torch_tensor(sample_batch[OPPONENT_OBS], policy.device),
                    convert_to_torch_tensor(
                        sample_batch[OPPONENT_ACTION], policy.device
                    ),
                )
                .cpu()
                .detach()
                .numpy()
            )
        else:
            sample_batch[SampleBatch.VF_PREDS] = convert_to_numpy(
                policy.compute_central_vf(
                    sample_batch[SampleBatch.CUR_OBS],
                    sample_batch[OPPONENT_OBS],
                    sample_batch[OPPONENT_ACTION],
                )
            )
    else:
        # Policy hasn't been initialized yet, use zeros.
        sample_batch[OPPONENT_OBS] = np.zeros_like(sample_batch[SampleBatch.CUR_OBS])
        sample_batch[OPPONENT_ACTION] = np.zeros_like(sample_batch[SampleBatch.ACTIONS])
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
            sample_batch[SampleBatch.REWARDS], dtype=np.float32
        )

    completed = sample_batch["dones"][-1]
    if completed:
        last_r = 0.0
    else:
        last_r = sample_batch[SampleBatch.VF_PREDS][-1]

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"],
    )
    return train_batch
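
All of the snippets in this collection revolve around RLlib's `convert_to_torch_tensor` helper. As a minimal, self-contained sketch (not part of the example above): it recursively converts numpy arrays, including nested dicts and lists, into torch tensors on the requested device. The import path below is an assumption about the Ray version in use.

import numpy as np
import torch

# Recent Ray releases expose the helper here; older ones use ray.rllib.utils.torch_ops.
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

# A nested numpy structure, as it would appear in a SampleBatch.
batch = {
    "obs": np.random.randn(4, 6).astype(np.float32),
    "actions": np.array([0, 1, 1, 0]),
}
tensors = convert_to_torch_tensor(batch, device=torch.device("cpu"))
assert isinstance(tensors["obs"], torch.Tensor)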
Example #2
 def _translate_weights_to_torch(self, weights_dict, map_):
     model_dict = {
         map_[k]: convert_to_torch_tensor(
             np.transpose(v) if re.search("kernel", k) else v)
         for k, v in weights_dict.items() if re.search(
             "default_policy/(actor_(hidden_0|out)|sequential(_1)?)/", k)
     }
     model_dict["policy_model.action_out_squashed.low_action"] = \
         convert_to_torch_tensor(np.array([0.0]))
     model_dict["policy_model.action_out_squashed.action_range"] = \
         convert_to_torch_tensor(np.array([1.0]))
     return model_dict
Example #3
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType],
                                 TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType],
                                 TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List["Episode"]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:

        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                "is_training": False,
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = np.asarray(
                    prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = np.asarray(
                    prev_reward_batch)
            state_batches = [
                convert_to_torch_tensor(s, self.device)
                for s in (state_batches or [])
            ]
            return self._compute_action_helper(input_dict, state_batches,
                                               seq_lens, explore, timestep)
Example #4
    def _compute_actions(policy,
                         obs_batch,
                         add_noise=False,
                         update=True,
                         **kwargs):
        # Batch is given as list -> Try converting to numpy first.
        if isinstance(obs_batch, list) and len(obs_batch) == 1:
            obs_batch = obs_batch[0]
        observation = policy.preprocessor.transform(obs_batch)
        observation = policy.observation_filter(observation[None],
                                                update=update)

        observation = convert_to_torch_tensor(observation, policy.device)
        dist_inputs, _ = policy.model({SampleBatch.CUR_OBS: observation}, [],
                                      None)
        dist = policy.dist_class(dist_inputs, policy.model)
        action = dist.sample()

        def _add_noise(single_action, single_action_space):
            single_action = single_action.detach().cpu().numpy()
            if add_noise and isinstance(single_action_space, gym.spaces.Box) \
                    and single_action_space.dtype.name.startswith("float"):
                single_action += np.random.randn(*single_action.shape) * \
                                 policy.action_noise_std
            return single_action

        action = tree.map_structure(_add_noise, action,
                                    policy.action_space_struct)
        action = unbatch(action)
        return action, [], {}
Example #5
File: test_sac.py  Project: ijrsvt/ray
    def _translate_weights_to_torch(self, weights_dict, map_):
        model_dict = {
            map_[k]: convert_to_torch_tensor(
                np.transpose(v) if re.search("kernel", k)
                else np.array([v]) if re.search("log_alpha", k) else v)
            for i, (k, v) in enumerate(weights_dict.items()) if i < 13
        }

        return model_dict
Example #6
 def to_device(self, device, framework="torch"):
     """TODO: transfer batch to given device as framework tensor."""
     if framework == "torch":
         assert torch is not None
         for k, v in self.items():
             if isinstance(v, np.ndarray) and v.dtype != object:
                 self[k] = convert_to_torch_tensor(v, device)
     else:
         raise NotImplementedError
     return self
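
A short sketch of how such a helper might be used, assuming `batch` is an instance of the SampleBatch-like class that defines `to_device` above and is filled with numpy columns:

import torch

# Pick a GPU if one is available; `batch` is an assumed, pre-filled SampleBatch-like object.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch.to_device(device)
# All numpy (non-object) columns are now torch tensors living on `device`.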
Example #7
 def set_state(self, state: dict) -> None:
     # Set optimizer vars first.
     optimizer_vars = state.get("_optimizer_variables", None)
     if optimizer_vars:
         assert len(optimizer_vars) == len(self._optimizers)
         for o, s in zip(self._optimizers, optimizer_vars):
             optim_state_dict = convert_to_torch_tensor(s, device=self.device)
             o.load_state_dict(optim_state_dict)
     # Set exploration's state.
     if hasattr(self, "exploration") and "_exploration_state" in state:
         self.exploration.set_state(state=state["_exploration_state"])
     # Then the Policy's (NN) weights.
     super().set_state(state)
Example #8
 def predict_model_batches(self, obs, actions, device=None):
     """Used by worker who gather trajectories via TD models."""
     pre_obs = obs
     if self.normalize_data:
         obs = normalize(obs, self.normalizations[SampleBatch.CUR_OBS])
         actions = normalize(actions,
                             self.normalizations[SampleBatch.ACTIONS])
     x = np.concatenate([obs, actions], axis=-1)
     x = convert_to_torch_tensor(x, device=device)
     delta = self.forward(x).detach().cpu().numpy()
     if self.normalize_data:
         delta = denormalize(delta, self.normalizations["delta"])
     new_obs = pre_obs + delta
     clipped_obs = np.clip(new_obs, self.env_obs_space.low,
                           self.env_obs_space.high)
     return clipped_obs
Example #9
 def set_weights(self, weights: ModelWeights) -> None:
     weights = convert_to_torch_tensor(weights, device=self.device)
     self.model.load_state_dict(weights)
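
Because `set_weights` converts the incoming (typically numpy-based) weights back to torch tensors before `load_state_dict`, it pairs naturally with `get_weights` for syncing two policies. Illustrative sketch only; `src_policy` and `dst_policy` are assumed to be two already-built torch policies with identical model architectures:

# get_weights() returns a numpy-based state dict; set_weights() moves it back onto
# dst_policy's device as torch tensors and loads it into the model.
dst_policy.set_weights(src_policy.get_weights())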
Example #10
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorStructType],
                                          TensorStructType]] = None,
        prev_reward_batch: Optional[Union[List[TensorStructType],
                                          TensorStructType]] = None,
        actions_normalized: bool = True,
    ) -> TensorType:

        if is_overridden(self.action_sampler_fn) and not is_overridden(
                self.action_distribution_fn):
            raise ValueError("Cannot compute log-prob/likelihood w/o an "
                             "`action_distribution_fn` and a provided "
                             "`action_sampler_fn`!")

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                SampleBatch.ACTIONS: actions
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device)
                for s in (state_batches or [])
            ]

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if is_overridden(self.action_distribution_fn):
                dist_inputs, dist_class, state_out = self.action_distribution_fn(
                    self.model,
                    input_dict=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods
Example #11
 def set_ensemble_weights(policy, pid, weights):
     weights = weights[pid]
     weights = convert_to_torch_tensor(weights, device=policy.device)
     model = policy.dynamics_model
     model.load_state_dict(weights)
Example #12
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        prev_reward_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        actions_normalized: bool = True,
    ) -> TensorType:

        if self.action_sampler_fn and self.action_distribution_fn is None:
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict(
                {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if self.action_distribution_fn:

                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=False,
                        is_training=False,
                    )
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    ):
                        dist_inputs, dist_class, _ = self.action_distribution_fn(
                            policy=self,
                            model=self.model,
                            obs_batch=input_dict[SampleBatch.CUR_OBS],
                            explore=False,
                            is_training=False,
                        )
                    else:
                        raise e

            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods
Example #13
def from_importance_weights(
    log_rhos,
    discounts,
    rewards,
    values,
    bootstrap_value,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
):
    """V-trace from log importance weights.

    Calculates V-trace actor critic targets as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1. B refers to the batch size. This code
    also supports the case where all tensors have the same number of additional
    dimensions, e.g., `rewards` is [T, B, C], `values` is [T, B, C],
    `bootstrap_value` is [B, C].

    Args:
        log_rhos: A float32 tensor of shape [T, B] representing the log
            importance sampling weights, i.e.
            log(target_policy(a) / behaviour_policy(a)). V-trace performs
            operations on rhos in log-space for numerical stability.
        discounts: A float32 tensor of shape [T, B] with discounts encountered
            when following the behaviour policy.
        rewards: A float32 tensor of shape [T, B] containing rewards generated
            by following the behaviour policy.
        values: A float32 tensor of shape [T, B] with the value function
            estimates wrt. the target policy.
        bootstrap_value: A float32 of shape [B] with the value function
            estimate at time T.
        clip_rho_threshold: A scalar float32 tensor with the clipping threshold
            for importance weights (rho) when calculating the baseline targets
            (vs). rho^bar in the paper. If None, no clipping is applied.
        clip_pg_rho_threshold: A scalar float32 tensor with the clipping
            threshold on rho_s in
            \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
            If None, no clipping is applied.

    Returns:
        A VTraceReturns namedtuple (vs, pg_advantages) where:
        vs: A float32 tensor of shape [T, B]. Can be used as target to
            train a baseline (V(x_t) - vs_t)^2.
        pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
            advantage in the calculation of policy gradients.
    """
    log_rhos = convert_to_torch_tensor(log_rhos, device="cpu")
    discounts = convert_to_torch_tensor(discounts, device="cpu")
    rewards = convert_to_torch_tensor(rewards, device="cpu")
    values = convert_to_torch_tensor(values, device="cpu")
    bootstrap_value = convert_to_torch_tensor(bootstrap_value, device="cpu")

    # Make sure tensor ranks are consistent.
    rho_rank = len(log_rhos.size())  # Usually 2.
    assert rho_rank == len(values.size())
    assert rho_rank - 1 == len(
        bootstrap_value.size()), "must have rank {}".format(rho_rank - 1)
    assert rho_rank == len(discounts.size())
    assert rho_rank == len(rewards.size())

    rhos = torch.exp(log_rhos)
    if clip_rho_threshold is not None:
        clipped_rhos = torch.clamp_max(rhos, clip_rho_threshold)
    else:
        clipped_rhos = rhos

    cs = torch.clamp_max(rhos, 1.0)
    # Append bootstrapped value to get [v1, ..., v_t+1]
    values_t_plus_1 = torch.cat(
        [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

    vs_minus_v_xs = [torch.zeros_like(bootstrap_value)]
    for i in reversed(range(len(discounts))):
        discount_t, c_t, delta_t = discounts[i], cs[i], deltas[i]
        vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1])
    vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:])
    # Reverse the results back to original order.
    vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[0])
    # Add V(x_s) to get v_s.
    vs = vs_minus_v_xs + values

    # Advantage for policy gradient.
    vs_t_plus_1 = torch.cat(
        [vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
    if clip_pg_rho_threshold is not None:
        clipped_pg_rhos = torch.clamp_max(rhos, clip_pg_rho_threshold)
    else:
        clipped_pg_rhos = rhos
    pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 -
                                       values)

    # Make sure no gradients backpropagated through the returned values.
    return VTraceReturns(vs=vs.detach(), pg_advantages=pg_advantages.detach())
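
For reference, the backward loop above implements the V-trace recursion from the IMPALA paper (Espeholt et al., 2018). With \gamma_t taken from `discounts`, \bar\rho = `clip_rho_threshold`, \bar\rho_{\mathrm{pg}} = `clip_pg_rho_threshold` and \rho_t = \exp(\text{log\_rhos}_t), the code computes

\delta_t = \min(\bar\rho, \rho_t)\,\bigl(r_t + \gamma_t V(x_{t+1}) - V(x_t)\bigr)
v_t - V(x_t) = \delta_t + \gamma_t \min(1, \rho_t)\,\bigl(v_{t+1} - V(x_{t+1})\bigr), \qquad v_T = \text{bootstrap\_value}
\text{pg\_advantages}_t = \min(\bar\rho_{\mathrm{pg}}, \rho_t)\,\bigl(r_t + \gamma_t\, v_{t+1} - V(x_t)\bigr)

The recursion is evaluated backwards in time, and both outputs are detached so no gradients flow through the targets.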
Example #14
def multi_from_logits(
    behaviour_policy_logits,
    target_policy_logits,
    actions,
    discounts,
    rewards,
    values,
    bootstrap_value,
    dist_class,
    model,
    behaviour_action_log_probs=None,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
):
    """V-trace for softmax policies.

    Calculates V-trace actor critic targets for softmax polices as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    Target policy refers to the policy we are interested in improving and
    behaviour policy refers to the policy that generated the given
    rewards and actions.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1. B refers to the batch size and
    ACTION_SPACE refers to the list of numbers each representing a number of
    actions.

    Args:
        behaviour_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing the softmax behavior policy.
        target_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing the softmax target policy.
        actions: A list with length of ACTION_SPACE of tensors of shapes
            [T, B, ...], ..., [T, B, ...]
            with actions sampled from the behavior policy.
        discounts: A float32 tensor of shape [T, B] with the discount
            encountered when following the behavior policy.
        rewards: A float32 tensor of shape [T, B] with the rewards generated by
            following the behavior policy.
        values: A float32 tensor of shape [T, B] with the value function
            estimates wrt. the target policy.
        bootstrap_value: A float32 of shape [B] with the value function
            estimate at time T.
        dist_class: action distribution class for the logits.
        model: backing ModelV2 instance
        behaviour_action_log_probs: Precalculated values of the behavior
            actions.
        clip_rho_threshold: A scalar float32 tensor with the clipping threshold
            for importance weights (rho) when calculating the baseline targets
            (vs). rho^bar in the paper.
        clip_pg_rho_threshold: A scalar float32 tensor with the clipping
            threshold on rho_s in:
            \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).

    Returns:
        A `VTraceFromLogitsReturns` namedtuple with the following fields:
        vs: A float32 tensor of shape [T, B]. Can be used as target to train a
            baseline (V(x_t) - vs_t)^2.
        pg_advantages: A float32 tensor of shape [T, B]. Can be used as an
            estimate of the advantage in the calculation of policy gradients.
        log_rhos: A float32 tensor of shape [T, B] containing the log
            importance sampling weights (log rhos).
        behaviour_action_log_probs: A float32 tensor of shape [T, B] containing
            behaviour policy action log probabilities (log \mu(a_t)).
        target_action_log_probs: A float32 tensor of shape [T, B] containing
            target policy action log probabilities (log \pi(a_t)).
    """

    behaviour_policy_logits = convert_to_torch_tensor(behaviour_policy_logits,
                                                      device="cpu")
    target_policy_logits = convert_to_torch_tensor(target_policy_logits,
                                                   device="cpu")
    actions = convert_to_torch_tensor(actions, device="cpu")

    # Make sure tensor ranks are as expected.
    # The rest will be checked by from_action_log_probs.
    for i in range(len(behaviour_policy_logits)):
        assert len(behaviour_policy_logits[i].size()) == 3
        assert len(target_policy_logits[i].size()) == 3

    target_action_log_probs = multi_log_probs_from_logits_and_actions(
        target_policy_logits, actions, dist_class, model)

    if len(behaviour_policy_logits) > 1 or behaviour_action_log_probs is None:
        # can't use precalculated values, recompute them. Note that
        # recomputing won't work well for autoregressive action dists
        # which may have variables not captured by 'logits'
        behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
            behaviour_policy_logits, actions, dist_class, model)

    behaviour_action_log_probs = convert_to_torch_tensor(
        behaviour_action_log_probs, device="cpu")
    behaviour_action_log_probs = force_list(behaviour_action_log_probs)
    log_rhos = get_log_rhos(target_action_log_probs,
                            behaviour_action_log_probs)

    vtrace_returns = from_importance_weights(
        log_rhos=log_rhos,
        discounts=discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
    )

    return VTraceFromLogitsReturns(
        log_rhos=log_rhos,
        behaviour_action_log_probs=behaviour_action_log_probs,
        target_action_log_probs=target_action_log_probs,
        **vtrace_returns._asdict())
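
The log importance weights handed to `from_importance_weights` are obtained by `get_log_rhos` as the difference between target and behaviour action log-probabilities, summed over the action components:

\log \rho_t = \sum_k \Bigl( \log \pi\bigl(a_t^{(k)} \mid x_t\bigr) - \log \mu\bigl(a_t^{(k)} \mid x_t\bigr) \Bigr)

with \pi the target policy and \mu the behaviour policy.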
Example #15
def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS].float()
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)

    model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [],
                             None)

    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)

    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)
    action_dist_t = action_dist_class(model.get_policy_output(model_out_t),
                                      policy.model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()

    batch_size = tree.flatten(obs)[0].shape[0]
    if batch_size == policy.config["train_batch_size"]:
        policy.alpha_optim.zero_grad()
        alpha_loss.backward()
        policy.alpha_optim.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        bc_logp = action_dist_t.logp(actions)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()
        # actor_loss = -bc_logp.mean()

    if batch_size == policy.config["train_batch_size"]:
        policy.actor_optim.zero_grad()
        actor_loss.backward(retain_graph=True)
        policy.actor_optim.step()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_tp1 = action_dist_class(model.get_policy_output(model_out_tp1),
                                        policy.model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1 = target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = target_model.get_twin_q_values(target_model_out_tp1,
                                                    policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = (
        rewards +
        (discount**policy.config["n_step"]) * q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device,
    )
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_tp1, num_actions)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(model,
                                          model_out_t,
                                          curr_actions,
                                          twin=True)
        q2_next_actions = q_values_repeat(model,
                                          model_out_t,
                                          next_actions,
                                          twin=True)

    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat(
        [
            q1_rand - random_density,
            q1_next_actions - next_logp.detach(),
            q1_curr_actions - curr_logp.detach(),
        ],
        1,
    )
    if twin_q:
        cat_q2 = torch.cat(
            [
                q2_rand - random_density,
                q2_next_actions - next_logp.detach(),
                q2_curr_actions - curr_logp.detach(),
            ],
            1,
        )

    min_qf1_loss_ = (torch.logsumexp(cat_q1 / cql_temp, dim=1).mean() *
                     min_q_weight * cql_temp)
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (torch.logsumexp(cat_q2 / cql_temp, dim=1).mean() *
                         min_q_weight * cql_temp)
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)

    if use_lagrange:
        alpha_prime = torch.clamp(model.log_alpha_prime.exp(),
                                  min=0.0,
                                  max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if batch_size == policy.config["train_batch_size"]:
        policy.critic_optims[0].zero_grad()
        critic_loss[0].backward(retain_graph=True)
        policy.critic_optims[0].step()

        if twin_q:
            policy.critic_optims[1].zero_grad()
            critic_loss[1].backward(retain_graph=False)
            policy.critic_optims[1].step()

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    # SAC stats.
    model.tower_stats["q_t"] = q_t_selected
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    model.tower_stats["log_alpha_value"] = model.log_alpha
    model.tower_stats["alpha_value"] = alpha
    model.tower_stats["target_entropy"] = model.target_entropy
    # CQL stats.
    model.tower_stats["cql_loss"] = cql_loss

    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    if use_lagrange:
        model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0]
        model.tower_stats["alpha_prime_value"] = alpha_prime
        model.tower_stats["alpha_prime_loss"] = alpha_prime_loss

        if batch_size == policy.config["train_batch_size"]:
            policy.alpha_prime_optim.zero_grad()
            alpha_prime_loss.backward()
            policy.alpha_prime_optim.step()

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss] +
                 ([alpha_prime_loss] if use_lagrange else []))
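
The `min_qf1_loss`/`min_qf2_loss` terms above correspond, roughly, to the conservative regularizer of CQL(H) (Kumar et al., 2020):

\mathcal{L}_{\mathrm{CQL}} = w \Bigl( \mathbb{E}_{s \sim \mathcal{D}} \bigl[ \log \textstyle\sum_a \exp Q(s, a) \bigr] - \mathbb{E}_{(s, a) \sim \mathcal{D}} \bigl[ Q(s, a) \bigr] \Bigr)

with w = `min_q_weight`. In the code, the log-sum-exp is approximated by importance sampling over uniformly random actions and actions drawn from the current policy at s_t and s_{t+1} (the `cat_q1`/`cat_q2` concatenations, each Q-value corrected by the corresponding log-density), the log-sum-exp is scaled by `temperature`, and, if `lagrangian` is enabled, an automatically tuned multiplier alpha' weights the whole term.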
Example #16
def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _,
    train_batch: SampleBatch,
) -> TensorType:
    """Constructs the choice- and Q-value losses for the SlateQTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        The user-choice- and Q-value loss tensors.
    """

    # B=batch size
    # S=slate size
    # C=num candidates
    # E=embedding size
    # A=number of all possible slates

    # Q-value computations.
    # ---------------------
    # action.shape: [B, S]
    actions = train_batch[SampleBatch.ACTIONS]

    observation = convert_to_torch_tensor(
        train_batch[SampleBatch.OBS], device=actions.device
    )
    # user.shape: [B, E]
    user_obs = observation["user"]
    batch_size, embedding_size = user_obs.shape
    # doc.shape: [B, C, E]
    doc_obs = list(observation["doc"].values())

    A, S = policy.slates.shape

    # click_indicator.shape: [B, S]
    click_indicator = torch.stack(
        [k["click"] for k in observation["response"]], 1
    ).float()
    # item_reward.shape: [B, S]
    item_reward = torch.stack([k["watch_time"] for k in observation["response"]], 1)
    # q_values.shape: [B, C]
    q_values = model.get_q_values(user_obs, doc_obs)
    # slate_q_values.shape: [B, S]
    slate_q_values = torch.take_along_dim(q_values, actions.long(), dim=-1)
    # Only get the Q from the clicked document.
    # replay_click_q.shape: [B]
    replay_click_q = torch.sum(slate_q_values * click_indicator, dim=1)

    # Target computations.
    # --------------------
    next_obs = convert_to_torch_tensor(
        train_batch[SampleBatch.NEXT_OBS], device=actions.device
    )

    # user.shape: [B, E]
    user_next_obs = next_obs["user"]
    # doc.shape: [B, C, E]
    doc_next_obs = list(next_obs["doc"].values())
    # Only compute the watch time reward of the clicked item.
    reward = torch.sum(item_reward * click_indicator, dim=1)

    # TODO: Find out, whether it's correct here to use obs, not next_obs!
    # Dopamine uses obs, then next_obs only for the score.
    # next_q_values = policy.target_model.get_q_values(user_next_obs, doc_next_obs)
    next_q_values = policy.target_models[model].get_q_values(user_obs, doc_obs)
    scores, score_no_click = score_documents(user_next_obs, doc_next_obs)

    # next_q_values_slate.shape: [B, A, S]
    indices = policy.slates_indices.to(next_q_values.device)
    next_q_values_slate = torch.take_along_dim(next_q_values, indices, dim=1).reshape(
        [-1, A, S]
    )
    # scores_slate.shape [B, A, S]
    scores_slate = torch.take_along_dim(scores, indices, dim=1).reshape([-1, A, S])
    # score_no_click_slate.shape: [B, A]
    score_no_click_slate = torch.reshape(
        torch.tile(score_no_click, policy.slates.shape[:1]), [batch_size, -1]
    )

    # next_q_target_slate.shape: [B, A]
    next_q_target_slate = torch.sum(next_q_values_slate * scores_slate, dim=2) / (
        torch.sum(scores_slate, dim=2) + score_no_click_slate
    )
    next_q_target_max, _ = torch.max(next_q_target_slate, dim=1)

    target = reward + policy.config["gamma"] * next_q_target_max * (
        1.0 - train_batch["dones"].float()
    )
    target = target.detach()

    clicked = torch.sum(click_indicator, dim=1)
    mask_clicked_slates = clicked > 0
    clicked_indices = torch.arange(batch_size).to(mask_clicked_slates.device)
    clicked_indices = torch.masked_select(clicked_indices, mask_clicked_slates)
    # Clicked_indices is a vector and torch.gather selects the batch dimension.
    q_clicked = torch.gather(replay_click_q, 0, clicked_indices)
    target_clicked = torch.gather(target, 0, clicked_indices)

    td_error = torch.where(
        clicked.bool(),
        replay_click_q - target,
        torch.zeros_like(train_batch[SampleBatch.REWARDS]),
    )
    if policy.config["use_huber"]:
        loss = huber_loss(td_error, delta=policy.config["huber_threshold"])
    else:
        loss = torch.pow(td_error, 2.0)
    loss = torch.mean(loss)
    td_error = torch.abs(td_error)
    mean_td_error = torch.mean(td_error)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_values"] = torch.mean(q_values)
    model.tower_stats["q_clicked"] = torch.mean(q_clicked)
    model.tower_stats["scores"] = torch.mean(scores)
    model.tower_stats["score_no_click"] = torch.mean(score_no_click)
    model.tower_stats["slate_q_values"] = torch.mean(slate_q_values)
    model.tower_stats["replay_click_q"] = torch.mean(replay_click_q)
    model.tower_stats["bellman_reward"] = torch.mean(reward)
    model.tower_stats["next_q_values"] = torch.mean(next_q_values)
    model.tower_stats["target"] = torch.mean(target)
    model.tower_stats["next_q_target_slate"] = torch.mean(next_q_target_slate)
    model.tower_stats["next_q_target_max"] = torch.mean(next_q_target_max)
    model.tower_stats["target_clicked"] = torch.mean(target_clicked)
    model.tower_stats["q_loss"] = loss
    model.tower_stats["td_error"] = td_error
    model.tower_stats["mean_td_error"] = mean_td_error
    model.tower_stats["mean_actions"] = torch.mean(actions.float())

    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        torch.stack(doc_obs, 1),
        1,
        # index.shape: [batch_size, slate_size, embedding_size]
        actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
    )

    scores = model.choice_model(user_obs, selected_doc)

    # click_indicator.shape: [batch_size, slate_size]
    # no_clicks.shape: [batch_size, 1]
    no_clicks = 1 - torch.sum(click_indicator, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([click_indicator, no_clicks], dim=1)
    choice_loss = nn.functional.cross_entropy(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.a.item(), model.choice_model.b.item())

    model.tower_stats["choice_loss"] = choice_loss

    return choice_loss, loss
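
The target computation above follows SlateQ's decomposition of a slate's Q-value into per-item Q-values weighted by the user-choice model. For each candidate slate A,

Q(s', A) = \frac{\sum_{i \in A} v(s', i)\, \overline{Q}(s', i)}{\sum_{i \in A} v(s', i) + v(s', \varnothing)},
\qquad
y = r + \gamma\,(1 - \mathrm{done}) \max_{A} Q(s', A)

where v(s', i) are the choice scores from `score_documents`, v(s', \varnothing) is the no-click score, and \overline{Q} are the target network's per-item values (`next_q_values_slate`). The Q-loss regresses the Q-value of the clicked item toward y (zero TD-error for slates with no click), while the separate `choice_loss` fits the choice model to the observed click indicators.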