Example 1
    def test_compute_her_future_step_distance(self, end_prob):
        num_envs = 2
        max_length = 100
        torch.manual_seed(0)
        configs = [
            "hindsight_relabel_fn.her_proportion=0.8",
            'hindsight_relabel_fn.achieved_goal_field="o.a"',
            'hindsight_relabel_fn.desired_goal_field="o.g"',
            "ReplayBuffer.postprocess_exp_fn=@hindsight_relabel_fn",
        ]
        gin.parse_config_files_and_bindings("", configs)

        replay_buffer = ReplayBuffer(data_spec=self.data_spec,
                                     num_environments=num_envs,
                                     max_length=max_length,
                                     keep_episodic_info=True,
                                     step_type_field="t")
        # insert data
        max_steps = 1000
        # generate step_types with certain density of episode ends
        steps = self.generate_step_types(num_envs,
                                         max_steps,
                                         end_prob=end_prob)
        for t in range(max_steps):
            for b in range(num_envs):
                batch = get_batch([b],
                                  self.dim,
                                  t=steps[b * max_steps + t],
                                  x=1. / max_steps * t + b)
                replay_buffer.add_batch(batch, batch.env_id)
            if t > 1:
                sample_steps = min(t, max_length)
                env_ids = torch.tensor([0] * sample_steps + [1] * sample_steps)
                idx = torch.tensor(
                    list(range(sample_steps)) + list(range(sample_steps)))
                gd = self.steps_to_episode_end(replay_buffer, env_ids, idx)
                idx_orig = replay_buffer._indexed_pos.clone()
                idx_headless_orig = replay_buffer._headless_indexed_pos.clone()
                d = replay_buffer.steps_to_episode_end(
                    replay_buffer._pad(idx, env_ids), env_ids)
                # Test the distance-to-episode-end computation
                if not torch.equal(gd, d):
                    outs = [
                        "t: ", t, "\nenvids:\n", env_ids, "\nidx:\n", idx,
                        "\npos:\n",
                        replay_buffer._pad(idx, env_ids), "\nNot Equal: a:\n",
                        gd, "\nb:\n", d, "\nsteps:\n", replay_buffer._buffer.t,
                        "\nindexed_pos:\n", replay_buffer._indexed_pos,
                        "\nheadless_indexed_pos:\n",
                        replay_buffer._headless_indexed_pos
                    ]
                    outs = [str(out) for out in outs]
                    assert False, "".join(outs)

                # Save original exp for later testing.
                g_orig = replay_buffer._buffer.o["g"].clone()
                r_orig = replay_buffer._buffer.reward.clone()

                # HER-relabel the sampled experience
                res = replay_buffer.get_batch(sample_steps, 2)[0]

                self.assertEqual(list(res.o["g"].shape), [sample_steps, 2])

                # Test relabeling doesn't change original experience
                self.assertTrue(
                    torch.allclose(r_orig, replay_buffer._buffer.reward))
                self.assertTrue(
                    torch.allclose(g_orig, replay_buffer._buffer.o["g"]))
                self.assertTrue(
                    torch.all(idx_orig == replay_buffer._indexed_pos))
                self.assertTrue(
                    torch.all(idx_headless_orig ==
                              replay_buffer._headless_indexed_pos))
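The test above checks ReplayBuffer.steps_to_episode_end against a reference
implementation (self.steps_to_episode_end). As a rough illustration of the
expected semantics, here is a minimal, self-contained sketch for a single
environment; steps_to_end_1d is a hypothetical helper, not ALF code, and it
assumes the distance falls back to the most recently stored step when the
episode has not terminated yet.

import alf.data_structures as ds  # assumed to be the `ds` used in these examples

def steps_to_end_1d(step_types):
    # For each stored position of one environment, return the number of steps
    # until that episode's LAST step, falling back to the newest stored
    # position when the episode has not terminated yet.
    n = len(step_types)
    dist = [0] * n
    next_end = n - 1  # unfinished episodes point at the newest step
    for i in reversed(range(n)):
        if step_types[i] == ds.StepType.LAST:
            next_end = i
        dist[i] = next_end - i
    return dist

# E.g. [FIRST, MID, LAST, FIRST, MID] -> [2, 1, 0, 1, 0]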
Example 2
    def test_replay_with_hindsight_relabel(self):
        self.max_length = 8
        torch.manual_seed(0)
        configs = [
            "hindsight_relabel_fn.her_proportion=0.8",
            'hindsight_relabel_fn.achieved_goal_field="o.a"',
            'hindsight_relabel_fn.desired_goal_field="o.g"',
            "ReplayBuffer.postprocess_exp_fn=@hindsight_relabel_fn",
        ]
        gin.parse_config_files_and_bindings("", configs)

        replay_buffer = ReplayBuffer(data_spec=self.data_spec,
                                     num_environments=2,
                                     max_length=self.max_length,
                                     keep_episodic_info=True,
                                     step_type_field="t",
                                     with_replacement=True)

        steps = [
            [
                ds.StepType.FIRST,  # will be overwritten
                ds.StepType.MID,  # idx == 1 in buffer
                ds.StepType.LAST,
                ds.StepType.FIRST,
                ds.StepType.MID,
                ds.StepType.MID,
                ds.StepType.LAST,
                ds.StepType.FIRST,
                ds.StepType.MID  # idx == 0
            ],
            [
                ds.StepType.FIRST,  # will be overwritten in RingBuffer
                ds.StepType.LAST,  # idx == 1 in RingBuffer
                ds.StepType.FIRST,
                ds.StepType.MID,
                ds.StepType.MID,
                ds.StepType.LAST,
                ds.StepType.FIRST,
                ds.StepType.MID,
                ds.StepType.MID  # idx == 0
            ]
        ]
        # insert data that will be overwritten later
        for b, t in list(itertools.product(range(2), range(8))):
            batch = get_batch([b], self.dim, t=steps[b][t], x=0.1 * t + b)
            replay_buffer.add_batch(batch, batch.env_id)
        # insert data
        for b, t in list(itertools.product(range(2), range(9))):
            batch = get_batch([b], self.dim, t=steps[b][t], x=0.1 * t + b)
            replay_buffer.add_batch(batch, batch.env_id)

        # Test padding
        idx = torch.tensor([[7, 0, 0, 6, 3, 3, 3, 0],
                            [6, 0, 5, 2, 2, 2, 0, 6]])
        pos = replay_buffer._pad(idx, torch.tensor([[0] * 8, [1] * 8]))
        self.assertTrue(
            torch.equal(
                pos,
                torch.tensor([[15, 16, 16, 14, 11, 11, 11, 16],
                              [14, 16, 13, 10, 10, 10, 16, 14]])))

        # Verify the episodic index is built correctly.
        # Note: an _indexed_pos value of 8 marks headless timesteps, which are
        # outdated and not the same as the result of padding (16).
        pos = torch.tensor([[15, 8, 8, 14, 11, 11, 11, 16],
                            [14, 8, 13, 10, 10, 10, 16, 14]])

        self.assertTrue(torch.equal(replay_buffer._indexed_pos, pos))
        self.assertTrue(
            torch.equal(replay_buffer._headless_indexed_pos,
                        torch.tensor([10, 9])))

        # Save original exp for later testing.
        g_orig = replay_buffer._buffer.o["g"].clone()
        r_orig = replay_buffer._buffer.reward.clone()

        # HER selects indices [0, 2, 3, 4] to relabel, from all 5:
        # env_ids: [[0, 0], [1, 1], [0, 0], [1, 1], [0, 0]]
        # pos:     [[6, 7], [1, 2], [1, 2], [3, 4], [5, 6]] + 8
        # selected:    x               x       x       x
        # future:  [   7       2       2       4       6  ] + 8
        # g        [[.7,.7],[0, 0], [.2,.2],[1.4,1.4],[.6,.6]]  # 0.1 * t + b with default 0
        # reward:  [[-1,0], [-1,-1],[-1,0], [-1,0], [-1,0]]  # recomputed with default -1
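        # E.g. for index 0 (env 0): the relabeled goal is the achieved goal at
        # future position 7 (+ 8), i.e. 0.1 * 7 + 0 = 0.7, and its second
        # sampled step is that future step itself, so its reward becomes 0.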
        env_ids = torch.tensor([0, 0, 1, 0])
        dist = replay_buffer.steps_to_episode_end(
            replay_buffer._pad(torch.tensor([7, 2, 4, 6]), env_ids), env_ids)
        self.assertEqual(list(dist), [1, 0, 1, 0])

        # Test HER relabeled experiences
        res = replay_buffer.get_batch(5, 2)[0]

        self.assertEqual(list(res.o["g"].shape), [5, 2])

        # Test relabeling doesn't change original experience
        self.assertTrue(torch.allclose(r_orig, replay_buffer._buffer.reward))
        self.assertTrue(torch.allclose(g_orig, replay_buffer._buffer.o["g"]))

        # Test relabeled goals
        g = torch.tensor([0.7, 0., .2, 1.4, .6]).unsqueeze(1).expand(5, 2)
        self.assertTrue(torch.allclose(res.o["g"], g))

        # Test relabeled rewards
        r = torch.tensor([[-1., 0.], [-1., -1.], [-1., 0.], [-1., 0.],
                          [-1., 0.]])
        self.assertTrue(torch.allclose(res.reward, r))
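The padding check above maps ring-buffer indices back to absolute positions. A
minimal sketch of arithmetic that reproduces the expected `pos` tensor (an
illustration, not ALF's actual _pad implementation), assuming each environment
has received 17 insertions (8 + 9), so its current position is 17 and
max_length is 8:

def pad_1d(idx, current_pos=17, max_length=8):
    # Most recent absolute position p with p % max_length == idx and
    # p < current_pos; e.g. idx 7 of env 0 maps to 7 + ((17 - 1 - 7) // 8) * 8 == 15.
    return idx + ((current_pos - 1 - idx) // max_length) * max_length

assert [pad_1d(i) for i in [7, 0, 0, 6, 3, 3, 3, 0]] == [15, 16, 16, 14, 11, 11, 11, 16]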
Example 3
    def _reanalyze1(self, replay_buffer: ReplayBuffer, env_ids, positions,
                    mcts_state_field):
        """Reanalyze one batch.

        This means:
        1. Re-plan the policy using MCTS for n1 = 1 + num_unroll_steps steps to
           get fresh policy and value targets.
        2. Calculate the values for the following n2 = reanalyze_td_steps steps
           so that we have values for a total of 1 + num_unroll_steps +
           reanalyze_td_steps steps.
        3. Use these values and the rewards from the replay buffer to calculate
           the n2-step bootstrapped value target for the first n1 steps.

        In order to do 1 and 2, we need to get the observations for n1 + n2
        steps and process them using data_transformer.
        """
        batch_size = env_ids.shape[0]
        n1 = self._num_unroll_steps + 1
        n2 = self._reanalyze_td_steps
        env_ids, positions = self._next_n_positions(
            replay_buffer, env_ids, positions, self._num_unroll_steps + n2)
        # [B, n1]
        positions1 = positions[:, :n1]
        game_overs = replay_buffer.get_field('discount', env_ids,
                                             positions1) == 0.

        steps_to_episode_end = replay_buffer.steps_to_episode_end(
            positions1, env_ids)
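        # Bootstrap no further than the episode end, and at most n2 steps ahead.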
        bootstrap_n = steps_to_episode_end.clamp(max=n2)

        exp1, exp2 = self._prepare_reanalyze_data(replay_buffer, env_ids,
                                                  positions, n1, n2)

        bootstrap_position = positions1 + bootstrap_n
        discount = replay_buffer.get_field('discount', env_ids,
                                           bootstrap_position)
        sum_reward = self._sum_discounted_reward(replay_buffer, env_ids,
                                                 positions1,
                                                 bootstrap_position, n2)

        if not self._train_reward_function:
            rewards = self._get_reward(replay_buffer, env_ids,
                                       bootstrap_position)

        with alf.device(self._device):
            bootstrap_n = convert_device(bootstrap_n)
            discount = convert_device(discount)
            sum_reward = convert_device(sum_reward)
            game_overs = convert_device(game_overs)

            # 1. Reanalyze the first n1 steps to get both the updated value and policy
            self._mcts.set_model(self._target_model)
            mcts_step = self._mcts.predict_step(
                exp1, alf.nest.get_field(exp1, mcts_state_field))
            self._mcts.set_model(self._model)
            candidate_actions = ()
            if not _is_empty(mcts_step.info.candidate_actions):
                candidate_actions = mcts_step.info.candidate_actions
                candidate_actions = candidate_actions.reshape(
                    batch_size, n1, *candidate_actions.shape[1:])
            candidate_action_policy = mcts_step.info.candidate_action_policy
            candidate_action_policy = candidate_action_policy.reshape(
                batch_size, n1, *candidate_action_policy.shape[1:])
            values1 = mcts_step.info.value.reshape(batch_size, n1)

            # 2. Calculate the values of the next n2 steps so that the n2-step
            # return can be computed.
            model_output = self._target_model.initial_inference(
                exp2.observation)
            values2 = model_output.value.reshape(batch_size, n2)

            # 3. Calculate n2-step return
            values = torch.cat([values1, values2], dim=1)
            # [B, n1]
            bootstrap_pos = torch.arange(n1).unsqueeze(0) + bootstrap_n
            values = values[torch.arange(batch_size).unsqueeze(-1),
                            bootstrap_pos]
            values = values * discount * (self._discount**bootstrap_n.to(
                torch.float32))
            values = values + sum_reward
            if not self._train_reward_function:
                # When the reward function is not trained, set the value at and
                # after the last step to be the last reward.
                values = torch.where(game_overs, convert_device(rewards),
                                     values)
            return candidate_actions, candidate_action_policy, values
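Step 3 above forms an n2-step bootstrapped value target: a discounted sum of up
to n2 rewards plus the discounted value at the bootstrap position, with the
bootstrap step clamped at the episode end. A minimal single-position sketch
(illustrative only, not the ALF implementation), assuming rewards[p] is the
reward received on arriving at position p:

def n_step_bootstrap_target(rewards, values, discounts, t, n, gamma):
    # target(t) = sum_{i=1..k} gamma^(i-1) * rewards[t + i]
    #             + gamma^k * discounts[t + k] * values[t + k],
    # with k clamped so t + k stays inside the stored trajectory (mirroring
    # bootstrap_n = steps_to_episode_end.clamp(max=n2) above).
    k = min(n, len(rewards) - 1 - t)
    ret = sum(gamma ** (i - 1) * rewards[t + i] for i in range(1, k + 1))
    return ret + gamma ** k * discounts[t + k] * values[t + k]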