def test_conversions(self):
    dists = {
        't': torch.tensor([[1., 2., 4.], [3., 3., 1.]]),
        'd': dist_utils.DiagMultivariateNormal(
            torch.tensor([[1., 2.], [2., 2.]]),
            torch.tensor([[2., 3.], [1., 1.]]))
    }
    params = dist_utils.distributions_to_params(dists)
    dists_spec = dist_utils.extract_spec(dists, from_dim=1)
    self.assertEqual(dists_spec['t'],
                     alf.TensorSpec(shape=(3, ), dtype=torch.float32))
    self.assertEqual(type(dists_spec['d']), dist_utils.DistributionSpec)
    self.assertEqual(len(params), 2)
    self.assertEqual(dists['t'], params['t'])
    self.assertEqual(dists['d'].base_dist.mean, params['d']['loc'])
    self.assertEqual(dists['d'].base_dist.stddev, params['d']['scale'])
    dists1 = dist_utils.params_to_distributions(params, dists_spec)
    self.assertEqual(len(dists1), 2)
    self.assertEqual(dists1['t'], dists['t'])
    self.assertEqual(type(dists1['d']), type(dists['d']))
    params_spec = dist_utils.to_distribution_param_spec(dists_spec)
    alf.nest.assert_same_structure(params_spec, params)
    params1_spec = dist_utils.extract_spec(params)
    self.assertEqual(params_spec, params1_spec)
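# A minimal sketch of the round-trip that test_conversions exercises: a
# distribution is lowered to its parameter tensors (e.g. for storage in a
# replay buffer) and rebuilt from its spec. The `loc`/`scale` parameter names
# follow DiagMultivariateNormal above; treat this as illustrative usage, not
# part of the test suite.
def _roundtrip_sketch():
    d = dist_utils.DiagMultivariateNormal(torch.zeros(4, 2), torch.ones(4, 2))
    spec = dist_utils.extract_spec(d, from_dim=1)   # a DistributionSpec
    params = dist_utils.distributions_to_params(d)  # {'loc': ..., 'scale': ...}
    d1 = dist_utils.params_to_distributions(params, spec)
    assert type(d1) == type(d)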
def train_step(self, exp: TimeStep, state):
    # [B, num_unroll_steps + 1]
    info = exp.rollout_info
    targets = common.as_list(info.target)
    batch_size = exp.step_type.shape[0]
    latent, state = self._encoding_net(exp.observation, state)
    sim_latent = self._multi_step_latent_rollout(
        latent, self._num_unroll_steps, info.action, state)
    loss = 0
    for i, decoder in enumerate(self._decoders):
        # [(num_unroll_steps + 1) * B, ...]
        train_info = decoder.train_step(sim_latent).info
        train_info_spec = dist_utils.extract_spec(train_info)
        train_info = dist_utils.distributions_to_params(train_info)
        train_info = alf.nest.map_structure(
            lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                                *x.shape[1:]), train_info)
        # [num_unroll_steps + 1, B, ...]
        train_info = dist_utils.params_to_distributions(
            train_info, train_info_spec)
        target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                        targets[i])
        loss_info = decoder.calc_loss(target, train_info, info.mask.t())
        loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0), loss_info)
        loss += loss_info.loss
    loss_info = LossInfo(loss=loss, extra=loss)
    return AlgStep(output=latent, state=state, info=loss_info)
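# The params round-trip in train_step exists because a torch Distribution
# cannot be reshaped directly. A minimal sketch of the same pattern, with
# hypothetical T = num_unroll_steps + 1 and batch size B:
def _reshape_distributions_sketch(train_info, T, B):
    spec = dist_utils.extract_spec(train_info)  # spec of the [T*B, ...] nest
    params = dist_utils.distributions_to_params(train_info)
    params = alf.nest.map_structure(
        lambda x: x.reshape(T, B, *x.shape[1:]), params)
    return dist_utils.params_to_distributions(params, spec)  # [T, B, ...]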
def train_step(self, exp: Experience, state):
    def _hook(grad, name):
        alf.summary.scalar("MCTS_state_grad_norm/" + name, grad.norm())

    model_output = self._model.initial_inference(exp.observation)
    if alf.summary.should_record_summaries():
        model_output.state.register_hook(partial(_hook, name="s0"))
    model_output_spec = dist_utils.extract_spec(model_output)
    model_outputs = [dist_utils.distributions_to_params(model_output)]
    info = exp.rollout_info
    for i in range(self._num_unroll_steps):
        model_output = self._model.recurrent_inference(
            model_output.state, info.action[:, i, ...])
        if alf.summary.should_record_summaries():
            model_output.state.register_hook(
                partial(_hook, name="s" + str(i + 1)))
        model_output = model_output._replace(state=scale_gradient(
            model_output.state, self._recurrent_gradient_scaling_factor))
        model_outputs.append(dist_utils.distributions_to_params(model_output))

    model_outputs = alf.nest.utils.stack_nests(model_outputs, dim=1)
    model_outputs = dist_utils.params_to_distributions(
        model_outputs, model_output_spec)
    return AlgStep(info=self._model.calc_loss(model_outputs, info.target))
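# `scale_gradient` damps the gradient flowing back through the recurrent
# dynamics (MuZero scales it by 0.5 at each unroll step). A common way to
# implement such scaling, shown as a sketch rather than ALF's actual
# implementation:
def _scale_gradient_sketch(x, scale):
    # forward value is unchanged; the backward gradient is multiplied by
    # `scale` because only the first term carries gradient
    return x * scale + x.detach() * (1 - scale)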
def _rollout_step(self, time_step: TimeStep, state):
    """A wrapper around user-defined ``rollout_step``.

    For every RL algorithm, this wrapper ensures that the rollout info
    spec will be computed.
    """
    policy_step = self._original_rollout_step(time_step, state)
    if self._rollout_info_spec is None:
        self._rollout_info_spec = dist_utils.extract_spec(policy_step.info)
    return policy_step
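# The spec is captured lazily because ``policy_step.info`` is user-defined
# and its structure is only known after the first rollout call. A
# hypothetical sketch of how the captured spec can later rebuild
# distributions from replayed parameter tensors:
def _rebuild_rollout_info_sketch(self, params):
    assert self._rollout_info_spec is not None, "rollout_step() never called"
    return dist_utils.params_to_distributions(params, self._rollout_info_spec)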
def output_spec(self):
    """Return the spec of the network's encoding output.

    By default, we use ``_test_forward`` to automatically compute the output
    and get its spec. For efficiency, subclasses can override this function
    if the output spec can be obtained easily in other ways.
    """
    if self._output_spec is None:
        training = self.training
        self.eval()
        self._output_spec = extract_spec(self._test_forward()[0], from_dim=1)
        self.train(training)
    return self._output_spec
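# A hypothetical subclass override for the efficiency case mentioned in the
# docstring: when the encoding dimension is known at construction time, no
# test forward pass is needed. The base class and attribute names here are
# illustrative.
class MyEncoder(SomeEncodingNetwork):
    def output_spec(self):
        # the encoding dim is fixed by the constructor; skip _test_forward
        return alf.TensorSpec((self._encoding_dim, ), dtype=torch.float32)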
def _prepare_reanalyze_data(self, replay_buffer: ReplayBuffer, env_ids,
                            positions, n1, n2):
    """Get the ``n1 + n2`` steps of experience indicated by ``positions``
    and return the first ``n1`` steps as ``exp1`` and the next ``n2`` steps
    as ``exp2``.
    """
    batch_size = env_ids.shape[0]
    n = n1 + n2
    flat_env_ids = env_ids.expand_as(positions).reshape(-1)
    flat_positions = positions.reshape(-1)
    exp = replay_buffer.get_field(None, flat_env_ids, flat_positions)

    if self._data_transformer_ctor is not None:
        if self._data_transformer is None:
            observation_spec = dist_utils.extract_spec(exp.observation)
            self._data_transformer = create_data_transformer(
                self._data_transformer_ctor, observation_spec)
        # DataTransformer assumes the shape of exp is [B, T, ...].
        # It also needs exp.batch_info and exp.replay_buffer.
        exp = alf.nest.map_structure(lambda x: x.unsqueeze(1), exp)
        exp = exp._replace(
            batch_info=BatchInfo(flat_env_ids, flat_positions),
            replay_buffer=replay_buffer)
        exp = self._data_transformer.transform_experience(exp)
        exp = exp._replace(batch_info=(), replay_buffer=())
        exp = alf.nest.map_structure(lambda x: x.squeeze(1), exp)

    def _split1(x):
        shape = x.shape[1:]
        x = x.reshape(batch_size, n, *shape)
        return x[:, :n1, ...].reshape(batch_size * n1, *shape)

    def _split2(x):
        shape = x.shape[1:]
        x = x.reshape(batch_size, n, *shape)
        return x[:, n1:, ...].reshape(batch_size * n2, *shape)

    with alf.device(self._device):
        exp = convert_device(exp)
        exp1 = alf.nest.map_structure(_split1, exp)
        exp2 = alf.nest.map_structure(_split2, exp)
    return exp1, exp2
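# A worked example of the split semantics, with batch_size=2, n1=2, n2=1;
# after the reshape, each row holds one environment's n consecutive steps:
def _split_example():
    x = torch.arange(6)          # [e0s0, e0s1, e0s2, e1s0, e1s1, e1s2]
    x = x.reshape(2, 3)          # batch_size=2, n=3
    exp1 = x[:, :2].reshape(-1)  # tensor([0, 1, 3, 4])
    exp2 = x[:, 2:].reshape(-1)  # tensor([2, 5])
    return exp1, exp2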
def _make_policy_step(self, time_step, state, policy_step):
    assert (alf.nest.is_namedtuple(policy_step.info)
            and "action_distribution" in policy_step.info._fields), (
                "PolicyStep.info from ac_algorithm.rollout_step() or "
                "ac_algorithm.train_step() should be a namedtuple containing "
                "`action_distribution` in order to use TracAlgorithm.")
    action_distribution = policy_step.info.action_distribution
    if self._action_distribution_spec is None:
        self._action_distribution_spec = dist_utils.extract_spec(
            action_distribution)
    ac_info = policy_step.info._replace(action_distribution=())
    # EntropyTargetAlgorithm needs info.action_distribution
    return policy_step._replace(
        info=TracInfo(
            action_distribution=action_distribution,
            observation=time_step.observation,
            prev_action=time_step.prev_action,
            state=self._ac_algorithm.convert_train_state_to_predict_state(
                state),
            ac=ac_info))
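# A minimal sketch of an `info` namedtuple that satisfies the assertion in
# _make_policy_step; the class name and the extra `value` field are
# illustrative:
from collections import namedtuple

ActorCriticInfoSketch = namedtuple('ActorCriticInfoSketch',
                                   ['action_distribution', 'value'])
# policy_step = policy_step._replace(
#     info=ActorCriticInfoSketch(action_distribution=dist, value=value))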
def train_step(self, exp: TimeStep, state):
    # [B, num_unroll_steps + 1]
    info = exp.rollout_info
    batch_size = exp.step_type.shape[0]
    latent, state = self._encoding_net(exp.observation, state)
    sim_latents = [latent]
    if self._num_unroll_steps > 0:
        if self._latent_to_dstate_fc is not None:
            dstate = self._latent_to_dstate_fc(latent)
            dstate = dstate.split(self._dynamics_state_dims, dim=1)
            dstate = alf.nest.pack_sequence_as(
                self._dynamics_net.state_spec, dstate)
        else:
            dstate = state
        for i in range(self._num_unroll_steps):
            sim_latent, dstate = self._dynamics_net(info.action[:, i, ...],
                                                    dstate)
            sim_latents.append(sim_latent)
    sim_latent = torch.cat(sim_latents, dim=0)

    # [(num_unroll_steps + 1) * B, ...]
    train_info = self._decoder.train_step(sim_latent).info
    train_info_spec = dist_utils.extract_spec(train_info)
    train_info = dist_utils.distributions_to_params(train_info)
    train_info = alf.nest.map_structure(
        lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                            *x.shape[1:]), train_info)
    # [num_unroll_steps + 1, B, ...]
    train_info = dist_utils.params_to_distributions(
        train_info, train_info_spec)
    target = alf.nest.map_structure(lambda x: x.transpose(0, 1), info.target)
    loss_info = self._decoder.calc_loss(target, train_info, info.mask.t())
    loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0), loss_info)
    return AlgStep(output=latent, state=state, info=loss_info)
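# `pack_sequence_as` above rearranges the flat FC projection into the nested
# state structure expected by the dynamics network. A self-contained toy
# version (the two-field spec is illustrative, not the real LSTM state_spec):
def _split_and_pack_sketch():
    toy_spec = (alf.TensorSpec((2, )), alf.TensorSpec((3, )))
    flat = torch.zeros(4, 5)               # [B, 2 + 3]
    parts = list(flat.split([2, 3], dim=1))
    return alf.nest.pack_sequence_as(toy_spec, parts)  # ([4, 2], [4, 3])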
def _test_preprocess_experience(self, train_reward_function, td_steps,
                                reanalyze_ratio, expected):
    """The following summarizes how the data is generated:

    .. code-block:: python

        # position:   01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'
        scale = 1. for current model, 2. for target model
        observation = [position] * 3
        reward = position if train_reward_function and td_steps != -1
                 else position * (step_type == LAST)
        value = 0.5 * position * scale
        action_probs = scale * [position, position + 1, position] for env 0
                       scale * [position + 1, position, position] for env 1
        action = 1 for env 0
                 0 for env 1
    """
    reanalyze_td_steps = 2
    num_unroll_steps = 4
    batch_size = 2
    obs_dim = 3
    observation_spec = alf.TensorSpec([obs_dim])
    action_spec = alf.BoundedTensorSpec((),
                                        minimum=0,
                                        maximum=1,
                                        dtype=torch.int32)
    reward_spec = alf.TensorSpec(())
    time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                       reward_spec)

    global _mcts_model_id
    _mcts_model_id = 0
    muzero = MuzeroAlgorithm(observation_spec,
                             action_spec,
                             model_ctor=_create_mcts_model,
                             mcts_algorithm_ctor=MockMCTSAlgorithm,
                             num_unroll_steps=num_unroll_steps,
                             td_steps=td_steps,
                             train_game_over_function=True,
                             train_reward_function=train_reward_function,
                             reanalyze_ratio=reanalyze_ratio,
                             reanalyze_td_steps=reanalyze_td_steps,
                             data_transformer_ctor=partial(FrameStacker,
                                                           stack_size=2))

    data_transformer = FrameStacker(observation_spec, stack_size=2)
    time_step = common.zero_tensor_from_nested_spec(time_step_spec,
                                                    batch_size)
    dt_state = common.zero_tensor_from_nested_spec(
        data_transformer.state_spec, batch_size)
    state = muzero.get_initial_predict_state(batch_size)
    transformed_time_step, dt_state = data_transformer.transform_timestep(
        time_step, dt_state)
    alg_step = muzero.rollout_step(transformed_time_step, state)
    alg_step_spec = dist_utils.extract_spec(alg_step)

    experience = ds.make_experience(time_step, alg_step, state)
    experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                         muzero.train_state_spec)
    replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                 num_environments=batch_size,
                                 max_length=16,
                                 keep_episodic_info=True)

    # position:   01234567890123
    step_type0 = 'FMMMLFMMLFMMMM'
    step_type1 = 'FMMMMMLFMMMMLF'
    dt_state = common.zero_tensor_from_nested_spec(
        data_transformer.state_spec, batch_size)
    for i in range(len(step_type0)):
        step_type = [step_type0[i], step_type1[i]]
        step_type = [
            ds.StepType.MID if c == 'M' else
            (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
            for c in step_type
        ]
        step_type = torch.tensor(step_type, dtype=torch.int32)
        reward = torch.full([batch_size], float(i))
        if not train_reward_function or td_steps == -1:
            reward = reward * (step_type == ds.StepType.LAST).to(
                torch.float32)
        time_step = time_step._replace(
            discount=(step_type != ds.StepType.LAST).to(torch.float32),
            step_type=step_type,
            observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                     dtype=torch.float32),
            reward=reward,
            env_id=torch.arange(batch_size, dtype=torch.int32))
        transformed_time_step, dt_state = data_transformer.transform_timestep(
            time_step, dt_state)
        alg_step = muzero.rollout_step(transformed_time_step, state)
        experience = ds.make_experience(time_step, alg_step, state)
        replay_buffer.add_batch(experience)
        state = alg_step.state

    env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
    positions = torch.tensor(list(range(14)) + list(range(14)),
                             dtype=torch.int64)
    experience = replay_buffer.get_field(None,
                                         env_ids.unsqueeze(-1).cpu(),
                                         positions.unsqueeze(-1).cpu())
    experience = experience._replace(
        replay_buffer=replay_buffer,
        batch_info=BatchInfo(env_ids=env_ids, positions=positions),
        rollout_info_field='rollout_info')
    processed_experience = muzero.preprocess_experience(experience)
    import pprint
    pprint.pprint(processed_experience.rollout_info)
    alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                           processed_experience.rollout_info, expected)
def __init__(self, num_expansions, model_output, known_value_bounds):
    batch_size, branch_factor = model_output.action_probs.shape
    action_spec = dist_utils.extract_spec(model_output.actions, from_dim=2)
    state_spec = dist_utils.extract_spec(model_output.state, from_dim=1)

    if known_value_bounds:
        self.fixed_bounds = True
        self.minimum, self.maximum = known_value_bounds
    else:
        self.fixed_bounds = False
        self.minimum, self.maximum = MAXIMUM_FLOAT_VALUE, -MAXIMUM_FLOAT_VALUE
    self.minimum = torch.full((batch_size, ),
                              self.minimum,
                              dtype=torch.float32)
    self.maximum = torch.full((batch_size, ),
                              self.maximum,
                              dtype=torch.float32)
    if known_value_bounds:
        self.normalize_scale = 1 / (self.maximum - self.minimum + 1e-30)
        self.normalize_base = self.minimum
    else:
        self.normalize_scale = torch.ones((batch_size, ))
        self.normalize_base = torch.zeros((batch_size, ))

    self.B = torch.arange(batch_size)
    self.root_indices = torch.zeros((batch_size, ), dtype=torch.int64)
    self.branch_factor = branch_factor

    parent_shape = (batch_size, num_expansions)
    children_shape = (batch_size, num_expansions, branch_factor)

    self.visit_count = torch.zeros(parent_shape, dtype=torch.int32)
    # the player who will take action from the current state
    self.to_play = torch.zeros(parent_shape, dtype=torch.int64)
    self.prior = torch.zeros(children_shape)
    # value for player 0
    self.value_sum = torch.zeros(parent_shape)
    # 0 for not expanded, value in range [0, num_expansions)
    self.children_index = torch.zeros(children_shape, dtype=torch.int64)
    self.model_state = common.zero_tensor_from_nested_spec(
        state_spec, parent_shape)
    # reward for player 0
    self.reward = None
    if isinstance(model_output.reward, torch.Tensor):
        self.reward = torch.zeros(parent_shape)
    self.action = None
    if isinstance(model_output.actions, torch.Tensor):
        # candidate actions for this state
        self.action = torch.zeros(children_shape + action_spec.shape,
                                  dtype=action_spec.dtype)
    self.game_over = None
    if isinstance(model_output.game_over, torch.Tensor):
        self.game_over = torch.zeros(parent_shape, dtype=torch.bool)
    # value in range [0, branch_factor)
    self.best_child_index = torch.zeros(parent_shape, dtype=torch.int64)
    self.ucb_score = torch.zeros(children_shape)
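# When value bounds are not known a priori, MCTS maintains them from observed
# node values and renormalizes Q values into [0, 1] (MuZero's min-max trick).
# A sketch consistent with the fields initialized above; the function name is
# hypothetical:
def _update_value_stats_sketch(tree, value):
    # value: [B] node values observed during backup
    if tree.fixed_bounds:
        return
    tree.minimum = torch.min(tree.minimum, value)
    tree.maximum = torch.max(tree.maximum, value)
    # same epsilon as the known-bounds branch above
    tree.normalize_scale = 1 / (tree.maximum - tree.minimum + 1e-30)
    tree.normalize_base = tree.minimum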
def test_preprocess_experience(self):
    """The following summarizes how the data is generated:

    .. code-block:: python

        # position:   01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'
        reward = position if train_reward_function and td_steps != -1
                 else position * (step_type == LAST)
        action = t + 1 for env 0
                 t for env 1
    """
    num_unroll_steps = 4
    batch_size = 2
    obs_dim = 3
    observation_spec = alf.TensorSpec([obs_dim])
    action_spec = alf.BoundedTensorSpec((1, ),
                                        minimum=0,
                                        maximum=1,
                                        dtype=torch.float32)
    reward_spec = alf.TensorSpec(())
    time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                       reward_spec)

    repr_learner = PredictiveRepresentationLearner(
        observation_spec,
        action_spec,
        num_unroll_steps=num_unroll_steps,
        decoder_ctor=partial(SimpleDecoder,
                             target_field='reward',
                             decoder_net_ctor=partial(EncodingNetwork,
                                                      fc_layer_params=(4, ))),
        encoding_net_ctor=LSTMEncodingNetwork,
        dynamics_net_ctor=LSTMEncodingNetwork)

    time_step = common.zero_tensor_from_nested_spec(time_step_spec,
                                                    batch_size)
    state = repr_learner.get_initial_predict_state(batch_size)
    alg_step = repr_learner.rollout_step(time_step, state)
    alg_step = alg_step._replace(output=torch.tensor([[1.], [0.]]))
    alg_step_spec = dist_utils.extract_spec(alg_step)

    experience = ds.make_experience(time_step, alg_step, state)
    experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                         repr_learner.train_state_spec)
    replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                 num_environments=batch_size,
                                 max_length=16,
                                 keep_episodic_info=True)

    # position:   01234567890123
    step_type0 = 'FMMMLFMMLFMMMM'
    step_type1 = 'FMMMMMLFMMMMLF'
    for i in range(len(step_type0)):
        step_type = [step_type0[i], step_type1[i]]
        step_type = [
            ds.StepType.MID if c == 'M' else
            (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
            for c in step_type
        ]
        step_type = torch.tensor(step_type, dtype=torch.int32)
        reward = torch.full([batch_size], float(i))
        time_step = time_step._replace(
            discount=(step_type != ds.StepType.LAST).to(torch.float32),
            step_type=step_type,
            observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                     dtype=torch.float32),
            reward=reward,
            env_id=torch.arange(batch_size, dtype=torch.int32))
        alg_step = repr_learner.rollout_step(time_step, state)
        alg_step = alg_step._replace(output=i + torch.tensor([[1.], [0.]]))
        experience = ds.make_experience(time_step, alg_step, state)
        replay_buffer.add_batch(experience)
        state = alg_step.state

    env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
    positions = torch.tensor(list(range(14)) + list(range(14)),
                             dtype=torch.int64)
    experience = replay_buffer.get_field(None,
                                         env_ids.unsqueeze(-1).cpu(),
                                         positions.unsqueeze(-1).cpu())
    experience = experience._replace(
        replay_buffer=replay_buffer,
        batch_info=BatchInfo(env_ids=env_ids, positions=positions),
        rollout_info_field='rollout_info')
    processed_experience = repr_learner.preprocess_experience(experience)
    pprint.pprint(processed_experience.rollout_info)

    # yapf: disable
    expected = PredictiveRepresentationLearnerInfo(
        action=torch.tensor(
            [[[ 1.,  2.,  3.,  4.,  5.]],
             [[ 2.,  3.,  4.,  5.,  5.]],
             [[ 3.,  4.,  5.,  5.,  5.]],
             [[ 4.,  5.,  5.,  5.,  5.]],
             [[ 5.,  5.,  5.,  5.,  5.]],
             [[ 6.,  7.,  8.,  9.,  9.]],
             [[ 7.,  8.,  9.,  9.,  9.]],
             [[ 8.,  9.,  9.,  9.,  9.]],
             [[ 9.,  9.,  9.,  9.,  9.]],
             [[10., 11., 12., 13., 14.]],
             [[11., 12., 13., 14., 14.]],
             [[12., 13., 14., 14., 14.]],
             [[13., 14., 14., 14., 14.]],
             [[14., 14., 14., 14., 14.]],
             [[ 0.,  1.,  2.,  3.,  4.]],
             [[ 1.,  2.,  3.,  4.,  5.]],
             [[ 2.,  3.,  4.,  5.,  6.]],
             [[ 3.,  4.,  5.,  6.,  6.]],
             [[ 4.,  5.,  6.,  6.,  6.]],
             [[ 5.,  6.,  6.,  6.,  6.]],
             [[ 6.,  6.,  6.,  6.,  6.]],
             [[ 7.,  8.,  9., 10., 11.]],
             [[ 8.,  9., 10., 11., 12.]],
             [[ 9., 10., 11., 12., 12.]],
             [[10., 11., 12., 12., 12.]],
             [[11., 12., 12., 12., 12.]],
             [[12., 12., 12., 12., 12.]],
             [[13., 13., 13., 13., 13.]]]).unsqueeze(-1),
        mask=torch.tensor(
            [[[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True,  True]],
             [[ True,  True,  True,  True, False]],
             [[ True,  True,  True, False, False]],
             [[ True,  True, False, False, False]],
             [[ True, False, False, False, False]],
             [[ True, False, False, False, False]]]),
        target=torch.tensor(
            [[[ 0.,  1.,  2.,  3.,  4.]],
             [[ 1.,  2.,  3.,  4.,  4.]],
             [[ 2.,  3.,  4.,  4.,  4.]],
             [[ 3.,  4.,  4.,  4.,  4.]],
             [[ 4.,  4.,  4.,  4.,  4.]],
             [[ 5.,  6.,  7.,  8.,  8.]],
             [[ 6.,  7.,  8.,  8.,  8.]],
             [[ 7.,  8.,  8.,  8.,  8.]],
             [[ 8.,  8.,  8.,  8.,  8.]],
             [[ 9., 10., 11., 12., 13.]],
             [[10., 11., 12., 13., 13.]],
             [[11., 12., 13., 13., 13.]],
             [[12., 13., 13., 13., 13.]],
             [[13., 13., 13., 13., 13.]],
             [[ 0.,  1.,  2.,  3.,  4.]],
             [[ 1.,  2.,  3.,  4.,  5.]],
             [[ 2.,  3.,  4.,  5.,  6.]],
             [[ 3.,  4.,  5.,  6.,  6.]],
             [[ 4.,  5.,  6.,  6.,  6.]],
             [[ 5.,  6.,  6.,  6.,  6.]],
             [[ 6.,  6.,  6.,  6.,  6.]],
             [[ 7.,  8.,  9., 10., 11.]],
             [[ 8.,  9., 10., 11., 12.]],
             [[ 9., 10., 11., 12., 12.]],
             [[10., 11., 12., 12., 12.]],
             [[11., 12., 12., 12., 12.]],
             [[12., 12., 12., 12., 12.]],
             [[13., 13., 13., 13., 13.]]]))
    # yapf: enable

    alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                           processed_experience.rollout_info, expected)