def test_log_probability_batch(self):
    # Three categorical distributions batched along dim 0, each over 2
    # classes; every row selects action 0.
    M = td.categorical.Categorical(
        torch.Tensor([[0.25, 0.75], [0.5, 0.5], [0.75, 0.25]]))
    actions = torch.Tensor([0]).repeat(3)
    # log(0.25), log(0.5), log(0.75)
    expected = torch.Tensor([-1.38629436, -0.6931471, -0.287682])
    obtained = compute_log_probability(M, actions)
    np.testing.assert_array_almost_equal(expected, obtained)
def test_log_probability(self):
    # A single 2-class categorical expanded to batch shape [2, 3]; every
    # element selects action 1, whose probability is 0.75.
    m = td.categorical.Categorical(torch.Tensor([0.25, 0.75]))
    M = m.expand([2, 3])
    actions = torch.Tensor([1]).repeat(2, 3)
    expected = torch.Tensor([[-0.287682, -0.287682, -0.287682],
                             [-0.287682, -0.287682, -0.287682]])
    obtained = compute_log_probability(M, actions)
    np.testing.assert_array_almost_equal(expected, obtained)
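# For reference, the tests above assume `compute_log_probability` reduces a
# (possibly nested) distribution/action pair to a single log-probability
# Tensor. A minimal sketch of that interface, assuming alf.nest is the
# nesting utility used elsewhere in this codebase (illustrative only, not
# the canonical implementation):
import alf


def compute_log_probability_sketch(distributions, actions):
    # Evaluate each distribution's log_prob on its paired action, then sum
    # the flattened per-distribution results into one Tensor.
    log_probs = alf.nest.map_structure(lambda dist, a: dist.log_prob(a),
                                       distributions, actions)
    return sum(alf.nest.flatten(log_probs))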
def _train_step(self, time_step: TimeStep, state: SarsaState):
    not_first_step = time_step.step_type != StepType.FIRST
    # Evaluate the critics on the previous (observation, action) pair and
    # reset any recurrent critic state at episode boundaries.
    prev_critics, critic_states = self._critic_networks(
        (state.prev_observation, time_step.prev_action), state.critics)
    critic_states = common.reset_state_if_necessary(
        state.critics, critic_states, not_first_step)
    action_distribution, action, actor_state, noise_state = self._get_action(
        self._actor_network, time_step, state)
    critics, _ = self._critic_networks((time_step.observation, action),
                                       critic_states)
    # Use the minimum over critics as a conservative value estimate.
    critic = critics.min(dim=1)[0]
    dqda = nest_utils.grad(action, critic.sum())

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = dqda.clamp(-self._dqda_clipping, self._dqda_clipping)
        # Deterministic policy gradient style surrogate loss: its gradient
        # w.r.t. ``action`` is -dqda (see the standalone check below).
        loss = 0.5 * losses.element_wise_squared_loss(
            (dqda + action).detach(), action)
        loss = loss.sum(list(range(1, loss.ndim)))
        return loss

    actor_loss = nest_map(actor_loss_fn, dqda, action)
    actor_loss = math_ops.add_n(alf.nest.flatten(actor_loss))

    neg_entropy = ()
    if self._log_alpha is not None:
        neg_entropy = dist_utils.compute_log_probability(
            action_distribution, action)
    target_critics, target_critic_states = self._target_critic_networks(
        (time_step.observation, action), state.target_critics)
    info = SarsaInfo(
        action_distribution=action_distribution,
        actor_loss=actor_loss,
        critics=prev_critics,
        neg_entropy=neg_entropy,
        target_critics=target_critics.min(dim=1)[0])
    rl_state = SarsaState(
        noise=noise_state,
        prev_observation=time_step.observation,
        prev_step_type=time_step.step_type,
        actor=actor_state,
        critics=critic_states,
        target_critics=target_critic_states)
    return AlgStep(action, rl_state, info)
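# The surrogate loss in `actor_loss_fn` uses the usual deterministic-policy-
# gradient trick: because (dqda + action).detach() is treated as a constant,
# the gradient of 0.5 * (const - action)^2 w.r.t. ``action`` is exactly
# -dqda, so descending this loss ascends Q. A standalone check with toy
# tensors (names are illustrative, not from the codebase):
import torch

_action = torch.tensor([0.3, -1.2], requires_grad=True)
_dqda = torch.tensor([0.5, 2.0])  # stand-in for grad of Q w.r.t. action
_loss = 0.5 * ((_dqda + _action).detach() - _action)**2
_loss.sum().backward()
assert torch.allclose(_action.grad, -_dqda)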
def _predict_skill_loss(self, observation, prev_action, prev_skill, steps,
                        state):
    # steps -> {1, 2, 3}
    # Build the discriminator input according to the configured skill type.
    if self._skill_type == "action":
        subtrajectory = (state.first_observation, prev_action)
    elif self._skill_type == "action_difference":
        action_difference = prev_action - state.subtrajectory[:, 1, :]
        subtrajectory = (state.first_observation, action_difference)
    elif self._skill_type == "state_action":
        subtrajectory = (observation, prev_action)
    elif self._skill_type == "state":
        subtrajectory = observation
    elif self._skill_type == "state_difference":
        subtrajectory = observation - state.untrans_observation
    elif "concatenation" in self._skill_type:
        # Flatten the stored sub-trajectory along time before selecting the
        # action or observation fields.
        subtrajectory = alf.nest.map_structure(
            lambda traj: traj.reshape(observation.shape[0], -1),
            state.subtrajectory)
        if is_action_skill(self._skill_type):
            subtrajectory = (state.first_observation,
                             subtrajectory.prev_action)
        else:
            subtrajectory = subtrajectory.observation

    if self._skill_encoder is not None:
        steps = self._num_steps_per_skill - steps
        if not isinstance(subtrajectory, tuple):
            subtrajectory = (subtrajectory, )
        subtrajectory = subtrajectory + (steps, )
        with torch.no_grad():
            prev_skill, _ = self._skill_encoder((prev_skill, steps))

    if isinstance(self._skill_discriminator, EncodingNetwork):
        # Deterministic discriminator: regress the skill directly.
        pred_skill, _ = self._skill_discriminator(subtrajectory)
        loss = torch.sum(
            losses.element_wise_squared_loss(prev_skill, pred_skill), dim=-1)
    else:
        # Stochastic discriminator: minimize negative log-likelihood of the
        # skill under the predicted distribution.
        pred_skill_dist, _ = self._skill_discriminator(subtrajectory)
        loss = -dist_utils.compute_log_probability(pred_skill_dist,
                                                   prev_skill)
    return loss
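# The two discriminator branches above optimize closely related objectives:
# for a unit-variance Gaussian discriminator, the negative log-probability
# equals the squared error up to an additive constant. A toy illustration
# (values are arbitrary, not from the codebase):
import math

import torch
import torch.distributions as td

_pred = torch.tensor([0.2, -0.4])
_target = torch.tensor([0.5, 0.1])
_nll = -td.Normal(_pred, torch.ones_like(_pred)).log_prob(_target).sum()
_mse = 0.5 * ((_pred - _target)**2).sum()
# The gap is the constant 0.5 * log(2 * pi) per dimension (2 dims here).
assert torch.allclose(_nll - _mse, torch.tensor(math.log(2 * math.pi)))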
def train_step(self, exp: Experience, state: SacState):
    # We detach exp.observation here so that if exp.observation is produced
    # by some other trainable module, that module's training is not affected
    # by the gradient back-propagated from the actor. The gradient from the
    # critic will still affect its training.
    (action_distribution, action, critics,
     action_state) = self._predict_action(
         common.detach(exp.observation), state=state.action)

    log_pi = nest.map_structure(lambda dist, a: dist.log_prob(a),
                                action_distribution, action)

    if self._act_type == ActionType.Mixed:
        # For mixed action types, sum the discrete and continuous log_pi
        # separately.
        log_pi = type(self._action_spec)(
            (sum(nest.flatten(log_pi[0])), sum(nest.flatten(log_pi[1]))))
    else:
        log_pi = sum(nest.flatten(log_pi))

    if self._prior_actor is not None:
        # With a prior actor, the entropy term becomes a KL-style
        # log_pi - log_prior correction.
        prior_step = self._prior_actor.train_step(exp, ())
        log_prior = dist_utils.compute_log_probability(
            prior_step.output, action)
        log_pi = log_pi - log_prior

    actor_state, actor_loss = self._actor_train_step(
        exp, state.actor, action, critics, log_pi, action_distribution)
    critic_state, critic_info = self._critic_train_step(
        exp, state.critic, action, log_pi, action_distribution)
    alpha_loss = self._alpha_train_step(log_pi)

    state = SacState(
        action=action_state, actor=actor_state, critic=critic_state)
    info = SacInfo(
        action_distribution=action_distribution,
        actor=actor_loss,
        critic=critic_info,
        alpha=alpha_loss)
    return AlgStep(action, state, info)
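# The detach in `train_step` can be seen in isolation: gradients through the
# actor branch stop at the detached observation, while the critic branch
# still trains whatever produced it. A toy check (plain tensors stand in for
# the actual modules):
import torch

_x = torch.tensor([1.0], requires_grad=True)
_obs = _x * 2.0                          # produced by an upstream module
_actor_loss = (_obs.detach()**2).sum()   # actor branch sees a detached copy
_critic_loss = (_obs**2).sum()           # critic branch keeps the graph
(_actor_loss + _critic_loss).backward()
assert _x.grad.item() == 8.0             # only the critic term contributes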
def _pg_loss(self, experience, train_info, advantages):
    # Vanilla policy gradient: maximize advantage-weighted log-probability
    # of the taken actions.
    action_log_prob = dist_utils.compute_log_probability(
        train_info.action_distribution, experience.action)
    return -advantages * action_log_prob
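# `_pg_loss` is the REINFORCE objective: minimizing -advantage * log_prob
# pushes probability mass toward actions with positive advantage. A toy
# example (arbitrary logits and advantage):
import torch
import torch.distributions as td

_logits = torch.zeros(3, requires_grad=True)
_dist = td.Categorical(logits=_logits)
_loss = -2.0 * _dist.log_prob(torch.tensor(1))  # advantage = 2.0
_loss.backward()
print(_logits.grad)  # negative at index 1, positive elsewhere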
def action_importance_ratio(action_distribution, collect_action_distribution,
                            action, clipping_mode, scope,
                            importance_ratio_clipping, log_prob_clipping,
                            check_numerics, debug_summaries):
    """Ratio for importance sampling, used in PPO loss and V-trace loss.

    The caller has to create the summary name scope and pass it in via
    ``scope``.

    Args:
        action_distribution (nested td.Distribution): distribution over
            actions under the target policy.
        collect_action_distribution (nested td.Distribution): distribution
            over actions from the behavior policy, used to sample actions
            for the rollout.
        action (nested Tensor): possibly batched action tuple taken during
            rollout.
        clipping_mode (str): mode for clipping the importance ratio.
            'double_sided': clips the importance ratio into
                [1 - importance_ratio_clipping, 1 + importance_ratio_clipping],
                which is used by PPOLoss.
            'capping': clips the importance ratio into
                min(1 + importance_ratio_clipping, importance_ratio),
                which is used by VTraceLoss, where c_bar or rho_bar equals
                1 + importance_ratio_clipping.
        scope (context manager): summary name scope created by the caller.
        importance_ratio_clipping (float): epsilon in the clipped, surrogate
            PPO objective. See the cited paper for more detail.
        log_prob_clipping (float): if > 0, clips log probs to the range
            (-log_prob_clipping, log_prob_clipping) to prevent inf / NaN
            values.
        check_numerics (bool): if True, asserts that log probs and
            importance ratios are finite, to help find NaN / Inf values.
            For debugging only.
        debug_summaries (bool): if True, records summary metrics via
            alf.summary.

    Returns:
        tuple:
        - importance_ratio (Tensor)
        - importance_ratio_clipped (Tensor)
    """
    current_policy_distribution = action_distribution

    sample_action_log_probs = dist_utils.compute_log_probability(
        collect_action_distribution, action).detach()
    action_log_prob = dist_utils.compute_log_probability(
        current_policy_distribution, action)
    if log_prob_clipping > 0.0:
        action_log_prob = action_log_prob.clamp(-log_prob_clipping,
                                                log_prob_clipping)
    if check_numerics:
        assert torch.all(torch.isfinite(action_log_prob))

    # Prepare both clipped and unclipped importance ratios.
    importance_ratio = (action_log_prob - sample_action_log_probs).exp()
    if check_numerics:
        assert torch.all(torch.isfinite(importance_ratio))

    if clipping_mode == 'double_sided':
        importance_ratio_clipped = importance_ratio.clamp(
            1 - importance_ratio_clipping, 1 + importance_ratio_clipping)
    elif clipping_mode == 'capping':
        importance_ratio_clipped = torch.min(
            importance_ratio, torch.tensor(1 + importance_ratio_clipping))
    else:
        raise ValueError('Unsupported clipping mode: ' + clipping_mode)

    if debug_summaries and alf.summary.should_record_summaries():
        with scope:
            if importance_ratio_clipping > 0.0:
                clip_fraction = (
                    torch.abs(importance_ratio - 1.0) >
                    importance_ratio_clipping).to(torch.float32).mean()
                alf.summary.scalar('clip_fraction', clip_fraction)
            alf.summary.histogram('action_log_prob', action_log_prob)
            alf.summary.histogram('action_log_prob_sample',
                                  sample_action_log_probs)
            alf.summary.histogram('importance_ratio', importance_ratio)
            alf.summary.scalar('importance_ratio_mean',
                               importance_ratio.mean())
            alf.summary.histogram('importance_ratio_clipped',
                                  importance_ratio_clipped)

    return importance_ratio, importance_ratio_clipped
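# A worked example of the two clipping modes (numbers are arbitrary):
import torch

_new_log_prob = torch.tensor([-0.5, -0.9, -1.0])
_old_log_prob = torch.tensor([-1.0, -0.5, -2.0])
_ratio = (_new_log_prob - _old_log_prob).exp()
# tensor([1.6487, 0.6703, 2.7183])
_eps = 0.2
_double_sided = _ratio.clamp(1 - _eps, 1 + _eps)     # PPO: [1.2, 0.8, 1.2]
_capped = torch.min(_ratio, torch.tensor(1 + _eps))  # V-trace: [1.2, 0.6703, 1.2]
# 'capping' only bounds the ratio from above, so low ratios pass through.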