def train_step(self, exp: TimeStep, state):
    # info.target: [B, num_unroll_steps + 1]
    info = exp.rollout_info
    targets = common.as_list(info.target)
    batch_size = exp.step_type.shape[0]
    latent, state = self._encoding_net(exp.observation, state)
    sim_latent = self._multi_step_latent_rollout(
        latent, self._num_unroll_steps, info.action, state)

    loss = 0
    for i, decoder in enumerate(self._decoders):
        # [(num_unroll_steps + 1) * B, ...]
        train_info = decoder.train_step(sim_latent).info
        train_info_spec = dist_utils.extract_spec(train_info)
        # Convert any distributions in train_info to plain parameter
        # tensors so they can be reshaped, then rebuild them afterwards.
        train_info = dist_utils.distributions_to_params(train_info)
        train_info = alf.nest.map_structure(
            lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                                *x.shape[1:]), train_info)
        # [num_unroll_steps + 1, B, ...]
        train_info = dist_utils.params_to_distributions(
            train_info, train_info_spec)
        # Targets are [B, num_unroll_steps + 1, ...]; make them time-major
        # to match train_info.
        target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                        targets[i])
        loss_info = decoder.calc_loss(target, train_info, info.mask.t())
        loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0),
                                           loss_info)
        loss += loss_info.loss

    loss_info = LossInfo(loss=loss, extra=loss)
    return AlgStep(output=latent, state=state, info=loss_info)
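
# Illustrative aside (not part of the module above): the reshape from
# [(num_unroll_steps + 1) * B, ...] back to [num_unroll_steps + 1, B, ...]
# is valid only because the unrolled latents are concatenated time-major
# along dim 0 (see the ``torch.cat`` in the single-decoder variant below).
# A toy check of that layout, with illustrative names:
import torch

T, B, D = 2, 3, 4
# One [B, D] latent per step, filled with its step index.
latents = [torch.full((B, D), float(t)) for t in range(T + 1)]
flat = torch.cat(latents, dim=0)      # [(T+1)*B, D], time-major
unflat = flat.reshape(T + 1, B, D)    # [T+1, B, D]
assert (unflat[1] == 1).all()         # step t sits at index t
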
def test_conversions(self):
    dists = {
        't': torch.tensor([[1., 2., 4.], [3., 3., 1.]]),
        'd': dist_utils.DiagMultivariateNormal(
            torch.tensor([[1., 2.], [2., 2.]]),
            torch.tensor([[2., 3.], [1., 1.]]))
    }
    params = dist_utils.distributions_to_params(dists)
    dists_spec = dist_utils.extract_spec(dists, from_dim=1)
    self.assertEqual(dists_spec['t'],
                     alf.TensorSpec(shape=(3, ), dtype=torch.float32))
    self.assertEqual(type(dists_spec['d']), dist_utils.DistributionSpec)
    self.assertEqual(len(params), 2)
    self.assertEqual(dists['t'], params['t'])
    self.assertEqual(dists['d'].base_dist.mean, params['d']['loc'])
    self.assertEqual(dists['d'].base_dist.stddev, params['d']['scale'])
    dists1 = dist_utils.params_to_distributions(params, dists_spec)
    self.assertEqual(len(dists1), 2)
    self.assertEqual(dists1['t'], dists['t'])
    self.assertEqual(type(dists1['d']), type(dists['d']))
    params_spec = dist_utils.to_distribution_param_spec(dists_spec)
    alf.nest.assert_same_structure(params_spec, params)
    params1_spec = dist_utils.extract_spec(params)
    self.assertEqual(params_spec, params1_spec)
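
# Illustrative aside: the round-trip exercised by the test above is the
# same idiom used by the ``train_step`` methods in this section. A minimal
# sketch, assuming the usual ``alf.utils.dist_utils`` import path:
import torch
import alf
from alf.utils import dist_utils

# A batched diagonal Gaussian: batch shape [4], event shape [2].
dist = dist_utils.DiagMultivariateNormal(
    torch.zeros(4, 2), torch.ones(4, 2))
# The spec records the distribution type and per-sample parameter specs.
spec = dist_utils.extract_spec(dist)
# Flatten to a plain {'loc': ..., 'scale': ...} nest of tensors, which can
# be reshaped or stacked like any tensor, then rebuild the distribution.
params = dist_utils.distributions_to_params(dist)
dist2 = dist_utils.params_to_distributions(params, spec)
assert type(dist2) == type(dist)
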
def train_step(self, exp: Experience, state):
    def _hook(grad, name):
        alf.summary.scalar("MCTS_state_grad_norm/" + name, grad.norm())

    model_output = self._model.initial_inference(exp.observation)
    if alf.summary.should_record_summaries():
        model_output.state.register_hook(partial(_hook, name="s0"))
    model_output_spec = dist_utils.extract_spec(model_output)
    model_outputs = [dist_utils.distributions_to_params(model_output)]
    info = exp.rollout_info

    for i in range(self._num_unroll_steps):
        model_output = self._model.recurrent_inference(
            model_output.state, info.action[:, i, ...])
        if alf.summary.should_record_summaries():
            model_output.state.register_hook(
                partial(_hook, name="s" + str(i + 1)))
        # Scale the gradient flowing back through the recurrent state so
        # that deeper unroll steps do not dominate the update.
        model_output = model_output._replace(
            state=scale_gradient(model_output.state,
                                 self._recurrent_gradient_scaling_factor))
        model_outputs.append(
            dist_utils.distributions_to_params(model_output))

    # Stack the per-step parameter nests along a new time axis and rebuild
    # the distributions: [B, num_unroll_steps + 1, ...].
    model_outputs = alf.nest.utils.stack_nests(model_outputs, dim=1)
    model_outputs = dist_utils.params_to_distributions(
        model_outputs, model_output_spec)
    return AlgStep(info=self._model.calc_loss(model_outputs, info.target))
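
# Illustrative aside: ``torch.distributions`` objects cannot be stacked
# directly, which is why each ``model_output`` is converted to parameter
# tensors before ``stack_nests`` and rebuilt afterwards. A sketch of that
# pattern with toy shapes (B=3, five unroll steps plus the initial step):
import torch
import alf
from alf.utils import dist_utils

steps = [
    dist_utils.DiagMultivariateNormal(torch.randn(3, 2), torch.ones(3, 2))
    for _ in range(6)
]
spec = dist_utils.extract_spec(steps[0])
# Stack parameter tensors along a new time axis (dim=1, as above), then
# rebuild one distribution with batch shape [B, T + 1].
stacked = alf.nest.utils.stack_nests(
    [dist_utils.distributions_to_params(d) for d in steps], dim=1)
dist = dist_utils.params_to_distributions(stacked, spec)
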
def _update_total_dists(new_action, exp, total_dists):
    old_action = dist_utils.params_to_distributions(
        exp.action_param, self._action_distribution_spec)
    valid_masks = (exp.step_type != StepType.LAST).to(torch.float32)
    return nest_map(
        lambda d1, d2, total_dist: (_dist(d1, d2) * valid_masks).sum() +
        total_dist, old_action, new_action, total_dists)
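
# Illustrative aside: ``_dist`` above is whatever per-element divergence
# the caller configured, and ``nest_map`` is assumed here to be an alias
# of ``alf.nest.map_structure``. A self-contained sketch of the masked
# accumulation, with ``torch.distributions.kl_divergence`` standing in
# for ``_dist``:
import torch
import torch.distributions as td
import alf

def masked_kl_sum(old, new, masks, totals):
    # Sum a per-sample divergence over valid (non-LAST) steps for each
    # element of a nest of distributions, accumulating into ``totals``.
    return alf.nest.map_structure(
        lambda d1, d2, t: (td.kl_divergence(d1, d2) * masks).sum() + t,
        old, new, totals)

old = td.Normal(torch.zeros(4), torch.ones(4))
new = td.Normal(torch.ones(4), torch.ones(4))
masks = torch.tensor([1., 1., 0., 1.])  # third step is LAST
total = masked_kl_sum(old, new, masks, torch.zeros(()))
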
def train_step(self, exp: TimeStep, state):
    # info.target: [B, num_unroll_steps + 1]
    info = exp.rollout_info
    batch_size = exp.step_type.shape[0]
    latent, state = self._encoding_net(exp.observation, state)

    sim_latents = [latent]
    if self._num_unroll_steps > 0:
        if self._latent_to_dstate_fc is not None:
            # Map the flat latent into the dynamics network's state nest.
            dstate = self._latent_to_dstate_fc(latent)
            dstate = dstate.split(self._dynamics_state_dims, dim=1)
            dstate = alf.nest.pack_sequence_as(
                self._dynamics_net.state_spec, dstate)
        else:
            dstate = state
        for i in range(self._num_unroll_steps):
            sim_latent, dstate = self._dynamics_net(
                info.action[:, i, ...], dstate)
            sim_latents.append(sim_latent)

    # Concatenate time-major along dim 0.
    sim_latent = torch.cat(sim_latents, dim=0)

    # [(num_unroll_steps + 1) * B, ...]
    train_info = self._decoder.train_step(sim_latent).info
    train_info_spec = dist_utils.extract_spec(train_info)
    train_info = dist_utils.distributions_to_params(train_info)
    train_info = alf.nest.map_structure(
        lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                            *x.shape[1:]), train_info)
    # [num_unroll_steps + 1, B, ...]
    train_info = dist_utils.params_to_distributions(
        train_info, train_info_spec)
    target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                    info.target)
    loss_info = self._decoder.calc_loss(target, train_info, info.mask.t())
    loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0), loss_info)
    return AlgStep(output=latent, state=state, info=loss_info)
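
# Illustrative aside: the ``_latent_to_dstate_fc`` branch above maps a flat
# latent into the dynamics network's (possibly nested) recurrent state. A
# standalone sketch of the split + ``pack_sequence_as`` idiom, with a toy
# spec and shapes:
import torch
import alf

state_spec = (alf.TensorSpec((3, )), alf.TensorSpec((5, )))
state_dims = [s.shape[0] for s in alf.nest.flatten(state_spec)]

fc_out = torch.randn(4, sum(state_dims))    # [B, 3 + 5]
parts = fc_out.split(state_dims, dim=1)     # ([B, 3], [B, 5])
dstate = alf.nest.pack_sequence_as(state_spec, list(parts))
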
def unroll(self, unroll_length):
    r"""Unroll ``unroll_length`` steps using the current policy.

    Because ``self._env`` is a batched environment, the total number of
    environment steps is ``self._env.batch_size * unroll_length``.

    Args:
        unroll_length (int): number of steps to unroll.

    Returns:
        Experience: The stacked experience with shape :math:`[T, B, \ldots]`
        for each of its members.
    """
    if self._current_time_step is None:
        self._current_time_step = common.get_initial_time_step(self._env)
    if self._current_policy_state is None:
        self._current_policy_state = self.get_initial_rollout_state(
            self._env.batch_size)
    if self._current_transform_state is None:
        self._current_transform_state = self.get_initial_transform_state(
            self._env.batch_size)

    time_step = self._current_time_step
    policy_state = self._current_policy_state
    trans_state = self._current_transform_state

    experience_list = []
    initial_state = self.get_initial_rollout_state(self._env.batch_size)

    env_step_time = 0.
    store_exp_time = 0.
    for _ in range(unroll_length):
        policy_state = common.reset_state_if_necessary(
            policy_state, initial_state, time_step.is_first())
        transformed_time_step, trans_state = self.transform_timestep(
            time_step, trans_state)
        # save the untransformed time step in case sub-algorithms need to
        # store it in replay buffers
        transformed_time_step = transformed_time_step._replace(
            untransformed=time_step)
        policy_step = self.rollout_step(transformed_time_step,
                                        policy_state)
        # release the reference to ``time_step``
        transformed_time_step = transformed_time_step._replace(
            untransformed=())
        action = common.detach(policy_step.output)

        t0 = time.time()
        next_time_step = self._env.step(action)
        env_step_time += time.time() - t0

        self.observe_for_metrics(time_step.cpu())

        if self._exp_replayer_type == "one_time":
            exp = make_experience(transformed_time_step, policy_step,
                                  policy_state)
        else:
            exp = make_experience(time_step.cpu(), policy_step,
                                  policy_state)

        t0 = time.time()
        self.observe_for_replay(exp)
        store_exp_time += time.time() - t0

        exp_for_training = Experience(
            action=action,
            reward=transformed_time_step.reward,
            discount=transformed_time_step.discount,
            step_type=transformed_time_step.step_type,
            state=policy_state,
            prev_action=transformed_time_step.prev_action,
            observation=transformed_time_step.observation,
            rollout_info=dist_utils.distributions_to_params(
                policy_step.info),
            env_id=transformed_time_step.env_id)

        experience_list.append(exp_for_training)

        time_step = next_time_step
        policy_state = policy_step.state

    alf.summary.scalar("time/unroll_env_step", env_step_time)
    alf.summary.scalar("time/unroll_store_exp", store_exp_time)
    experience = alf.nest.utils.stack_nests(experience_list)
    experience = experience._replace(
        rollout_info=dist_utils.params_to_distributions(
            experience.rollout_info, self._rollout_info_spec))

    self._current_time_step = time_step
    # Need to detach so that the graph from this unroll is disconnected from
    # the next unroll. Otherwise backward() will report an error for
    # on-policy training after the next unroll.
    self._current_policy_state = common.detach(policy_state)
    self._current_transform_state = common.detach(trans_state)

    return experience
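
# Illustrative aside: the final ``common.detach`` calls matter for
# on-policy training. If the recurrent state carried its autograd graph
# into the next unroll, the second ``backward()`` would revisit
# already-freed graph buffers. A minimal PyTorch-only illustration of the
# failure mode:
import torch

w = torch.ones(1, requires_grad=True)
state = torch.ones(1)
state = state * w             # first "unroll": state joins w's graph
state.sum().backward()        # frees the graph's saved buffers
# Reusing the still-attached state in a second unroll would fail:
#   (state * w).sum().backward()  # RuntimeError: trying to backward
#                                 # through the graph a second time
state = state.detach()        # what common.detach does, per nest leaf
(state * w).sum().backward()  # ok: a fresh, disconnected graph
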