def test_num_earliest_frames_ignored_priortized(self):
    """Prioritized sampling must skip the ``num_earliest_frames_ignored`` oldest steps.

    With two earliest frames ignored, sampling raises until three timesteps
    exist for env 1; afterwards only the newest step (t == 2) is sampled and
    its importance weight is uniform (1.0).
    """
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=self.num_envs,
        max_length=self.max_length,
        num_earliest_frames_ignored=2,
        keep_episodic_info=False,
        prioritized_sampling=True)

    batch1 = get_batch([1], self.dim, x=0.25, t=0)
    replay_buffer.add_batch(batch1, batch1.env_id)
    # Only one step stored: everything sampleable is ignored -> must raise.
    self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

    batch2 = get_batch([1], self.dim, x=0.25, t=1)
    replay_buffer.add_batch(batch2, batch1.env_id)
    # Still within the ignored prefix -> must raise.
    self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

    batch3 = get_batch([1], self.dim, x=0.25, t=2)
    replay_buffer.add_batch(batch3, batch1.env_id)
    # Now exactly one step (t == 2) is eligible; sample repeatedly to make
    # sure it is always the one returned.
    for _ in range(10):
        batch, batch_info = replay_buffer.get_batch(1, 1)
        self.assertEqual(batch_info.env_ids,
                         torch.tensor([1], dtype=torch.int64))
        self.assertEqual(batch_info.importance_weights, 1.)
        self.assertEqual(batch_info.importance_weights, torch.tensor([1.]))
        self.assertEqual(batch.t, torch.tensor([[2]]))
def test_recent_data_and_without_replacement(self):
    """Sampling without replacement combined with a recent-data quota.

    With ``recent_data_ratio=0.5`` and ``recent_data_steps=4``, half of every
    sampled batch must come from the most recent 4 steps of each environment.
    """
    num_envs = 4
    max_length = 100
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=num_envs,
        max_length=max_length,
        with_replacement=False,
        recent_data_ratio=0.5,
        recent_data_steps=4)

    replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=0, x=0.))
    batch, info = replay_buffer.get_batch(4, 1)
    # One step per env: without replacement each env appears exactly once.
    self.assertEqual(info.env_ids, torch.tensor([0, 1, 2, 3]))

    replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=1, x=1.0))
    batch, info = replay_buffer.get_batch(8, 1)
    self.assertEqual(info.env_ids, torch.tensor([0, 1, 2, 3] * 2))

    # Fill in 30 more steps per environment.
    for t in range(2, 32):
        replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=t, x=t))

    batch, info = replay_buffer.get_batch(32, 1)
    self.assertEqual(info.env_ids[16:], torch.tensor([0, 1, 2, 3] * 4))
    # The first half is from recent data
    self.assertEqual(info.env_ids[:16], torch.tensor([0, 1, 2, 3] * 4))
    self.assertEqual(
        info.positions[:16],
        torch.tensor([28] * 4 + [29] * 4 + [30] * 4 + [31] * 4))
def test_num_earliest_frames_ignored_uniform(self):
    """Uniform sampling must skip the ``num_earliest_frames_ignored`` oldest steps.

    Mirrors the prioritized variant: sampling raises until three steps are
    stored, and afterwards only the newest step (t == 2) is ever drawn.
    """
    num_envs = 4
    max_length = 100
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=num_envs,
        max_length=max_length,
        keep_episodic_info=False,
        num_earliest_frames_ignored=2)

    replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=0, x=0.))
    # Only the ignored prefix exists -> must raise.
    self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

    replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=1, x=0.))
    # Still nothing sampleable -> must raise.
    self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

    replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=2, x=0.))
    # Sample repeatedly: the only eligible step is t == 2.
    for _ in range(10):
        batch, batch_info = replay_buffer.get_batch(1, 1)
        self.assertEqual(batch.t, torch.tensor([[2]]))
def test_prioritized_replay(self):
    """End-to-end check of prioritized sampling and ``update_priority``.

    Verifies (1) empty-buffer sampling raises, (2) sampling frequencies match
    the priority distribution (exact counts thanks to the deterministic
    stratified sampler), and (3) importance weights track priority changes.
    """
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=self.num_envs,
        max_length=self.max_length,
        prioritized_sampling=True)
    # Empty buffer: nothing to sample.
    self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

    batch1 = get_batch([1], self.dim, x=0.25, t=0)
    replay_buffer.add_batch(batch1, batch1.env_id)

    batch, batch_info = replay_buffer.get_batch(1, 1)
    self.assertEqual(batch_info.env_ids,
                     torch.tensor([1], dtype=torch.int64))
    self.assertEqual(batch_info.importance_weights, 1.)
    self.assertEqual(batch_info.importance_weights, torch.tensor([1.]))
    # Not enough steps for a length-2 sample yet.
    self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 2)

    batch2 = get_batch([1], self.dim, x=0.5, t=1)
    # NOTE(review): ``batch2`` is built but ``batch1`` is re-added below —
    # looks like a copy-paste slip; behavior preserved, confirm intent.
    replay_buffer.add_batch(batch1, batch1.env_id)
    batch, batch_info = replay_buffer.get_batch(4, 2)
    self.assertEqual(batch_info.env_ids,
                     torch.tensor([1], dtype=torch.int64))
    self.assertEqual(batch_info.importance_weights, torch.tensor([1.]))
    self.assertEqual(batch_info.importance_weights, torch.tensor([1.] * 4))

    # Equal priorities: positions 0 and 1 each drawn exactly half the time.
    batch, batch_info = replay_buffer.get_batch(1000, 1)
    n0 = (replay_buffer.circular(batch_info.positions) == 0).sum()
    n1 = (replay_buffer.circular(batch_info.positions) == 1).sum()
    self.assertEqual(n0, 500)
    self.assertEqual(n1, 500)

    # Skew priorities 0.5 : 1.5 -> expected 250 : 750 split.
    replay_buffer.update_priority(
        env_ids=torch.tensor([1, 1], dtype=torch.int64),
        positions=torch.tensor([0, 1], dtype=torch.int64),
        priorities=torch.tensor([0.5, 1.5]))
    batch, batch_info = replay_buffer.get_batch(1000, 1)
    n0 = (replay_buffer.circular(batch_info.positions) == 0).sum()
    n1 = (replay_buffer.circular(batch_info.positions) == 1).sum()
    self.assertEqual(n0, 250)
    self.assertEqual(n1, 750)

    # New envs 0 and 2 enter with default priority.
    batch2 = get_batch([0, 2], self.dim, x=0.5, t=1)
    replay_buffer.add_batch(batch2, batch2.env_id)
    batch, batch_info = replay_buffer.get_batch(1000, 1)

    def _get(env_id, pos):
        # Count samples drawn at (env_id, pos) and collect their
        # importance weights.  Reads the enclosing ``batch_info``, so it
        # reflects whichever get_batch() call ran last.
        flag = ((batch_info.env_ids == env_id) *
                (batch_info.positions == replay_buffer._pad(pos, env_id)))
        w = batch_info.importance_weights[torch.nonzero(
            flag, as_tuple=True)[0]]
        return flag.sum(), w

    n0, w0 = _get(0, 0)
    n1, w1 = _get(1, 0)
    n2, w2 = _get(1, 1)
    n3, w3 = _get(2, 0)
    self.assertEqual(n0, 300)
    self.assertEqual(n1, 100)
    self.assertEqual(n2, 300)
    self.assertEqual(n3, 300)
    self.assertTrue(torch.all(w0 == 1.2))
    self.assertTrue(torch.all(w1 == 0.4))
    self.assertTrue(torch.all(w2 == 1.2))
    self.assertTrue(torch.all(w3 == 1.2))

    # Reset two priorities to 1.0 and re-check counts and weights.
    replay_buffer.update_priority(
        env_ids=torch.tensor([1, 2], dtype=torch.int64),
        positions=torch.tensor([1, 0], dtype=torch.int64),
        priorities=torch.tensor([1.0, 1.0]))
    batch, batch_info = replay_buffer.get_batch(1000, 1)
    n0, w0 = _get(0, 0)
    n1, w1 = _get(1, 0)
    n2, w2 = _get(1, 1)
    n3, w3 = _get(2, 0)
    self.assertEqual(n0, 375)
    self.assertEqual(n1, 125)
    self.assertEqual(n2, 250)
    self.assertEqual(n3, 250)
    self.assertTrue(torch.all(w0 == 1.5))
    self.assertTrue(torch.all(w1 == 0.5))
    self.assertTrue(torch.all(w2 == 1.0))
    self.assertTrue(torch.all(w3 == 1.0))
def test_replay_with_hindsight_relabel(self):
    """Hindsight (HER) relabeling: padding, episode index, and relabeled samples.

    Fills a small ring buffer past wrap-around, then checks position padding,
    the episode-boundary index structures, distance-to-episode-end, and that
    HER relabeling rewrites goals/rewards in the sampled batch without
    mutating the stored experience.
    """
    self.max_length = 8
    torch.manual_seed(0)  # fixes HER's random selection below
    configs = [
        "hindsight_relabel_fn.her_proportion=0.8",
        'hindsight_relabel_fn.achieved_goal_field="o.a"',
        'hindsight_relabel_fn.desired_goal_field="o.g"',
        "ReplayBuffer.postprocess_exp_fn=@hindsight_relabel_fn",
    ]
    gin.parse_config_files_and_bindings("", configs)
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=2,
        max_length=self.max_length,
        keep_episodic_info=True,
        step_type_field="t",
        with_replacement=True)

    steps = [
        [
            ds.StepType.FIRST,  # will be overwritten
            ds.StepType.MID,  # idx == 1 in buffer
            ds.StepType.LAST,
            ds.StepType.FIRST,
            ds.StepType.MID,
            ds.StepType.MID,
            ds.StepType.LAST,
            ds.StepType.FIRST,
            ds.StepType.MID  # idx == 0
        ],
        [
            ds.StepType.FIRST,  # will be overwritten in RingBuffer
            ds.StepType.LAST,  # idx == 1 in RingBuffer
            ds.StepType.FIRST,
            ds.StepType.MID,
            ds.StepType.MID,
            ds.StepType.LAST,
            ds.StepType.FIRST,
            ds.StepType.MID,
            ds.StepType.MID  # idx == 0
        ]
    ]
    # insert data that will be overwritten later
    for b, t in list(itertools.product(range(2), range(8))):
        batch = get_batch([b], self.dim, t=steps[b][t], x=0.1 * t + b)
        replay_buffer.add_batch(batch, batch.env_id)
    # insert data
    for b, t in list(itertools.product(range(2), range(9))):
        batch = get_batch([b], self.dim, t=steps[b][t], x=0.1 * t + b)
        replay_buffer.add_batch(batch, batch.env_id)

    # Test padding
    idx = torch.tensor([[7, 0, 0, 6, 3, 3, 3, 0], [6, 0, 5, 2, 2, 2, 0, 6]])
    pos = replay_buffer._pad(idx, torch.tensor([[0] * 8, [1] * 8]))
    self.assertTrue(
        torch.equal(
            pos,
            torch.tensor([[15, 16, 16, 14, 11, 11, 11, 16],
                          [14, 16, 13, 10, 10, 10, 16, 14]])))

    # Verify _index is built correctly.
    # Note, the _index_pos 8 represents headless timesteps, which are
    # outdated and not the same as the result of padding: 16.
    pos = torch.tensor([[15, 8, 8, 14, 11, 11, 11, 16],
                       [14, 8, 13, 10, 10, 10, 16, 14]])
    self.assertTrue(torch.equal(replay_buffer._indexed_pos, pos))
    self.assertTrue(
        torch.equal(replay_buffer._headless_indexed_pos,
                    torch.tensor([10, 9])))

    # Save original exp for later testing.
    g_orig = replay_buffer._buffer.o["g"].clone()
    r_orig = replay_buffer._buffer.reward.clone()

    # HER selects indices [0, 2, 3, 4] to relabel, from all 5:
    # env_ids: [[0, 0], [1, 1], [0, 0], [1, 1], [0, 0]]
    # pos:     [[6, 7], [1, 2], [1, 2], [3, 4], [5, 6]] + 8
    # selected:   x               x       x       x
    # future:  [ 7       2       2       4       6 ] + 8
    # g        [[.7,.7],[0, 0], [.2,.2],[1.4,1.4],[.6,.6]]
    #   0.1 * t + b with default 0
    # reward:  [[-1,0], [-1,-1],[-1,0], [-1,0], [-1,0]]
    #   recomputed with default -1
    env_ids = torch.tensor([0, 0, 1, 0])
    dist = replay_buffer.steps_to_episode_end(
        replay_buffer._pad(torch.tensor([7, 2, 4, 6]), env_ids), env_ids)
    self.assertEqual(list(dist), [1, 0, 1, 0])

    # Test HER relabeled experiences
    res = replay_buffer.get_batch(5, 2)[0]
    self.assertEqual(list(res.o["g"].shape), [5, 2])
    # Test relabeling doesn't change original experience
    self.assertTrue(torch.allclose(r_orig, replay_buffer._buffer.reward))
    self.assertTrue(torch.allclose(g_orig, replay_buffer._buffer.o["g"]))
    # test relabeled goals
    g = torch.tensor([0.7, 0., .2, 1.4, .6]).unsqueeze(1).expand(5, 2)
    self.assertTrue(torch.allclose(res.o["g"], g))
    # test relabeled rewards
    r = torch.tensor([[-1., 0.], [-1., -1.], [-1., 0.], [-1., 0.],
                      [-1., 0.]])
    self.assertTrue(torch.allclose(res.reward, r))
def test_replay_buffer(self, allow_multiprocess, with_replacement):
    """Core ReplayBuffer behavior: add/sample/gather_all/clear.

    Checks per-env size/position bookkeeping, gradient detaching, sampled
    content matching what was added, multi-step time consistency, the
    equal-size precondition of ``gather_all()``, and ``clear()``.

    NOTE(review): ``with_replacement`` is not used in this body — presumably
    it pairs with a parameterized decorator outside this view; confirm.
    """
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=self.num_envs,
        max_length=self.max_length,
        allow_multiprocess=allow_multiprocess)

    batch1 = get_batch([0, 4, 7], self.dim, t=0, x=0.1)
    replay_buffer.add_batch(batch1, batch1.env_id)
    self.assertEqual(replay_buffer._current_size,
                     torch.tensor([1, 0, 0, 0, 1, 0, 0, 1]))
    self.assertEqual(replay_buffer._current_pos,
                     torch.tensor([1, 0, 0, 0, 1, 0, 0, 1]))
    # Some environments still have no data -> sampling all 8 must raise.
    self.assertRaises(AssertionError, replay_buffer.get_batch, 8, 1)

    batch2 = get_batch([1, 2, 3, 5, 6], self.dim, t=0, x=0.2)
    replay_buffer.add_batch(batch2, batch2.env_id)
    self.assertEqual(replay_buffer._current_size,
                     torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]))
    self.assertEqual(replay_buffer._current_pos,
                     torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]))

    batch = replay_buffer.gather_all()
    self.assertEqual(list(batch.t.shape), [8, 1])
    # test that RingBuffer detaches gradients of inputs
    self.assertFalse(batch.x.requires_grad)

    # Only 1 step stored -> length-2 samples impossible.
    self.assertRaises(AssertionError, replay_buffer.get_batch, 8, 2)
    # Batch size larger than num_envs is fine (with replacement).
    replay_buffer.get_batch(13, 1)[0]

    batch = replay_buffer.get_batch(8, 1)[0]
    # squeeze the time dimension
    batch = alf.nest.map_structure(lambda bat: bat.squeeze(1), batch)
    bat1 = alf.nest.map_structure(lambda bat: bat[batch1.env_id], batch)
    bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch)
    self.assertEqual(bat1.env_id, batch1.env_id)
    self.assertEqual(bat1.x, batch1.x)
    self.assertEqual(bat1.t, batch1.t)
    self.assertEqual(bat2.env_id, batch2.env_id)
    self.assertEqual(bat2.x, batch2.x)
    self.assertEqual(bat2.t, batch2.t)

    # Keep adding to envs {0, 4, 7}; size saturates at max_length while
    # the write position keeps advancing.
    for t in range(1, 10):
        batch3 = get_batch([0, 4, 7], self.dim, t=t, x=0.3)
        j = t + 1
        s = min(t + 1, self.max_length)
        replay_buffer.add_batch(batch3, batch3.env_id)
        self.assertEqual(replay_buffer._current_size,
                         torch.tensor([s, 1, 1, 1, s, 1, 1, s]))
        self.assertEqual(replay_buffer._current_pos,
                         torch.tensor([j, 1, 1, 1, j, 1, 1, j]))

    batch2 = get_batch([1, 2, 3, 5, 6], self.dim, t=1, x=0.2)
    replay_buffer.add_batch(batch2, batch2.env_id)
    batch = replay_buffer.get_batch(8, 1)[0]
    # squeeze the time dimension
    batch = alf.nest.map_structure(lambda bat: bat.squeeze(1), batch)
    bat3 = alf.nest.map_structure(lambda bat: bat[batch3.env_id], batch)
    bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch)
    self.assertEqual(bat3.env_id, batch3.env_id)
    self.assertEqual(bat3.x, batch3.x)
    self.assertEqual(bat2.env_id, batch2.env_id)
    self.assertEqual(bat2.x, batch2.x)

    # Length-2 samples: per-step content checks plus time consistency.
    batch = replay_buffer.get_batch(8, 2)[0]
    t2 = []
    t3 = []
    for t in range(2):
        batch_t = alf.nest.map_structure(lambda b: b[:, t], batch)
        bat3 = alf.nest.map_structure(lambda bat: bat[batch3.env_id],
                                      batch_t)
        bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id],
                                      batch_t)
        t2.append(bat2.t)
        self.assertEqual(bat3.env_id, batch3.env_id)
        self.assertEqual(bat3.x, batch3.x)
        self.assertEqual(bat2.env_id, batch2.env_id)
        self.assertEqual(bat2.x, batch2.x)
        t3.append(bat3.t)
    # Test time consistency
    self.assertEqual(t2[0] + 1, t2[1])
    self.assertEqual(t3[0] + 1, t3[1])

    batch = replay_buffer.get_batch(128, 2)[0]
    self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(list(batch.t.shape), [128, 2])
    batch = replay_buffer.get_batch(10, 2)[0]
    self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(list(batch.t.shape), [10, 2])
    batch = replay_buffer.get_batch(4, 2)[0]
    self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(list(batch.t.shape), [4, 2])

    # Test gather_all()
    # Exception because the size of all the environments are not same
    self.assertRaises(AssertionError, replay_buffer.gather_all)
    # Even out all environments to 4 steps each, then gather.
    for t in range(2, 10):
        batch4 = get_batch([1, 2, 3, 5, 6], self.dim, t=t, x=0.4)
        replay_buffer.add_batch(batch4, batch4.env_id)
    batch = replay_buffer.gather_all()
    self.assertEqual(list(batch.t.shape), [8, 4])

    # Test clear()
    replay_buffer.clear()
    self.assertEqual(replay_buffer.total_size, 0)
def test_compute_her_future_step_distance(self, end_prob):
    """Distance-to-episode-end bookkeeping under continuous insertion.

    Streams 1000 steps per environment with episode ends occurring at rate
    ``end_prob``, and after every insertion (once t > 1) compares the
    buffer's ``steps_to_episode_end`` against a reference computation, then
    verifies HER relabeling leaves stored data and indices untouched.
    """
    num_envs = 2
    max_length = 100
    torch.manual_seed(0)
    configs = [
        "hindsight_relabel_fn.her_proportion=0.8",
        'hindsight_relabel_fn.achieved_goal_field="o.a"',
        'hindsight_relabel_fn.desired_goal_field="o.g"',
        "ReplayBuffer.postprocess_exp_fn=@hindsight_relabel_fn",
    ]
    gin.parse_config_files_and_bindings("", configs)
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=num_envs,
        max_length=max_length,
        keep_episodic_info=True,
        step_type_field="t")

    # insert data
    max_steps = 1000
    # generate step_types with certain density of episode ends
    steps = self.generate_step_types(
        num_envs, max_steps, end_prob=end_prob)
    for t in range(max_steps):
        for b in range(num_envs):
            batch = get_batch([b],
                              self.dim,
                              t=steps[b * max_steps + t],
                              x=1. / max_steps * t + b)
            replay_buffer.add_batch(batch, batch.env_id)
        if t > 1:
            sample_steps = min(t, max_length)
            env_ids = torch.tensor([0] * sample_steps + [1] * sample_steps)
            idx = torch.tensor(
                list(range(sample_steps)) + list(range(sample_steps)))
            # Reference distances computed by the test helper.
            gd = self.steps_to_episode_end(replay_buffer, env_ids, idx)
            idx_orig = replay_buffer._indexed_pos.clone()
            idx_headless_orig = (
                replay_buffer._headless_indexed_pos.clone())
            d = replay_buffer.steps_to_episode_end(
                replay_buffer._pad(idx, env_ids), env_ids)

            # Test distance to end computation; dump full state on mismatch
            # to make failures diagnosable.
            if not torch.equal(gd, d):
                outs = [
                    "t: ", t, "\nenvids:\n", env_ids, "\nidx:\n", idx,
                    "\npos:\n",
                    replay_buffer._pad(idx, env_ids), "\nNot Equal: a:\n",
                    gd, "\nb:\n", d, "\nsteps:\n", replay_buffer._buffer.t,
                    "\nindexed_pos:\n", replay_buffer._indexed_pos,
                    "\nheadless_indexed_pos:\n",
                    replay_buffer._headless_indexed_pos
                ]
                outs = [str(out) for out in outs]
                assert False, "".join(outs)

            # Save original exp for later testing.
            g_orig = replay_buffer._buffer.o["g"].clone()
            r_orig = replay_buffer._buffer.reward.clone()

            # HER relabel experience
            res = replay_buffer.get_batch(sample_steps, 2)[0]
            self.assertEqual(list(res.o["g"].shape), [sample_steps, 2])

            # Test relabeling doesn't change original experience
            self.assertTrue(
                torch.allclose(r_orig, replay_buffer._buffer.reward))
            self.assertTrue(
                torch.allclose(g_orig, replay_buffer._buffer.o["g"]))
            self.assertTrue(
                torch.all(idx_orig == replay_buffer._indexed_pos))
            self.assertTrue(
                torch.all(idx_headless_orig ==
                          replay_buffer._headless_indexed_pos))