Example 1
    def train_step(self, exp: Experience, state):
        def _hook(grad, name):
            alf.summary.scalar("MCTS_state_grad_norm/" + name, grad.norm())

        model_output = self._model.initial_inference(exp.observation)
        if alf.summary.should_record_summaries():
            model_output.state.register_hook(partial(_hook, name="s0"))
        model_output_spec = dist_utils.extract_spec(model_output)
        model_outputs = [dist_utils.distributions_to_params(model_output)]
        info = exp.rollout_info

        for i in range(self._num_unroll_steps):
            model_output = self._model.recurrent_inference(
                model_output.state, info.action[:, i, ...])
            if alf.summary.should_record_summaries():
                model_output.state.register_hook(
                    partial(_hook, name="s" + str(i + 1)))
            model_output = model_output._replace(state=scale_gradient(
                model_output.state, self._recurrent_gradient_scaling_factor))
            model_outputs.append(
                dist_utils.distributions_to_params(model_output))

        model_outputs = alf.nest.utils.stack_nests(model_outputs, dim=1)
        model_outputs = dist_utils.params_to_distributions(
            model_outputs, model_output_spec)
        return AlgStep(info=self._model.calc_loss(model_outputs, info.target))
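
The per-step outputs above are converted to parameter tensors before stacking and rebuilt afterwards because ``torch.distributions`` objects cannot be stacked directly by ``stack_nests``. A minimal sketch of that round trip, assuming the usual ``alf.utils.dist_utils`` import path and illustrative shapes:

import torch
import alf
import alf.utils.dist_utils as dist_utils

# One "step" worth of output: a nest holding a distribution over a batch of 4.
step_output = {
    'd': dist_utils.DiagMultivariateNormal(
        torch.zeros(4, 2), torch.ones(4, 2))
}
spec = dist_utils.extract_spec(step_output)
# Convert each step to plain parameter tensors so they can be stacked.
params_per_step = [
    dist_utils.distributions_to_params(step_output) for _ in range(3)
]
stacked = alf.nest.utils.stack_nests(params_per_step, dim=1)
# Rebuild distributions whose parameters now carry the extra step dimension.
stacked_dists = dist_utils.params_to_distributions(stacked, spec)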
Example 2
    def train_step(self, exp: TimeStep, state):
        # [B, num_unroll_steps + 1]
        info = exp.rollout_info
        targets = common.as_list(info.target)
        batch_size = exp.step_type.shape[0]
        latent, state = self._encoding_net(exp.observation, state)

        sim_latent = self._multi_step_latent_rollout(latent,
                                                     self._num_unroll_steps,
                                                     info.action, state)

        loss = 0
        for i, decoder in enumerate(self._decoders):
            # [(num_unroll_steps + 1)*B, ...]
            train_info = decoder.train_step(sim_latent).info
            train_info_spec = dist_utils.extract_spec(train_info)
            train_info = dist_utils.distributions_to_params(train_info)
            train_info = alf.nest.map_structure(
                lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                                    *x.shape[1:]), train_info)
            # [num_unroll_steps + 1, B, ...]
            train_info = dist_utils.params_to_distributions(
                train_info, train_info_spec)
            target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                            targets[i])
            loss_info = decoder.calc_loss(target, train_info, info.mask.t())
            loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0),
                                               loss_info)
            loss += loss_info.loss

        loss_info = LossInfo(loss=loss, extra=loss)

        return AlgStep(output=latent, state=state, info=loss_info)
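
Both this ``train_step`` and the one in Example 4 flatten the unrolled latents to ``[(num_unroll_steps + 1)*B, ...]`` before decoding and then fold the decoder output back to ``[num_unroll_steps + 1, B, ...]``, while the targets arrive as ``[B, num_unroll_steps + 1, ...]`` and are transposed to match. A small shape check with made-up sizes:

import torch

num_unroll_steps, B, d = 3, 8, 5
flat = torch.randn((num_unroll_steps + 1) * B, d)   # decoder output, time-major
folded = flat.reshape(num_unroll_steps + 1, B, d)   # [num_unroll_steps + 1, B, ...]
target = torch.randn(B, num_unroll_steps + 1, d)    # targets come in as [B, T, ...]
assert folded.shape == target.transpose(0, 1).shape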
Example 3
    def test_conversions(self):
        dists = {
            't':
            torch.tensor([[1., 2., 4.], [3., 3., 1.]]),
            'd':
            dist_utils.DiagMultivariateNormal(
                torch.tensor([[1., 2.], [2., 2.]]),
                torch.tensor([[2., 3.], [1., 1.]]))
        }
        params = dist_utils.distributions_to_params(dists)
        dists_spec = dist_utils.extract_spec(dists, from_dim=1)
        self.assertEqual(dists_spec['t'],
                         alf.TensorSpec(shape=(3, ), dtype=torch.float32))
        self.assertEqual(type(dists_spec['d']), dist_utils.DistributionSpec)
        self.assertEqual(len(params), 2)
        self.assertEqual(dists['t'], params['t'])
        self.assertEqual(dists['d'].base_dist.mean, params['d']['loc'])
        self.assertEqual(dists['d'].base_dist.stddev, params['d']['scale'])

        dists1 = dist_utils.params_to_distributions(params, dists_spec)
        self.assertEqual(len(dists1), 2)
        self.assertEqual(dists1['t'], dists['t'])
        self.assertEqual(type(dists1['d']), type(dists['d']))

        params_spec = dist_utils.to_distribution_param_spec(dists_spec)
        alf.nest.assert_same_structure(params_spec, params)
        params1_spec = dist_utils.extract_spec(params)
        self.assertEqual(params_spec, params1_spec)
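
The last four lines also exercise ``to_distribution_param_spec``, which turns a spec nest that may contain ``DistributionSpec`` entries into a nest of plain ``TensorSpec``s matching the structure of the ``distributions_to_params`` output, presumably so that distribution parameters can be described by tensor-only specs (e.g. for replay buffers). A short sketch of just that direction, reusing the kind of fixture the test builds (same assumed import path as above):

import torch
import alf
import alf.utils.dist_utils as dist_utils

dists = {
    'd': dist_utils.DiagMultivariateNormal(
        torch.zeros(2, 3), torch.ones(2, 3))
}
spec = dist_utils.extract_spec(dists, from_dim=1)          # {'d': DistributionSpec}
param_spec = dist_utils.to_distribution_param_spec(spec)   # {'d': {'loc': ..., 'scale': ...}}
params = dist_utils.distributions_to_params(dists)
alf.nest.assert_same_structure(param_spec, params)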
Example 4
    def train_step(self, exp: TimeStep, state):
        # [B, num_unroll_steps + 1]
        info = exp.rollout_info
        batch_size = exp.step_type.shape[0]
        latent, state = self._encoding_net(exp.observation, state)

        sim_latents = [latent]

        if self._num_unroll_steps > 0:

            if self._latent_to_dstate_fc is not None:
                dstate = self._latent_to_dstate_fc(latent)
                dstate = dstate.split(self._dynamics_state_dims, dim=1)
                dstate = alf.nest.pack_sequence_as(
                    self._dynamics_net.state_spec, dstate)
            else:
                dstate = state

        for i in range(self._num_unroll_steps):
            sim_latent, dstate = self._dynamics_net(info.action[:, i, ...],
                                                    dstate)
            sim_latents.append(sim_latent)

        sim_latent = torch.cat(sim_latents, dim=0)

        # [(num_unroll_steps + 1)*B, ...]
        train_info = self._decoder.train_step(sim_latent).info
        train_info_spec = dist_utils.extract_spec(train_info)
        train_info = dist_utils.distributions_to_params(train_info)
        train_info = alf.nest.map_structure(
            lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                                *x.shape[1:]), train_info)
        # [num_unroll_steps + 1, B, ...]
        train_info = dist_utils.params_to_distributions(
            train_info, train_info_spec)
        target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                        info.target)
        loss_info = self._decoder.calc_loss(target, train_info, info.mask.t())
        loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0), loss_info)

        return AlgStep(output=latent, state=state, info=loss_info)
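
The rollout loop above collects the initial latent plus one dynamics step per action and concatenates them along the batch dimension, which is what produces the time-major ``[(num_unroll_steps + 1)*B, ...]`` layout consumed by the decoder. A toy sketch of the same unrolling pattern (the ``nn.Linear`` stand-ins below are hypothetical, not the ALF networks used here):

import torch
import torch.nn as nn

B, obs_dim, act_dim, latent_dim, num_unroll_steps = 8, 10, 2, 16, 3
encoder = nn.Linear(obs_dim, latent_dim)                 # stand-in encoding net
dynamics = nn.Linear(latent_dim + act_dim, latent_dim)   # stand-in dynamics net

obs = torch.randn(B, obs_dim)
actions = torch.randn(B, num_unroll_steps, act_dim)      # [B, T, ...] like info.action

latent = encoder(obs)
sim_latents = [latent]
for i in range(num_unroll_steps):
    latent = dynamics(torch.cat([latent, actions[:, i, :]], dim=-1))
    sim_latents.append(latent)

sim_latent = torch.cat(sim_latents, dim=0)               # time-major [(T + 1)*B, latent_dim]
assert sim_latent.shape == ((num_unroll_steps + 1) * B, latent_dim)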
Example 5
    def after_update(self, experience, train_info: TracInfo):
        """Adjust actor parameter according to KL-divergence."""
        action_param = dist_utils.distributions_to_params(
            train_info.action_distribution)
        exp_array = TracExperience(observation=train_info.observation,
                                   step_type=experience.step_type,
                                   action_param=action_param,
                                   prev_action=train_info.prev_action,
                                   state=train_info.state)
        dists, steps = self._trusted_updater.adjust_step(
            lambda: self._calc_change(exp_array), self._action_dist_clips)

        if alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):
                for i, d in enumerate(alf.nest.flatten(dists)):
                    alf.summary.scalar("unadjusted_action_dist/%s" % i, d)
                alf.summary.scalar("adjust_steps", steps)

        ac_info = train_info.ac._replace(
            action_distribution=train_info.action_distribution)
        self._ac_algorithm.after_update(experience, ac_info)
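
The summary block follows ALF's usual pattern for logging a nest of scalars: guard on ``should_record_summaries()``, open a name scope, flatten the nest, and write one scalar per leaf. A minimal helper sketch assuming a summary writer has already been set up by the training loop (``summarize_nest`` is a hypothetical name, not an ALF function):

import alf

def summarize_nest(name_scope, values):
    # Write each leaf of a (possibly nested) structure as its own scalar.
    if alf.summary.should_record_summaries():
        with alf.summary.scope(name_scope):
            for i, v in enumerate(alf.nest.flatten(values)):
                alf.summary.scalar("value/%s" % i, v)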
Example 6
    def unroll(self, unroll_length):
        r"""Unroll ``unroll_length`` steps using the current policy.

        Because ``self._env`` is a batched environment, the total number of
        environment steps is ``self._env.batch_size * unroll_length``.

        Args:
            unroll_length (int): number of steps to unroll.
        Returns:
            Experience: The stacked experience with shape :math:`[T, B, \ldots]`
            for each of its members.
        """
        if self._current_time_step is None:
            self._current_time_step = common.get_initial_time_step(self._env)
        if self._current_policy_state is None:
            self._current_policy_state = self.get_initial_rollout_state(
                self._env.batch_size)
        if self._current_transform_state is None:
            self._current_transform_state = self.get_initial_transform_state(
                self._env.batch_size)

        time_step = self._current_time_step
        policy_state = self._current_policy_state
        trans_state = self._current_transform_state

        experience_list = []
        initial_state = self.get_initial_rollout_state(self._env.batch_size)

        env_step_time = 0.
        store_exp_time = 0.
        for _ in range(unroll_length):
            policy_state = common.reset_state_if_necessary(
                policy_state, initial_state, time_step.is_first())
            transformed_time_step, trans_state = self.transform_timestep(
                time_step, trans_state)
            # save the untransformed time step in case sub-algorithms need to
            # store it in replay buffers
            transformed_time_step = transformed_time_step._replace(
                untransformed=time_step)
            policy_step = self.rollout_step(transformed_time_step,
                                            policy_state)
            # release the reference to ``time_step``
            transformed_time_step = transformed_time_step._replace(
                untransformed=())

            action = common.detach(policy_step.output)

            t0 = time.time()
            next_time_step = self._env.step(action)
            env_step_time += time.time() - t0

            self.observe_for_metrics(time_step.cpu())

            if self._exp_replayer_type == "one_time":
                exp = make_experience(transformed_time_step, policy_step,
                                      policy_state)
            else:
                exp = make_experience(time_step.cpu(), policy_step,
                                      policy_state)

            t0 = time.time()
            self.observe_for_replay(exp)
            store_exp_time += time.time() - t0

            exp_for_training = Experience(
                action=action,
                reward=transformed_time_step.reward,
                discount=transformed_time_step.discount,
                step_type=transformed_time_step.step_type,
                state=policy_state,
                prev_action=transformed_time_step.prev_action,
                observation=transformed_time_step.observation,
                rollout_info=dist_utils.distributions_to_params(
                    policy_step.info),
                env_id=transformed_time_step.env_id)

            experience_list.append(exp_for_training)
            time_step = next_time_step
            policy_state = policy_step.state

        alf.summary.scalar("time/unroll_env_step", env_step_time)
        alf.summary.scalar("time/unroll_store_exp", store_exp_time)
        experience = alf.nest.utils.stack_nests(experience_list)
        experience = experience._replace(
            rollout_info=dist_utils.params_to_distributions(
                experience.rollout_info, self._rollout_info_spec))

        self._current_time_step = time_step
        # Need to detach so that the graph from this unroll is disconnected from
        # the next unroll. Otherwise backward() will report error for on-policy
        # training after the next unroll.
        self._current_policy_state = common.detach(policy_state)
        self._current_transform_state = common.detach(trans_state)

        return experience
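
The closing ``common.detach`` calls are what keep successive unrolls' autograd graphs separate: the recurrent state handed to the next unroll must not reference the graph that the current on-policy update has already backpropagated through and freed. A stripped-down illustration with plain tensors (not ALF code):

import torch

w = torch.ones(1, requires_grad=True)
state = torch.zeros(1)
for _ in range(2):
    state = state * w + 1.0    # "rollout" step whose graph goes through w
    loss = state.sum()
    loss.backward()            # on-policy update for this unroll frees the graph
    # Without this detach, the next backward() would try to backprop through
    # the freed graph via ``state`` and raise a runtime error.
    state = state.detach()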