def _actor_train_step(self, exp: Experience, state, action, critics, log_pi,
                      action_distribution):
    neg_entropy = sum(nest.flatten(log_pi))

    if self._act_type == ActionType.Discrete:
        # Pure discrete case doesn't need to learn an actor network
        return (), LossInfo(extra=SacActorInfo(neg_entropy=neg_entropy))

    if self._act_type == ActionType.Continuous:
        critics, critics_state = self._compute_critics(
            self._critic_networks, exp.observation, action, state)
        if critics.ndim == 3:
            # Multidimensional reward: [B, num_critic_replicas, reward_dim]
            if self._reward_weights is None:
                critics = critics.sum(dim=2)
            else:
                critics = torch.tensordot(critics, self._reward_weights,
                                          dims=1)
        target_q_value = critics.min(dim=1)[0]
        continuous_log_pi = log_pi
        cont_alpha = torch.exp(self._log_alpha).detach()
    else:
        # Use the critics computed during action prediction for Mixed type
        critics_state = ()
        discrete_act_dist = action_distribution[0]
        discrete_entropy = discrete_act_dist.entropy()
        # critics is already after min over replicas
        weighted_q_value = torch.sum(discrete_act_dist.probs * critics, dim=-1)
        discrete_alpha = torch.exp(self._log_alpha[0]).detach()
        target_q_value = weighted_q_value + discrete_alpha * discrete_entropy
        action, continuous_log_pi = action[1], log_pi[1]
        cont_alpha = torch.exp(self._log_alpha[1]).detach()

    dqda = nest_utils.grad(action, target_q_value.sum())

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = torch.clamp(dqda, -self._dqda_clipping,
                               self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            (dqda + action).detach(), action)
        return loss.sum(list(range(1, loss.ndim)))

    actor_loss = nest.map_structure(actor_loss_fn, dqda, action)
    actor_loss = math_ops.add_n(nest.flatten(actor_loss))
    actor_info = LossInfo(
        loss=actor_loss + cont_alpha * continuous_log_pi,
        extra=SacActorInfo(actor_loss=actor_loss, neg_entropy=neg_entropy))
    return critics_state, actor_info
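# Illustrative sketch (not part of ALF): the surrogate loss used above,
# 0.5 * ((dqda + action).detach() - action)**2, has gradient -dqda with
# respect to `action`, so minimizing it ascends the Q-value along dQ/da
# (the deterministic-policy-gradient trick). All tensor values are made up.
import torch

action = torch.tensor([0.3, -0.7], requires_grad=True)
dqda = torch.tensor([1.5, -2.0])  # pretend gradient of Q w.r.t. action

surrogate = 0.5 * ((dqda + action).detach() - action) ** 2
surrogate.sum().backward()

# d(surrogate)/d(action) = -((dqda + action).detach() - action) = -dqda
assert torch.allclose(action.grad, -dqda)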
def calc_loss(self, training_info: TrainingInfo):
    info = training_info.info  # SarsaInfo
    critic_loss = losses.element_wise_squared_loss(info.returns, info.critic)
    not_first_step = tf.not_equal(training_info.step_type, StepType.FIRST)
    critic_loss *= tf.cast(not_first_step, tf.float32)

    def _summary():
        with self.name_scope:
            tf.summary.scalar("values", tf.reduce_mean(info.critic))
            tf.summary.scalar("returns", tf.reduce_mean(info.returns))
            safe_mean_hist_summary("td_error", info.returns - info.critic)
            tf.summary.scalar(
                "explained_variance_of_return_by_value",
                common.explained_variance(info.critic, info.returns))

    if self._debug_summaries:
        common.run_if(common.should_record_summaries(), _summary)

    return LossInfo(
        loss=info.actor_loss,
        # Put critic_loss into scalar_loss because loss will be masked by
        # ~is_last at train_complete(). The critic_loss here should be
        # masked by ~is_first instead, which is done above.
        scalar_loss=tf.reduce_mean(critic_loss),
        extra=SarsaLossInfo(actor=info.actor_loss, critic=critic_loss))
def _actor_train_step(self, exp: Experience, state: DdpgActorState):
    action, actor_state = self._actor_network(
        exp.observation, exp.step_type, network_state=state.actor)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(action)
        q_value, critic_state = self._critic_network(
            (exp.observation, action), network_state=state.critic)

    dqda = tape.gradient(q_value, action)

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                    self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            tf.stop_gradient(dqda + action), action)
        loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
        return loss

    actor_loss = tf.nest.map_structure(actor_loss_fn, dqda, action)
    state = DdpgActorState(actor=actor_state, critic=critic_state)
    info = LossInfo(loss=tf.add_n(tf.nest.flatten(actor_loss)),
                    extra=actor_loss)
    return PolicyStep(action=action, state=state, info=info)
def train_step(self, exp: TimeStep, state):
    # [B, num_unroll_steps + 1]
    info = exp.rollout_info
    targets = common.as_list(info.target)
    batch_size = exp.step_type.shape[0]
    latent, state = self._encoding_net(exp.observation, state)
    sim_latent = self._multi_step_latent_rollout(
        latent, self._num_unroll_steps, info.action, state)

    loss = 0
    for i, decoder in enumerate(self._decoders):
        # [(num_unroll_steps + 1) * B, ...]
        train_info = decoder.train_step(sim_latent).info
        train_info_spec = dist_utils.extract_spec(train_info)
        train_info = dist_utils.distributions_to_params(train_info)
        train_info = alf.nest.map_structure(
            lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                                *x.shape[1:]), train_info)
        # [num_unroll_steps + 1, B, ...]
        train_info = dist_utils.params_to_distributions(
            train_info, train_info_spec)
        target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                        targets[i])
        loss_info = decoder.calc_loss(target, train_info, info.mask.t())
        loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0), loss_info)
        loss += loss_info.loss

    loss_info = LossInfo(loss=loss, extra=loss)
    return AlgStep(output=latent, state=state, info=loss_info)
def train_step(self, exp: Experience, state, trainable=True):
    """This function trains the discriminator or generates intrinsic rewards.

    If ``trainable=True``, then it only generates and returns the pred loss;
    otherwise it only generates rewards with no grad.
    """
    # Discriminator training from its own replay buffer, or
    # discriminator computing intrinsic rewards for training lower_rl.
    untrans_observation, prev_skill, switch_skill, steps = exp.observation
    observation = self._observation_transformer(untrans_observation)
    loss = self._predict_skill_loss(observation, exp.prev_action, prev_skill,
                                    steps, state)

    first_observation = self._update_state_if_necessary(
        switch_skill, observation, state.first_observation)
    subtrajectory = self._clear_subtrajectory_if_necessary(
        state.subtrajectory, switch_skill)
    new_state = DiscriminatorState(
        first_observation=first_observation,
        untrans_observation=untrans_observation,
        subtrajectory=subtrajectory)

    valid_masks = (exp.step_type != StepType.FIRST)
    if self._sparse_reward:
        # Only give intrinsic rewards at the last step of the skill
        valid_masks &= switch_skill
    loss *= valid_masks.to(torch.float32)

    if trainable:
        info = LossInfo(loss=loss, extra=dict(discriminator_loss=loss))
        return AlgStep(state=new_state, info=info)
    else:
        intrinsic_reward = -loss.detach() / self._skill_dim
        return AlgStep(state=common.detach(new_state), info=intrinsic_reward)
def _calc_critic_loss(self, experience, train_info: SacInfo):
    critic_info = train_info.critic

    critic_losses = []
    for i, l in enumerate(self._critic_losses):
        critic_losses.append(
            l(experience=experience,
              value=critic_info.critics[:, :, i, ...],
              target_value=critic_info.target_critic).loss)

    critic_loss = math_ops.add_n(critic_losses)

    if (experience.batch_info != ()
            and experience.batch_info.importance_weights != ()):
        valid_masks = (experience.step_type != StepType.LAST).to(torch.float32)
        valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
        priority = ((critic_loss * valid_masks).sum(dim=0) / valid_n).sqrt()
    else:
        priority = ()

    return LossInfo(loss=critic_loss,
                    priority=priority,
                    extra=critic_loss / float(self._num_critic_replicas))
def calc_loss(self, experience, train_info: DdpgInfo):
    critic_losses = [None] * self._num_critic_replicas
    for i in range(self._num_critic_replicas):
        critic_losses[i] = self._critic_losses[i](
            experience=experience,
            value=train_info.critic.q_values[:, :, i, ...],
            target_value=train_info.critic.target_q_values).loss

    critic_loss = math_ops.add_n(critic_losses)

    if (experience.batch_info != ()
            and experience.batch_info.importance_weights != ()):
        valid_masks = (experience.step_type != StepType.LAST).to(torch.float32)
        valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
        priority = ((critic_loss * valid_masks).sum(dim=0) / valid_n).sqrt()
    else:
        priority = ()

    actor_loss = train_info.actor_loss

    return LossInfo(loss=critic_loss + actor_loss.loss,
                    priority=priority,
                    extra=DdpgLossInfo(critic=critic_loss,
                                       actor=actor_loss.extra))
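# Illustrative sketch (not part of ALF): how the replay priority above is
# formed. `critic_loss` is time-major [T, B]; LAST steps are masked out,
# the per-sample loss is averaged over the valid time steps, and the square
# root puts the priority on the scale of a TD error. Shapes are made up.
import torch

T, B = 4, 3
critic_loss = torch.rand(T, B)        # per-step squared TD errors
valid_masks = torch.ones(T, B)
valid_masks[-1] = 0.                  # pretend the final step is LAST

valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)                 # [B]
priority = ((critic_loss * valid_masks).sum(dim=0) / valid_n).sqrt()   # [B]
print(priority.shape)  # torch.Size([3])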
def _minmax_grad(self, inputs, outputs, loss_func, entropy_regularization,
                 transform_func=None):
    """Compute particle gradients via minmax svgd (Fisher Neural Sampler)."""
    assert inputs is None, '"minmax" does not support conditional generator'
    assert transform_func is None, (
        "function value based vi is not supported for minmax_grad")

    # Optimize the critic using resampled particles.
    num_particles = outputs.shape[0]
    for i in range(self._critic_iter_num):
        if self._minmax_resample:
            critic_inputs, _ = self._predict(inputs, batch_size=num_particles)
        else:
            critic_inputs = outputs.detach().clone()
            critic_inputs.requires_grad = True
        critic_loss = self._critic_train_step(critic_inputs, loss_func,
                                              entropy_regularization)
        self._critic.update_with_gradient(LossInfo(loss=critic_loss))

    # Compute amortized svgd.
    loss = loss_func(outputs.detach())
    critic_outputs = self._critic.predict_step(outputs.detach()).output
    loss_propagated = torch.sum(-critic_outputs.detach() * outputs, dim=-1)

    return loss, loss_propagated
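# Illustrative sketch (not part of ALF): `loss_propagated` above is a
# surrogate whose gradient w.r.t. `outputs` is exactly -critic_outputs, so
# backpropagating it pushes each particle along the direction produced by
# the critic network. Values below are made up for demonstration.
import torch

outputs = torch.tensor([[0.1, 0.2], [0.3, 0.4]], requires_grad=True)
critic_outputs = torch.tensor([[1.0, -1.0], [0.5, 2.0]])

loss_propagated = torch.sum(-critic_outputs.detach() * outputs, dim=-1)
loss_propagated.sum().backward()

assert torch.allclose(outputs.grad, -critic_outputs)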
def train_step(self, time_step: TimeStep, state: DynamicsState):
    """
    Args:
        time_step (TimeStep): input data for dynamics learning
        state (DynamicsState): state for dynamics learning (previous
            observation)
    Returns:
        AlgStep:
            output: empty tuple ()
            state (DynamicsState): state for training
            info (DynamicsInfo):
    """
    feature = time_step.observation
    dynamics_step = self.predict_step(time_step, state)
    forward_pred = dynamics_step.output
    forward_loss = (feature - forward_pred)**2
    forward_loss = 0.5 * forward_loss.mean(list(range(1, forward_loss.ndim)))

    # We mask out FIRST steps as their states are invalid
    valid_masks = (time_step.step_type != StepType.FIRST).to(torch.float32)
    forward_loss = forward_loss * valid_masks

    info = DynamicsInfo(loss=LossInfo(
        loss=forward_loss, extra=dict(forward_loss=forward_loss)))
    state = state._replace(feature=feature)

    return AlgStep(output=(), state=state, info=info)
def _step(self, time_step: TimeStep, state, calc_rewards=True):
    """This step is for both `rollout_step` and `train_step`.

    Args:
        time_step (TimeStep): input time_step data for ICM
        state (Tensor): state for ICM (previous observation)
        calc_rewards (bool): whether to calculate rewards

    Returns:
        AlgStep:
            output: empty tuple ()
            state: observation
            info (ICMInfo):
    """
    feature = time_step.observation
    prev_action = time_step.prev_action.detach()

    # Normalize observation for easier prediction
    if self._observation_normalizer is not None:
        feature = self._observation_normalizer.normalize(feature)

    if self._encoding_net is not None:
        feature, _ = self._encoding_net(feature)
    prev_feature = state

    forward_pred, _ = self._forward_net(
        inputs=[prev_feature.detach(),
                self._encode_action(prev_action)])
    # nn.MSELoss doesn't support reducing along a dim
    forward_loss = 0.5 * torch.mean(
        math_ops.square(forward_pred - feature.detach()), dim=-1)

    action_pred, _ = self._inverse_net([prev_feature, feature])

    if self._action_spec.is_discrete:
        inverse_loss = torch.nn.CrossEntropyLoss(reduction='none')(
            input=action_pred, target=prev_action.to(torch.int64))
    else:
        # nn.MSELoss doesn't support reducing along a dim
        inverse_loss = 0.5 * torch.mean(
            math_ops.square(action_pred - prev_action), dim=-1)

    intrinsic_reward = ()
    if calc_rewards:
        intrinsic_reward = forward_loss.detach()
        intrinsic_reward = self._reward_normalizer.normalize(intrinsic_reward)

    return AlgStep(
        output=(),
        state=feature,
        info=ICMInfo(
            reward=intrinsic_reward,
            loss=LossInfo(
                loss=forward_loss + inverse_loss,
                extra=dict(
                    forward_loss=forward_loss, inverse_loss=inverse_loss))))
def calc_loss(self, experience, train_info: MerlinInfo):
    """Calculate loss."""
    self.summarize_reward("reward", experience.reward)
    mbp_loss_info = self._mbp.calc_loss(experience, train_info.mbp_info)
    mba_loss_info = self._mba.calc_loss(experience, train_info.mba_info)
    return LossInfo(loss=mbp_loss_info.loss + mba_loss_info.loss,
                    extra=MerlinLossInfo(mbp=mbp_loss_info.extra,
                                         mba=mba_loss_info.extra))
def calc_loss(self, training_info: TrainingInfo):
    critic_loss = self._calc_critic_loss(training_info)
    alpha_loss = training_info.info.alpha.loss
    actor_loss = training_info.info.actor.loss
    return LossInfo(loss=actor_loss.loss + critic_loss.loss + alpha_loss.loss,
                    extra=SacLossInfo(actor=actor_loss.extra,
                                      critic=critic_loss.extra,
                                      alpha=alpha_loss.extra))
def train_step(self, time_step: TimeStep, state: DynamicsState):
    """
    Args:
        time_step (TimeStep): time step structure. The ``prev_action`` from
            time_step will be used for predicting the feature of the next
            step. It should be a Tensor of the shape [B, ...] or [B, n, ...]
            when n > 1, where n denotes the number of dynamics network
            replicas. When the input tensor has the shape of [B, ...] and
            n > 1, it will be first expanded to [B, n, ...] to match the
            number of dynamics network replicas.
        state (DynamicsState): state for dynamics learning with the
            following fields:
            - feature (Tensor): features of the previous observation of the
                shape [B, ...] or [B, n, ...] when n > 1. When
                ``state.feature`` has the shape of [B, ...] and n > 1, it
                will be first expanded to [B, n, ...] to match the number of
                dynamics network replicas. It is used for predicting the
                feature of the next step together with
                ``time_step.prev_action``.
            - network: the input state of the dynamics network
    Returns:
        AlgStep:
            outputs: empty tuple ()
            state (DynamicsState): with the following fields
                - feature (Tensor): [B, ...] (or [B, n, ...] when n > 1)
                    shape tensor representing the predicted feature of the
                    next step
                - network: the updated state of the dynamics network
            info (DynamicsInfo): with the following fields being updated:
                - loss (LossInfo):
                - dist (td.Distribution): the predictive distribution which
                    can be used for further calculation or summarization.
    """
    feature = time_step.observation
    feature = self._expand_to_replica(feature, self._feature_spec)
    dynamics_step = self.predict_step(time_step, state)
    dist = dynamics_step.info.dist
    forward_loss = -dist.log_prob(feature - state.feature)

    if forward_loss.ndim > 2:
        # [B, n, ...] -> [B, ...]
        forward_loss = forward_loss.sum(1)
    if forward_loss.ndim > 1:
        forward_loss = forward_loss.mean(list(range(1, forward_loss.ndim)))

    valid_masks = (time_step.step_type != StepType.FIRST).to(torch.float32)
    forward_loss = forward_loss * valid_masks

    info = DynamicsInfo(
        loss=LossInfo(
            loss=forward_loss, extra=dict(forward_loss=forward_loss)),
        dist=dist)
    state = state._replace(feature=feature)

    return AlgStep(output=(), state=state, info=info)
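# Illustrative sketch (not part of ALF): the forward loss above is the
# negative log-likelihood of the observed feature *change* under the
# predictive distribution returned by the dynamics network. The distribution
# and shapes below are made up for demonstration.
import torch
import torch.distributions as td

B, D = 4, 3
prev_feature = torch.rand(B, D)
feature = torch.rand(B, D)
# Pretend the dynamics network predicted this distribution over the delta
dist = td.Independent(td.Normal(torch.zeros(B, D), torch.ones(B, D)), 1)

forward_loss = -dist.log_prob(feature - prev_feature)   # [B]
print(forward_loss.shape)  # torch.Size([4])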
def __call__(self, training_info: TrainingInfo, value):
    """Calculate actor critic loss.

    The first dimension of all the tensors is the time dimension and the
    second dimension is the batch dimension.

    Args:
        training_info (TrainingInfo): training_info collected by
            OnPolicyDriver/OffPolicyAlgorithm. All tensors in training_info
            are time-major.
        value (tf.Tensor): the time-major tensor for the value at each time
            step.
    Returns:
        loss_info (LossInfo): with loss_info.extra being ActorCriticLossInfo.
    """
    returns, advantages = self._calc_returns_and_advantages(
        training_info, value)

    def _summary():
        with self.name_scope:
            tf.summary.scalar("values", tf.reduce_mean(value))
            tf.summary.scalar("returns", tf.reduce_mean(returns))
            tf.summary.scalar("advantages/mean", tf.reduce_mean(advantages))
            tf.summary.histogram("advantages/value", advantages)
            tf.summary.scalar("explained_variance_of_return_by_value",
                              common.explained_variance(value, returns))

    if self._debug_summaries:
        common.run_if(common.should_record_summaries(), _summary)

    if self._normalize_advantages:
        advantages = _normalize_advantages(advantages, axes=(0, 1))

    if self._advantage_clip:
        advantages = tf.clip_by_value(advantages, -self._advantage_clip,
                                      self._advantage_clip)

    pg_loss = self._pg_loss(training_info, tf.stop_gradient(advantages))
    td_loss = self._td_error_loss_fn(tf.stop_gradient(returns), value)
    loss = pg_loss + self._td_loss_weight * td_loss

    entropy_loss = ()
    if self._entropy_regularization is not None:
        entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
            training_info.info.action_distribution, self._action_spec)
        entropy_loss = -entropy
        loss -= self._entropy_regularization * entropy_for_gradient

    return LossInfo(loss=loss,
                    extra=ActorCriticLossInfo(td_loss=td_loss,
                                              pg_loss=pg_loss,
                                              neg_entropy=entropy_loss))
def calc_loss(self, training_info: TrainingInfo):
    """Calculate loss."""
    self.add_reward_summary("reward", training_info.reward)
    mbp_loss_info = self._mbp.calc_loss(training_info.info.mbp_info)
    mba_loss_info = self._mba.calc_loss(
        training_info._replace(info=training_info.info.mba_info))
    return LossInfo(loss=mbp_loss_info.loss + mba_loss_info.loss,
                    extra=MerlinLossInfo(mbp=mbp_loss_info.extra,
                                         mba=mba_loss_info.extra))
def _update_loss(loss_info, training_info, name, algorithm):
    if algorithm is None:
        return loss_info
    new_loss_info = algorithm.calc_loss(getattr(training_info.info, name))
    return LossInfo(
        loss=add_ignore_empty(loss_info.loss, new_loss_info.loss),
        scalar_loss=add_ignore_empty(loss_info.scalar_loss,
                                     new_loss_info.scalar_loss),
        extra=loss_info.extra._replace(**{name: new_loss_info.extra}))
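# Hypothetical re-implementation (not ALF's actual helper) sketching the
# assumed behavior of `add_ignore_empty`: add two loss terms while treating
# an empty tuple () placeholder as "no loss", so sub-algorithms that report
# no loss or scalar_loss don't break the accumulation.
def add_ignore_empty(x, y):
    """Return x + y, where an empty tuple () on either side is ignored."""
    def _empty(v):
        return isinstance(v, tuple) and len(v) == 0

    if _empty(x):
        return y
    if _empty(y):
        return x
    return x + y


assert add_ignore_empty((), 3.0) == 3.0
assert add_ignore_empty(1.0, ()) == 1.0
assert add_ignore_empty(1.0, 3.0) == 4.0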
def calc_loss(self, target, predicted, mask=None):
    """Calculate the loss between ``target`` and ``predicted``.

    Args:
        target (Tensor): target to be predicted. Its shape is [T, B, ...]
        predicted (Tensor): predicted target. Its shape is [T, B, ...]
        mask (bool Tensor): indicating which target should be predicted.
            Its shape is [T, B].
    Returns:
        LossInfo
    """
    if self._target_normalizer:
        self._target_normalizer.update(target)
        target = self._target_normalizer.normalize(target)
        predicted = self._target_normalizer.normalize(predicted)

    loss = self._loss(predicted, target)

    if self._debug_summaries and alf.summary.should_record_summaries():
        with alf.summary.scope(self._name):

            def _summarize1(pred, tgt, loss, mask, suffix):
                alf.summary.scalar(
                    "explained_variance" + suffix,
                    tensor_utils.explained_variance(pred, tgt, mask))
                safe_mean_hist_summary('predict' + suffix, pred, mask)
                safe_mean_hist_summary('target' + suffix, tgt, mask)
                safe_mean_summary("loss" + suffix, loss, mask)

            def _summarize(pred, tgt, loss, mask, suffix):
                _summarize1(pred[0], tgt[0], loss[0], mask[0],
                            suffix + "/current")
                if pred.shape[0] > 1:
                    _summarize1(pred[1:], tgt[1:], loss[1:], mask[1:],
                                suffix + "/future")

            if loss.ndim == 2:
                _summarize(predicted, target, loss, mask, '')
            elif not self._summarize_each_dimension:
                m = mask
                if m is not None:
                    m = m.unsqueeze(-1).expand_as(predicted)
                _summarize(predicted, target, loss, m, '')
            else:
                for i in range(predicted.shape[2]):
                    suffix = '/' + str(i)
                    _summarize(predicted[..., i], target[..., i],
                               loss[..., i], mask, suffix)

    if loss.ndim == 3:
        loss = loss.mean(dim=2)

    if mask is not None:
        loss = loss * mask

    return LossInfo(loss=loss * self._loss_weight, extra=loss)
def calc_loss(self, training_info: TrainingInfo):
    critic_loss = self._critic_loss(
        training_info=training_info,
        value=training_info.info.critic.q_value,
        target_value=training_info.info.critic.target_q_value)
    actor_loss = training_info.info.actor_loss
    return LossInfo(loss=critic_loss.loss + actor_loss.loss,
                    extra=DdpgLossInfo(critic=critic_loss.extra,
                                       actor=actor_loss.extra))
def calc_loss(self, experience, train_info: MdqInfo):
    alpha_loss = train_info.alpha
    critic_loss, distill_loss = self._calc_critic_loss(experience, train_info)
    total_loss = critic_loss.loss + distill_loss + alpha_loss.loss.squeeze(-1)
    return LossInfo(loss=total_loss,
                    extra=MdqLossInfo(critic=critic_loss.extra,
                                      alpha=alpha_loss.extra,
                                      distill=distill_loss))
def forward(self, experience, train_info):
    """Calculate actor critic loss.

    The first dimension of all the tensors is the time dimension and the
    second dimension is the batch dimension.

    Args:
        experience (nest): experience used for training. All tensors are
            time-major.
        train_info (nest): information collected for training. It is batched
            from each ``AlgStep.info`` returned by ``rollout_step()``
            (on-policy training) or ``train_step()`` (off-policy training).
            All tensors in ``train_info`` are time-major.
    Returns:
        LossInfo: with ``extra`` being ``ActorCriticLossInfo``.
    """
    value = train_info.value
    returns, advantages = self._calc_returns_and_advantages(experience, value)

    if self._debug_summaries and alf.summary.should_record_summaries():
        with alf.summary.scope(self._name):
            alf.summary.scalar("values", value.mean())
            alf.summary.scalar("returns", returns.mean())
            alf.summary.scalar("advantages/mean", advantages.mean())
            alf.summary.histogram("advantages/value", advantages)
            alf.summary.scalar(
                "explained_variance_of_return_by_value",
                tensor_utils.explained_variance(value, returns))

    if self._normalize_advantages:
        advantages = _normalize_advantages(advantages)

    if self._advantage_clip:
        advantages = torch.clamp(advantages, -self._advantage_clip,
                                 self._advantage_clip)

    pg_loss = self._pg_loss(experience, train_info, advantages.detach())
    td_loss = self._td_error_loss_fn(returns.detach(), value)
    loss = pg_loss + self._td_loss_weight * td_loss

    entropy_loss = ()
    if self._entropy_regularization is not None:
        entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
            train_info.action_distribution)
        entropy_loss = -entropy
        loss -= self._entropy_regularization * entropy_for_gradient

    return LossInfo(loss=loss,
                    extra=ActorCriticLossInfo(td_loss=td_loss,
                                              pg_loss=pg_loss,
                                              neg_entropy=entropy_loss))
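# Sketch of a typical `_normalize_advantages` (an assumed implementation,
# not necessarily ALF's): standardize the advantages over the whole
# time x batch block so the policy-gradient scale is insensitive to reward
# magnitude. Names and the eps value are illustrative.
import torch

def _normalize_advantages(advantages, eps=1e-8):
    # advantages: time-major [T, B]
    return (advantages - advantages.mean()) / (advantages.std() + eps)

adv = torch.randn(8, 4) * 10 + 3
norm_adv = _normalize_advantages(adv)
print(norm_adv.mean().item(), norm_adv.std().item())  # ~0 and ~1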
def decode_step(self, latent_vector, observations):
    """Calculate decoding loss."""
    decoders = tf.nest.flatten(self._decoders)
    observations = tf.nest.flatten(observations)
    decoder_losses = [
        decoder.train_step((latent_vector, obs)).info
        for decoder, obs in zip(decoders, observations)
    ]
    loss = tf.add_n([decoder_loss.loss for decoder_loss in decoder_losses])
    decoder_losses = tf.nest.pack_sequence_as(self._decoders, decoder_losses)
    return LossInfo(loss=loss, extra=decoder_losses)
def calc_loss(self, experience, train_info: LossInfo):
    assert experience.batch_info != ()
    if (experience.batch_info != ()
            and experience.batch_info.importance_weights != ()):
        priority = (experience.rollout_info.value -
                    experience.rollout_info.target.value[..., 0])
        priority = priority.abs().sum(dim=1)
    else:
        priority = ()

    return train_info._replace(
        loss=(), scalar_loss=train_info.loss.mean(), priority=priority)
def __call__(self, training_info: TrainingInfo, value, target_value):
    returns = value_ops.one_step_discounted_return(
        rewards=training_info.reward,
        values=target_value,
        step_types=training_info.step_type,
        discounts=training_info.discount * self._gamma)
    returns = common.tensor_extend(returns, value[-1])

    if self._debug_summaries:
        with self.name_scope:
            tf.summary.scalar("values", tf.reduce_mean(value))
            tf.summary.scalar("returns", tf.reduce_mean(returns))

    loss = self._td_error_loss_fn(tf.stop_gradient(returns), value)
    return LossInfo(loss=loss, extra=loss)
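# Illustrative sketch (not ALF's value_ops): a minimal time-major TD(0)
# target, returns[t] = rewards[t+1] + discounts[t+1] * target_value[t+1],
# which is what the loss above regresses `value` toward, ignoring the
# episode-boundary handling the real implementation performs. The last entry
# is then extended with the final value prediction, mirroring
# `common.tensor_extend(returns, value[-1])`. Shapes are made up.
import torch

T, B = 5, 2
rewards = torch.rand(T, B)
discounts = torch.full((T, B), 0.99)
value = torch.rand(T, B)           # online critic values
target_value = torch.rand(T, B)    # target critic values

returns = rewards[1:] + discounts[1:] * target_value[1:]   # [T-1, B]
returns = torch.cat([returns, value[-1:]], dim=0)          # [T, B]
print(returns.shape)  # torch.Size([5, 2])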
def _update_loss(loss_info, algorithm, name):
    info = getattr(train_info, name)
    exp = _make_alg_experience(experience, name)
    new_loss_info = algorithm.calc_loss(exp, info)
    if loss_info is None:
        return new_loss_info._replace(extra={name: new_loss_info.extra})
    else:
        loss_info.extra[name] = new_loss_info.extra
        return LossInfo(
            loss=add_ignore_empty(loss_info.loss, new_loss_info.loss),
            scalar_loss=add_ignore_empty(loss_info.scalar_loss,
                                         new_loss_info.scalar_loss),
            extra=loss_info.extra)
def _calc_critic_loss(self, training_info):
    critic_info = training_info.info.critic
    target_critic = critic_info.target_critic

    critic_loss1 = self._critic_loss(
        training_info=training_info,
        value=critic_info.critic1,
        target_value=target_critic)

    critic_loss2 = self._critic_loss(
        training_info=training_info,
        value=critic_info.critic2,
        target_value=target_critic)

    critic_loss = critic_loss1.loss + critic_loss2.loss
    return LossInfo(loss=critic_loss, extra=critic_loss)
def _calc_critic_loss(self, experience, train_info: MdqInfo):
    critic_info = train_info.critic
    # [t, B, n]
    critic_free_form = critic_info.critic_free_form
    # [t, B, n, action_dim]
    critic_adv_form = critic_info.critic_adv_form
    target_critic_free_form = critic_info.target_critic_free_form
    distill_target = critic_info.distill_target

    num_critic_replicas = critic_free_form.shape[2]
    alpha = torch.exp(self._log_alpha).detach()

    kl_wrt_prior = critic_info.kl_wrt_prior
    # [t, B, n, action_dim] -> [t, B]
    # Note that currently kl_wrt_prior is independent of the ensembles, so
    # we slice over the ensemble dimension by taking the first element;
    # for the action dimension, the first element is the full KL.
    kl_wrt_prior = kl_wrt_prior[..., 0, 0]

    # [t, B, n] -> [t, B]
    target_critic, min_target_ind = torch.min(target_critic_free_form, dim=2)
    # [t, B, n] -> [t, B]
    distill_target, _ = torch.min(distill_target, dim=2)

    target_critic_corrected = target_critic - alpha * kl_wrt_prior

    critic_losses = []
    for j in range(num_critic_replicas):
        critic_losses.append(self._critic_losses[j](
            experience=experience,
            value=critic_free_form[:, :, j],
            target_value=target_critic_corrected).loss)

    critic_loss = math_ops.add_n(critic_losses)

    distill_loss = (critic_adv_form[..., -1] -
                    distill_target.unsqueeze(2).detach())**2
    # Mean over replicas
    distill_loss = distill_loss.mean(dim=2)

    return LossInfo(
        loss=critic_loss,
        extra=critic_loss / len(critic_losses)), distill_loss
def train_step(self, distribution, step_type):
    """Train step.

    Args:
        distribution (nested Distribution): action distribution from the
            policy.
        step_type (StepType): the step type for the distributions.
    Returns:
        AlgStep: ``info`` field is ``LossInfo``, other fields are empty.
    """
    entropy, entropy_for_gradient = entropy_with_fallback(distribution)
    return AlgStep(
        output=(),
        state=(),
        info=EntropyTargetInfo(
            loss=LossInfo(
                loss=-entropy_for_gradient,
                extra=EntropyTargetLossInfo(neg_entropy=-entropy))))
def train_step(self, inputs, state: MBPState):
    """Train one step.

    Args:
        inputs (tuple): a tuple of (observation, action)
    """
    observation, _ = inputs
    latent_vector, kld, next_state = self.encode_step(inputs, state)

    # TODO: decoder for action
    decoder_loss = self.decode_step(latent_vector, observation)

    return AlgorithmStep(
        outputs=latent_vector,
        state=next_state,
        info=LossInfo(loss=self._loss_weight * (decoder_loss.loss + kld),
                      extra=MBPLossInfo(decoder=decoder_loss, vae=kld)))
def train_step(self,
               inputs,
               loss_func,
               outputs=None,
               batch_size=None,
               entropy_regularization=None,
               state=None):
    """
    Args:
        inputs (nested Tensor): if None, the outputs is generated only from
            noise.
        outputs (Tensor): generator's output (possibly from previous runs)
            used for this train_step.
        loss_func (Callable): loss_func([outputs, inputs])
            (loss_func(outputs) if inputs is None) returns a Tensor or a
            namedtuple of tensors with field `loss`, which is a Tensor of
            shape [batch_size] giving a loss term for optimizing the
            generator.
        batch_size (int): batch_size. Must be provided if inputs is None.
            It is ignored if inputs is not None.
        state: not used
    Returns:
        AlgorithmStep:
            outputs: Tensor with shape (batch_size, dim)
            info: LossInfo
    """
    if outputs is None:
        outputs, gen_inputs = self._predict(inputs, batch_size=batch_size)
    if entropy_regularization is None:
        entropy_regularization = self._entropy_regularization

    loss, loss_propagated = self._grad_func(inputs, outputs, loss_func,
                                            entropy_regularization)

    mi_loss = ()
    if self._mi_estimator is not None:
        mi_step = self._mi_estimator.train_step([gen_inputs, outputs])
        mi_loss = mi_step.info.loss
        loss_propagated = loss_propagated + self._mi_weight * mi_loss

    return AlgStep(
        output=outputs,
        state=(),
        info=LossInfo(
            loss=loss_propagated,
            extra=GeneratorLossInfo(generator=loss, mi_estimator=mi_loss)))
def _step(self, time_step: TimeStep, state, calc_rewards=True):
    """
    Args:
        time_step (TimeStep): input time_step data
        state (tuple): empty tuple ()
        calc_rewards (bool): whether to calculate rewards

    Returns:
        AlgStep:
            output: empty tuple ()
            state: empty tuple ()
            info: ICMInfo
    """
    observation = time_step.observation

    if self._keep_stacked_frames > 0:
        # Assuming stacking in the first dim, we only keep the last frames.
        observation = observation[:, -self._keep_stacked_frames:, ...]

    if self._observation_normalizer is not None:
        observation = self._observation_normalizer.normalize(observation)

    if self._encoder_net is not None:
        with torch.no_grad():
            observation, _ = self._encoder_net(observation)

    pred_embedding, _ = self._predictor_net(observation)
    with torch.no_grad():
        target_embedding, _ = self._target_net(observation)

    loss = torch.sum(
        math_ops.square(pred_embedding - target_embedding), dim=-1)

    intrinsic_reward = ()
    if calc_rewards:
        intrinsic_reward = loss.detach()
        if self._reward_normalizer:
            intrinsic_reward = self._reward_normalizer.normalize(
                intrinsic_reward, clip_value=self._reward_clip_value)

    return AlgStep(
        output=(),
        state=(),
        info=ICMInfo(reward=intrinsic_reward, loss=LossInfo(loss=loss)))