def test_compute_her_future_step_distance(self, end_prob):
    num_envs = 2
    max_length = 100
    torch.manual_seed(0)
    configs = [
        "hindsight_relabel_fn.her_proportion=0.8",
        'hindsight_relabel_fn.achieved_goal_field="o.a"',
        'hindsight_relabel_fn.desired_goal_field="o.g"',
        "ReplayBuffer.postprocess_exp_fn=@hindsight_relabel_fn",
    ]
    gin.parse_config_files_and_bindings("", configs)

    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=num_envs,
        max_length=max_length,
        keep_episodic_info=True,
        step_type_field="t")

    # Insert data.
    max_steps = 1000
    # Generate step_types with a certain density of episode ends.
    steps = self.generate_step_types(
        num_envs, max_steps, end_prob=end_prob)
    for t in range(max_steps):
        for b in range(num_envs):
            batch = get_batch(
                [b],
                self.dim,
                t=steps[b * max_steps + t],
                x=1. / max_steps * t + b)
            replay_buffer.add_batch(batch, batch.env_id)
        if t > 1:
            sample_steps = min(t, max_length)
            env_ids = torch.tensor(
                [0] * sample_steps + [1] * sample_steps)
            idx = torch.tensor(
                list(range(sample_steps)) + list(range(sample_steps)))
            gd = self.steps_to_episode_end(replay_buffer, env_ids, idx)
            idx_orig = replay_buffer._indexed_pos.clone()
            idx_headless_orig = replay_buffer._headless_indexed_pos.clone()
            d = replay_buffer.steps_to_episode_end(
                replay_buffer._pad(idx, env_ids), env_ids)

            # Test distance to end computation.
            if not torch.equal(gd, d):
                outs = [
                    "t: ", t, "\nenvids:\n", env_ids, "\nidx:\n", idx,
                    "\npos:\n",
                    replay_buffer._pad(idx, env_ids), "\nNot Equal: a:\n",
                    gd, "\nb:\n", d, "\nsteps:\n", replay_buffer._buffer.t,
                    "\nindexed_pos:\n", replay_buffer._indexed_pos,
                    "\nheadless_indexed_pos:\n",
                    replay_buffer._headless_indexed_pos
                ]
                outs = [str(out) for out in outs]
                assert False, "".join(outs)

            # Save original exp for later testing.
            g_orig = replay_buffer._buffer.o["g"].clone()
            r_orig = replay_buffer._buffer.reward.clone()

            # HER relabel experience.
            res = replay_buffer.get_batch(sample_steps, 2)[0]
            self.assertEqual(list(res.o["g"].shape), [sample_steps, 2])

            # Test that relabeling doesn't change the original experience.
            self.assertTrue(
                torch.allclose(r_orig, replay_buffer._buffer.reward))
            self.assertTrue(
                torch.allclose(g_orig, replay_buffer._buffer.o["g"]))
            self.assertTrue(
                torch.all(idx_orig == replay_buffer._indexed_pos))
            self.assertTrue(
                torch.all(idx_headless_orig ==
                          replay_buffer._headless_indexed_pos))
def test_replay_with_hindsight_relabel(self):
    self.max_length = 8
    torch.manual_seed(0)
    configs = [
        "hindsight_relabel_fn.her_proportion=0.8",
        'hindsight_relabel_fn.achieved_goal_field="o.a"',
        'hindsight_relabel_fn.desired_goal_field="o.g"',
        "ReplayBuffer.postprocess_exp_fn=@hindsight_relabel_fn",
    ]
    gin.parse_config_files_and_bindings("", configs)

    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=2,
        max_length=self.max_length,
        keep_episodic_info=True,
        step_type_field="t",
        with_replacement=True)

    steps = [
        [
            ds.StepType.FIRST,  # will be overwritten
            ds.StepType.MID,  # idx == 1 in buffer
            ds.StepType.LAST,
            ds.StepType.FIRST,
            ds.StepType.MID,
            ds.StepType.MID,
            ds.StepType.LAST,
            ds.StepType.FIRST,
            ds.StepType.MID  # idx == 0
        ],
        [
            ds.StepType.FIRST,  # will be overwritten in RingBuffer
            ds.StepType.LAST,  # idx == 1 in RingBuffer
            ds.StepType.FIRST,
            ds.StepType.MID,
            ds.StepType.MID,
            ds.StepType.LAST,
            ds.StepType.FIRST,
            ds.StepType.MID,
            ds.StepType.MID  # idx == 0
        ]
    ]
    # Insert data that will be overwritten later.
    for b, t in list(itertools.product(range(2), range(8))):
        batch = get_batch([b], self.dim, t=steps[b][t], x=0.1 * t + b)
        replay_buffer.add_batch(batch, batch.env_id)
    # Insert data.
    for b, t in list(itertools.product(range(2), range(9))):
        batch = get_batch([b], self.dim, t=steps[b][t], x=0.1 * t + b)
        replay_buffer.add_batch(batch, batch.env_id)

    # Test padding.
    idx = torch.tensor([[7, 0, 0, 6, 3, 3, 3, 0],
                        [6, 0, 5, 2, 2, 2, 0, 6]])
    pos = replay_buffer._pad(idx, torch.tensor([[0] * 8, [1] * 8]))
    self.assertTrue(
        torch.equal(
            pos,
            torch.tensor([[15, 16, 16, 14, 11, 11, 11, 16],
                          [14, 16, 13, 10, 10, 10, 16, 14]])))

    # Verify that _indexed_pos is built correctly.
    # Note: the _indexed_pos value 8 represents headless timesteps, which
    # are outdated and not the same as the result of padding: 16.
    pos = torch.tensor([[15, 8, 8, 14, 11, 11, 11, 16],
                        [14, 8, 13, 10, 10, 10, 16, 14]])
    self.assertTrue(torch.equal(replay_buffer._indexed_pos, pos))
    self.assertTrue(
        torch.equal(replay_buffer._headless_indexed_pos,
                    torch.tensor([10, 9])))

    # Save original exp for later testing.
    g_orig = replay_buffer._buffer.o["g"].clone()
    r_orig = replay_buffer._buffer.reward.clone()

    # HER selects indices [0, 2, 3, 4] to relabel, from all 5:
    # env_ids:  [[0, 0],  [1, 1],  [0, 0],   [1, 1],    [0, 0]]
    # pos:      [[6, 7],  [1, 2],  [1, 2],   [3, 4],    [5, 6]] + 8
    # selected:     x                 x          x         x
    # future:   [  7        2        2          4         6  ] + 8
    # g:        [[.7,.7], [0, 0],  [.2,.2],  [1.4,1.4], [.6,.6]]
    #           0.1 * t + b with default 0
    # reward:   [[-1,0],  [-1,-1], [-1,0],   [-1,0],    [-1,0]]
    #           recomputed with default -1
    # (A simplified standalone sketch of this relabeling follows this test.)
    env_ids = torch.tensor([0, 0, 1, 0])
    dist = replay_buffer.steps_to_episode_end(
        replay_buffer._pad(torch.tensor([7, 2, 4, 6]), env_ids), env_ids)
    self.assertEqual(list(dist), [1, 0, 1, 0])

    # Test HER relabeled experiences.
    res = replay_buffer.get_batch(5, 2)[0]
    self.assertEqual(list(res.o["g"].shape), [5, 2])

    # Test that relabeling doesn't change the original experience.
    self.assertTrue(torch.allclose(r_orig, replay_buffer._buffer.reward))
    self.assertTrue(torch.allclose(g_orig, replay_buffer._buffer.o["g"]))

    # Test relabeled goals.
    g = torch.tensor([0.7, 0., .2, 1.4, .6]).unsqueeze(1).expand(5, 2)
    self.assertTrue(torch.allclose(res.o["g"], g))
    # Test relabeled rewards.
    r = torch.tensor([[-1., 0.], [-1., -1.], [-1., 0.], [-1., 0.],
                      [-1., 0.]])
    self.assertTrue(torch.allclose(res.reward, r))
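# A simplified, self-contained sketch of the hindsight relabeling that the
# test above exercises: the desired goal "o.g" of a sampled step is replaced
# by the achieved goal "o.a" of a future step from the same episode, and the
# sparse reward is recomputed (0 once the new goal is reached, -1 otherwise).
# This is only an illustration of the idea, not the actual
# hindsight_relabel_fn; the function name, tensor shapes and the distance
# threshold below are made up for the example.
import torch


def _sketch_hindsight_relabel(achieved, future_goal, threshold=1e-6):
    """Relabel goals with a future achieved goal and recompute sparse rewards.

    Args:
        achieved: [B, T, d] achieved goals of the sampled trajectories.
        future_goal: [B, d] achieved goal taken from a future step of the
            same episode.
        threshold: distance under which the goal counts as reached.
    Returns:
        tuple of (relabeled goals [B, T, d], recomputed rewards [B, T]).
    """
    relabeled_goal = future_goal.unsqueeze(1).expand_as(achieved)
    reached = (achieved - relabeled_goal).norm(dim=-1) <= threshold
    # Sparse goal-reaching reward: 0 when reached, -1 otherwise, matching the
    # "recomputed with default -1" pattern checked in the test above.
    reward = torch.where(reached, torch.zeros(()), -torch.ones(()))
    return relabeled_goal, reward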
def _reanalyze1(self, replay_buffer: ReplayBuffer, env_ids, positions,
                mcts_state_field):
    """Reanalyze one batch.

    This means:

    1. Re-plan the policy using MCTS for n1 = 1 + num_unroll_steps steps to
       get a fresh policy and value target.
    2. Calculate the value for the following n2 = reanalyze_td_steps steps so
       that we have values for a total of 1 + num_unroll_steps +
       reanalyze_td_steps steps.
    3. Use these values and the rewards from the replay buffer to calculate
       the n2-step bootstrapped value target for the first n1 steps.

    In order to do 1 and 2, we need to get the observations for n1 + n2 steps
    and process them using the data_transformer.

    (A minimal standalone sketch of the step 3 computation follows this
    method.)
    """
    batch_size = env_ids.shape[0]
    n1 = self._num_unroll_steps + 1
    n2 = self._reanalyze_td_steps
    env_ids, positions = self._next_n_positions(
        replay_buffer, env_ids, positions, self._num_unroll_steps + n2)
    # [B, n1]
    positions1 = positions[:, :n1]
    game_overs = replay_buffer.get_field('discount', env_ids,
                                         positions1) == 0.

    steps_to_episode_end = replay_buffer.steps_to_episode_end(
        positions1, env_ids)
    bootstrap_n = steps_to_episode_end.clamp(max=n2)

    exp1, exp2 = self._prepare_reanalyze_data(replay_buffer, env_ids,
                                              positions, n1, n2)

    bootstrap_position = positions1 + bootstrap_n
    discount = replay_buffer.get_field('discount', env_ids,
                                       bootstrap_position)
    sum_reward = self._sum_discounted_reward(
        replay_buffer, env_ids, positions1, bootstrap_position, n2)

    if not self._train_reward_function:
        rewards = self._get_reward(replay_buffer, env_ids,
                                   bootstrap_position)

    with alf.device(self._device):
        bootstrap_n = convert_device(bootstrap_n)
        discount = convert_device(discount)
        sum_reward = convert_device(sum_reward)
        game_overs = convert_device(game_overs)

        # 1. Reanalyze the first n1 steps to get both the updated value
        # and policy.
        self._mcts.set_model(self._target_model)
        mcts_step = self._mcts.predict_step(
            exp1, alf.nest.get_field(exp1, mcts_state_field))
        self._mcts.set_model(self._model)
        candidate_actions = ()
        if not _is_empty(mcts_step.info.candidate_actions):
            candidate_actions = mcts_step.info.candidate_actions
            candidate_actions = candidate_actions.reshape(
                batch_size, n1, *candidate_actions.shape[1:])
        candidate_action_policy = mcts_step.info.candidate_action_policy
        candidate_action_policy = candidate_action_policy.reshape(
            batch_size, n1, *candidate_action_policy.shape[1:])
        values1 = mcts_step.info.value.reshape(batch_size, n1)

        # 2. Calculate the value of the next n2 steps so that the n2-step
        # return can be computed.
        model_output = self._target_model.initial_inference(
            exp2.observation)
        values2 = model_output.value.reshape(batch_size, n2)

        # 3. Calculate the n2-step return.
        values = torch.cat([values1, values2], dim=1)
        # [B, n1]
        bootstrap_pos = torch.arange(n1).unsqueeze(0) + bootstrap_n
        values = values[torch.arange(batch_size).unsqueeze(-1),
                        bootstrap_pos]
        values = values * discount * (self._discount**bootstrap_n.to(
            torch.float32))
        values = values + sum_reward
        if not self._train_reward_function:
            # In this case, we need to set the value at and after the last
            # step to be the last reward.
            values = torch.where(game_overs, convert_device(rewards),
                                 values)
        return candidate_actions, candidate_action_policy, values
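# A minimal standalone sketch of the "step 3" computation in _reanalyze1
# above: gather the bootstrap value for each of the first n1 positions and
# form the n2-step return
#     G_t = sum_reward_t + discount_t * gamma**k_t * V_{t + k_t},
# where k_t = bootstrap_n[:, t] is the (episode-end clamped) number of steps
# to the bootstrap position.  The function name and argument names are
# illustrative and not part of the class above; only the indexing and the
# formula mirror the code.
import torch


def _sketch_n_step_return(values, sum_reward, discount, bootstrap_n, gamma):
    """Sketch of the n2-step bootstrapped value target.

    Args:
        values: [B, n1 + n2] value estimates for all n1 + n2 positions.
        sum_reward: [B, n1] discounted sums of the intermediate rewards.
        discount: [B, n1] discount read at the bootstrap positions
            (0 for terminal steps).
        bootstrap_n: [B, n1] steps from each position to its bootstrap
            position.
        gamma: scalar discount factor.
    Returns:
        [B, n1] value targets for the first n1 positions.
    """
    B, n1 = bootstrap_n.shape
    # Position t bootstraps from position t + bootstrap_n[:, t].
    bootstrap_pos = torch.arange(n1).unsqueeze(0) + bootstrap_n   # [B, n1]
    bootstrap_values = values[torch.arange(B).unsqueeze(-1), bootstrap_pos]
    return (sum_reward +
            discount * gamma**bootstrap_n.float() * bootstrap_values)


# Example call with made-up shapes (B=2, n1=3, n2=2):
#   _sketch_n_step_return(torch.randn(2, 5), torch.zeros(2, 3),
#                         torch.ones(2, 3), torch.randint(1, 3, (2, 3)), 0.99)
# returns a [2, 3] tensor of targets.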