Example 1
    def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
        """This function does one thing, i.e., every ``self._num_steps_per_skill``
        it calls ``self._rl`` to generate new skills.
        """
        switch_skill = self._should_switch_skills(time_step, state)
        discriminator_step = self._discriminator_predict_step(
            time_step, state, switch_skill)

        def _generate_new_skills(time_step, state):
            rl_step = self._rl.predict_step(time_step, state.rl,
                                            epsilon_greedy)
            return SkillGeneratorState(skill=rl_step.output,
                                       steps=torch.zeros_like(state.steps),
                                       rl=rl_step.state)

        new_state = conditional_update(target=state,
                                       cond=switch_skill,
                                       func=_generate_new_skills,
                                       time_step=time_step,
                                       state=state)

        new_state = new_state._replace(steps=new_state.steps + 1,
                                       discriminator=discriminator_step.state)

        return AlgStep(output=new_state.skill,
                       state=new_state,
                       info=SkillGeneratorInfo(switch_skill=switch_skill))
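
All of the examples on this page revolve around ``conditional_update``: it applies a function only to the batch entries selected by a boolean ``cond`` and leaves the remaining entries of ``target`` unchanged (the tests further down this page verify exactly this behavior). As a rough, single-tensor illustration of that per-sample behavior, here is a standalone sketch; the helper name ``conditional_update_sketch`` is made up and is not ALF's actual implementation, which also handles arbitrarily nested targets.

import torch


def conditional_update_sketch(target, cond, func, *args):
    # Hypothetical single-tensor sketch: rows where ``cond`` is True take
    # ``func(*args)``; the remaining rows keep ``target``'s values.
    new_value = func(*args)
    # Broadcast the [B] mask over the trailing dimensions of ``target``.
    mask = cond.view(-1, *([1] * (target.dim() - 1)))
    return torch.where(mask, new_value, target)


# Tiny usage example with made-up shapes.
target = torch.zeros(4, 3)
x = torch.ones(4, 3)
cond = torch.tensor([True, False, True, False])
print(conditional_update_sketch(target, cond, lambda t: t + 1, x))
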
Example 2
    def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
        switch_action = self._should_switch_action(time_step, state)

        @torch.no_grad()
        def _generate_new_action(time_step, state):
            repr_state = ()
            if self._repr_learner is not None:
                repr_step = self._repr_learner.predict_step(
                    time_step, state.repr)
                time_step = time_step._replace(observation=repr_step.output)
                repr_state = repr_step.state

            rl_step = self._rl.predict_step(time_step, state.rl,
                                            epsilon_greedy)
            steps, action = rl_step.output
            return ActionRepeatState(
                action=action,
                steps=steps + 1,  # [0, K-1] -> [1, K]
                rl=rl_step.state,
                repr=repr_state)

        new_state = conditional_update(
            target=state,
            cond=switch_action,
            func=_generate_new_action,
            time_step=time_step,
            state=state)
        new_state = new_state._replace(steps=new_state.steps - 1)

        return AlgStep(
            output=new_state.action,
            state=new_state,
            # plot steps and action when rendering video
            info=dict(action=(new_state.action, new_state.steps)))
Example 3
    def test_conditional_update(self):
        def _func(x, y):
            return x + 1, y - 1

        batch_size = 256
        target = (tf.random.uniform([batch_size, 3]),
                  tf.random.uniform([batch_size]))
        x = tf.random.uniform([batch_size, 3])
        y = tf.random.uniform([batch_size])

        cond = tf.constant([False] * batch_size)
        updated_target = conditional_update(target, cond, _func, x, y)
        self.assertAllEqual(updated_target[0], target[0])
        self.assertAllEqual(updated_target[1], target[1])

        cond = tf.constant([True] * batch_size)
        updated_target = conditional_update(target, cond, _func, x, y)
        self.assertAllEqual(updated_target[0], x + 1)
        self.assertAllEqual(updated_target[1], y - 1)

        cond = tf.random.uniform((batch_size, )) < 0.5
        updated_target = conditional_update(target, cond, _func, x, y)
        self.assertAllEqual(select_from_mask(updated_target[0], cond),
                            select_from_mask(x + 1, cond))
        self.assertAllEqual(select_from_mask(updated_target[1], cond),
                            select_from_mask(y - 1, cond))

        vx = tf.Variable(initial_value=0.)
        vy = tf.Variable(initial_value=0.)

        def _func1(x, y):
            vx.assign(tf.reduce_sum(x))
            vy.assign(tf.reduce_sum(y))
            return ()

        # test empty return
        conditional_update((), cond, _func1, x, y)
        self.assertEqual(vx, tf.reduce_sum(select_from_mask(x, cond)))
        self.assertEqual(vy, tf.reduce_sum(select_from_mask(y, cond)))
Example 4
    def test_conditional_update(self):
        def _func(x, y):
            return x + 1, y - 1

        batch_size = 256
        target = (torch.rand([batch_size, 3]), torch.rand([batch_size]))
        x = torch.rand([batch_size, 3])
        y = torch.rand([batch_size])

        cond = torch.as_tensor([False] * batch_size)
        updated_target = conditional_update(target, cond, _func, x, y)
        self.assertTensorEqual(updated_target[0], target[0])
        self.assertTensorEqual(updated_target[1], target[1])

        cond = torch.as_tensor([True] * batch_size)
        updated_target = conditional_update(target, cond, _func, x, y)
        self.assertTensorEqual(updated_target[0], x + 1)
        self.assertTensorEqual(updated_target[1], y - 1)

        cond = torch.rand((batch_size, )) < 0.5
        updated_target = conditional_update(target, cond, _func, x, y)
        self.assertTensorEqual(select_from_mask(updated_target[0], cond),
                               select_from_mask(x + 1, cond))
        self.assertTensorEqual(select_from_mask(updated_target[1], cond),
                               select_from_mask(y - 1, cond))

        vx = torch.zeros(())
        vy = torch.zeros(())

        def _func1(x, y):
            vx.copy_(torch.sum(x))
            vy.copy_(torch.sum(y))
            return ()

        # test empty return
        conditional_update((), cond, _func1, x, y)
        self.assertEqual(vx, torch.sum(select_from_mask(x, cond)))
        self.assertEqual(vy, torch.sum(select_from_mask(y, cond)))
Example 5
    def test_conditional_update_high_dims(self):
        def _func(x):
            return x**2

        batch_size = 100
        x = torch.rand([batch_size, 3, 4, 5])
        y = torch.rand([batch_size, 3, 4, 5])
        cond = torch.randint(high=2, size=[batch_size]).to(torch.bool)

        updated_y = conditional_update(y, cond, _func, x)
        self.assertTensorEqual(select_from_mask(updated_y, cond),
                               select_from_mask(x**2, cond))
        self.assertTensorEqual(select_from_mask(updated_y, ~cond),
                               select_from_mask(y, ~cond))
Example 6
    def rollout_step(self, time_step: TimeStep, state: ActionRepeatState):
        switch_action = self._should_switch_action(time_step, state)
        state = state._replace(
            rl_reward=state.rl_reward + state.rl_discount * time_step.reward,
            rl_discount=state.rl_discount * time_step.discount * self._gamma)

        @torch.no_grad()
        def _generate_new_action(time_step, state):
            rl_time_step = time_step._replace(
                reward=state.rl_reward, discount=state.rl_discount)

            observation, repr_state = rl_time_step.observation, ()
            if self._repr_learner is not None:
                repr_step = self._repr_learner.rollout_step(
                    time_step, state.repr)
                observation = repr_step.output
                repr_state = repr_step.state

            rl_step = self._rl.rollout_step(
                rl_time_step._replace(observation=observation), state.rl)
            # Store to replay buffer.
            super(DynamicActionRepeatAgent, self).observe_for_replay(
                make_experience(
                    rl_time_step._replace(
                        # Store the untransformed observation so that later it will
                        # be transformed again during training
                        observation=rl_time_step.untransformed.observation),
                    rl_step,
                    state))
            steps, action = rl_step.output
            return ActionRepeatState(
                action=action,
                steps=steps + 1,  # [0, K-1] -> [1, K]
                repr=repr_state,
                rl=rl_step.state,
                rl_reward=torch.zeros_like(state.rl_reward),
                rl_discount=torch.ones_like(state.rl_discount))

        new_state = conditional_update(
            target=state,
            cond=switch_action,
            func=_generate_new_action,
            time_step=time_step,
            state=state)

        new_state = new_state._replace(steps=new_state.steps - 1)

        return AlgStep(output=new_state.action, state=new_state)
Example 7
    def _make_low_level_observation(self, subtrajectory, skill, switch_skill,
                                    steps, updated_first_observation):
        r"""Given the skill generator's output, this function makes the
        skill-conditioned observation for the lower-level policy. Both observation
        and action are a stacking of recent states, with the most recent one appearing
        at index=0.

        X: first observation of a skill
        X': first observation of the next skill
        O: middle observation of a skill
        _: zero

        num_steps_per_skill=3:

            subtrajectory (discriminator)    low_rl_observation
            O _ _                        ->  O X _   (steps==2)
            O O _                        ->  O O X   (steps==3)
            X'O O                        ->  X'_ _   (steps==1,switch_skill==True)

        The same applies to action except that there is no ``first_observation``.
        """
        subtrajectory.observation[
            torch.arange(updated_first_observation.shape[0]).long(),
            steps - 1] = updated_first_observation

        def _zero(subtrajectory):
            subtrajectory.prev_action.fill_(0.)
            # When ``switch_skill`` is triggered by a FINAL step, filling zeros
            # might cause issues if the FINAL step arrives before
            # ``num_steps_per_skill`` steps. But since RL algorithms don't train
            # on FINAL steps, we leave it like this for simplicity.
            subtrajectory.observation[:, 1:, ...] = 0.
            return subtrajectory

        subtrajectory = conditional_update(target=subtrajectory,
                                           cond=switch_skill,
                                           func=_zero,
                                           subtrajectory=subtrajectory)
        subtrajectory = alf.nest.map_structure(
            lambda traj: traj.reshape(traj.shape[0], -1), subtrajectory)

        low_rl_observation = (alf.nest.flatten(subtrajectory) +
                              [self._num_steps_per_skill - steps, skill])
        return low_rl_observation
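
To make the indexing in ``_make_low_level_observation`` more concrete, here is a toy, self-contained sketch (all names, shapes, and values are made up) of its two in-place operations: writing the skill's first observation at index ``steps - 1`` of the stacked subtrajectory, and zeroing the stale history for the samples whose skill just switched.

import torch

batch_size, num_steps_per_skill, obs_dim = 2, 3, 1
stacked_obs = torch.arange(
    batch_size * num_steps_per_skill * obs_dim,
    dtype=torch.float32).reshape(batch_size, num_steps_per_skill, obs_dim)
steps = torch.tensor([2, 1])               # per-sample step count within the skill
switch_skill = torch.tensor([False, True])
first_obs = torch.full((batch_size, obs_dim), -1.)

# Write the skill's first observation at position ``steps - 1``.
stacked_obs[torch.arange(batch_size), steps - 1] = first_obs
# Zero everything except index 0 for the samples that just switched skills.
stacked_obs[switch_skill, 1:, ...] = 0.
print(stacked_obs)
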
Example 8
    def rollout_step(self, time_step: TimeStep, state):
        r"""This function does three things:

        1. every ``self._num_steps_per_skill`` steps it calls ``self._rl`` to
           generate new skills;
        2. at the same time, it writes ``time_step`` to a replay buffer when new
           skills are generated;
        3. it calls ``rollout_step()`` of the discriminator to write ``time_step``
           to a replay buffer for training.

        Regarding accumulating rewards for the higher-level policy: suppose that
        during an episode we have :math:`H` segments where each segment contains
        :math:`K` steps. Then the objective for the higher-level policy is:

        .. math::

            \begin{array}{ll}
                &\sum_{h=0}^{H-1}(\gamma^K)^h\sum_{t=0}^{K-1}\gamma^t r(s_{t+hK},a_{t+hK})\\
                =&\sum_{h=0}^{H-1}\beta^h R_h\\
            \end{array}

        where :math:`\gamma` is the discount and :math:`r(\cdot)` is the extrinsic
        reward of the original task. Thus :math:`\beta=\gamma^K` should be the
        discount per higher-level time step and :math:`R_h=\sum_{t=0}^{K-1}\gamma^t r(s_{t+hK},a_{t+hK})`
        should be the reward per higher-level time step. (A short numeric check
        of this accumulation follows the example.)
        """
        switch_skill = self._should_switch_skills(time_step, state)

        discriminator_step = self._discriminator_rollout_step(
            time_step, state, switch_skill)

        state = state._replace(
            rl_reward=state.rl_reward + state.rl_discount * time_step.reward,
            rl_discount=state.rl_discount * time_step.discount * self._gamma)

        def _generate_new_skills(time_step, state):
            rl_prev_action = state.skill
            # avoid dividing by 0
            #steps = torch.max(
            #    state.steps.to(torch.float32), torch.as_tensor(1.0))
            rl_time_step = time_step._replace(reward=state.rl_reward,
                                              discount=state.rl_discount,
                                              prev_action=rl_prev_action)

            rl_step = self._rl.rollout_step(rl_time_step, state.rl)

            # store to replay buffer
            self._rl.observe_for_replay(
                make_experience(
                    # ``rl_time_step.observation`` has been transformed!!!
                    rl_time_step._replace(
                        observation=rl_time_step.untransformed.observation),
                    rl_step,
                    state.rl))

            return SkillGeneratorState(
                skill=rl_step.output,
                steps=torch.zeros_like(state.steps),
                discriminator=state.discriminator,
                rl=rl_step.state,
                rl_reward=torch.zeros_like(state.rl_reward),
                rl_discount=torch.ones_like(state.rl_discount))

        new_state = conditional_update(target=state,
                                       cond=switch_skill,
                                       func=_generate_new_skills,
                                       time_step=time_step,
                                       state=state)
        new_state = new_state._replace(steps=new_state.steps + 1,
                                       discriminator=discriminator_step.state)
        return AlgStep(output=new_state.skill,
                       state=new_state,
                       info=SkillGeneratorInfo(skill=new_state.skill,
                                               steps=new_state.steps,
                                               switch_skill=switch_skill))
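
As a quick sanity check of the discounting math in the docstring, the following standalone snippet (made-up rewards, environment discount fixed to 1) shows that the recursive updates of ``rl_reward`` and ``rl_discount`` used in ``rollout_step`` above yield R_h = sum_t gamma^t r_t and beta = gamma^K after K low-level steps.

import torch

gamma, K = 0.99, 5
rewards = torch.rand(K)          # made-up low-level rewards of one segment
rl_reward, rl_discount = torch.tensor(0.), torch.tensor(1.)
for r in rewards:
    # Same recursion as in ``rollout_step`` above (time_step.discount == 1 here).
    rl_reward = rl_reward + rl_discount * r
    rl_discount = rl_discount * 1.0 * gamma

R_h = (gamma ** torch.arange(K, dtype=torch.float32) * rewards).sum()
assert torch.allclose(rl_reward, R_h)
assert torch.allclose(rl_discount, torch.tensor(gamma ** K))
print(rl_reward, rl_discount)
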
Example 9
    def rollout_step(self, time_step: TimeStep, state: ActionRepeatState):
        switch_action = self._should_switch_action(time_step, state)

        # state.k is the current step index over K steps
        state = state._replace(
            rl_reward=state.rl_reward + torch.pow(
                self._gamma, state.k.to(torch.float32)) * time_step.reward,
            rl_discount=state.rl_discount * time_step.discount * self._gamma,
            k=state.k + 1)

        if self._reward_normalizer is not None:
            # The probability that the reward at step k is still kept after K
            # steps is 1/k * k/(k+1) * ... * (K-1)/K = 1/K. This provides enough
            # randomness to keep the normalizer unbiased. (A Monte-Carlo check of
            # this follows the example.)
            state = state._replace(
                sample_rewards=torch.where((
                    torch.rand_like(state.sample_rewards) < 1. /
                    state.k.to(torch.float32)
                ), time_step.reward, state.sample_rewards))

        @torch.no_grad()
        def _generate_new_action(time_step, state):
            rl_time_step = time_step._replace(
                reward=state.rl_reward,
                # To be consistent with other algorithms, the discount is
                # multiplied by gamma once more in td_loss.py.
                discount=state.rl_discount / self._gamma)

            observation, repr_state = rl_time_step.observation, ()
            if self._repr_learner is not None:
                repr_step = self._repr_learner.rollout_step(
                    time_step, state.repr)
                observation = repr_step.output
                repr_state = repr_step.state

            rl_step = self._rl.rollout_step(
                rl_time_step._replace(observation=observation), state.rl)
            rl_step = rl_step._replace(
                info=(rl_step.info, state.k, state.sample_rewards))
            # Store to replay buffer.
            super(DynamicActionRepeatAgent, self).observe_for_replay(
                make_experience(
                    rl_time_step._replace(
                        # Store the untransformed observation so that later it will
                        # be transformed again during training
                        observation=rl_time_step.untransformed.observation),
                    rl_step,
                    state))
            steps, action = rl_step.output
            return ActionRepeatState(
                action=action,
                steps=steps + 1,  # [0, K-1] -> [1, K]
                k=torch.zeros_like(state.k),
                repr=repr_state,
                rl=rl_step.state,
                rl_reward=torch.zeros_like(state.rl_reward),
                sample_rewards=torch.zeros_like(state.sample_rewards),
                rl_discount=torch.ones_like(state.rl_discount))

        new_state = conditional_update(
            target=state,
            cond=switch_action,
            func=_generate_new_action,
            time_step=time_step,
            state=state)

        new_state = new_state._replace(steps=new_state.steps - 1)

        return AlgStep(output=new_state.action, state=new_state)
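
The 1/K claim in the reward-normalizer comment above can be checked with a small Monte-Carlo simulation. This is illustrative only and uses made-up values: replacing the stored sample with probability 1/k at step k leaves each of the K rewards stored with probability 1/K at the end.

import torch

K, num_trials = 4, 20000
counts = torch.zeros(K)
for _ in range(num_trials):
    kept = 0                          # index of the reward currently stored
    for k in range(1, K + 1):         # steps are 1-based, like ``state.k`` after +1
        if torch.rand(()) < 1.0 / k:
            kept = k - 1
    counts[kept] += 1
# Each entry should be close to 1/K = 0.25.
print(counts / num_trials)
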