class SyncExperienceReplayer(ExperienceReplayer):
    """
    For synchronous off-policy training. Example algorithms: DDPG, SAC
    """

    def __init__(self,
                 experience_spec,
                 batch_size,
                 max_length,
                 num_earliest_frames_ignored=0,
                 prioritized_sampling=False,
                 name="SyncExperienceReplayer"):
        """Create a ReplayBuffer.

        Args:
            experience_spec (nested TensorSpec): spec describing a single item
                that can be stored in the replayer.
            batch_size (int): number of environments.
            max_length (int): the maximum number of items that can be stored
                for a single environment.
            num_earliest_frames_ignored (int): ignore the earliest so many
                frames when sampling from the buffer. This is typically
                required when FrameStack is used.
            prioritized_sampling (bool): use prioritized sampling if this is
                True.
        """
        super().__init__()
        self._experience_spec = experience_spec
        self._buffer = ReplayBuffer(
            experience_spec,
            batch_size,
            max_length=max_length,
            prioritized_sampling=prioritized_sampling,
            num_earliest_frames_ignored=num_earliest_frames_ignored,
            name=name)
        self._data_iter = None

    def observe(self, exp):
        """
        For the sync driver, `exp` has the shape (`env_batch_size`, ...)
        with `num_envs`==1 and `unroll_length`==1.
        """
        outer_rank = alf.nest.utils.get_outer_rank(exp, self._experience_spec)

        if outer_rank == 1:
            self._buffer.add_batch(exp, exp.env_id)
        elif outer_rank == 3:
            # The shape is [learn_queue_cap, unroll_length, env_batch_size, ...]
            for q in range(exp.step_type.shape[0]):
                for t in range(exp.step_type.shape[1]):
                    bat = alf.nest.map_structure(lambda x: x[q, t, ...], exp)
                    self._buffer.add_batch(bat, bat.env_id)
        else:
            raise ValueError("Unsupported outer rank %s of `exp`" % outer_rank)

    def replay(self, sample_batch_size, mini_batch_length):
        """Get a random batch.

        Args:
            sample_batch_size (int): number of sequences
            mini_batch_length (int): the length of each sequence
        Returns:
            tuple:
            - nested Tensors: the samples. Their shapes are
              [batch_size, batch_length, ...]
            - BatchInfo: information about the batch. Its shapes are
              [batch_size].
              - env_ids: environment id for each sequence
              - positions: starting position in the replay buffer for each
                sequence
              - importance_weights: importance weight divided by the average
                of all non-zero importance weights in the buffer
        """
        return self._buffer.get_batch(sample_batch_size, mini_batch_length)

    def replay_all(self):
        return self._buffer.gather_all()

    def clear(self):
        self._buffer.clear()

    def update_priority(self, env_ids, positions, priorities):
        """Update the priorities for the given experiences.

        Args:
            env_ids (Tensor): 1-D int64 Tensor.
            positions (Tensor): 1-D int64 Tensor with the same shape as
                ``env_ids``. The positions should be obtained from the
                BatchInfo returned by ``get_batch()``.
        """
        self._buffer.update_priority(env_ids, positions, priorities)

    @property
    def batch_size(self):
        return self._buffer.num_environments

    @property
    def total_size(self):
        return self._buffer.total_size

    @property
    def replay_buffer(self):
        return self._buffer
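# Illustrative usage sketch (not part of the class). Only the API defined
# above is used; `exp_spec`, `num_envs`, `unrolled_exp` and
# `compute_td_error` are hypothetical placeholders.
#
#   replayer = SyncExperienceReplayer(
#       experience_spec=exp_spec,   # spec of a single experience item
#       batch_size=num_envs,        # one ring buffer per environment
#       max_length=1024,
#       prioritized_sampling=True)
#   replayer.observe(unrolled_exp)  # outer shape [env_batch_size, ...]
#   exp, info = replayer.replay(sample_batch_size=64, mini_batch_length=2)
#   # Feed fresh TD errors back as priorities, addressed by the
#   # env_ids/positions that BatchInfo reported for this batch.
#   replayer.update_priority(info.env_ids, info.positions,
#                            compute_td_error(exp).abs())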
def test_frame_stacker(self, stack_axis=0):
    data_spec = DataItem(
        step_type=alf.TensorSpec((), dtype=torch.int32),
        observation=dict(
            scalar=alf.TensorSpec(()),
            vector=alf.TensorSpec((7, )),
            matrix=alf.TensorSpec((5, 6)),
            tensor=alf.TensorSpec((2, 3, 4))))
    replay_buffer = ReplayBuffer(
        data_spec=data_spec,
        num_environments=2,
        max_length=1024,
        num_earliest_frames_ignored=2)
    frame_stacker = FrameStacker(
        data_spec.observation,
        stack_size=3,
        stack_axis=stack_axis,
        fields=['scalar', 'vector', 'matrix', 'tensor'])

    new_spec = frame_stacker.transformed_observation_spec
    self.assertEqual(new_spec['scalar'].shape, (3, ))
    self.assertEqual(new_spec['vector'].shape, (21, ))
    if stack_axis == -1:
        self.assertEqual(new_spec['matrix'].shape, (5, 18))
        self.assertEqual(new_spec['tensor'].shape, (2, 3, 12))
    elif stack_axis == 0:
        self.assertEqual(new_spec['matrix'].shape, (15, 6))
        self.assertEqual(new_spec['tensor'].shape, (6, 3, 4))

    def _step_type(t, period):
        if t % period == 0:
            return StepType.FIRST
        if t % period == period - 1:
            return StepType.LAST
        return StepType.MID

    observation = alf.nest.map_structure(
        lambda spec: spec.randn((1000, 2)), data_spec.observation)
    state = common.zero_tensor_from_nested_spec(frame_stacker.state_spec, 2)

    def _get_stacked_data(t, b):
        if stack_axis == -1:
            return dict(
                scalar=observation['scalar'][t, b],
                vector=observation['vector'][t, b].reshape(-1),
                matrix=observation['matrix'][t, b].transpose(0, 1).reshape(
                    5, 18),
                tensor=observation['tensor'][t, b].permute(1, 2, 0,
                                                           3).reshape(
                                                               2, 3, 12))
        elif stack_axis == 0:
            return dict(
                scalar=observation['scalar'][t, b],
                vector=observation['vector'][t, b].reshape(-1),
                matrix=observation['matrix'][t, b].reshape(15, 6),
                tensor=observation['tensor'][t, b].reshape(6, 3, 4))

    def _check_equal(stacked, expected, b):
        self.assertEqual(stacked['scalar'][b], expected['scalar'])
        self.assertEqual(stacked['vector'][b], expected['vector'])
        self.assertEqual(stacked['matrix'][b], expected['matrix'])
        self.assertEqual(stacked['tensor'][b], expected['tensor'])

    for t in range(1000):
        batch = DataItem(
            step_type=torch.tensor([_step_type(t, 17),
                                    _step_type(t, 22)]),
            observation=alf.nest.map_structure(lambda x: x[t], observation))
        replay_buffer.add_batch(batch)
        timestep, state = frame_stacker.transform_timestep(batch, state)
        if t == 0:
            for b in (0, 1):
                expected = _get_stacked_data([0, 0, 0], b)
                _check_equal(timestep.observation, expected, b)
        if t == 1:
            for b in (0, 1):
                expected = _get_stacked_data([0, 0, 1], b)
                _check_equal(timestep.observation, expected, b)
        if t == 2:
            for b in (0, 1):
                expected = _get_stacked_data([0, 1, 2], b)
                _check_equal(timestep.observation, expected, b)
        if t == 16:
            for b in (0, 1):
                expected = _get_stacked_data([14, 15, 16], b)
                _check_equal(timestep.observation, expected, b)
        if t == 17:
            # `ts` avoids shadowing the outer loop variable `t`
            for b, ts in ((0, [17, 17, 17]), (1, [15, 16, 17])):
                expected = _get_stacked_data(ts, b)
                _check_equal(timestep.observation, expected, b)
        if t == 18:
            for b, ts in ((0, [17, 17, 18]), (1, [16, 17, 18])):
                expected = _get_stacked_data(ts, b)
                _check_equal(timestep.observation, expected, b)
        if t == 22:
            for b, ts in ((0, [20, 21, 22]), (1, [22, 22, 22])):
                expected = _get_stacked_data(ts, b)
                _check_equal(timestep.observation, expected, b)

    batch_info = BatchInfo(
        env_ids=torch.tensor([0, 1, 0, 1], dtype=torch.int64),
        positions=torch.tensor([0, 1, 18, 22], dtype=torch.int64))

    # [4, 2, ...]
    experience = replay_buffer.get_field(
        '', batch_info.env_ids.unsqueeze(-1),
        batch_info.positions.unsqueeze(-1) + torch.arange(2))
    experience = experience._replace(
        batch_info=batch_info, replay_buffer=replay_buffer)
    experience = frame_stacker.transform_experience(experience)
    expected = _get_stacked_data([0, 0, 0], 0)
    _check_equal(experience.observation, expected, (0, 0))
    expected = _get_stacked_data([0, 0, 1], 0)
    _check_equal(experience.observation, expected, (0, 1))
    expected = _get_stacked_data([0, 0, 1], 1)
    _check_equal(experience.observation, expected, (1, 0))
    expected = _get_stacked_data([0, 1, 2], 1)
    _check_equal(experience.observation, expected, (1, 1))
    expected = _get_stacked_data([17, 17, 18], 0)
    _check_equal(experience.observation, expected, (2, 0))
    expected = _get_stacked_data([17, 18, 19], 0)
    _check_equal(experience.observation, expected, (2, 1))
    expected = _get_stacked_data([22, 22, 22], 1)
    _check_equal(experience.observation, expected, (3, 0))
    expected = _get_stacked_data([22, 22, 23], 1)
    _check_equal(experience.observation, expected, (3, 1))
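# A self-contained sketch (plain PyTorch, independent of FrameStacker) of
# why stack_size=3 turns a (5, 6) observation into (15, 6) for stack_axis=0
# and (5, 18) for stack_axis=-1, matching the spec assertions in the test
# above: stacking along an axis multiplies that axis by stack_size.
def _frame_stack_shape_demo():
    import torch
    frames = [torch.randn(5, 6) for _ in range(3)]  # 3 most recent frames
    assert torch.cat(frames, dim=0).shape == (15, 6)   # stack_axis == 0
    assert torch.cat(frames, dim=-1).shape == (5, 18)  # stack_axis == -1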
def _test_preprocess_experience(self, train_reward_function, td_steps,
                                reanalyze_ratio, expected):
    """
    The following summarizes how the data is generated:

    .. code-block:: python

        # position:    01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'
        scale = 1. for the current model, 2. for the target model
        observation = [position] * 3
        reward = position if train_reward_function and td_steps != -1
                 else position * (step_type == LAST)
        value = 0.5 * position * scale
        action_probs = scale * [position, position + 1, position] for env 0
                       scale * [position + 1, position, position] for env 1
        action = 1 for env 0
                 0 for env 1
    """
    reanalyze_td_steps = 2

    num_unroll_steps = 4
    batch_size = 2
    obs_dim = 3
    observation_spec = alf.TensorSpec([obs_dim])
    action_spec = alf.BoundedTensorSpec((),
                                        minimum=0,
                                        maximum=1,
                                        dtype=torch.int32)
    reward_spec = alf.TensorSpec(())
    time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                       reward_spec)

    global _mcts_model_id
    _mcts_model_id = 0
    muzero = MuzeroAlgorithm(
        observation_spec,
        action_spec,
        model_ctor=_create_mcts_model,
        mcts_algorithm_ctor=MockMCTSAlgorithm,
        num_unroll_steps=num_unroll_steps,
        td_steps=td_steps,
        train_game_over_function=True,
        train_reward_function=train_reward_function,
        reanalyze_ratio=reanalyze_ratio,
        reanalyze_td_steps=reanalyze_td_steps,
        data_transformer_ctor=partial(FrameStacker, stack_size=2))

    data_transformer = FrameStacker(observation_spec, stack_size=2)
    time_step = common.zero_tensor_from_nested_spec(time_step_spec,
                                                    batch_size)
    dt_state = common.zero_tensor_from_nested_spec(
        data_transformer.state_spec, batch_size)
    state = muzero.get_initial_predict_state(batch_size)
    transformed_time_step, dt_state = data_transformer.transform_timestep(
        time_step, dt_state)
    alg_step = muzero.rollout_step(transformed_time_step, state)
    alg_step_spec = dist_utils.extract_spec(alg_step)

    experience = ds.make_experience(time_step, alg_step, state)
    experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                         muzero.train_state_spec)
    replay_buffer = ReplayBuffer(
        data_spec=experience_spec,
        num_environments=batch_size,
        max_length=16,
        keep_episodic_info=True)

    #             01234567890123
    step_type0 = 'FMMMLFMMLFMMMM'
    step_type1 = 'FMMMMMLFMMMMLF'

    dt_state = common.zero_tensor_from_nested_spec(
        data_transformer.state_spec, batch_size)
    for i in range(len(step_type0)):
        step_type = [step_type0[i], step_type1[i]]
        step_type = [
            ds.StepType.MID if c == 'M' else
            (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
            for c in step_type
        ]
        step_type = torch.tensor(step_type, dtype=torch.int32)
        reward = torch.full([batch_size], float(i))
        if not train_reward_function or td_steps == -1:
            reward = reward * (step_type == ds.StepType.LAST).to(
                torch.float32)
        time_step = time_step._replace(
            discount=(step_type != ds.StepType.LAST).to(torch.float32),
            step_type=step_type,
            observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                     dtype=torch.float32),
            reward=reward,
            env_id=torch.arange(batch_size, dtype=torch.int32))
        transformed_time_step, dt_state = data_transformer.transform_timestep(
            time_step, dt_state)
        alg_step = muzero.rollout_step(transformed_time_step, state)
        experience = ds.make_experience(time_step, alg_step, state)
        replay_buffer.add_batch(experience)
        state = alg_step.state

    env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
    positions = torch.tensor(list(range(14)) + list(range(14)),
                             dtype=torch.int64)
    experience = replay_buffer.get_field(None,
                                         env_ids.unsqueeze(-1).cpu(),
                                         positions.unsqueeze(-1).cpu())
    experience = experience._replace(
        replay_buffer=replay_buffer,
        batch_info=BatchInfo(env_ids=env_ids, positions=positions),
        rollout_info_field='rollout_info')

    processed_experience = muzero.preprocess_experience(experience)
    import pprint
    pprint.pprint(processed_experience.rollout_info)
    alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                           processed_experience.rollout_info, expected)
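# A self-contained sketch of the advanced indexing that get_field() performs
# above: each sampled (env_id, position) pair is expanded into a window of
# consecutive positions, producing [B, T, ...] mini-batches. `buf` stands in
# for one buffered field of shape [num_envs, max_length].
def _window_indexing_demo():
    import torch
    buf = torch.arange(2 * 16).reshape(2, 16)  # [num_envs, max_length]
    env_ids = torch.tensor([0, 1, 0], dtype=torch.int64)
    positions = torch.tensor([3, 5, 7], dtype=torch.int64)
    # [B, 1] env ids broadcast against [B, T] positions -> [B, T] gather
    window = buf[env_ids.unsqueeze(-1),
                 positions.unsqueeze(-1) + torch.arange(2)]
    assert window.tolist() == [[3, 4], [21, 22], [7, 8]]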
def test_replay_buffer(self):
    dim = 20
    max_length = 4
    num_envs = 8
    data_spec = DataItem(
        env_id=tf.TensorSpec(shape=(), dtype=tf.int32),
        x=tf.TensorSpec(shape=(dim, ), dtype=tf.float32),
        t=tf.TensorSpec(shape=(), dtype=tf.int32))
    replay_buffer = ReplayBuffer(
        data_spec=data_spec,
        num_environments=num_envs,
        max_length=max_length)

    def _get_batch(env_ids, t, x):
        batch_size = len(env_ids)
        x = (x * tf.expand_dims(tf.range(batch_size, dtype=tf.float32), 1) *
             tf.expand_dims(tf.range(dim, dtype=tf.float32), 0))
        return DataItem(
            env_id=tf.constant(env_ids),
            x=x,
            t=t * tf.ones((batch_size, ), tf.int32))

    batch1 = _get_batch([0, 4, 7], t=0, x=0.1)
    replay_buffer.add_batch(batch1, batch1.env_id)
    self.assertArrayEqual(replay_buffer._current_size,
                          [1, 0, 0, 0, 1, 0, 0, 1])
    self.assertArrayEqual(replay_buffer._current_pos,
                          [1, 0, 0, 0, 1, 0, 0, 1])
    with self.assertRaises(tf.errors.InvalidArgumentError):
        replay_buffer.get_batch(8, 1)

    batch2 = _get_batch([1, 2, 3, 5, 6], t=0, x=0.2)
    replay_buffer.add_batch(batch2, batch2.env_id)
    self.assertArrayEqual(replay_buffer._current_size,
                          [1, 1, 1, 1, 1, 1, 1, 1])
    self.assertArrayEqual(replay_buffer._current_pos,
                          [1, 1, 1, 1, 1, 1, 1, 1])

    batch = replay_buffer.gather_all()
    self.assertEqual(batch.t.shape, [8, 1])
    with self.assertRaises(tf.errors.InvalidArgumentError):
        replay_buffer.get_batch(8, 2)

    replay_buffer.get_batch(13, 1)
    batch = replay_buffer.get_batch(8, 1)
    # squeeze the time dimension
    batch = tf.nest.map_structure(lambda bat: tf.squeeze(bat, axis=1),
                                  batch)
    bat1 = tf.nest.map_structure(
        lambda bat: tf.gather(bat, batch1.env_id, axis=0), batch)
    bat2 = tf.nest.map_structure(
        lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch)
    self.assertArrayEqual(bat1.env_id, batch1.env_id)
    self.assertArrayEqual(bat1.x, batch1.x)
    self.assertArrayEqual(bat1.t, batch1.t)
    self.assertArrayEqual(bat2.env_id, batch2.env_id)
    self.assertArrayEqual(bat2.x, batch2.x)
    self.assertArrayEqual(bat2.t, batch2.t)

    for t in range(1, 10):
        batch3 = _get_batch([0, 4, 7], t=t, x=0.3)
        j = (t + 1) % max_length
        s = min(t + 1, max_length)
        replay_buffer.add_batch(batch3, batch3.env_id)
        self.assertArrayEqual(replay_buffer._current_size,
                              [s, 1, 1, 1, s, 1, 1, s])
        self.assertArrayEqual(replay_buffer._current_pos,
                              [j, 1, 1, 1, j, 1, 1, j])

    batch2 = _get_batch([1, 2, 3, 5, 6], t=1, x=0.2)
    replay_buffer.add_batch(batch2, batch2.env_id)
    batch = replay_buffer.get_batch(8, 1)
    # squeeze the time dimension
    batch = tf.nest.map_structure(lambda bat: tf.squeeze(bat, axis=1),
                                  batch)
    bat3 = tf.nest.map_structure(
        lambda bat: tf.gather(bat, batch3.env_id, axis=0), batch)
    bat2 = tf.nest.map_structure(
        lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch)
    self.assertArrayEqual(bat3.env_id, batch3.env_id)
    self.assertArrayEqual(bat3.x, batch3.x)
    self.assertArrayEqual(bat2.env_id, batch2.env_id)
    self.assertArrayEqual(bat2.x, batch2.x)

    batch = replay_buffer.get_batch(8, 2)
    t2 = []
    t3 = []
    for t in range(2):
        batch_t = tf.nest.map_structure(lambda b: b[:, t], batch)
        bat3 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch3.env_id, axis=0), batch_t)
        bat2 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch_t)
        t2.append(bat2.t)
        self.assertArrayEqual(bat3.env_id, batch3.env_id)
        self.assertArrayEqual(bat3.x, batch3.x)
        self.assertArrayEqual(bat2.env_id, batch2.env_id)
        self.assertArrayEqual(bat2.x, batch2.x)
        t3.append(bat3.t)

    # Test time consistency
    self.assertArrayEqual(t2[0] + 1, t2[1])
    self.assertArrayEqual(t3[0] + 1, t3[1])

    batch = replay_buffer.get_batch(128, 2)
    self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(batch.t.shape, [128, 2])

    batch = replay_buffer.get_batch(10, 2)
    self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(batch.t.shape, [10, 2])

    batch = replay_buffer.get_batch(4, 2)
    self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(batch.t.shape, [4, 2])

    # Test gather_all()
    with self.assertRaises(tf.errors.InvalidArgumentError):
        replay_buffer.gather_all()
    for t in range(2, 10):
        batch4 = _get_batch([1, 2, 3, 5, 6], t=t, x=0.4)
        replay_buffer.add_batch(batch4, batch4.env_id)
    batch = replay_buffer.gather_all()
    self.assertEqual(batch.t.shape, [8, 4])
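# The circular-buffer invariants asserted above, in isolation: after the
# (t+1)-th add to one environment, the write position wraps modulo
# max_length while the stored size saturates at max_length. Plain-Python
# sketch of the `j`/`s` arithmetic used in the loop above.
def _ring_buffer_invariants_demo(max_length=4):
    for t in range(10):
        pos = (t + 1) % max_length     # next write position, as in `j`
        size = min(t + 1, max_length)  # number of stored items, as in `s`
        assert 0 <= pos < max_length and 1 <= size <= max_length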
def test_recent_data_and_without_replacement(self):
    num_envs = 4
    max_length = 100
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=num_envs,
        max_length=max_length,
        with_replacement=False,
        recent_data_ratio=0.5,
        recent_data_steps=4)
    replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=0, x=0.))
    batch, info = replay_buffer.get_batch(4, 1)
    self.assertEqual(info.env_ids, torch.tensor([0, 1, 2, 3]))

    replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=1, x=1.0))
    batch, info = replay_buffer.get_batch(8, 1)
    self.assertEqual(info.env_ids, torch.tensor([0, 1, 2, 3] * 2))

    for t in range(2, 32):
        replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=t, x=t))
    batch, info = replay_buffer.get_batch(32, 1)
    self.assertEqual(info.env_ids[16:], torch.tensor([0, 1, 2, 3] * 4))
    # The first half is from recent data
    self.assertEqual(info.env_ids[:16], torch.tensor([0, 1, 2, 3] * 4))
    self.assertEqual(
        info.positions[:16],
        torch.tensor([28] * 4 + [29] * 4 + [30] * 4 + [31] * 4))
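# The arithmetic behind the final assertions above: with
# recent_data_ratio=0.5, half of the 32 samples must come from the last
# recent_data_steps=4 positions of the 4 environments, i.e. exactly one
# sample per (position, env) pair in positions 28..31.
def _recent_split_demo(batch_size=32, recent_data_ratio=0.5,
                       recent_data_steps=4, num_envs=4):
    num_recent = int(batch_size * recent_data_ratio)  # 16
    assert num_recent == recent_data_steps * num_envs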
def test_prioritized_replay(self):
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=self.num_envs,
        max_length=self.max_length,
        prioritized_sampling=True)
    self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

    batch1 = get_batch([1], self.dim, x=0.25, t=0)
    replay_buffer.add_batch(batch1, batch1.env_id)
    batch, batch_info = replay_buffer.get_batch(1, 1)
    self.assertEqual(batch_info.env_ids,
                     torch.tensor([1], dtype=torch.int64))
    self.assertEqual(batch_info.importance_weights, 1.)
    self.assertEqual(batch_info.importance_weights, torch.tensor([1.]))
    self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 2)

    batch2 = get_batch([1], self.dim, x=0.5, t=1)
    replay_buffer.add_batch(batch2, batch2.env_id)
    batch, batch_info = replay_buffer.get_batch(4, 2)
    self.assertEqual(batch_info.env_ids,
                     torch.tensor([1], dtype=torch.int64))
    self.assertEqual(batch_info.importance_weights, torch.tensor([1.]))
    self.assertEqual(batch_info.importance_weights, torch.tensor([1.] * 4))

    batch, batch_info = replay_buffer.get_batch(1000, 1)
    n0 = (replay_buffer.circular(batch_info.positions) == 0).sum()
    n1 = (replay_buffer.circular(batch_info.positions) == 1).sum()
    self.assertEqual(n0, 500)
    self.assertEqual(n1, 500)

    replay_buffer.update_priority(
        env_ids=torch.tensor([1, 1], dtype=torch.int64),
        positions=torch.tensor([0, 1], dtype=torch.int64),
        priorities=torch.tensor([0.5, 1.5]))
    batch, batch_info = replay_buffer.get_batch(1000, 1)
    n0 = (replay_buffer.circular(batch_info.positions) == 0).sum()
    n1 = (replay_buffer.circular(batch_info.positions) == 1).sum()
    self.assertEqual(n0, 250)
    self.assertEqual(n1, 750)

    batch2 = get_batch([0, 2], self.dim, x=0.5, t=1)
    replay_buffer.add_batch(batch2, batch2.env_id)
    batch, batch_info = replay_buffer.get_batch(1000, 1)

    def _get(env_id, pos):
        flag = ((batch_info.env_ids == env_id) *
                (batch_info.positions == replay_buffer._pad(pos, env_id)))
        w = batch_info.importance_weights[torch.nonzero(
            flag, as_tuple=True)[0]]
        return flag.sum(), w

    n0, w0 = _get(0, 0)
    n1, w1 = _get(1, 0)
    n2, w2 = _get(1, 1)
    n3, w3 = _get(2, 0)
    self.assertEqual(n0, 300)
    self.assertEqual(n1, 100)
    self.assertEqual(n2, 300)
    self.assertEqual(n3, 300)
    self.assertTrue(torch.all(w0 == 1.2))
    self.assertTrue(torch.all(w1 == 0.4))
    self.assertTrue(torch.all(w2 == 1.2))
    self.assertTrue(torch.all(w3 == 1.2))

    replay_buffer.update_priority(
        env_ids=torch.tensor([1, 2], dtype=torch.int64),
        positions=torch.tensor([1, 0], dtype=torch.int64),
        priorities=torch.tensor([1.0, 1.0]))
    batch, batch_info = replay_buffer.get_batch(1000, 1)

    n0, w0 = _get(0, 0)
    n1, w1 = _get(1, 0)
    n2, w2 = _get(1, 1)
    n3, w3 = _get(2, 0)
    self.assertEqual(n0, 375)
    self.assertEqual(n1, 125)
    self.assertEqual(n2, 250)
    self.assertEqual(n3, 250)
    self.assertTrue(torch.all(w0 == 1.5))
    self.assertTrue(torch.all(w1 == 0.5))
    self.assertTrue(torch.all(w2 == 1.0))
    self.assertTrue(torch.all(w3 == 1.0))
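# The importance-weight arithmetic consistent with the counts and weights
# asserted above (the two items added via the second batch2 behave as if
# they carry the current maximum priority, 1.5 -- an inference from the
# 300/100/300/300 counts, not a statement of ReplayBuffer internals).
# Sampling probabilities are proportional to priority; importance weights
# are priorities normalized by the mean non-zero priority.
def _priority_weight_demo():
    import torch
    # priorities of (env0, pos0), (env1, pos0), (env1, pos1), (env2, pos0)
    priorities = torch.tensor([1.5, 0.5, 1.5, 1.5])
    probs = priorities / priorities.sum()     # -> [0.3, 0.1, 0.3, 0.3]
    weights = priorities / priorities.mean()  # -> [1.2, 0.4, 1.2, 1.2]
    assert torch.allclose(probs, torch.tensor([0.3, 0.1, 0.3, 0.3]))
    assert torch.allclose(weights, torch.tensor([1.2, 0.4, 1.2, 1.2]))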
def test_replay_with_hindsight_relabel(self):
    self.max_length = 8
    torch.manual_seed(0)
    configs = [
        "hindsight_relabel_fn.her_proportion=0.8",
        'hindsight_relabel_fn.achieved_goal_field="o.a"',
        'hindsight_relabel_fn.desired_goal_field="o.g"',
        "ReplayBuffer.postprocess_exp_fn=@hindsight_relabel_fn",
    ]
    gin.parse_config_files_and_bindings("", configs)

    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=2,
        max_length=self.max_length,
        keep_episodic_info=True,
        step_type_field="t",
        with_replacement=True)

    steps = [
        [
            ds.StepType.FIRST,  # will be overwritten
            ds.StepType.MID,  # idx == 1 in buffer
            ds.StepType.LAST,
            ds.StepType.FIRST,
            ds.StepType.MID,
            ds.StepType.MID,
            ds.StepType.LAST,
            ds.StepType.FIRST,
            ds.StepType.MID  # idx == 0
        ],
        [
            ds.StepType.FIRST,  # will be overwritten in RingBuffer
            ds.StepType.LAST,  # idx == 1 in RingBuffer
            ds.StepType.FIRST,
            ds.StepType.MID,
            ds.StepType.MID,
            ds.StepType.LAST,
            ds.StepType.FIRST,
            ds.StepType.MID,
            ds.StepType.MID  # idx == 0
        ]
    ]
    # insert data that will be overwritten later
    for b, t in list(itertools.product(range(2), range(8))):
        batch = get_batch([b], self.dim, t=steps[b][t], x=0.1 * t + b)
        replay_buffer.add_batch(batch, batch.env_id)
    # insert data
    for b, t in list(itertools.product(range(2), range(9))):
        batch = get_batch([b], self.dim, t=steps[b][t], x=0.1 * t + b)
        replay_buffer.add_batch(batch, batch.env_id)

    # Test padding
    idx = torch.tensor([[7, 0, 0, 6, 3, 3, 3, 0],
                        [6, 0, 5, 2, 2, 2, 0, 6]])
    pos = replay_buffer._pad(idx, torch.tensor([[0] * 8, [1] * 8]))
    self.assertTrue(
        torch.equal(
            pos,
            torch.tensor([[15, 16, 16, 14, 11, 11, 11, 16],
                          [14, 16, 13, 10, 10, 10, 16, 14]])))

    # Verify _index is built correctly.
    # Note, the _index_pos 8 represents headless timesteps, which are
    # outdated and not the same as the result of padding: 16.
    pos = torch.tensor([[15, 8, 8, 14, 11, 11, 11, 16],
                        [14, 8, 13, 10, 10, 10, 16, 14]])
    self.assertTrue(torch.equal(replay_buffer._indexed_pos, pos))
    self.assertTrue(
        torch.equal(replay_buffer._headless_indexed_pos,
                    torch.tensor([10, 9])))

    # Save original exp for later testing.
    g_orig = replay_buffer._buffer.o["g"].clone()
    r_orig = replay_buffer._buffer.reward.clone()

    # HER selects indices [0, 2, 3, 4] to relabel, from all 5:
    # env_ids:  [[0, 0], [1, 1],  [0, 0], [1, 1],   [0, 0]]
    # pos:      [[6, 7], [1, 2],  [1, 2], [3, 4],   [5, 6]] + 8
    # selected:    x                x       x          x
    # future:   [  7       2        2       4          6  ] + 8
    # g:        [[.7,.7], [0, 0],  [.2,.2],[1.4,1.4],[.6,.6]]
    #           0.1 * t + b with default 0
    # reward:   [[-1,0],  [-1,-1], [-1,0], [-1,0],   [-1,0]]
    #           recomputed with default -1
    env_ids = torch.tensor([0, 0, 1, 0])
    dist = replay_buffer.steps_to_episode_end(
        replay_buffer._pad(torch.tensor([7, 2, 4, 6]), env_ids), env_ids)
    self.assertEqual(list(dist), [1, 0, 1, 0])

    # Test HER relabeled experiences
    res = replay_buffer.get_batch(5, 2)[0]
    self.assertEqual(list(res.o["g"].shape), [5, 2])
    # Test relabeling doesn't change original experience
    self.assertTrue(torch.allclose(r_orig, replay_buffer._buffer.reward))
    self.assertTrue(torch.allclose(g_orig, replay_buffer._buffer.o["g"]))

    # test relabeled goals
    g = torch.tensor([0.7, 0., .2, 1.4, .6]).unsqueeze(1).expand(5, 2)
    self.assertTrue(torch.allclose(res.o["g"], g))
    # test relabeled rewards
    r = torch.tensor([[-1., 0.], [-1., -1.], [-1., 0.], [-1., 0.],
                      [-1., 0.]])
    self.assertTrue(torch.allclose(res.reward, r))
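# A minimal sketch of the relabeling being tested (a simplification of
# hindsight_relabel_fn, not its implementation): replace the desired goal of
# a sampled step with the achieved goal at a future step of the same
# episode, then recompute the reward, which becomes 0 at the step whose
# achieved goal matches the relabeled goal and -1 elsewhere (the default
# reward convention noted in the comments above).
def _her_relabel_demo(achieved, future_idx):
    import torch
    # achieved: [T] achieved goals of one episode; future_idx: relabel source
    relabeled_goal = achieved[future_idx]
    rewards = torch.where(achieved == relabeled_goal, torch.tensor(0.),
                          torch.tensor(-1.))
    return relabeled_goal, rewards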
def test_replay_buffer(self, allow_multiprocess, with_replacement):
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=self.num_envs,
        max_length=self.max_length,
        allow_multiprocess=allow_multiprocess)
    batch1 = get_batch([0, 4, 7], self.dim, t=0, x=0.1)
    replay_buffer.add_batch(batch1, batch1.env_id)
    self.assertEqual(replay_buffer._current_size,
                     torch.tensor([1, 0, 0, 0, 1, 0, 0, 1]))
    self.assertEqual(replay_buffer._current_pos,
                     torch.tensor([1, 0, 0, 0, 1, 0, 0, 1]))
    self.assertRaises(AssertionError, replay_buffer.get_batch, 8, 1)

    batch2 = get_batch([1, 2, 3, 5, 6], self.dim, t=0, x=0.2)
    replay_buffer.add_batch(batch2, batch2.env_id)
    self.assertEqual(replay_buffer._current_size,
                     torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]))
    self.assertEqual(replay_buffer._current_pos,
                     torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]))

    batch = replay_buffer.gather_all()
    self.assertEqual(list(batch.t.shape), [8, 1])
    # test that RingBuffer detaches gradients of inputs
    self.assertFalse(batch.x.requires_grad)

    self.assertRaises(AssertionError, replay_buffer.get_batch, 8, 2)
    replay_buffer.get_batch(13, 1)[0]

    batch = replay_buffer.get_batch(8, 1)[0]
    # squeeze the time dimension
    batch = alf.nest.map_structure(lambda bat: bat.squeeze(1), batch)
    bat1 = alf.nest.map_structure(lambda bat: bat[batch1.env_id], batch)
    bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch)
    self.assertEqual(bat1.env_id, batch1.env_id)
    self.assertEqual(bat1.x, batch1.x)
    self.assertEqual(bat1.t, batch1.t)
    self.assertEqual(bat2.env_id, batch2.env_id)
    self.assertEqual(bat2.x, batch2.x)
    self.assertEqual(bat2.t, batch2.t)

    for t in range(1, 10):
        batch3 = get_batch([0, 4, 7], self.dim, t=t, x=0.3)
        j = t + 1
        s = min(t + 1, self.max_length)
        replay_buffer.add_batch(batch3, batch3.env_id)
        self.assertEqual(replay_buffer._current_size,
                         torch.tensor([s, 1, 1, 1, s, 1, 1, s]))
        self.assertEqual(replay_buffer._current_pos,
                         torch.tensor([j, 1, 1, 1, j, 1, 1, j]))

    batch2 = get_batch([1, 2, 3, 5, 6], self.dim, t=1, x=0.2)
    replay_buffer.add_batch(batch2, batch2.env_id)
    batch = replay_buffer.get_batch(8, 1)[0]
    # squeeze the time dimension
    batch = alf.nest.map_structure(lambda bat: bat.squeeze(1), batch)
    bat3 = alf.nest.map_structure(lambda bat: bat[batch3.env_id], batch)
    bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch)
    self.assertEqual(bat3.env_id, batch3.env_id)
    self.assertEqual(bat3.x, batch3.x)
    self.assertEqual(bat2.env_id, batch2.env_id)
    self.assertEqual(bat2.x, batch2.x)

    batch = replay_buffer.get_batch(8, 2)[0]
    t2 = []
    t3 = []
    for t in range(2):
        batch_t = alf.nest.map_structure(lambda b: b[:, t], batch)
        bat3 = alf.nest.map_structure(lambda bat: bat[batch3.env_id],
                                      batch_t)
        bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id],
                                      batch_t)
        t2.append(bat2.t)
        self.assertEqual(bat3.env_id, batch3.env_id)
        self.assertEqual(bat3.x, batch3.x)
        self.assertEqual(bat2.env_id, batch2.env_id)
        self.assertEqual(bat2.x, batch2.x)
        t3.append(bat3.t)

    # Test time consistency
    self.assertEqual(t2[0] + 1, t2[1])
    self.assertEqual(t3[0] + 1, t3[1])

    batch = replay_buffer.get_batch(128, 2)[0]
    self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(list(batch.t.shape), [128, 2])

    batch = replay_buffer.get_batch(10, 2)[0]
    self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(list(batch.t.shape), [10, 2])

    batch = replay_buffer.get_batch(4, 2)[0]
    self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(list(batch.t.shape), [4, 2])

    # Test gather_all(): it raises an exception because the sizes of the
    # environments are not all the same.
    self.assertRaises(AssertionError, replay_buffer.gather_all)
    for t in range(2, 10):
        batch4 = get_batch([1, 2, 3, 5, 6], self.dim, t=t, x=0.4)
        replay_buffer.add_batch(batch4, batch4.env_id)
    batch = replay_buffer.gather_all()
    self.assertEqual(list(batch.t.shape), [8, 4])

    # Test clear()
    replay_buffer.clear()
    self.assertEqual(replay_buffer.total_size, 0)
def test_compute_her_future_step_distance(self, end_prob):
    num_envs = 2
    max_length = 100
    torch.manual_seed(0)
    configs = [
        "hindsight_relabel_fn.her_proportion=0.8",
        'hindsight_relabel_fn.achieved_goal_field="o.a"',
        'hindsight_relabel_fn.desired_goal_field="o.g"',
        "ReplayBuffer.postprocess_exp_fn=@hindsight_relabel_fn",
    ]
    gin.parse_config_files_and_bindings("", configs)
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=num_envs,
        max_length=max_length,
        keep_episodic_info=True,
        step_type_field="t")

    # insert data
    max_steps = 1000
    # generate step_types with certain density of episode ends
    steps = self.generate_step_types(num_envs, max_steps, end_prob=end_prob)
    for t in range(max_steps):
        for b in range(num_envs):
            batch = get_batch([b],
                              self.dim,
                              t=steps[b * max_steps + t],
                              x=1. / max_steps * t + b)
            replay_buffer.add_batch(batch, batch.env_id)
        if t > 1:
            sample_steps = min(t, max_length)
            env_ids = torch.tensor([0] * sample_steps + [1] * sample_steps)
            idx = torch.tensor(
                list(range(sample_steps)) + list(range(sample_steps)))
            gd = self.steps_to_episode_end(replay_buffer, env_ids, idx)
            idx_orig = replay_buffer._indexed_pos.clone()
            idx_headless_orig = replay_buffer._headless_indexed_pos.clone()
            d = replay_buffer.steps_to_episode_end(
                replay_buffer._pad(idx, env_ids), env_ids)

            # Test distance to end computation
            if not torch.equal(gd, d):
                outs = [
                    "t: ", t, "\nenvids:\n", env_ids, "\nidx:\n", idx,
                    "\npos:\n", replay_buffer._pad(idx, env_ids),
                    "\nNot Equal: a:\n", gd, "\nb:\n", d, "\nsteps:\n",
                    replay_buffer._buffer.t, "\nindexed_pos:\n",
                    replay_buffer._indexed_pos,
                    "\nheadless_indexed_pos:\n",
                    replay_buffer._headless_indexed_pos
                ]
                outs = [str(out) for out in outs]
                assert False, "".join(outs)

            # Save original exp for later testing.
            g_orig = replay_buffer._buffer.o["g"].clone()
            r_orig = replay_buffer._buffer.reward.clone()

            # HER relabel experience
            res = replay_buffer.get_batch(sample_steps, 2)[0]
            self.assertEqual(list(res.o["g"].shape), [sample_steps, 2])

            # Test relabeling doesn't change original experience
            self.assertTrue(
                torch.allclose(r_orig, replay_buffer._buffer.reward))
            self.assertTrue(
                torch.allclose(g_orig, replay_buffer._buffer.o["g"]))
            self.assertTrue(
                torch.all(idx_orig == replay_buffer._indexed_pos))
            self.assertTrue(
                torch.all(idx_headless_orig ==
                          replay_buffer._headless_indexed_pos))
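# Reference semantics of steps_to_episode_end(), which the ground-truth
# helper above recomputes slowly: the distance from each position to the
# next LAST step of its episode (0 at the LAST step itself). A plain-Python
# sketch for one environment whose episodes all terminate with 'L'; the
# real buffer additionally handles headless/unfinished tails.
def _steps_to_end_reference(step_types):
    dist, d = [0] * len(step_types), 0
    for i in reversed(range(len(step_types))):
        d = 0 if step_types[i] == 'L' else d + 1
        dist[i] = d
    return dist

assert _steps_to_end_reference('FMML') == [3, 2, 1, 0]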
def test_preprocess_experience(self):
    """
    The following summarizes how the data is generated:

    .. code-block:: python

        # position:    01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'
        reward = position
        action = t + 1 for env 0
                 t for env 1
    """
    num_unroll_steps = 4
    batch_size = 2
    obs_dim = 3
    observation_spec = alf.TensorSpec([obs_dim])
    action_spec = alf.BoundedTensorSpec((1, ),
                                        minimum=0,
                                        maximum=1,
                                        dtype=torch.float32)
    reward_spec = alf.TensorSpec(())
    time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                       reward_spec)

    repr_learner = PredictiveRepresentationLearner(
        observation_spec,
        action_spec,
        num_unroll_steps=num_unroll_steps,
        decoder_ctor=partial(
            SimpleDecoder,
            target_field='reward',
            decoder_net_ctor=partial(
                EncodingNetwork, fc_layer_params=(4, ))),
        encoding_net_ctor=LSTMEncodingNetwork,
        dynamics_net_ctor=LSTMEncodingNetwork)

    time_step = common.zero_tensor_from_nested_spec(time_step_spec,
                                                    batch_size)
    state = repr_learner.get_initial_predict_state(batch_size)
    alg_step = repr_learner.rollout_step(time_step, state)
    alg_step = alg_step._replace(output=torch.tensor([[1.], [0.]]))
    alg_step_spec = dist_utils.extract_spec(alg_step)

    experience = ds.make_experience(time_step, alg_step, state)
    experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                         repr_learner.train_state_spec)
    replay_buffer = ReplayBuffer(
        data_spec=experience_spec,
        num_environments=batch_size,
        max_length=16,
        keep_episodic_info=True)

    #             01234567890123
    step_type0 = 'FMMMLFMMLFMMMM'
    step_type1 = 'FMMMMMLFMMMMLF'
    for i in range(len(step_type0)):
        step_type = [step_type0[i], step_type1[i]]
        step_type = [
            ds.StepType.MID if c == 'M' else
            (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
            for c in step_type
        ]
        step_type = torch.tensor(step_type, dtype=torch.int32)
        reward = torch.full([batch_size], float(i))
        time_step = time_step._replace(
            discount=(step_type != ds.StepType.LAST).to(torch.float32),
            step_type=step_type,
            observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                     dtype=torch.float32),
            reward=reward,
            env_id=torch.arange(batch_size, dtype=torch.int32))
        alg_step = repr_learner.rollout_step(time_step, state)
        alg_step = alg_step._replace(output=i + torch.tensor([[1.], [0.]]))
        experience = ds.make_experience(time_step, alg_step, state)
        replay_buffer.add_batch(experience)
        state = alg_step.state

    env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
    positions = torch.tensor(list(range(14)) + list(range(14)),
                             dtype=torch.int64)
    experience = replay_buffer.get_field(None,
                                         env_ids.unsqueeze(-1).cpu(),
                                         positions.unsqueeze(-1).cpu())
    experience = experience._replace(
        replay_buffer=replay_buffer,
        batch_info=BatchInfo(env_ids=env_ids, positions=positions),
        rollout_info_field='rollout_info')

    processed_experience = repr_learner.preprocess_experience(experience)
    pprint.pprint(processed_experience.rollout_info)

    # yapf: disable
    expected = PredictiveRepresentationLearnerInfo(
        action=torch.tensor(
            [[[ 1.,  2.,  3.,  4.,  5.]],
             [[ 2.,  3.,  4.,  5.,  5.]],
             [[ 3.,  4.,  5.,  5.,  5.]],
             [[ 4.,  5.,  5.,  5.,  5.]],
             [[ 5.,  5.,  5.,  5.,  5.]],
             [[ 6.,  7.,  8.,  9.,  9.]],
             [[ 7.,  8.,  9.,  9.,  9.]],
             [[ 8.,  9.,  9.,  9.,  9.]],
             [[ 9.,  9.,  9.,  9.,  9.]],
             [[10., 11., 12., 13., 14.]],
             [[11., 12., 13., 14., 14.]],
             [[12., 13., 14., 14., 14.]],
             [[13., 14., 14., 14., 14.]],
             [[14., 14., 14., 14., 14.]],
             [[ 0.,  1.,  2.,  3.,  4.]],
             [[ 1.,  2.,  3.,  4.,  5.]],
             [[ 2.,  3.,  4.,  5.,  6.]],
             [[ 3.,  4.,  5.,  6.,  6.]],
             [[ 4.,  5.,  6.,  6.,  6.]],
             [[ 5.,  6.,  6.,  6.,  6.]],
             [[ 6.,  6.,  6.,  6.,  6.]],
             [[ 7.,  8.,  9., 10., 11.]],
             [[ 8.,  9., 10., 11., 12.]],
             [[ 9., 10., 11., 12., 12.]],
             [[10., 11., 12., 12., 12.]],
             [[11., 12., 12., 12., 12.]],
             [[12., 12., 12., 12., 12.]],
             [[13., 13., 13., 13., 13.]]]).unsqueeze(-1),
        mask=torch.tensor(
            [[[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True, False, False, False, False]]]),
        target=torch.tensor(
            [[[ 0.,  1.,  2.,  3.,  4.]],
             [[ 1.,  2.,  3.,  4.,  4.]],
             [[ 2.,  3.,  4.,  4.,  4.]],
             [[ 3.,  4.,  4.,  4.,  4.]],
             [[ 4.,  4.,  4.,  4.,  4.]],
             [[ 5.,  6.,  7.,  8.,  8.]],
             [[ 6.,  7.,  8.,  8.,  8.]],
             [[ 7.,  8.,  8.,  8.,  8.]],
             [[ 8.,  8.,  8.,  8.,  8.]],
             [[ 9., 10., 11., 12., 13.]],
             [[10., 11., 12., 13., 13.]],
             [[11., 12., 13., 13., 13.]],
             [[12., 13., 13., 13., 13.]],
             [[13., 13., 13., 13., 13.]],
             [[ 0.,  1.,  2.,  3.,  4.]],
             [[ 1.,  2.,  3.,  4.,  5.]],
             [[ 2.,  3.,  4.,  5.,  6.]],
             [[ 3.,  4.,  5.,  6.,  6.]],
             [[ 4.,  5.,  6.,  6.,  6.]],
             [[ 5.,  6.,  6.,  6.,  6.]],
             [[ 6.,  6.,  6.,  6.,  6.]],
             [[ 7.,  8.,  9., 10., 11.]],
             [[ 8.,  9., 10., 11., 12.]],
             [[ 9., 10., 11., 12., 12.]],
             [[10., 11., 12., 12., 12.]],
             [[11., 12., 12., 12., 12.]],
             [[12., 12., 12., 12., 12.]],
             [[13., 13., 13., 13., 13.]]]))
    # yapf: enable

    alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                           processed_experience.rollout_info, expected)
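# The pattern inside `expected`, in isolation: for each position t, the
# unroll targets are the next num_unroll_steps + 1 rewards with indices
# clamped at the episode boundary (hence rows like [1, 2, 3, 4, 4] when the
# episode ends at position 4), and `mask` marks which entries are real steps
# rather than clamped repeats. `episode_end` is the position of the last
# step reachable from t.
def _unroll_targets_demo(rewards, episode_end, t, num_unroll_steps=4):
    idx = [min(t + k, episode_end) for k in range(num_unroll_steps + 1)]
    target = [rewards[i] for i in idx]
    mask = [t + k <= episode_end for k in range(num_unroll_steps + 1)]
    return target, mask

# e.g. rewards 0..13 of env 0, whose first episode ends (LAST) at position 4:
# _unroll_targets_demo(list(range(14)), 4, 1)
#   -> ([1, 2, 3, 4, 4], [True, True, True, True, False])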
def _reanalyze1(self, replay_buffer: ReplayBuffer, env_ids, positions,
                mcts_state_field):
    """Reanalyze one batch.

    This means:

    1. Re-plan the policy using MCTS for n1 = 1 + num_unroll_steps steps to
       get a fresh policy and value target.
    2. Calculate the value for the following n2 = reanalyze_td_steps steps
       so that we have values for a total of
       1 + num_unroll_steps + reanalyze_td_steps steps.
    3. Use these values and the rewards from the replay buffer to calculate
       the n2-step bootstrapped value target for the first n1 steps.

    In order to do 1 and 2, we need to get the observations for n1 + n2
    steps and process them using the data_transformer.
    """
    batch_size = env_ids.shape[0]
    n1 = self._num_unroll_steps + 1
    n2 = self._reanalyze_td_steps
    env_ids, positions = self._next_n_positions(
        replay_buffer, env_ids, positions, self._num_unroll_steps + n2)
    # [B, n1]
    positions1 = positions[:, :n1]
    game_overs = replay_buffer.get_field('discount', env_ids,
                                         positions1) == 0.

    steps_to_episode_end = replay_buffer.steps_to_episode_end(
        positions1, env_ids)
    bootstrap_n = steps_to_episode_end.clamp(max=n2)

    exp1, exp2 = self._prepare_reanalyze_data(replay_buffer, env_ids,
                                              positions, n1, n2)

    bootstrap_position = positions1 + bootstrap_n
    discount = replay_buffer.get_field('discount', env_ids,
                                       bootstrap_position)
    sum_reward = self._sum_discounted_reward(replay_buffer, env_ids,
                                             positions1, bootstrap_position,
                                             n2)

    if not self._train_reward_function:
        rewards = self._get_reward(replay_buffer, env_ids,
                                   bootstrap_position)

    with alf.device(self._device):
        bootstrap_n = convert_device(bootstrap_n)
        discount = convert_device(discount)
        sum_reward = convert_device(sum_reward)
        game_overs = convert_device(game_overs)

        # 1. Reanalyze the first n1 steps to get both the updated value and
        # policy.
        self._mcts.set_model(self._target_model)
        mcts_step = self._mcts.predict_step(
            exp1, alf.nest.get_field(exp1, mcts_state_field))
        self._mcts.set_model(self._model)
        candidate_actions = ()
        if not _is_empty(mcts_step.info.candidate_actions):
            candidate_actions = mcts_step.info.candidate_actions
            candidate_actions = candidate_actions.reshape(
                batch_size, n1, *candidate_actions.shape[1:])
        candidate_action_policy = mcts_step.info.candidate_action_policy
        candidate_action_policy = candidate_action_policy.reshape(
            batch_size, n1, *candidate_action_policy.shape[1:])
        values1 = mcts_step.info.value.reshape(batch_size, n1)

        # 2. Calculate the value of the next n2 steps so that the n2-step
        # return can be computed.
        model_output = self._target_model.initial_inference(
            exp2.observation)
        values2 = model_output.value.reshape(batch_size, n2)

        # 3. Calculate the n2-step return.
        values = torch.cat([values1, values2], dim=1)
        # [B, n1]
        bootstrap_pos = torch.arange(n1).unsqueeze(0) + bootstrap_n
        values = values[torch.arange(batch_size).unsqueeze(-1),
                        bootstrap_pos]
        values = values * discount * (self._discount**bootstrap_n.to(
            torch.float32))
        values = values + sum_reward
        if not self._train_reward_function:
            # For this condition, we need to set the value at and after the
            # last step to be the last reward.
            values = torch.where(game_overs, convert_device(rewards),
                                 values)
        return candidate_actions, candidate_action_policy, values
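# Step 3 above as a standalone formula. Under one common MuZero convention
# (a sketch; the exact reward offset used by _sum_discounted_reward may
# differ), the n2-step bootstrapped target for step t with
# bn = min(steps_to_episode_end, n2) is
#
#   value[t] = sum_{k=0}^{bn-1} gamma^k * reward[t + 1 + k]
#              + gamma^bn * V(t + bn)
#
def _n_step_return_demo(rewards, values, t, n2, gamma):
    bn = min(n2, len(rewards) - 1 - t)  # bootstrap step, clamped at the end
    ret = values[t + bn]                # bootstrap with the reanalyzed value
    for k in reversed(range(bn)):
        ret = rewards[t + 1 + k] + gamma * ret
    return ret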