def predict_step(self, time_step: TimeStep, state, epsilon_greedy): """This function does one thing, i.e., every ``self._num_steps_per_skill`` it calls ``self._rl`` to generate new skills. """ switch_skill = self._should_switch_skills(time_step, state) discriminator_step = self._discriminator_predict_step( time_step, state, switch_skill) def _generate_new_skills(time_step, state): rl_step = self._rl.predict_step(time_step, state.rl, epsilon_greedy) return SkillGeneratorState(skill=rl_step.output, steps=torch.zeros_like(state.steps), rl=rl_step.state) new_state = conditional_update(target=state, cond=switch_skill, func=_generate_new_skills, time_step=time_step, state=state) new_state = new_state._replace(steps=new_state.steps + 1, discriminator=discriminator_step.state) return AlgStep(output=new_state.skill, state=new_state, info=SkillGeneratorInfo(switch_skill=switch_skill))
def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
    switch_action = self._should_switch_action(time_step, state)

    @torch.no_grad()
    def _generate_new_action(time_step, state):
        repr_state = ()
        if self._repr_learner is not None:
            repr_step = self._repr_learner.predict_step(
                time_step, state.repr)
            time_step = time_step._replace(observation=repr_step.output)
            repr_state = repr_step.state
        rl_step = self._rl.predict_step(time_step, state.rl, epsilon_greedy)
        steps, action = rl_step.output
        return ActionRepeatState(
            action=action,
            steps=steps + 1,  # [0, K-1] -> [1, K]
            rl=rl_step.state,
            repr=repr_state)

    new_state = conditional_update(
        target=state,
        cond=switch_action,
        func=_generate_new_action,
        time_step=time_step,
        state=state)
    new_state = new_state._replace(steps=new_state.steps - 1)

    return AlgStep(
        output=new_state.action,
        state=new_state,
        # plot steps and action when rendering video
        info=dict(action=(new_state.action, new_state.steps)))
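# Illustration (an assumption, not taken from the source) of the repeat-counter
# bookkeeping above: the RL policy outputs a pair (steps, action) with steps in
# [0, K-1]; the state stores steps + 1 and decrements it once per call, so the
# same action is repeated steps + 1 times before a new one is requested.
# ``_should_switch_action`` is assumed here to trigger when the counter reaches
# 0 (and at the first step of an episode).
counter = 0                       # state.steps at the start of an episode
trace = []
for env_step in range(8):
    if counter == 0:              # switch_action: ask the policy for a new pair
        steps, action = 2, 'a%d' % env_step   # pretend rl_step.output == (2, ...)
        counter = steps + 1       # [0, K-1] -> [1, K]
    counter -= 1                  # new_state.steps = steps - 1
    trace.append((action, counter))
print(trace)  # each action repeated 3 times: [('a0', 2), ('a0', 1), ('a0', 0), ('a3', 2), ...]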
def test_conditional_update(self):
    def _func(x, y):
        return x + 1, y - 1

    batch_size = 256
    target = (tf.random.uniform([batch_size, 3]),
              tf.random.uniform([batch_size]))
    x = tf.random.uniform([batch_size, 3])
    y = tf.random.uniform([batch_size])

    cond = tf.constant([False] * batch_size)
    updated_target = conditional_update(target, cond, _func, x, y)
    self.assertAllEqual(updated_target[0], target[0])
    self.assertAllEqual(updated_target[1], target[1])

    cond = tf.constant([True] * batch_size)
    updated_target = conditional_update(target, cond, _func, x, y)
    self.assertAllEqual(updated_target[0], x + 1)
    self.assertAllEqual(updated_target[1], y - 1)

    cond = tf.random.uniform((batch_size, )) < 0.5
    updated_target = conditional_update(target, cond, _func, x, y)
    self.assertAllEqual(
        select_from_mask(updated_target[0], cond),
        select_from_mask(x + 1, cond))
    self.assertAllEqual(
        select_from_mask(updated_target[1], cond),
        select_from_mask(y - 1, cond))

    vx = tf.Variable(initial_value=0.)
    vy = tf.Variable(initial_value=0.)

    def _func1(x, y):
        vx.assign(tf.reduce_sum(x))
        vy.assign(tf.reduce_sum(y))
        return ()

    # test empty return
    conditional_update((), cond, _func1, x, y)
    self.assertEqual(vx, tf.reduce_sum(select_from_mask(x, cond)))
    self.assertEqual(vy, tf.reduce_sum(select_from_mask(y, cond)))
def test_conditional_update(self):
    def _func(x, y):
        return x + 1, y - 1

    batch_size = 256
    target = (torch.rand([batch_size, 3]), torch.rand([batch_size]))
    x = torch.rand([batch_size, 3])
    y = torch.rand([batch_size])

    cond = torch.as_tensor([False] * batch_size)
    updated_target = conditional_update(target, cond, _func, x, y)
    self.assertTensorEqual(updated_target[0], target[0])
    self.assertTensorEqual(updated_target[1], target[1])

    cond = torch.as_tensor([True] * batch_size)
    updated_target = conditional_update(target, cond, _func, x, y)
    self.assertTensorEqual(updated_target[0], x + 1)
    self.assertTensorEqual(updated_target[1], y - 1)

    cond = torch.rand((batch_size, )) < 0.5
    updated_target = conditional_update(target, cond, _func, x, y)
    self.assertTensorEqual(
        select_from_mask(updated_target[0], cond),
        select_from_mask(x + 1, cond))
    self.assertTensorEqual(
        select_from_mask(updated_target[1], cond),
        select_from_mask(y - 1, cond))

    vx = torch.zeros(())
    vy = torch.zeros(())

    def _func1(x, y):
        vx.copy_(torch.sum(x))
        vy.copy_(torch.sum(y))
        return ()

    # test empty return
    conditional_update((), cond, _func1, x, y)
    self.assertEqual(vx, torch.sum(select_from_mask(x, cond)))
    self.assertEqual(vy, torch.sum(select_from_mask(y, cond)))
def test_conditional_update_high_dims(self):
    def _func(x):
        return x**2

    batch_size = 100
    x = torch.rand([batch_size, 3, 4, 5])
    y = torch.rand([batch_size, 3, 4, 5])
    cond = torch.randint(high=2, size=[batch_size]).to(torch.bool)
    updated_y = conditional_update(y, cond, _func, x)
    self.assertTensorEqual(
        select_from_mask(updated_y, cond), select_from_mask(x**2, cond))
    self.assertTensorEqual(
        select_from_mask(updated_y, ~cond), select_from_mask(y, ~cond))
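# The tests above pin down the contract of ``conditional_update``: ``func`` is
# applied only to the rows of its arguments selected by ``cond``, and its
# outputs are scattered back into ``target`` at those rows (an empty target
# just runs ``func`` for its side effects). Below is a minimal sketch of that
# contract inferred from the tests, not ALF's actual implementation; it only
# handles a flat tuple of tensors, whereas the real function works on
# arbitrary nests.
import torch


def _conditional_update_sketch(target, cond, func, *args):
    selected = [a[cond] for a in args]      # keep only rows where cond is True
    results = func(*selected)               # func sees just the selected rows
    if target == ():                        # side-effect-only call
        return ()
    updated = []
    for t, r in zip(target, results):
        t = t.clone()                       # don't modify target in place
        t[cond] = r                         # scatter results back
        updated.append(t)
    return tuple(updated)


# Usage mirroring the tests: entries where cond is False keep their old value.
x = torch.rand(4, 3)
y = torch.rand(4, 3)
cond = torch.tensor([True, False, True, False])
out, = _conditional_update_sketch((y, ), cond, lambda a: (a + 1, ), x)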
def rollout_step(self, time_step: TimeStep, state: ActionRepeatState):
    switch_action = self._should_switch_action(time_step, state)
    state = state._replace(
        rl_reward=state.rl_reward + state.rl_discount * time_step.reward,
        rl_discount=state.rl_discount * time_step.discount * self._gamma)

    @torch.no_grad()
    def _generate_new_action(time_step, state):
        rl_time_step = time_step._replace(
            reward=state.rl_reward, discount=state.rl_discount)
        observation, repr_state = rl_time_step.observation, ()
        if self._repr_learner is not None:
            repr_step = self._repr_learner.rollout_step(
                time_step, state.repr)
            observation = repr_step.output
            repr_state = repr_step.state
        rl_step = self._rl.rollout_step(
            rl_time_step._replace(observation=observation), state.rl)
        # Store to replay buffer.
        super(DynamicActionRepeatAgent, self).observe_for_replay(
            make_experience(
                rl_time_step._replace(
                    # Store the untransformed observation so that later it
                    # will be transformed again during training.
                    observation=rl_time_step.untransformed.observation),
                rl_step, state))
        steps, action = rl_step.output
        return ActionRepeatState(
            action=action,
            steps=steps + 1,  # [0, K-1] -> [1, K]
            repr=repr_state,
            rl=rl_step.state,
            rl_reward=torch.zeros_like(state.rl_reward),
            rl_discount=torch.ones_like(state.rl_discount))

    new_state = conditional_update(
        target=state,
        cond=switch_action,
        func=_generate_new_action,
        time_step=time_step,
        state=state)
    new_state = new_state._replace(steps=new_state.steps - 1)

    return AlgStep(output=new_state.action, state=new_state)
def _make_low_level_observation(self, subtrajectory, skill, switch_skill,
                                steps, updated_first_observation):
    r"""Given the skill generator's output, this function makes the
    skill-conditioned observation for the lower-level policy.

    Both observation and action are a stacking of recent states, with the
    most recent one appearing at index=0.

    X: first observation of a skill
    X': first observation of the next skill
    O: middle observation of a skill
    _: zero

    num_steps_per_skill=3:

        subtrajectory (discriminator)     low_rl_observation
        O _ _                         ->  O X _     (steps==2)
        O O _                         ->  O O X     (steps==3)
        X'O O                         ->  X'_ _     (steps==1, switch_skill==True)

    The same applies to action except that there is no
    ``first_observation``.
    """
    subtrajectory.observation[
        torch.arange(updated_first_observation.shape[0]).long(),
        steps - 1] = updated_first_observation

    def _zero(subtrajectory):
        subtrajectory.prev_action.fill_(0.)
        # When switch_skill is due to a FINAL step, filling 0s might have
        # issues if the FINAL step comes before num_steps_per_skill. But
        # since RL algorithms don't train FINAL steps, for now we leave it
        # like this for simplicity.
        subtrajectory.observation[:, 1:, ...] = 0.
        return subtrajectory

    subtrajectory = conditional_update(
        target=subtrajectory,
        cond=switch_skill,
        func=_zero,
        subtrajectory=subtrajectory)

    subtrajectory = alf.nest.map_structure(
        lambda traj: traj.reshape(traj.shape[0], -1), subtrajectory)

    low_rl_observation = (alf.nest.flatten(subtrajectory) +
                          [self._num_steps_per_skill - steps, skill])
    return low_rl_observation
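# Toy illustration (not part of the algorithm) of the two tensor operations
# used in ``_make_low_level_observation``: the indexed write that places the
# stored first observation at position ``steps - 1`` of each stacked
# sub-trajectory, and the zeroing of everything but index 0 when the skill
# switches. Shapes are made up for the example: batch B=2, K=3 stacked frames,
# observation dim D=1.
import torch

B, K, D = 2, 3, 1
obs = torch.ones(B, K, D)                 # stacked recent observations
steps = torch.tensor([2, 1])              # per-sample step counter in [1, K]
first_obs = torch.full((B, D), 9.)        # stand-in for the first observation
switch_skill = torch.tensor([False, True])

# Same indexed write as in the function above.
obs[torch.arange(B), steps - 1] = first_obs
# Zero out indices 1: for the samples that just switched skills.
obs[switch_skill, 1:, ...] = 0.
print(obs[0].squeeze(-1))  # tensor([1., 9., 1.])  (steps==2: first obs at index 1)
print(obs[1].squeeze(-1))  # tensor([9., 0., 0.])  (steps==1, switch: only index 0 kept)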
def rollout_step(self, time_step: TimeStep, state): r"""This function does three things: 1. every ``self._num_steps_per_skill`` it calls ``self._rl`` to generate new skills. 2. at the same time writes ``time_step`` to a replay buffer when new skills are generated. 3. call ``rollout_step()`` of the discriminator to write ``time_step`` to a replay buffer for training Regarding accumulating rewards for the higher-level policy. Suppose that during an episode we have :math:`H` segments where each segment contains :math:`K` steps. Then the objective for the higher-level policy is: .. math:: \begin{array}{ll} &\sum_{h=0}^{H-1}(\gamma^K)^h\sum_{t=0}^{K-1}\gamma^t r(s_{t+hK},a_{t+hK})\\ =&\sum_{h=0}^{H-1}\beta^h R_h\\ \end{array} where :math:`\gamma` is the discount and :math:`r(\cdot)` is the extrinsic reward of the original task. Thus :math:`\beta=\gamma^K` should be the discount per higher-level time step and :math:`R_h=\sum_{t=0}^{K-1}\gamma^t r(s_{t+hK},a_{t+hK})` should the reward per higher-level time step. """ switch_skill = self._should_switch_skills(time_step, state) discriminator_step = self._discriminator_rollout_step( time_step, state, switch_skill) state = state._replace( rl_reward=state.rl_reward + state.rl_discount * time_step.reward, rl_discount=state.rl_discount * time_step.discount * self._gamma) def _generate_new_skills(time_step, state): rl_prev_action = state.skill # avoid dividing by 0 #steps = torch.max( # state.steps.to(torch.float32), torch.as_tensor(1.0)) rl_time_step = time_step._replace(reward=state.rl_reward, discount=state.rl_discount, prev_action=rl_prev_action) rl_step = self._rl.rollout_step(rl_time_step, state.rl) # store to replay buffer self._rl.observe_for_replay( make_experience( # ``rl_time_step.observation`` has been transformed!!! rl_time_step._replace( observation=rl_time_step.untransformed.observation), rl_step, state.rl)) return SkillGeneratorState( skill=rl_step.output, steps=torch.zeros_like(state.steps), discriminator=state.discriminator, rl=rl_step.state, rl_reward=torch.zeros_like(state.rl_reward), rl_discount=torch.ones_like(state.rl_discount)) new_state = conditional_update(target=state, cond=switch_skill, func=_generate_new_skills, time_step=time_step, state=state) new_state = new_state._replace(steps=new_state.steps + 1, discriminator=discriminator_step.state) return AlgStep(output=new_state.skill, state=new_state, info=SkillGeneratorInfo(skill=new_state.skill, steps=new_state.steps, switch_skill=switch_skill))
def rollout_step(self, time_step: TimeStep, state: ActionRepeatState):
    switch_action = self._should_switch_action(time_step, state)
    # state.k is the current step index within the K-step segment.
    state = state._replace(
        rl_reward=(state.rl_reward + torch.pow(
            self._gamma, state.k.to(torch.float32)) * time_step.reward),
        rl_discount=state.rl_discount * time_step.discount * self._gamma,
        k=state.k + 1)

    if self._reward_normalizer is not None:
        # The probability that the reward at step k is kept until step K is
        # 1/k * k/(k+1) * ... * (K-1)/K = 1/K, which provides enough
        # randomness to keep the normalizer unbiased.
        state = state._replace(
            sample_rewards=torch.where(
                torch.rand_like(state.sample_rewards) <
                1. / state.k.to(torch.float32),
                time_step.reward, state.sample_rewards))

    @torch.no_grad()
    def _generate_new_action(time_step, state):
        rl_time_step = time_step._replace(
            reward=state.rl_reward,
            # To keep consistent with other algorithms, the discount will be
            # multiplied by gamma once more in td_loss.py, so divide it out
            # here.
            discount=state.rl_discount / self._gamma)
        observation, repr_state = rl_time_step.observation, ()
        if self._repr_learner is not None:
            repr_step = self._repr_learner.rollout_step(
                time_step, state.repr)
            observation = repr_step.output
            repr_state = repr_step.state
        rl_step = self._rl.rollout_step(
            rl_time_step._replace(observation=observation), state.rl)
        rl_step = rl_step._replace(
            info=(rl_step.info, state.k, state.sample_rewards))
        # Store to replay buffer.
        super(DynamicActionRepeatAgent, self).observe_for_replay(
            make_experience(
                rl_time_step._replace(
                    # Store the untransformed observation so that later it
                    # will be transformed again during training.
                    observation=rl_time_step.untransformed.observation),
                rl_step, state))
        steps, action = rl_step.output
        return ActionRepeatState(
            action=action,
            steps=steps + 1,  # [0, K-1] -> [1, K]
            k=torch.zeros_like(state.k),
            repr=repr_state,
            rl=rl_step.state,
            rl_reward=torch.zeros_like(state.rl_reward),
            sample_rewards=torch.zeros_like(state.sample_rewards),
            rl_discount=torch.ones_like(state.rl_discount))

    new_state = conditional_update(
        target=state,
        cond=switch_action,
        func=_generate_new_action,
        time_step=time_step,
        state=state)
    new_state = new_state._replace(steps=new_state.steps - 1)

    return AlgStep(output=new_state.action, state=new_state)
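# Standalone Monte-Carlo check (illustration only) of the comment above: with
# the 1/k replacement rule, each of the K rewards in a segment ends up as the
# kept ``sample_rewards`` value with probability 1/K.
import torch

K, n_trials = 4, 200000
kept = torch.zeros(n_trials, dtype=torch.long)
for k in range(1, K + 1):                       # k is 1-indexed, as in the code
    replace = torch.rand(n_trials) < 1. / k     # replace with probability 1/k
    kept = torch.where(replace, torch.full_like(kept, k), kept)

# Each step index should be kept roughly n_trials / K times.
counts = torch.bincount(kept, minlength=K + 1)[1:].float()
assert torch.allclose(counts / n_trials, torch.full((K, ), 1. / K), atol=0.01)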