Code example #1
    def set_ensemble_weights(policy, pid, weights):
        # Select this policy's entry from the per-policy weights dict, move it
        # onto the policy's device, and load it into the dynamics model.
        weights = weights[pid]
        weights = convert_to_torch_tensor(weights, device=policy.device)
        model = policy.dynamics_model
        model.load_state_dict(weights)
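All of the snippets on this page share the same pattern: convert a numpy-based weights structure into torch tensors with convert_to_torch_tensor, then load it via load_state_dict. Below is a self-contained sketch of that pattern; the import path is an assumption (it has lived under ray.rllib.utils.torch_ops in older Ray releases and ray.rllib.utils.torch_utils in newer ones).

import torch.nn as nn
from ray.rllib.utils.torch_ops import convert_to_torch_tensor  # assumed (older) path

# Round-trip a model's weights through numpy and back, as the snippets do.
model = nn.Linear(4, 2)
numpy_weights = {k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}
torch_weights = convert_to_torch_tensor(numpy_weights, device="cpu")
model.load_state_dict(torch_weights)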
Code example #2
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None
    ) -> TensorType:

        if self.action_sampler_fn and self.action_distribution_fn is None:
            raise ValueError(
                "Cannot compute log-prob/likelihood when a custom "
                "`action_sampler_fn` is provided without an accompanying "
                "`action_distribution_fn`!")

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                SampleBatch.ACTIONS: actions
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device)
                for s in (state_batches or [])
            ]

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if self.action_distribution_fn:

                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict=input_dict,
                            state_batches=state_batches,
                            seq_lens=seq_lens,
                            explore=False,
                            is_training=False)
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if "positional argument" in e.args[0] or \
                            "unexpected keyword argument" in e.args[0]:
                        dist_inputs, dist_class, _ = \
                            self.action_distribution_fn(
                                policy=self,
                                model=self.model,
                                obs_batch=input_dict[SampleBatch.CUR_OBS],
                                explore=False,
                                is_training=False)
                    else:
                        raise e

            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)

            action_dist = dist_class(dist_inputs, self.model)
            log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])

            return log_likelihoods
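A hedged example of how this method is typically called from outside the policy; the `policy` object, its 4-dim observation space, and its Discrete(2) action space are assumptions, not taken from the source.

import numpy as np

# Per-sample log-likelihoods of the given actions under the current policy.
obs_batch = np.random.random((3, 4)).astype(np.float32)
actions = np.array([0, 1, 1])
logp = policy.compute_log_likelihoods(actions=actions, obs_batch=obs_batch)
# `logp` is a torch tensor of shape [3] holding log pi(a_i | s_i).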
Code example #3
File: vtrace_torch.py    Project: wallacetroy/ray
def from_importance_weights(log_rhos,
                            discounts,
                            rewards,
                            values,
                            bootstrap_value,
                            clip_rho_threshold=1.0,
                            clip_pg_rho_threshold=1.0):
    """V-trace from log importance weights.

    Calculates V-trace actor critic targets as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1. B refers to the batch size. This code
    also supports the case where all tensors have the same number of additional
    dimensions, e.g., `rewards` is [T, B, C], `values` is [T, B, C],
    `bootstrap_value` is [B, C].

    Args:
        log_rhos: A float32 tensor of shape [T, B] representing the log
            importance sampling weights, i.e.
            log(target_policy(a) / behaviour_policy(a)). V-trace performs
            operations on rhos in log-space for numerical stability.
        discounts: A float32 tensor of shape [T, B] with discounts encountered
            when following the behaviour policy.
        rewards: A float32 tensor of shape [T, B] containing rewards generated
            by following the behaviour policy.
        values: A float32 tensor of shape [T, B] with the value function
            estimates wrt. the target policy.
        bootstrap_value: A float32 tensor of shape [B] with the value function
            estimate at time T.
        clip_rho_threshold: A scalar float32 tensor with the clipping threshold
            for importance weights (rho) when calculating the baseline targets
            (vs). rho^bar in the paper. If None, no clipping is applied.
        clip_pg_rho_threshold: A scalar float32 tensor with the clipping
            threshold on rho_s in
            \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
            If None, no clipping is applied.

    Returns:
        A VTraceReturns namedtuple (vs, pg_advantages) where:
        vs: A float32 tensor of shape [T, B]. Can be used as target to
            train a baseline (V(x_t) - vs_t)^2.
        pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
            advantage in the calculation of policy gradients.
    """
    log_rhos = convert_to_torch_tensor(log_rhos, device="cpu")
    discounts = convert_to_torch_tensor(discounts, device="cpu")
    rewards = convert_to_torch_tensor(rewards, device="cpu")
    values = convert_to_torch_tensor(values, device="cpu")
    bootstrap_value = convert_to_torch_tensor(bootstrap_value, device="cpu")

    # Make sure tensor ranks are consistent.
    rho_rank = len(log_rhos.size())  # Usually 2.
    assert rho_rank == len(values.size())
    assert rho_rank - 1 == len(bootstrap_value.size()),\
        "must have rank {}".format(rho_rank - 1)
    assert rho_rank == len(discounts.size())
    assert rho_rank == len(rewards.size())

    rhos = torch.exp(log_rhos)
    if clip_rho_threshold is not None:
        clipped_rhos = torch.clamp_max(rhos, clip_rho_threshold)
    else:
        clipped_rhos = rhos

    cs = torch.clamp_max(rhos, 1.0)
    # Append bootstrapped value to get [v1, ..., v_t+1]
    values_t_plus_1 = torch.cat(
        [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

    vs_minus_v_xs = [torch.zeros_like(bootstrap_value)]
    for i in reversed(range(len(discounts))):
        discount_t, c_t, delta_t = discounts[i], cs[i], deltas[i]
        vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1])
    vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:])
    # Reverse the results back to original order.
    vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[0])
    # Add V(x_s) to get v_s.
    vs = vs_minus_v_xs + values

    # Advantage for policy gradient.
    vs_t_plus_1 = torch.cat(
        [vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
    if clip_pg_rho_threshold is not None:
        clipped_pg_rhos = torch.clamp_max(rhos, clip_pg_rho_threshold)
    else:
        clipped_pg_rhos = rhos
    pg_advantages = (
        clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values))

    # Make sure no gradients backpropagated through the returned values.
    return VTraceReturns(vs=vs.detach(), pg_advantages=pg_advantages.detach())
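For shape intuition, a minimal sketch of calling the function above with dummy data (assuming from_importance_weights is imported from the vtrace_torch module named in the header; the tensor values are arbitrary).

import torch

# T=5 timesteps, B=4 batch items. log_rhos == 0 corresponds to on-policy data
# (importance ratio rho == 1 everywhere), so vs reduces to ordinary discounted
# returns bootstrapped with bootstrap_value.
T, B = 5, 4
ret = from_importance_weights(
    log_rhos=torch.zeros(T, B),
    discounts=torch.full((T, B), 0.99),
    rewards=torch.randn(T, B),
    values=torch.randn(T, B),
    bootstrap_value=torch.zeros(B))
assert ret.vs.shape == (T, B) and ret.pg_advantages.shape == (T, B)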
Code example #4
    def set_weights(self, weights: ModelWeights) -> None:
        weights = convert_to_torch_tensor(weights, device=self.device)
        self.model.load_state_dict(weights)
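A hedged usage sketch for this method (both policy objects are assumed; on a TorchPolicy, get_weights() returns the numpy-based ModelWeights dict that set_weights() expects).

# Sync weights from one policy instance to another.
weights = source_policy.get_weights()  # numpy-based ModelWeights dict
target_policy.set_weights(weights)     # converted to torch tensors and loaded (see above)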
Code example #5
def cql_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TorchDistributionWrapper],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    print(policy.cur_iter)
    policy.cur_iter += 1
    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS]
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model({
        "obs": obs,
        "is_training": True,
    }, [], None)

    model_out_tp1, _ = model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    target_model_out_tp1, _ = policy.target_model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    action_dist_class = _get_dist_class(policy.config, policy.action_space)
    action_dist_t = action_dist_class(
        model.get_policy_output(model_out_t), policy.model)
    policy_t = action_dist_t.sample() if not deterministic else \
        action_dist_t.deterministic_sample()
    log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        bc_logp = action_dist_t.logp(actions)
        actor_loss = (alpha * log_pis_t - bc_logp).mean()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss
    action_dist_tp1 = action_dist_class(
        model.get_policy_output(model_out_tp1), policy.model)
    policy_tp1 = action_dist_tp1.sample() if not deterministic else \
        action_dist_tp1.deterministic_sample()

    # Q-values for the batched actions.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)
    q_tp1 = torch.squeeze(input=q_tp1, dim=-1)
    q_tp1 = (1.0 - terminals.float()) * q_tp1

    # compute RHS of bellman equation
    q_t_target = (
        rewards + (discount**policy.config["n_step"]) * q_tp1).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error
    critic_loss = [nn.MSELoss()(q_t, q_t_target)]
    if twin_q:
        critic_loss.append(nn.MSELoss()(twin_q_t, q_t_target))

    # CQL Loss (We are using Entropy version of CQL (the best version))
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device)
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    obs, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    next_obs, num_actions)

    curr_logp = curr_logp.view(actions.shape[0], num_actions, 1)
    next_logp = next_logp.view(actions.shape[0], num_actions, 1)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat([
        q1_rand - random_density, q1_next_actions - next_logp.detach(),
        q1_curr_actions - curr_logp.detach()
    ], 1)
    if twin_q:
        cat_q2 = torch.cat([
            q2_rand - random_density, q2_next_actions - next_logp.detach(),
            q2_curr_actions - curr_logp.detach()
        ], 1)

    min_qf1_loss = torch.logsumexp(
        cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
    min_qf1_loss = min_qf1_loss - q_t.mean() * min_q_weight
    if twin_q:
        min_qf2_loss = torch.logsumexp(
            cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
        min_qf2_loss = min_qf2_loss - twin_q_t.mean() * min_q_weight

    if use_lagrange:
        alpha_prime = torch.clamp(
            model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss[0] += min_qf1_loss
    if twin_q:
        critic_loss[1] += min_qf2_loss

    # Save for stats function.
    policy.q_t = q_t
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return tuple([policy.actor_loss] + policy.critic_loss +
                     [policy.alpha_loss] + [policy.alpha_prime_loss])
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])
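For orientation, a rough sketch of how a loss function with this (policy, model, dist_class, train_batch) signature gets attached to a policy class via RLlib's torch policy template in the Ray versions these snippets come from. The builder call is simplified and the config import path is an assumption; the real CQL policy also passes stats_fn, optimizer_fn and other hooks that are omitted here.

from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.agents.cql.cql import CQL_DEFAULT_CONFIG  # assumed location of the default config

# Sketch only: wire the loss function above into a trainable torch policy class.
MyCQLTorchPolicy = build_torch_policy(
    name="MyCQLTorchPolicy",
    loss_fn=cql_loss,
    get_default_config=lambda: CQL_DEFAULT_CONFIG,
)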
Code example #6
File: vtrace_torch.py    Project: wallacetroy/ray
def multi_from_logits(behaviour_policy_logits,
                      target_policy_logits,
                      actions,
                      discounts,
                      rewards,
                      values,
                      bootstrap_value,
                      dist_class,
                      model,
                      behaviour_action_log_probs=None,
                      clip_rho_threshold=1.0,
                      clip_pg_rho_threshold=1.0):
    """V-trace for softmax policies.

    Calculates V-trace actor critic targets for softmax policies as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    Target policy refers to the policy we are interested in improving and
    behaviour policy refers to the policy that generated the given
    rewards and actions.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1. B refers to the batch size and
    ACTION_SPACE refers to the list of sizes of the individual action
    components (the number of actions in each sub-action space).

    Args:
        behaviour_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing the softmax behavior policy.
        target_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing the softmax target policy.
        actions: A list with length of ACTION_SPACE of tensors of shapes
            [T, B, ...], ..., [T, B, ...]
            with actions sampled from the behavior policy.
        discounts: A float32 tensor of shape [T, B] with the discount
            encountered when following the behavior policy.
        rewards: A float32 tensor of shape [T, B] with the rewards generated by
            following the behavior policy.
        values: A float32 tensor of shape [T, B] with the value function
            estimates wrt. the target policy.
        bootstrap_value: A float32 tensor of shape [B] with the value function
            estimate at time T.
        dist_class: action distribution class for the logits.
        model: backing ModelV2 instance
        behaviour_action_log_probs: Precalculated values of the behavior
            actions.
        clip_rho_threshold: A scalar float32 tensor with the clipping threshold
            for importance weights (rho) when calculating the baseline targets
            (vs). rho^bar in the paper.
        clip_pg_rho_threshold: A scalar float32 tensor with the clipping
            threshold on rho_s in:
            \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).

    Returns:
        A `VTraceFromLogitsReturns` namedtuple with the following fields:
        vs: A float32 tensor of shape [T, B]. Can be used as target to train a
            baseline (V(x_t) - vs_t)^2.
        pg_advantages: A float32 tensor of shape [T, B]. Can be used as an
            estimate of the advantage in the calculation of policy gradients.
        log_rhos: A float32 tensor of shape [T, B] containing the log
            importance sampling weights (log rhos).
        behaviour_action_log_probs: A float32 tensor of shape [T, B] containing
            behaviour policy action log probabilities (log \mu(a_t)).
        target_action_log_probs: A float32 tensor of shape [T, B] containing
            target policy action log probabilities (log \pi(a_t)).
    """

    behaviour_policy_logits = convert_to_torch_tensor(
        behaviour_policy_logits, device="cpu")
    target_policy_logits = convert_to_torch_tensor(
        target_policy_logits, device="cpu")
    actions = convert_to_torch_tensor(actions, device="cpu")

    for i in range(len(behaviour_policy_logits)):
        # Make sure tensor ranks are as expected.
        # The rest will be checked by from_action_log_probs.
        assert len(behaviour_policy_logits[i].size()) == 3
        assert len(target_policy_logits[i].size()) == 3

    target_action_log_probs = multi_log_probs_from_logits_and_actions(
        target_policy_logits, actions, dist_class, model)

    if (len(behaviour_policy_logits) > 1
            or behaviour_action_log_probs is None):
        # can't use precalculated values, recompute them. Note that
        # recomputing won't work well for autoregressive action dists
        # which may have variables not captured by 'logits'
        behaviour_action_log_probs = (multi_log_probs_from_logits_and_actions(
            behaviour_policy_logits, actions, dist_class, model))

    behaviour_action_log_probs = force_list(behaviour_action_log_probs)
    log_rhos = get_log_rhos(target_action_log_probs,
                            behaviour_action_log_probs)

    vtrace_returns = from_importance_weights(
        log_rhos=log_rhos,
        discounts=discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    return VTraceFromLogitsReturns(
        log_rhos=log_rhos,
        behaviour_action_log_probs=behaviour_action_log_probs,
        target_action_log_probs=target_action_log_probs,
        **vtrace_returns._asdict())
Code example #7
File: ppo_torch_policy.py    Project: roclark/ray
    def value(**input_dict):
        model_out, _ = self.model.from_batch(
            convert_to_torch_tensor(input_dict, self.device),
            is_training=False)
        # [0] = remove the batch dim.
        return self.model.value_function()[0]
Code example #8
File: torch_policy.py    Project: radovankavicky/raylab
    def set_weights(self, weights: dict):
        self.module.load_state_dict(
            convert_to_torch_tensor(weights["module"], device=self.device))
        # Optimizer state dicts don't store tensors, only ids
        self.optimizers.load_state_dict(weights["optimizers"])
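A possible round trip for this raylab-style policy, assuming symmetric state_dict() accessors exist on policy.module and policy.optimizers (the optimizers' state_dict() counterpart is not shown in the snippet and is an assumption).

# Capture and restore both module and optimizer state.
weights = {
    "module": policy.module.state_dict(),
    "optimizers": policy.optimizers.state_dict(),  # assumed counterpart of load_state_dict()
}
policy.set_weights(weights)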
Code example #9
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: np.asarray(obs_batch),
                "is_training": False,
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    np.asarray(prev_reward_batch)
            state_batches = [
                convert_to_torch_tensor(s) for s in (state_batches or [])
            ]

            if self.action_sampler_fn:
                action_dist = dist_inputs = None
                state_out = []
                actions, logp = self.action_sampler_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=explore,
                    timestep=timestep)
            else:
                # Call the exploration before_compute_actions hook.
                self.exploration.before_compute_actions(
                    explore=explore, timestep=timestep)
                if self.action_distribution_fn:
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False)
                else:
                    dist_class = self.dist_class
                    dist_inputs, state_out = self.model(
                        input_dict, state_batches, seq_lens)
                if not (isinstance(dist_class, functools.partial)
                        or issubclass(dist_class, TorchDistributionWrapper)):
                    raise ValueError(
                        "`dist_class` ({}) not a TorchDistributionWrapper "
                        "subclass! Make sure your `action_distribution_fn` or "
                        "`make_model_and_action_dist` return a correct "
                        "distribution class.".format(dist_class.__name__))
                action_dist = dist_class(dist_inputs, self.model)

                # Get the exploration action from the forward results.
                actions, logp = \
                    self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            input_dict[SampleBatch.ACTIONS] = actions

            # Add default and custom fetches.
            extra_fetches = self.extra_action_out(input_dict, state_batches,
                                                  self.model, action_dist)
            # Action-logp and action-prob.
            if logp is not None:
                logp = convert_to_non_torch_type(logp)
                extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            return convert_to_non_torch_type((actions, state_out,
                                              extra_fetches))
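A hedged example of calling compute_actions from outside the policy (the `policy` object and its 4-dim observation space are assumptions).

import numpy as np

obs_batch = np.random.random((2, 4)).astype(np.float32)
actions, state_out, fetches = policy.compute_actions(obs_batch, explore=True)
# `actions` is a numpy array with one row per observation; `fetches` may contain
# SampleBatch.ACTION_LOGP, SampleBatch.ACTION_PROB and SampleBatch.ACTION_DIST_INPUTS.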
Code example #10
def cql_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TorchDistributionWrapper],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS]
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    policy_optimizer = policy._optimizers[0]
    critic1_optimizer = policy._optimizers[1]
    critic2_optimizer = policy._optimizers[2]
    alpha_optimizer = policy._optimizers[3]

    model_out_t, _ = model({
        "obs": obs,
        "is_training": True,
    }, [], None)

    model_out_tp1, _ = model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    target_model_out_tp1, _ = policy.target_model(
        {
            "obs": next_obs,
            "is_training": True,
        }, [], None)

    action_dist_class = _get_dist_class(policy.config, policy.action_space)
    action_dist_t = action_dist_class(model.get_policy_output(model_out_t),
                                      policy.model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()

    if obs.shape[0] == policy.config["train_batch_size"]:
        alpha_optimizer.zero_grad()
        alpha_loss.backward()
        alpha_optimizer.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:

        def bc_log(model, obs, actions):
            z = atanh(actions)
            logits = model.get_policy_output(obs)
            mean, log_std = torch.chunk(logits, 2, dim=-1)
            # Mean Clamping for Stability
            mean = torch.clamp(mean, MEAN_MIN, MEAN_MAX)
            log_std = torch.clamp(log_std, MIN_LOG_NN_OUTPUT,
                                  MAX_LOG_NN_OUTPUT)
            std = torch.exp(log_std)
            normal_dist = torch.distributions.Normal(mean, std)
            return torch.sum(normal_dist.log_prob(z) -
                             torch.log(1 - actions * actions + SMALL_NUMBER),
                             dim=-1)

        bc_logp = bc_log(model, model_out_t, actions)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()

    if obs.shape[0] == policy.config["train_batch_size"]:
        policy_optimizer.zero_grad()
        actor_loss.backward(retain_graph=True)
        policy_optimizer.step()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_tp1 = action_dist_class(model.get_policy_output(model_out_tp1),
                                        policy.model)
    policy_tp1, log_pis_tp1 = action_dist_tp1.sample_logp()

    log_pis_tp1 = torch.unsqueeze(log_pis_tp1, -1)
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = (
        rewards +
        (discount**policy.config["n_step"]) * q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device)
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_tp1, num_actions)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(model,
                                          model_out_t,
                                          curr_actions,
                                          twin=True)
        q2_next_actions = q_values_repeat(model,
                                          model_out_t,
                                          next_actions,
                                          twin=True)

    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat([
        q1_rand - random_density, q1_next_actions - next_logp.detach(),
        q1_curr_actions - curr_logp.detach()
    ], 1)
    if twin_q:
        cat_q2 = torch.cat([
            q2_rand - random_density, q2_next_actions - next_logp.detach(),
            q2_curr_actions - curr_logp.detach()
        ], 1)

    min_qf1_loss_ = torch.logsumexp(cat_q1 / cql_temp,
                                    dim=1).mean() * min_q_weight * cql_temp
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = torch.logsumexp(cat_q2 / cql_temp,
                                        dim=1).mean() * min_q_weight * cql_temp
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)

    if use_lagrange:
        alpha_prime = torch.clamp(model.log_alpha_prime.exp(),
                                  min=0.0,
                                  max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if obs.shape[0] == policy.config["train_batch_size"]:
        critic1_optimizer.zero_grad()
        critic_loss[0].backward(retain_graph=True)
        critic1_optimizer.step()

        critic2_optimizer.zero_grad()
        critic_loss[1].backward(retain_graph=False)
        critic2_optimizer.step()

    # Save for stats function.
    policy.q_t = q_t_selected
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    model.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return tuple([policy.actor_loss] + policy.critic_loss +
                     [policy.alpha_loss] + [policy.alpha_prime_loss])
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])
Code example #11
    def _copy_weights_of_one_model(self, model_in_use, model_evading):
        # Copy the weights of `model_evading` into `model_in_use` on this
        # policy's device.
        weights = model_evading.state_dict()
        weights = convert_to_torch_tensor(weights, device=self.device)
        model_in_use.load_state_dict(weights)