def _calc_cost_for_action_sequence(self, time_step: TimeStep, state,
                                   ac_seqs):
    """
    Args:
        time_step (TimeStep): input data for next step prediction
        state (MbrlState): input state for next step prediction
        ac_seqs: action_sequence (Tensor) of shape
            [batch_size, population_size, solution_dim], where
            solution_dim = planning_horizon * num_actions
    Returns:
        cost (Tensor) with shape [batch_size, population_size]
    """
    obs = time_step.observation
    batch_size = obs.shape[0]

    ac_seqs = torch.reshape(
        ac_seqs,
        [batch_size, self._population_size, self._planning_horizon, -1])
    ac_seqs = ac_seqs.permute(2, 0, 1, 3)
    ac_seqs = torch.reshape(
        ac_seqs, (self._planning_horizon, -1, self._num_actions))

    state = state._replace(dynamics=state.dynamics._replace(feature=obs))
    init_obs = self._expand_to_population(obs)
    state = nest.map_structure(self._expand_to_population, state)

    obs = init_obs
    cost = 0
    for i in range(ac_seqs.shape[0]):
        action = ac_seqs[i]
        time_step = time_step._replace(prev_action=action)
        time_step, state = self._dynamics_func(time_step, state)
        next_obs = time_step.observation
        # Note: currently using (next_obs, action), might need to
        # consider (obs, action) in order to be more compatible
        # with the conventional definition of the reward function
        reward_step, state = self._reward_func(next_obs, action, state)
        cost = cost - reward_step
        obs = next_obs

    # reshape cost back to [batch_size, population_size]
    cost = torch.reshape(cost, [batch_size, -1])
    return cost
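The reshape/permute above unpacks each population member's flat solution vector into a time-major layout, so every planning step can be evaluated as one large batch of size batch_size * population_size. A minimal standalone sketch (made-up sizes, plain torch, not ALF code) that checks the resulting shape:

import torch

# hypothetical sizes, for illustration only
batch_size, population_size, planning_horizon, num_actions = 4, 8, 5, 2
ac_seqs = torch.randn(batch_size, population_size,
                      planning_horizon * num_actions)

x = torch.reshape(
    ac_seqs, [batch_size, population_size, planning_horizon, -1])
x = x.permute(2, 0, 1, 3)  # -> [horizon, batch, population, num_actions]
x = torch.reshape(x, (planning_horizon, -1, num_actions))
assert x.shape == (planning_horizon, batch_size * population_size,
                   num_actions)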
def train_step(self, exp: Experience, state: SacState):
    # We detach exp.observation here so that in the case that exp.observation
    # is calculated by some other trainable module, the training of that
    # module will not be affected by the gradient back-propagated from the
    # actor. However, the gradient from critic will still affect the training
    # of that module.
    (action_distribution, action, critics,
     action_state) = self._predict_action(
         common.detach(exp.observation), state=state.action)

    log_pi = nest.map_structure(lambda dist, a: dist.log_prob(a),
                                action_distribution, action)

    if self._act_type == ActionType.Mixed:
        # For mixed type, add log_pi separately
        log_pi = type(self._action_spec)((sum(nest.flatten(log_pi[0])),
                                          sum(nest.flatten(log_pi[1]))))
    else:
        log_pi = sum(nest.flatten(log_pi))

    if self._prior_actor is not None:
        prior_step = self._prior_actor.train_step(exp, ())
        log_prior = dist_utils.compute_log_probability(
            prior_step.output, action)
        log_pi = log_pi - log_prior

    actor_state, actor_loss = self._actor_train_step(
        exp, state.actor, action, critics, log_pi, action_distribution)
    critic_state, critic_info = self._critic_train_step(
        exp, state.critic, action, log_pi, action_distribution)
    alpha_loss = self._alpha_train_step(log_pi)

    state = SacState(
        action=action_state, actor=actor_state, critic=critic_state)
    info = SacInfo(
        action_distribution=action_distribution,
        actor=actor_loss,
        critic=critic_info,
        alpha=alpha_loss)
    return AlgStep(action, state, info)
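For a non-mixed nested action, the log-probability bookkeeping above amounts to computing a per-leaf log-prob and summing over the flattened nest to get one joint log_pi per sample. A minimal standalone sketch (plain torch distributions, made-up sizes, not ALF code):

import torch
import torch.distributions as td

# a hypothetical nested action with two scalar continuous leaves, batch of 4
dists = (td.Normal(torch.zeros(4), torch.ones(4)),
         td.Normal(torch.zeros(4), torch.ones(4)))
actions = tuple(d.sample() for d in dists)
log_pi = sum(d.log_prob(a) for d, a in zip(dists, actions))
assert log_pi.shape == (4, )  # one joint log-probability per batch element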
def _step(self, action):
    if self._done:
        return self.reset()

    if self._action_spec:
        nest.assert_same_structure(self._action_spec, action)

    self._num_steps += 1

    observation = self._get_observation()
    if self._num_steps < self._min_duration:
        self._done = False
    elif self._max_duration and self._num_steps >= self._max_duration:
        self._done = True
    else:
        self._done = self._rng.uniform() < self._episode_end_probability

    if self._batch_size:
        action = nest.map_structure(
            lambda t: np.concatenate(
                [np.expand_dims(t, 0)] * self._batch_size), action)

    if self._done:
        reward = self._reward_fn(ds.StepType.LAST, action, observation)
        self._check_reward_shape(reward)
        time_step = ds.termination(
            observation, action, reward, env_id=self._env_id)
        self._num_steps = 0
    else:
        reward = self._reward_fn(ds.StepType.MID, action, observation)
        self._check_reward_shape(reward)
        time_step = ds.transition(
            observation,
            action,
            reward,
            discount=self._discount,
            env_id=self._env_id)

    return time_step
def _critic_train_step(self, exp: Experience, state: SacCriticState, action,
                       log_pi, action_distribution):
    critics, critics_state = self._compute_critics(
        self._critic_networks, exp.observation, exp.action, state.critics)

    target_critics, target_critics_state = self._compute_critics(
        self._target_critic_networks, exp.observation, action,
        state.target_critics)
    target_critics = target_critics.min(dim=1)[0]

    if self._act_type == ActionType.Discrete:
        critics = self._select_q_value(exp.action, critics)
        target_critics = self._select_q_value(
            action, target_critics.unsqueeze(dim=1))
    elif self._act_type == ActionType.Mixed:
        critics = self._select_q_value(exp.action[0], critics)
        discrete_act_dist = action_distribution[0]
        target_critics = torch.sum(
            discrete_act_dist.probs * target_critics, dim=-1)

    target_critic = target_critics.reshape(exp.reward.shape)
    if self._use_entropy_reward:
        entropy_reward = nest.map_structure(
            lambda la, lp: -torch.exp(la) * lp, self._log_alpha, log_pi)
        entropy_reward = sum(nest.flatten(entropy_reward))
        target_critic = target_critic + entropy_reward
    target_critic = target_critic.detach()

    state = SacCriticState(
        critics=critics_state, target_critics=target_critics_state)
    info = SacCriticInfo(critics=critics, target_critic=target_critic)

    return state, info
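Before it reaches the TD loss, the target value above is the clipped double-Q estimate plus the entropy bonus -alpha * log_pi, with gradients blocked. A minimal standalone sketch (made-up shapes, plain torch, not ALF code) of that computation for the continuous-action case:

import torch

q_targets = torch.randn(32, 2)   # [batch, num_critic_replicas] from target nets
log_pi = torch.randn(32)         # log pi(a'|s') of the freshly sampled action
log_alpha = torch.tensor(0.0)

target_q = q_targets.min(dim=1)[0]                    # clipped double-Q
target_q = target_q - torch.exp(log_alpha) * log_pi   # entropy reward term
target_q = target_q.detach()                          # no gradient into targets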
def _actor_train_step(self, exp: Experience, state: DdpgActorState):
    action, actor_state = self._actor_network(
        exp.observation, state=state.actor)

    q_values, critic_states = self._critic_networks(
        (exp.observation, action), state=state.critics)
    if q_values.ndim == 3:
        # Multidimensional reward: [B, num_critic_replicas, reward_dim]
        if self._reward_weights is None:
            q_values = q_values.sum(dim=2)
        else:
            q_values = torch.tensordot(
                q_values, self._reward_weights, dims=1)

    if self._num_critic_replicas > 1:
        q_value = q_values.min(dim=1)[0]
    else:
        q_value = q_values.squeeze(dim=1)

    dqda = nest_utils.grad(action, q_value.sum())

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = torch.clamp(dqda, -self._dqda_clipping,
                               self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            (dqda + action).detach(), action)
        if self._action_l2 > 0:
            assert action.requires_grad
            loss += self._action_l2 * (action**2)
        loss = loss.sum(list(range(1, loss.ndim)))
        return loss

    actor_loss = nest.map_structure(actor_loss_fn, dqda, action)
    state = DdpgActorState(actor=actor_state, critics=critic_states)
    info = LossInfo(loss=sum(nest.flatten(actor_loss)), extra=actor_loss)
    return AlgStep(output=action, state=state, info=info)
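The surrogate loss in actor_loss_fn is the usual deterministic-policy-gradient trick: with c = (dqda + a).detach(), the gradient of 0.5 * (c - a)^2 with respect to a is -(c - a) = -dqda, so minimizing the surrogate is equivalent to following dQ/da. A minimal standalone sketch (made-up numbers, plain torch, not ALF code) verifying this:

import torch

a = torch.tensor([0.3, -1.2], requires_grad=True)
dqda = torch.tensor([0.5, 2.0])

loss = (0.5 * ((dqda + a).detach() - a)**2).sum()
loss.backward()
assert torch.allclose(a.grad, -dqda)  # gradient descent on the loss ascends Q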
def after_update(self, experience, train_info: SacInfo):
    self._update_target()
    if self._max_log_alpha is not None:
        nest.map_structure(
            lambda la: la.data.copy_(torch.min(la, self._max_log_alpha)),
            self._log_alpha)
def _alpha_train_step(self, log_pi):
    alpha_loss = nest.map_structure(
        lambda la, lp, t: la * (-lp - t).detach(), self._log_alpha, log_pi,
        self._target_entropy)
    return sum(nest.flatten(alpha_loss))
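The gradient of this loss with respect to log_alpha is (-log_pi - target_entropy): when the current entropy estimate -log_pi is above the target, a gradient-descent step lowers log_alpha (and hence alpha), and when it is below, alpha is raised. A minimal standalone sketch (made-up numbers, not ALF code):

import torch

log_alpha = torch.tensor(0.0, requires_grad=True)
target_entropy = -2.0
log_pi = torch.tensor(-0.5)  # entropy estimate -log_pi = 0.5, above the target

alpha_loss = log_alpha * (-log_pi - target_entropy).detach()
alpha_loss.backward()
# grad = -log_pi - target_entropy = 0.5 - (-2.0) = 2.5 > 0, so a descent step
# decreases log_alpha because the policy is already entropic enough.
assert torch.isclose(log_alpha.grad, torch.tensor(2.5))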
def __init__(self,
             observation_spec,
             action_spec: BoundedTensorSpec,
             actor_network_cls=ActorDistributionNetwork,
             critic_network_cls=CriticNetwork,
             q_network_cls=QNetwork,
             reward_weights=None,
             use_entropy_reward=True,
             use_parallel_network=False,
             num_critic_replicas=2,
             env=None,
             config: TrainerConfig = None,
             critic_loss_ctor=None,
             target_entropy=None,
             prior_actor_ctor=None,
             target_kld_per_dim=3.,
             initial_log_alpha=0.0,
             max_log_alpha=None,
             target_update_tau=0.05,
             target_update_period=1,
             dqda_clipping=None,
             actor_optimizer=None,
             critic_optimizer=None,
             alpha_optimizer=None,
             debug_summaries=False,
             name="SacAlgorithm"):
    """
    Args:
        observation_spec (nested TensorSpec): representing the observations.
        action_spec (nested BoundedTensorSpec): representing the actions; can
            be a mixture of discrete and continuous actions. The number of
            continuous actions can be arbitrary while only one discrete
            action is allowed currently. If it's a mixture, then it must be a
            tuple/list ``(discrete_action_spec, continuous_action_spec)``.
        actor_network_cls (Callable): is used to construct the actor network.
            The constructed actor network will be called to sample continuous
            actions. All of its output specs must be continuous. Note that we
            don't need a discrete actor network because a discrete action can
            simply be sampled from the Q values.
        critic_network_cls (Callable): is used to construct the critic
            network for estimating ``Q(s,a)`` given that the action is
            continuous.
        q_network_cls (Callable): is used to construct QNetwork for
            estimating ``Q(s,a)`` given that the action is discrete. Its
            output spec must be consistent with the discrete action in
            ``action_spec``.
        reward_weights (None|list[float]): this is only used when the reward
            is multidimensional. In that case, the weighted sum of the q
            values is used for training the actor if reward_weights is not
            None. Otherwise, the sum of the q values is used.
        use_entropy_reward (bool): whether to include entropy as reward.
        use_parallel_network (bool): whether to use parallel network for
            calculating critics.
        num_critic_replicas (int): number of critics to be used. Default is 2.
        env (Environment): The environment to interact with. ``env`` is a
            batched environment, which means that it runs multiple
            simulations simultaneously. ``env`` only needs to be provided to
            the root algorithm.
        config (TrainerConfig): config for training. It only needs to be
            provided to the algorithm which performs ``train_iter()`` by
            itself.
        critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss
            constructor. If ``None``, a default ``OneStepTDLoss`` will be
            used.
        initial_log_alpha (float): initial value for variable ``log_alpha``.
        max_log_alpha (float|None): if not None, ``log_alpha`` will be capped
            at this value.
        target_entropy (float|Callable|None): If a floating value, it's the
            target average policy entropy, for updating ``alpha``. If a
            callable function, then it will be called on the action spec to
            calculate a target entropy. If ``None``, a default entropy will
            be calculated. For the mixed action type, discrete action and
            continuous action will have separate alphas and target entropies,
            so this argument can be a 2-element list/tuple, where the first
            is for discrete action and the second for continuous action.
        prior_actor_ctor (Callable): If provided, it will be called using
            ``prior_actor_ctor(observation_spec, action_spec,
            debug_summaries=debug_summaries)`` to construct a prior actor.
            The output of the prior actor is the distribution of the next
            action. Two prior actors are implemented:
            ``alf.algorithms.prior_actor.SameActionPriorActor`` and
            ``alf.algorithms.prior_actor.UniformPriorActor``.
        target_kld_per_dim (float): ``alpha`` is dynamically adjusted so that
            the KLD is about ``target_kld_per_dim * dim``.
        target_update_tau (float): Factor for soft update of the target
            networks.
        target_update_period (int): Period for soft update of the target
            networks.
        dqda_clipping (float): when computing the actor loss, clips the
            gradient dqda element-wise between
            ``[-dqda_clipping, dqda_clipping]``. Will not perform clipping if
            ``dqda_clipping == 0``.
        actor_optimizer (torch.optim.optimizer): The optimizer for actor.
        critic_optimizer (torch.optim.optimizer): The optimizer for critic.
        alpha_optimizer (torch.optim.optimizer): The optimizer for alpha.
        debug_summaries (bool): True if debug summaries should be created.
        name (str): The name of this algorithm.
    """
    self._num_critic_replicas = num_critic_replicas
    self._use_parallel_network = use_parallel_network

    critic_networks, actor_network, self._act_type, reward_dim = self._make_networks(
        observation_spec, action_spec, actor_network_cls, critic_network_cls,
        q_network_cls)

    self._use_entropy_reward = use_entropy_reward
    if reward_dim > 1:
        assert not use_entropy_reward, (
            "use_entropy_reward=True is not supported for multidimensional reward"
        )
        assert self._act_type == ActionType.Continuous, (
            "Only continuous action is supported for multidimensional reward"
        )

    self._reward_weights = None
    if reward_weights:
        assert reward_dim > 1, (
            "reward_weights cannot be used for one dimensional reward")
        assert len(reward_weights) == reward_dim, (
            "Mismatch between len(reward_weights)=%s and reward_dim=%s" %
            (len(reward_weights), reward_dim))
        self._reward_weights = torch.tensor(
            reward_weights, dtype=torch.float32)

    def _init_log_alpha():
        return nn.Parameter(torch.tensor(float(initial_log_alpha)))

    if self._act_type == ActionType.Mixed:
        # separate alphas for discrete and continuous actions
        log_alpha = type(action_spec)((_init_log_alpha(),
                                       _init_log_alpha()))
    else:
        log_alpha = _init_log_alpha()

    action_state_spec = SacActionState(
        actor_network=(() if self._act_type == ActionType.Discrete else
                       actor_network.state_spec),
        critic=(() if self._act_type == ActionType.Continuous else
                critic_networks.state_spec))
    super().__init__(
        observation_spec,
        action_spec,
        train_state_spec=SacState(
            action=action_state_spec,
            actor=(() if self._act_type != ActionType.Continuous else
                   critic_networks.state_spec),
            critic=SacCriticState(
                critics=critic_networks.state_spec,
                target_critics=critic_networks.state_spec)),
        predict_state_spec=SacState(action=action_state_spec),
        env=env,
        config=config,
        debug_summaries=debug_summaries,
        name=name)

    if actor_optimizer is not None:
        self.add_optimizer(actor_optimizer, [actor_network])
    if critic_optimizer is not None:
        self.add_optimizer(critic_optimizer, [critic_networks])
    if alpha_optimizer is not None:
        self.add_optimizer(alpha_optimizer, nest.flatten(log_alpha))

    self._log_alpha = log_alpha
    if self._act_type == ActionType.Mixed:
        self._log_alpha_paralist = nn.ParameterList(nest.flatten(log_alpha))

    if max_log_alpha is not None:
        self._max_log_alpha = torch.tensor(float(max_log_alpha))
    else:
        self._max_log_alpha = None

    self._actor_network = actor_network
    self._critic_networks = critic_networks
    self._target_critic_networks = self._critic_networks.copy(
        name='target_critic_networks')

    if critic_loss_ctor is None:
        critic_loss_ctor = OneStepTDLoss
    critic_loss_ctor = functools.partial(
        critic_loss_ctor, debug_summaries=debug_summaries)
    # Have different names to separate their summary curves
    self._critic_losses = []
    for i in range(num_critic_replicas):
        self._critic_losses.append(
            critic_loss_ctor(name="critic_loss%d" % (i + 1)))

    self._prior_actor = None
    if prior_actor_ctor is not None:
        assert self._act_type == ActionType.Continuous, (
            "Only continuous action is supported when using prior_actor")
        self._prior_actor = prior_actor_ctor(
            observation_spec=observation_spec,
            action_spec=action_spec,
            debug_summaries=debug_summaries)
        total_action_dims = sum(
            [spec.numel for spec in alf.nest.flatten(action_spec)])
        self._target_entropy = -target_kld_per_dim * total_action_dims
    else:
        if self._act_type == ActionType.Mixed:
            if not isinstance(target_entropy, (tuple, list)):
                target_entropy = nest.map_structure(
                    lambda _: target_entropy, self._action_spec)
            # separate target entropies for discrete and continuous actions
            self._target_entropy = nest.map_structure(
                lambda spec, t: _set_target_entropy(self.name, t, [spec]),
                self._action_spec, target_entropy)
        else:
            self._target_entropy = _set_target_entropy(
                self.name, target_entropy, nest.flatten(self._action_spec))

    self._dqda_clipping = dqda_clipping

    self._update_target = common.get_target_updater(
        models=[self._critic_networks],
        target_models=[self._target_critic_networks],
        tau=target_update_tau,
        period=target_update_period)
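The updater configured at the end of __init__ performs the usual Polyak averaging: every target_update_period training iterations, each target parameter moves toward the online parameter by a factor tau. A minimal standalone sketch of that rule (not ALF's actual common.get_target_updater implementation):

import torch

def soft_update(model, target_model, tau=0.05):
    # target <- (1 - tau) * target + tau * online, without tracking gradients
    with torch.no_grad():
        for p, tp in zip(model.parameters(), target_model.parameters()):
            tp.mul_(1.0 - tau).add_(tau * p)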
def calc_loss(self, info: DynamicsInfo):
    # Here we take mean over the loss to avoid undesired additional
    # masking from base algorithm's ``update_with_gradient``.
    scalar_loss = nest.map_structure(torch.mean, info.loss)
    return LossInfo(scalar_loss=scalar_loss.loss, extra=scalar_loss.loss)
def _get_observation(self):
    batch_size = (self._batch_size, ) if self._batch_size else ()
    return nest.map_structure(
        lambda spec: self._sample_spec(spec, batch_size).cpu().numpy(),
        self._observation_spec)