def _test_forward(self):
    """Generate a dummy input according to `nested_input_tensor_spec` and
    forward it. Can be used to calculate the output spec or to test the
    network.
    """
    inputs = common.zero_tensor_from_nested_spec(
        self._input_tensor_spec, batch_size=1)
    states = common.zero_tensor_from_nested_spec(
        self.state_spec, batch_size=1)
    return self.forward(inputs, states)
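# Hypothetical usage sketch (the `net` object and spec helper below are
# assumptions, not from the original code): since `_test_forward` forwards
# zeros through the network, its outputs can be used to derive an output
# spec by stripping the batch dimension, e.g.:
#
#   outputs, _ = net._test_forward()
#   output_spec = alf.nest.map_structure(
#       lambda t: alf.TensorSpec(t.shape[1:], t.dtype), outputs)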
def test_encoding_network_input_preprocessor(self):
    input_spec = TensorSpec((1, ))
    inputs = common.zero_tensor_from_nested_spec(input_spec, batch_size=1)
    network = EncodingNetwork(
        input_tensor_spec=input_spec, input_preprocessors=torch.tanh)
    output, _ = network(inputs)
    self.assertEqual(output.size()[1], 1)
def _calc_change(self, exp_array):
    """Calculate the distance between old/new action distributions.

    The distance is:
    ||logits_1 - logits_2||^2 for Categorical distribution
    ||loc_1 - loc_2||^2 for Deterministic distribution
    KL(d1||d2) + KL(d2||d1) for others
    """

    def _dist(d1, d2):
        if isinstance(d1, tfp.distributions.Categorical):
            return tf.reduce_sum(tf.square(d1.logits - d2.logits), axis=-1)
        elif isinstance(d1, tfp.distributions.Deterministic):
            return tf.reduce_sum(tf.square(d1.loc - d2.loc), axis=-1)
        else:
            if isinstance(d1, SquashToSpecNormal):
                # TODO: `SquashToSpecNormal.kl_divergence` checks that the
                # two distributions have the same action mean and magnitude,
                # but this check fails in graph mode.
                d1 = d1.input_distribution
                d2 = d2.input_distribution
            return tf.reduce_sum(
                d1.kl_divergence(d2) + d2.kl_divergence(d1), axis=-1)

    def _update_total_dists(new_action, exp, total_dists):
        old_action = nested_distributions_from_specs(
            common.to_distribution_spec(self.action_distribution_spec),
            exp.action_param)
        dists = nest_map(_dist, old_action, new_action)
        valid_masks = tf.cast(
            tf.not_equal(exp.step_type, StepType.LAST), tf.float32)
        dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
        return nest_map(lambda x, y: x + y, total_dists, dists)

    num_steps = exp_array.step_type.size()
    # element_shape for `TensorArray` can be (None, ...)
    batch_size = tf.shape(exp_array.step_type.read(0))[0]
    state = tf.nest.map_structure(lambda x: x.read(0), exp_array.state)
    # exp_array.state is no longer needed
    exp_array = exp_array._replace(state=())
    initial_state = common.zero_tensor_from_nested_spec(
        self.predict_state_spec, batch_size)
    total_dists = nest_map(lambda _: tf.zeros(()), self.action_spec)
    for t in tf.range(num_steps):
        exp = tf.nest.map_structure(lambda x: x.read(t), exp_array)
        state = common.reset_state_if_necessary(
            state, initial_state, exp.step_type == StepType.FIRST)
        time_step = ActionTimeStep(
            observation=exp.observation, step_type=exp.step_type)
        policy_step = self._ac_algorithm.predict(
            time_step=time_step, state=state)
        new_action, state = policy_step.action, policy_step.state
        new_action = common.to_distribution(new_action)
        total_dists = _update_total_dists(new_action, exp, total_dists)
    size = tf.cast(num_steps * batch_size, tf.float32)
    total_dists = nest_map(lambda d: d / size, total_dists)
    return total_dists
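# A minimal sketch (not part of the original code) of the symmetric KL term
# that `_dist` uses for generic distributions. For two Normals, tfp provides
# the KL divergence in closed form:
import tensorflow as tf
import tensorflow_probability as tfp

d1 = tfp.distributions.Normal(loc=[0., 0.], scale=[1., 1.])
d2 = tfp.distributions.Normal(loc=[1., 0.], scale=[1., 2.])
# KL(d1||d2) + KL(d2||d1), summed over the action dims as in `_dist`
sym_kl = tf.reduce_sum(d1.kl_divergence(d2) + d2.kl_divergence(d1), axis=-1)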
def setUp(self):
    self._input_spec = [
        TensorSpec((3, 20, 20), torch.float32),
        TensorSpec((1, 20, 20), torch.float32)
    ]
    self._image = zero_tensor_from_nested_spec(
        self._input_spec, batch_size=1)
    self._conv_layer_params = ((8, 3, 1), (16, 3, 2, 1))
    self._fc_layer_params = (100, )
    self._input_preprocessors = [torch.tanh, None]
    self._preprocessing_combiner = NestConcat(dim=1)
def test_encoding_network_preprocessing_combiner(self):
    input_spec = dict(
        a=TensorSpec((3, 80, 80)),
        b=[TensorSpec((80, 80)), TensorSpec(())])
    imgs = common.zero_tensor_from_nested_spec(input_spec, batch_size=1)
    network = EncodingNetwork(
        input_tensor_spec=input_spec,
        preprocessing_combiner=NestSum(average=True),
        conv_layer_params=((1, 2, 2, 0), ))
    self.assertEqual(network._processed_input_tensor_spec,
                     TensorSpec((3, 80, 80)))
    output, _ = network(imgs)
    self.assertTensorEqual(output, torch.zeros((40 * 40, )))
def test_q_value_network(self, lstm_hidden_size):
    input_spec = [TensorSpec((3, 20, 20), torch.float32)]
    conv_layer_params = ((8, 3, 1), (16, 3, 2, 1))
    image = common.zero_tensor_from_nested_spec(input_spec, batch_size=1)
    network_ctor, state = self._init(lstm_hidden_size)
    q_net = network_ctor(
        input_spec,
        self._action_spec,
        input_preprocessors=[torch.relu],
        preprocessing_combiner=NestSum(),
        conv_layer_params=conv_layer_params)
    q_value, state = q_net(image, state)
    # q_value shape: (batch_size, num_actions)
    self.assertEqual(q_value.shape, (1, self._num_actions))
def test_encoding_network_nested_input(self, lstm):
    input_spec = dict(
        a=TensorSpec((3, 80, 80)),
        b=[
            TensorSpec((80, )),
            BoundedTensorSpec((), dtype="int64"),
            dict(x=TensorSpec((100, )), y=TensorSpec((200, )))
        ])
    imgs = common.zero_tensor_from_nested_spec(input_spec, batch_size=1)
    input_preprocessors = dict(
        a=EmbeddingPreprocessor(
            input_spec["a"],
            conv_layer_params=((1, 2, 2, 0), ),
            embedding_dim=100),
        b=[
            EmbeddingPreprocessor(input_spec["b"][0], embedding_dim=50),
            EmbeddingPreprocessor(input_spec["b"][1], embedding_dim=50),
            dict(x=None, y=torch.relu)
        ])
    if lstm:
        network_ctor = functools.partial(
            LSTMEncodingNetwork, hidden_size=(100, ))
    else:
        network_ctor = EncodingNetwork
    network = network_ctor(
        input_tensor_spec=input_spec,
        input_preprocessors=input_preprocessors,
        preprocessing_combiner=NestConcat())
    output, _ = network(imgs, state=[(torch.zeros((1, 100)), ) * 2])
    if lstm:
        self.assertEqual(network.output_spec, TensorSpec((100, )))
        self.assertEqual(output.size()[-1], 100)
    else:
        self.assertEqual(len(list(network.parameters())), 4 + 2 + 1)
        self.assertEqual(network.output_spec, TensorSpec((500, )))
        self.assertEqual(output.size()[-1], 500)
def test_frame_stacker(self, stack_axis=0):
    data_spec = DataItem(
        step_type=alf.TensorSpec((), dtype=torch.int32),
        observation=dict(
            scalar=alf.TensorSpec(()),
            vector=alf.TensorSpec((7, )),
            matrix=alf.TensorSpec((5, 6)),
            tensor=alf.TensorSpec((2, 3, 4))))
    replay_buffer = ReplayBuffer(
        data_spec=data_spec,
        num_environments=2,
        max_length=1024,
        num_earliest_frames_ignored=2)
    frame_stacker = FrameStacker(
        data_spec.observation,
        stack_size=3,
        stack_axis=stack_axis,
        fields=['scalar', 'vector', 'matrix', 'tensor'])

    new_spec = frame_stacker.transformed_observation_spec
    self.assertEqual(new_spec['scalar'].shape, (3, ))
    self.assertEqual(new_spec['vector'].shape, (21, ))
    if stack_axis == -1:
        self.assertEqual(new_spec['matrix'].shape, (5, 18))
        self.assertEqual(new_spec['tensor'].shape, (2, 3, 12))
    elif stack_axis == 0:
        self.assertEqual(new_spec['matrix'].shape, (15, 6))
        self.assertEqual(new_spec['tensor'].shape, (6, 3, 4))

    def _step_type(t, period):
        if t % period == 0:
            return StepType.FIRST
        if t % period == period - 1:
            return StepType.LAST
        return StepType.MID

    observation = alf.nest.map_structure(
        lambda spec: spec.randn((1000, 2)), data_spec.observation)
    state = common.zero_tensor_from_nested_spec(frame_stacker.state_spec, 2)

    def _get_stacked_data(t, b):
        if stack_axis == -1:
            return dict(
                scalar=observation['scalar'][t, b],
                vector=observation['vector'][t, b].reshape(-1),
                matrix=observation['matrix'][t, b].transpose(0, 1).reshape(
                    5, 18),
                tensor=observation['tensor'][t, b].permute(1, 2, 0,
                                                           3).reshape(
                                                               2, 3, 12))
        elif stack_axis == 0:
            return dict(
                scalar=observation['scalar'][t, b],
                vector=observation['vector'][t, b].reshape(-1),
                matrix=observation['matrix'][t, b].reshape(15, 6),
                tensor=observation['tensor'][t, b].reshape(6, 3, 4))

    def _check_equal(stacked, expected, b):
        self.assertEqual(stacked['scalar'][b], expected['scalar'])
        self.assertEqual(stacked['vector'][b], expected['vector'])
        self.assertEqual(stacked['matrix'][b], expected['matrix'])
        self.assertEqual(stacked['tensor'][b], expected['tensor'])

    for t in range(1000):
        batch = DataItem(
            step_type=torch.tensor([_step_type(t, 17),
                                    _step_type(t, 22)]),
            observation=alf.nest.map_structure(lambda x: x[t], observation))
        replay_buffer.add_batch(batch)
        timestep, state = frame_stacker.transform_timestep(batch, state)
        if t == 0:
            for b in (0, 1):
                expected = _get_stacked_data([0, 0, 0], b)
                _check_equal(timestep.observation, expected, b)
        if t == 1:
            for b in (0, 1):
                expected = _get_stacked_data([0, 0, 1], b)
                _check_equal(timestep.observation, expected, b)
        if t == 2:
            for b in (0, 1):
                expected = _get_stacked_data([0, 1, 2], b)
                _check_equal(timestep.observation, expected, b)
        if t == 16:
            for b in (0, 1):
                expected = _get_stacked_data([14, 15, 16], b)
                _check_equal(timestep.observation, expected, b)
        # use `ts` for the stacked positions to avoid shadowing the loop
        # variable `t`
        if t == 17:
            for b, ts in ((0, [17, 17, 17]), (1, [15, 16, 17])):
                expected = _get_stacked_data(ts, b)
                _check_equal(timestep.observation, expected, b)
        if t == 18:
            for b, ts in ((0, [17, 17, 18]), (1, [16, 17, 18])):
                expected = _get_stacked_data(ts, b)
                _check_equal(timestep.observation, expected, b)
        if t == 22:
            for b, ts in ((0, [20, 21, 22]), (1, [22, 22, 22])):
                expected = _get_stacked_data(ts, b)
                _check_equal(timestep.observation, expected, b)

    batch_info = BatchInfo(
        env_ids=torch.tensor([0, 1, 0, 1], dtype=torch.int64),
        positions=torch.tensor([0, 1, 18, 22], dtype=torch.int64))
    # experience shape: [4, 2, ...]
    experience = replay_buffer.get_field(
        '', batch_info.env_ids.unsqueeze(-1),
        batch_info.positions.unsqueeze(-1) + torch.arange(2))
    experience = experience._replace(
        batch_info=batch_info, replay_buffer=replay_buffer)
    experience = frame_stacker.transform_experience(experience)
    expected = _get_stacked_data([0, 0, 0], 0)
    _check_equal(experience.observation, expected, (0, 0))
    expected = _get_stacked_data([0, 0, 1], 0)
    _check_equal(experience.observation, expected, (0, 1))
    expected = _get_stacked_data([0, 0, 1], 1)
    _check_equal(experience.observation, expected, (1, 0))
    expected = _get_stacked_data([0, 1, 2], 1)
    _check_equal(experience.observation, expected, (1, 1))
    expected = _get_stacked_data([17, 17, 18], 0)
    _check_equal(experience.observation, expected, (2, 0))
    expected = _get_stacked_data([17, 18, 19], 0)
    _check_equal(experience.observation, expected, (2, 1))
    expected = _get_stacked_data([22, 22, 22], 1)
    _check_equal(experience.observation, expected, (3, 0))
    expected = _get_stacked_data([22, 22, 23], 1)
    _check_equal(experience.observation, expected, (3, 1))
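# A small illustration (not from the test) of the stacked shapes asserted
# above: stacking stack_size=3 frames of a (5, 6) matrix along dim 0 gives
# (15, 6); along the last dim it gives (5, 18).
import torch

frames = [torch.zeros(5, 6)] * 3
assert torch.cat(frames, dim=0).shape == (15, 6)
assert torch.cat(frames, dim=-1).shape == (5, 18)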
def _prepare_specs(self, algorithm):
    """Prepare various tensor specs."""

    def extract_spec(nest):
        return tf.nest.map_structure(
            lambda t: tf.TensorSpec(t.shape[1:], t.dtype), nest)

    time_step = self.get_initial_time_step()
    self._time_step_spec = extract_spec(time_step)
    self._action_spec = self._env.action_spec()

    policy_step = algorithm.predict(time_step, self._initial_state)
    info_spec = extract_spec(policy_step.info)
    self._pred_policy_step_spec = PolicyStep(
        action=self._action_spec,
        state=algorithm.predict_state_spec,
        info=info_spec)

    def _to_distribution_spec(spec):
        if isinstance(spec, tf.TensorSpec):
            return DistributionSpec(
                tfp.distributions.Deterministic,
                input_params_spec={"loc": spec},
                sample_spec=spec)
        return spec

    self._action_distribution_spec = tf.nest.map_structure(
        _to_distribution_spec, algorithm.action_distribution_spec)
    self._action_dist_param_spec = tf.nest.map_structure(
        lambda spec: spec.input_params_spec, self._action_distribution_spec)

    self._experience_spec = Experience(
        step_type=self._time_step_spec.step_type,
        reward=self._time_step_spec.reward,
        discount=self._time_step_spec.discount,
        observation=self._time_step_spec.observation,
        prev_action=self._action_spec,
        action=self._action_spec,
        info=info_spec,
        action_distribution=self._action_dist_param_spec)

    action_dist_params = common.zero_tensor_from_nested_spec(
        self._experience_spec.action_distribution, self._env.batch_size)
    action_dist = nested_distributions_from_specs(
        self._action_distribution_spec, action_dist_params)

    exp = Experience(
        step_type=time_step.step_type,
        reward=time_step.reward,
        discount=time_step.discount,
        observation=time_step.observation,
        prev_action=time_step.prev_action,
        action=time_step.prev_action,
        info=policy_step.info,
        action_distribution=action_dist)
    processed_exp = algorithm.preprocess_experience(exp)
    self._processed_experience_spec = self._experience_spec._replace(
        info=extract_spec(processed_exp.info))

    policy_step = common.algorithm_step(
        algorithm,
        ob_transformer=self._observation_transformer,
        time_step=exp,
        state=common.get_initial_policy_state(self._env.batch_size,
                                              algorithm.train_state_spec),
        training=True)
    info_spec = extract_spec(policy_step.info)
    self._training_info_spec = make_training_info(
        action=self._action_spec,
        action_distribution=self._action_dist_param_spec,
        step_type=self._time_step_spec.step_type,
        reward=self._time_step_spec.reward,
        discount=self._time_step_spec.discount,
        info=info_spec,
        collect_info=self._processed_experience_spec.info,
        collect_action_distribution=self._action_dist_param_spec)
def get_initial_train_state(self, batch_size):
    """Return zero tensors conforming to the training state spec."""
    return common.zero_tensor_from_nested_spec(
        self._algorithm.train_state_spec, batch_size)
def _test_preprocess_experience(self, train_reward_function, td_steps,
                                reanalyze_ratio, expected):
    """The following summarizes how the data is generated:

    .. code-block:: python

        # position:   01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'
        scale = 1. for the current model, 2. for the target model
        observation = [position] * 3
        reward = position if train_reward_function and td_steps != -1
                 else position * (step_type == LAST)
        value = 0.5 * position * scale
        action_probs = scale * [position, position + 1, position] for env 0
                       scale * [position + 1, position, position] for env 1
        action = 1 for env 0
                 0 for env 1
    """
    reanalyze_td_steps = 2
    num_unroll_steps = 4
    batch_size = 2
    obs_dim = 3
    observation_spec = alf.TensorSpec([obs_dim])
    action_spec = alf.BoundedTensorSpec((),
                                        minimum=0,
                                        maximum=1,
                                        dtype=torch.int32)
    reward_spec = alf.TensorSpec(())
    time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                       reward_spec)

    global _mcts_model_id
    _mcts_model_id = 0
    muzero = MuzeroAlgorithm(
        observation_spec,
        action_spec,
        model_ctor=_create_mcts_model,
        mcts_algorithm_ctor=MockMCTSAlgorithm,
        num_unroll_steps=num_unroll_steps,
        td_steps=td_steps,
        train_game_over_function=True,
        train_reward_function=train_reward_function,
        reanalyze_ratio=reanalyze_ratio,
        reanalyze_td_steps=reanalyze_td_steps,
        data_transformer_ctor=partial(FrameStacker, stack_size=2))

    data_transformer = FrameStacker(observation_spec, stack_size=2)
    time_step = common.zero_tensor_from_nested_spec(time_step_spec,
                                                    batch_size)
    dt_state = common.zero_tensor_from_nested_spec(
        data_transformer.state_spec, batch_size)
    state = muzero.get_initial_predict_state(batch_size)
    transformed_time_step, dt_state = data_transformer.transform_timestep(
        time_step, dt_state)
    alg_step = muzero.rollout_step(transformed_time_step, state)
    alg_step_spec = dist_utils.extract_spec(alg_step)
    experience = ds.make_experience(time_step, alg_step, state)
    experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                         muzero.train_state_spec)
    replay_buffer = ReplayBuffer(
        data_spec=experience_spec,
        num_environments=batch_size,
        max_length=16,
        keep_episodic_info=True)

    #              01234567890123
    step_type0 = 'FMMMLFMMLFMMMM'
    step_type1 = 'FMMMMMLFMMMMLF'
    dt_state = common.zero_tensor_from_nested_spec(
        data_transformer.state_spec, batch_size)
    for i in range(len(step_type0)):
        step_type = [step_type0[i], step_type1[i]]
        step_type = [
            ds.StepType.MID if c == 'M' else
            (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
            for c in step_type
        ]
        step_type = torch.tensor(step_type, dtype=torch.int32)
        reward = torch.full([batch_size], float(i))
        if not train_reward_function or td_steps == -1:
            reward = reward * (step_type == ds.StepType.LAST).to(
                torch.float32)
        time_step = time_step._replace(
            discount=(step_type != ds.StepType.LAST).to(torch.float32),
            step_type=step_type,
            observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                     dtype=torch.float32),
            reward=reward,
            env_id=torch.arange(batch_size, dtype=torch.int32))
        transformed_time_step, dt_state = data_transformer.transform_timestep(
            time_step, dt_state)
        alg_step = muzero.rollout_step(transformed_time_step, state)
        experience = ds.make_experience(time_step, alg_step, state)
        replay_buffer.add_batch(experience)
        state = alg_step.state

    env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
    positions = torch.tensor(list(range(14)) + list(range(14)),
                             dtype=torch.int64)
    experience = replay_buffer.get_field(None,
                                         env_ids.unsqueeze(-1).cpu(),
                                         positions.unsqueeze(-1).cpu())
    experience = experience._replace(
        replay_buffer=replay_buffer,
        batch_info=BatchInfo(env_ids=env_ids, positions=positions),
        rollout_info_field='rollout_info')
    processed_experience = muzero.preprocess_experience(experience)
    import pprint
    pprint.pprint(processed_experience.rollout_info)
    alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                           processed_experience.rollout_info, expected)
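# Worked example (illustrative, mirrors the reward rule in the docstring
# above): when the reward function is not trained, only LAST steps keep
# their reward.
position, is_last = 3., False
train_reward_function, td_steps = False, 2
reward = position if (train_reward_function and td_steps != -1) \
    else position * float(is_last)  # -> 0.0; it would be 3.0 at a LAST step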
def prepare_off_policy_specs(self, time_step: ActionTimeStep):
    """Prepare various tensor specs for off-policy training.

    `prepare_off_policy_specs` is called by
    `OffPolicyDriver._prepare_spec()`.
    """
    self._env_batch_size = time_step.step_type.shape[0]
    self._time_step_spec = common.extract_spec(time_step)
    initial_state = common.get_initial_policy_state(
        self._env_batch_size, self.train_state_spec)
    transformed_timestep = self.transform_timestep(time_step)
    policy_step = self.rollout(transformed_timestep, initial_state)
    info_spec = common.extract_spec(policy_step.info)

    self._action_distribution_spec = tf.nest.map_structure(
        common.to_distribution_spec, self.action_distribution_spec)
    self._action_dist_param_spec = tf.nest.map_structure(
        lambda spec: spec.input_params_spec, self._action_distribution_spec)

    self._experience_spec = Experience(
        step_type=self._time_step_spec.step_type,
        reward=self._time_step_spec.reward,
        discount=self._time_step_spec.discount,
        observation=self._time_step_spec.observation,
        prev_action=self._action_spec,
        action=self._action_spec,
        info=info_spec,
        action_distribution=self._action_dist_param_spec,
        state=self.train_state_spec if self._use_rollout_state else ())

    action_dist_params = common.zero_tensor_from_nested_spec(
        self._experience_spec.action_distribution, self._env_batch_size)
    action_dist = nested_distributions_from_specs(
        self._action_distribution_spec, action_dist_params)

    exp = Experience(
        step_type=time_step.step_type,
        reward=time_step.reward,
        discount=time_step.discount,
        observation=time_step.observation,
        prev_action=time_step.prev_action,
        action=time_step.prev_action,
        info=policy_step.info,
        action_distribution=action_dist,
        state=initial_state if self._use_rollout_state else ())

    transformed_exp = self.transform_timestep(exp)
    processed_exp = self.preprocess_experience(transformed_exp)
    self._processed_experience_spec = self._experience_spec._replace(
        observation=common.extract_spec(processed_exp.observation),
        info=common.extract_spec(processed_exp.info))

    policy_step = common.algorithm_step(
        algorithm_step_func=self.train_step,
        time_step=processed_exp,
        state=initial_state)
    info_spec = common.extract_spec(policy_step.info)
    self._training_info_spec = TrainingInfo(
        action_distribution=self._action_dist_param_spec, info=info_spec)
def _calc_change(self, exp_array):
    """Calculate the distance between old/new action distributions.

    The squared distance is:
    ||logits_1 - logits_2||^2 for Categorical distribution
    ||loc_1 - loc_2||^2 for Deterministic distribution
    KL(d1||d2) + KL(d2||d1) for others
    """

    def _dist(d1, d2):
        if isinstance(d1, tfp.distributions.Categorical):
            dist = tf.square(d1.logits - d2.logits)
        elif isinstance(d1, tf.Tensor):
            dist = tf.square(d1 - d2)
        else:
            if isinstance(d1, SquashToSpecNormal):
                # TODO: `SquashToSpecNormal.kl_divergence` checks that the
                # two distributions have the same action mean and magnitude,
                # but this check fails in graph mode.
                d1 = d1.input_distribution
                d2 = d2.input_distribution
            dist = d1.kl_divergence(d2) + d2.kl_divergence(d1)
        if len(dist.shape) > 1:
            # reduce to shape [B]
            reduce_dims = list(range(1, len(dist.shape)))
            dist = tf.reduce_sum(dist, axis=reduce_dims)
        return dist

    def _update_total_dists(new_action, exp, total_dists):
        old_action = nest_utils.params_to_distributions(
            exp.action_param, self._action_distribution_spec)
        dists = nest_map(_dist, old_action, new_action)
        valid_masks = tf.cast(
            tf.not_equal(exp.step_type, StepType.LAST), tf.float32)
        dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
        return nest_map(lambda x, y: x + y, total_dists, dists)

    num_steps = exp_array.step_type.size()
    # element_shape for `TensorArray` can be (None, ...)
    batch_size = tf.shape(exp_array.step_type.read(0))[0]
    state = tf.nest.map_structure(lambda x: x.read(0), exp_array.state)
    # exp_array.state is no longer needed
    exp_array = exp_array._replace(state=())
    initial_state = common.zero_tensor_from_nested_spec(
        self.predict_state_spec, batch_size)
    total_dists = nest_map(lambda _: tf.zeros(()), self.action_spec)
    for t in tf.range(num_steps):
        exp = tf.nest.map_structure(lambda x: x.read(t), exp_array)
        state = common.reset_state_if_necessary(
            state, initial_state, exp.step_type == StepType.FIRST)
        time_step = ActionTimeStep(
            observation=exp.observation, step_type=exp.step_type)
        policy_step = self._ac_algorithm.predict(
            time_step=time_step, state=state, epsilon_greedy=1.0)
        assert (common.is_namedtuple(policy_step.info)
                and "action_distribution" in policy_step.info._fields), (
                    "PolicyStep.info from ac_algorithm.predict() should be "
                    "a namedtuple containing `action_distribution` in order "
                    "to use TracAlgorithm.")
        new_action = policy_step.info.action_distribution
        state = policy_step.state
        total_dists = _update_total_dists(new_action, exp, total_dists)
    size = tf.cast(num_steps * batch_size, tf.float32)
    total_dists = nest_map(lambda d: tf.sqrt(d / size), total_dists)
    return total_dists
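# Numeric sketch (illustrative only): unlike the earlier variant that
# returns a mean distance, this version returns a root-mean-square change.
# E.g. with per-step squared distances summing to 8.0 over num_steps=2 and
# batch_size=4:
import math
rms_change = math.sqrt(8.0 / (2 * 4))  # -> 1.0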
def __init__(self, num_expansions, model_output, known_value_bounds):
    batch_size, branch_factor = model_output.action_probs.shape
    action_spec = dist_utils.extract_spec(model_output.actions, from_dim=2)
    state_spec = dist_utils.extract_spec(model_output.state, from_dim=1)

    if known_value_bounds:
        self.fixed_bounds = True
        self.minimum, self.maximum = known_value_bounds
    else:
        self.fixed_bounds = False
        self.minimum, self.maximum = (MAXIMUM_FLOAT_VALUE,
                                      -MAXIMUM_FLOAT_VALUE)
    self.minimum = torch.full((batch_size, ),
                              self.minimum,
                              dtype=torch.float32)
    self.maximum = torch.full((batch_size, ),
                              self.maximum,
                              dtype=torch.float32)
    if known_value_bounds:
        self.normalize_scale = 1 / (self.maximum - self.minimum + 1e-30)
        self.normalize_base = self.minimum
    else:
        self.normalize_scale = torch.ones((batch_size, ))
        self.normalize_base = torch.zeros((batch_size, ))

    self.B = torch.arange(batch_size)
    self.root_indices = torch.zeros((batch_size, ), dtype=torch.int64)
    self.branch_factor = branch_factor

    parent_shape = (batch_size, num_expansions)
    children_shape = (batch_size, num_expansions, branch_factor)

    self.visit_count = torch.zeros(parent_shape, dtype=torch.int32)
    # the player who will take action from the current state
    self.to_play = torch.zeros(parent_shape, dtype=torch.int64)
    self.prior = torch.zeros(children_shape)
    # value for player 0
    self.value_sum = torch.zeros(parent_shape)
    # 0 for not expanded, value in range [0, num_expansions)
    self.children_index = torch.zeros(children_shape, dtype=torch.int64)
    self.model_state = common.zero_tensor_from_nested_spec(
        state_spec, parent_shape)
    # reward for player 0
    self.reward = None
    if isinstance(model_output.reward, torch.Tensor):
        self.reward = torch.zeros(parent_shape)
    self.action = None
    if isinstance(model_output.actions, torch.Tensor):
        # candidate actions for this state
        self.action = torch.zeros(
            children_shape + action_spec.shape, dtype=action_spec.dtype)
    self.game_over = None
    if isinstance(model_output.game_over, torch.Tensor):
        self.game_over = torch.zeros(parent_shape, dtype=torch.bool)
    # value in range [0, branch_factor)
    self.best_child_index = torch.zeros(parent_shape, dtype=torch.int64)
    self.ucb_score = torch.zeros(children_shape)
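# Sketch (an assumption, not from the original class) of how
# `normalize_scale`/`normalize_base` are meant to be applied: a raw value v
# maps to (v - normalize_base) * normalize_scale, which lies in [0, 1] when
# minimum <= v <= maximum.
import torch

minimum, maximum = torch.tensor([-1.]), torch.tensor([3.])
scale = 1 / (maximum - minimum + 1e-30)
v = torch.tensor([1.])
normalized = (v - minimum) * scale  # -> tensor([0.5])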
def test_preprocess_experience(self):
    """The following summarizes how the data is generated:

    .. code-block:: python

        # position:   01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'
        reward = position
        action = t + 1 for env 0
                 t for env 1
    """
    num_unroll_steps = 4
    batch_size = 2
    obs_dim = 3
    observation_spec = alf.TensorSpec([obs_dim])
    action_spec = alf.BoundedTensorSpec((1, ),
                                        minimum=0,
                                        maximum=1,
                                        dtype=torch.float32)
    reward_spec = alf.TensorSpec(())
    time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                       reward_spec)

    repr_learner = PredictiveRepresentationLearner(
        observation_spec,
        action_spec,
        num_unroll_steps=num_unroll_steps,
        decoder_ctor=partial(
            SimpleDecoder,
            target_field='reward',
            decoder_net_ctor=partial(
                EncodingNetwork, fc_layer_params=(4, ))),
        encoding_net_ctor=LSTMEncodingNetwork,
        dynamics_net_ctor=LSTMEncodingNetwork)

    time_step = common.zero_tensor_from_nested_spec(time_step_spec,
                                                    batch_size)
    state = repr_learner.get_initial_predict_state(batch_size)
    alg_step = repr_learner.rollout_step(time_step, state)
    alg_step = alg_step._replace(output=torch.tensor([[1.], [0.]]))
    alg_step_spec = dist_utils.extract_spec(alg_step)
    experience = ds.make_experience(time_step, alg_step, state)
    experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                         repr_learner.train_state_spec)
    replay_buffer = ReplayBuffer(
        data_spec=experience_spec,
        num_environments=batch_size,
        max_length=16,
        keep_episodic_info=True)

    #              01234567890123
    step_type0 = 'FMMMLFMMLFMMMM'
    step_type1 = 'FMMMMMLFMMMMLF'
    for i in range(len(step_type0)):
        step_type = [step_type0[i], step_type1[i]]
        step_type = [
            ds.StepType.MID if c == 'M' else
            (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
            for c in step_type
        ]
        step_type = torch.tensor(step_type, dtype=torch.int32)
        reward = torch.full([batch_size], float(i))
        time_step = time_step._replace(
            discount=(step_type != ds.StepType.LAST).to(torch.float32),
            step_type=step_type,
            observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                     dtype=torch.float32),
            reward=reward,
            env_id=torch.arange(batch_size, dtype=torch.int32))
        alg_step = repr_learner.rollout_step(time_step, state)
        alg_step = alg_step._replace(output=i + torch.tensor([[1.], [0.]]))
        experience = ds.make_experience(time_step, alg_step, state)
        replay_buffer.add_batch(experience)
        state = alg_step.state

    env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
    positions = torch.tensor(list(range(14)) + list(range(14)),
                             dtype=torch.int64)
    experience = replay_buffer.get_field(None,
                                         env_ids.unsqueeze(-1).cpu(),
                                         positions.unsqueeze(-1).cpu())
    experience = experience._replace(
        replay_buffer=replay_buffer,
        batch_info=BatchInfo(env_ids=env_ids, positions=positions),
        rollout_info_field='rollout_info')
    processed_experience = repr_learner.preprocess_experience(experience)
    pprint.pprint(processed_experience.rollout_info)

    # yapf: disable
    expected = PredictiveRepresentationLearnerInfo(
        action=torch.tensor(
            [[[ 1.,  2.,  3.,  4.,  5.]],
             [[ 2.,  3.,  4.,  5.,  5.]],
             [[ 3.,  4.,  5.,  5.,  5.]],
             [[ 4.,  5.,  5.,  5.,  5.]],
             [[ 5.,  5.,  5.,  5.,  5.]],
             [[ 6.,  7.,  8.,  9.,  9.]],
             [[ 7.,  8.,  9.,  9.,  9.]],
             [[ 8.,  9.,  9.,  9.,  9.]],
             [[ 9.,  9.,  9.,  9.,  9.]],
             [[10., 11., 12., 13., 14.]],
             [[11., 12., 13., 14., 14.]],
             [[12., 13., 14., 14., 14.]],
             [[13., 14., 14., 14., 14.]],
             [[14., 14., 14., 14., 14.]],
             [[ 0.,  1.,  2.,  3.,  4.]],
             [[ 1.,  2.,  3.,  4.,  5.]],
             [[ 2.,  3.,  4.,  5.,  6.]],
             [[ 3.,  4.,  5.,  6.,  6.]],
             [[ 4.,  5.,  6.,  6.,  6.]],
             [[ 5.,  6.,  6.,  6.,  6.]],
             [[ 6.,  6.,  6.,  6.,  6.]],
             [[ 7.,  8.,  9., 10., 11.]],
             [[ 8.,  9., 10., 11., 12.]],
             [[ 9., 10., 11., 12., 12.]],
             [[10., 11., 12., 12., 12.]],
             [[11., 12., 12., 12., 12.]],
             [[12., 12., 12., 12., 12.]],
             [[13., 13., 13., 13., 13.]]]).unsqueeze(-1),
        mask=torch.tensor(
            [[[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True, False, False, False, False]]]),
        target=torch.tensor(
            [[[ 0.,  1.,  2.,  3.,  4.]],
             [[ 1.,  2.,  3.,  4.,  4.]],
             [[ 2.,  3.,  4.,  4.,  4.]],
             [[ 3.,  4.,  4.,  4.,  4.]],
             [[ 4.,  4.,  4.,  4.,  4.]],
             [[ 5.,  6.,  7.,  8.,  8.]],
             [[ 6.,  7.,  8.,  8.,  8.]],
             [[ 7.,  8.,  8.,  8.,  8.]],
             [[ 8.,  8.,  8.,  8.,  8.]],
             [[ 9., 10., 11., 12., 13.]],
             [[10., 11., 12., 13., 13.]],
             [[11., 12., 13., 13., 13.]],
             [[12., 13., 13., 13., 13.]],
             [[13., 13., 13., 13., 13.]],
             [[ 0.,  1.,  2.,  3.,  4.]],
             [[ 1.,  2.,  3.,  4.,  5.]],
             [[ 2.,  3.,  4.,  5.,  6.]],
             [[ 3.,  4.,  5.,  6.,  6.]],
             [[ 4.,  5.,  6.,  6.,  6.]],
             [[ 5.,  6.,  6.,  6.,  6.]],
             [[ 6.,  6.,  6.,  6.,  6.]],
             [[ 7.,  8.,  9., 10., 11.]],
             [[ 8.,  9., 10., 11., 12.]],
             [[ 9., 10., 11., 12., 12.]],
             [[10., 11., 12., 12., 12.]],
             [[11., 12., 12., 12., 12.]],
             [[12., 12., 12., 12., 12.]],
             [[13., 13., 13., 13., 13.]]]))
    # yapf: enable
    alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                           processed_experience.rollout_info, expected)
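# Worked example (read off the expected tables above): env 0's first episode
# covers positions 0-4 ('FMMML'). Unrolling num_unroll_steps=4 from position
# 1 clamps indices at the episode end, so the target row is [1., 2., 3., 4.,
# 4.] and the mask marks the clamped step invalid:
import torch

episode_end = 4
positions = torch.arange(1, 1 + 5).clamp(max=episode_end)  # -> [1, 2, 3, 4, 4]
mask = torch.arange(1, 1 + 5) <= episode_end  # -> [T, T, T, T, F]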