def train_step(self, exp: Experience, state, trainable=True):
    """This function trains the discriminator or generates intrinsic rewards.

    If ``trainable=True``, it only computes and returns the prediction loss;
    otherwise it only generates intrinsic rewards with no gradient.
    """
    # Discriminator training from its own replay buffer, or
    # discriminator computing intrinsic rewards for training lower_rl.
    untrans_observation, prev_skill, switch_skill, steps = exp.observation
    observation = self._observation_transformer(untrans_observation)
    loss = self._predict_skill_loss(observation, exp.prev_action, prev_skill,
                                    steps, state)

    first_observation = self._update_state_if_necessary(
        switch_skill, observation, state.first_observation)
    subtrajectory = self._clear_subtrajectory_if_necessary(
        state.subtrajectory, switch_skill)
    new_state = DiscriminatorState(
        first_observation=first_observation,
        untrans_observation=untrans_observation,
        subtrajectory=subtrajectory)

    valid_masks = (exp.step_type != StepType.FIRST)
    if self._sparse_reward:
        # Only give intrinsic rewards at the last step of the skill
        valid_masks &= switch_skill
    loss *= valid_masks.to(torch.float32)

    if trainable:
        info = LossInfo(loss=loss, extra=dict(discriminator_loss=loss))
        return AlgStep(state=new_state, info=info)
    else:
        intrinsic_reward = -loss.detach() / self._skill_dim
        return AlgStep(state=common.detach(new_state), info=intrinsic_reward)
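
# A minimal, self-contained sketch (not ALF code) of how the masked per-step
# prediction loss above is turned into an intrinsic reward. The tensor values
# and the ``skill_dim`` below are illustrative assumptions.
import torch

skill_dim = 4
# Per-step discriminator prediction loss for a batch of 3 environments.
loss = torch.tensor([0.9, 1.7, 0.3])
# Mask out FIRST steps (and, for sparse rewards, non-switching steps).
valid_masks = torch.tensor([True, False, True])
masked_loss = loss * valid_masks.to(torch.float32)
# A lower prediction loss (easier-to-infer skill) gives a higher reward.
intrinsic_reward = -masked_loss.detach() / skill_dim
print(intrinsic_reward)  # [-0.225, -0.0, -0.075]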
def _buffer_sampler(self, x, y):
    """Sample a batch of ``y`` from the replay buffer.

    The new ``y`` is added to the buffer after sampling so that the returned
    batch does not contain it, unless the buffer does not yet hold enough
    entries, in which case ``y`` is added first.
    """
    batch_size = get_nest_batch_size(y)
    if self._y_buffer.current_size >= batch_size:
        y1 = self._y_buffer.get_batch(batch_size)
        self._y_buffer.add_batch(y)
    else:
        self._y_buffer.add_batch(y)
        y1 = self._y_buffer.get_batch(batch_size)
    return x, common.detach(y1)
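
# A minimal sketch of the sampling order implemented above, using a toy FIFO
# buffer in place of ``self._y_buffer``. The buffer class and its
# ``current_size``/``get_batch``/``add_batch`` API are assumptions made for
# illustration only.
import random
import torch


class ToyBuffer:
    def __init__(self, capacity):
        self._data = []
        self._capacity = capacity

    @property
    def current_size(self):
        return len(self._data)

    def add_batch(self, batch):
        # Append each row of ``batch``; drop the oldest rows when full.
        self._data.extend(batch.unbind(0))
        self._data = self._data[-self._capacity:]

    def get_batch(self, batch_size):
        # Sample ``batch_size`` rows with replacement.
        return torch.stack(random.choices(self._data, k=batch_size))


buffer = ToyBuffer(capacity=64)
y = torch.randn(8, 3)
# Sample before adding when possible, so the returned batch never contains
# the batch currently being added; otherwise add first so that there is
# something to sample from.
if buffer.current_size >= y.shape[0]:
    y1 = buffer.get_batch(y.shape[0])
    buffer.add_batch(y)
else:
    buffer.add_batch(y)
    y1 = buffer.get_batch(y.shape[0])
print(y1.shape)  # torch.Size([8, 3])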
def rollout_step(self, time_step: TimeStep, state: ActorCriticState):
    """Rollout for one step."""
    value, value_state = self._value_network(
        time_step.observation, state=state.value)

    # We detach time_step.observation here so that in the case that the
    # observation is calculated by some other trainable module, the training
    # of that module will not be affected by the gradient back-propagated
    # from the actor. However, the gradient from the critic will still affect
    # the training of that module.
    action_distribution, actor_state = self._actor_network(
        common.detach(time_step.observation), state=state.actor)

    action = dist_utils.sample_action_distribution(action_distribution)
    return AlgStep(
        output=action,
        state=ActorCriticState(actor=actor_state, value=value_state),
        info=ActorCriticInfo(
            value=value, action_distribution=action_distribution))
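
# A minimal sketch of why the observation is detached for the actor but not
# for the value network: gradients from the actor head should not reach a
# trainable module that produced the observation, while gradients from the
# value/critic head still should. The tiny modules below are illustrative
# assumptions, not ALF classes.
import torch
import torch.nn as nn

encoder = nn.Linear(4, 8)    # stands in for a trainable observation producer
actor_head = nn.Linear(8, 2)
value_head = nn.Linear(8, 1)

obs = encoder(torch.randn(3, 4))

# Actor branch uses the detached observation: no gradient reaches the encoder.
actor_head(obs.detach()).mean().backward()
print(encoder.weight.grad)  # None

# Value branch uses the original observation: gradient does reach the encoder.
value_head(obs).mean().backward()
print(encoder.weight.grad.abs().sum() > 0)  # tensor(True)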
def train_step(self, exp: Experience, state: SacState):
    # We detach exp.observation here so that in the case that exp.observation
    # is calculated by some other trainable module, the training of that
    # module will not be affected by the gradient back-propagated from the
    # actor. However, the gradient from critic will still affect the training
    # of that module.
    (action_distribution, action, critics,
     action_state) = self._predict_action(
         common.detach(exp.observation), state=state.action)

    log_pi = nest.map_structure(lambda dist, a: dist.log_prob(a),
                                action_distribution, action)

    if self._act_type == ActionType.Mixed:
        # For mixed type, add log_pi separately
        log_pi = type(self._action_spec)(
            (sum(nest.flatten(log_pi[0])), sum(nest.flatten(log_pi[1]))))
    else:
        log_pi = sum(nest.flatten(log_pi))

    if self._prior_actor is not None:
        prior_step = self._prior_actor.train_step(exp, ())
        log_prior = dist_utils.compute_log_probability(
            prior_step.output, action)
        log_pi = log_pi - log_prior

    actor_state, actor_loss = self._actor_train_step(
        exp, state.actor, action, critics, log_pi, action_distribution)
    critic_state, critic_info = self._critic_train_step(
        exp, state.critic, action, log_pi, action_distribution)
    alpha_loss = self._alpha_train_step(log_pi)

    state = SacState(
        action=action_state, actor=actor_state, critic=critic_state)
    info = SacInfo(
        action_distribution=action_distribution,
        actor=actor_loss,
        critic=critic_info,
        alpha=alpha_loss)
    return AlgStep(action, state, info)
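
# A minimal sketch of how per-part log-probabilities are combined above. The
# nested action structure (one discrete part, one continuous part) is an
# illustrative assumption, and plain Python replaces ALF's ``nest`` utilities.
import torch
import torch.distributions as td

discrete_dist = td.Categorical(logits=torch.zeros(3, 5))
continuous_dist = td.Independent(
    td.Normal(torch.zeros(3, 2), torch.ones(3, 2)), 1)

discrete_action = discrete_dist.sample()
continuous_action = continuous_dist.sample()

log_pi_parts = (discrete_dist.log_prob(discrete_action),
                continuous_dist.log_prob(continuous_action))

# Non-mixed case: sum all parts into a single per-sample log_pi of shape [B].
log_pi = sum(log_pi_parts)

# Mixed case: keep the discrete and continuous log_pi as separate entries.
log_pi_mixed = (log_pi_parts[0], log_pi_parts[1])
print(log_pi.shape)  # torch.Size([3])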
def unroll(self, unroll_length):
    r"""Unroll ``unroll_length`` steps using the current policy.

    Because ``self._env`` is a batched environment, the total number of
    environment steps is ``self._env.batch_size * unroll_length``.

    Args:
        unroll_length (int): number of steps to unroll.

    Returns:
        Experience: The stacked experience with shape :math:`[T, B, \ldots]`
        for each of its members.
    """
    if self._current_time_step is None:
        self._current_time_step = common.get_initial_time_step(self._env)
    if self._current_policy_state is None:
        self._current_policy_state = self.get_initial_rollout_state(
            self._env.batch_size)
    if self._current_transform_state is None:
        self._current_transform_state = self.get_initial_transform_state(
            self._env.batch_size)

    time_step = self._current_time_step
    policy_state = self._current_policy_state
    trans_state = self._current_transform_state

    experience_list = []
    initial_state = self.get_initial_rollout_state(self._env.batch_size)

    env_step_time = 0.
    store_exp_time = 0.
    for _ in range(unroll_length):
        policy_state = common.reset_state_if_necessary(
            policy_state, initial_state, time_step.is_first())
        transformed_time_step, trans_state = self.transform_timestep(
            time_step, trans_state)
        # save the untransformed time step in case that sub-algorithms need
        # to store it in replay buffers
        transformed_time_step = transformed_time_step._replace(
            untransformed=time_step)
        policy_step = self.rollout_step(transformed_time_step, policy_state)
        # release the reference to ``time_step``
        transformed_time_step = transformed_time_step._replace(
            untransformed=())
        action = common.detach(policy_step.output)

        t0 = time.time()
        next_time_step = self._env.step(action)
        env_step_time += time.time() - t0

        self.observe_for_metrics(time_step.cpu())
        if self._exp_replayer_type == "one_time":
            exp = make_experience(transformed_time_step, policy_step,
                                  policy_state)
        else:
            exp = make_experience(time_step.cpu(), policy_step, policy_state)
        t0 = time.time()
        self.observe_for_replay(exp)
        store_exp_time += time.time() - t0

        exp_for_training = Experience(
            action=action,
            reward=transformed_time_step.reward,
            discount=transformed_time_step.discount,
            step_type=transformed_time_step.step_type,
            state=policy_state,
            prev_action=transformed_time_step.prev_action,
            observation=transformed_time_step.observation,
            rollout_info=dist_utils.distributions_to_params(
                policy_step.info),
            env_id=transformed_time_step.env_id)

        experience_list.append(exp_for_training)

        time_step = next_time_step
        policy_state = policy_step.state

    alf.summary.scalar("time/unroll_env_step", env_step_time)
    alf.summary.scalar("time/unroll_store_exp", store_exp_time)

    experience = alf.nest.utils.stack_nests(experience_list)
    experience = experience._replace(
        rollout_info=dist_utils.params_to_distributions(
            experience.rollout_info, self._rollout_info_spec))

    self._current_time_step = time_step
    # Need to detach so that the graph from this unroll is disconnected from
    # the next unroll. Otherwise backward() will report error for on-policy
    # training after the next unroll.
    self._current_policy_state = common.detach(policy_state)
    self._current_transform_state = common.detach(trans_state)

    return experience
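
# A minimal sketch of stacking per-step results into the returned
# :math:`[T, B, \ldots]` layout. A plain namedtuple and ``torch.stack`` stand
# in for ALF's ``Experience`` and ``alf.nest.utils.stack_nests``; the shapes
# are illustrative assumptions.
import collections
import torch

Step = collections.namedtuple('Step', ['reward', 'observation'])

batch_size, unroll_length = 4, 3
steps = [
    Step(reward=torch.zeros(batch_size),
         observation=torch.zeros(batch_size, 5)) for _ in range(unroll_length)
]
# Stack each field across time, turning a list of [B, ...] steps into
# [T, B, ...] tensors.
stacked = Step(*(torch.stack(values) for values in zip(*steps)))
print(stacked.reward.shape)       # torch.Size([3, 4])    -> [T, B]
print(stacked.observation.shape)  # torch.Size([3, 4, 5]) -> [T, B, ...]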