Code example #1
    def update(self, rollouts: Sequence[StepSequence]):
        r"""
        Train the particles $mu$.

        :param rollouts: rewards collected from the rollout
        """
        policy_grads = []
        parameters = []

        for i in range(self.num_particles):
            # Get the rollouts associated to the i-th particle
            concat_ros = StepSequence.concat(rollouts[i])
            concat_ros.torch()

            act_stats = compute_action_statistics(concat_ros,
                                                  self.expl_strats[i])
            act_stats_fixed = compute_action_statistics(
                concat_ros, self.fixed_expl_strats[i])

            klds = to.distributions.kl_divergence(act_stats.act_distr,
                                                  act_stats_fixed.act_distr)
            entropy = act_stats.act_distr.entropy()
            log_prob = act_stats.log_probs

            # Subtract a KL penalty and an entropy penalty (both weighted by a hard-coded factor of 0.1) from the rewards
            concat_ros.rewards = concat_ros.rewards - (
                0.1 * klds.mean(1)).view(-1) - 0.1 * entropy.mean(1).view(-1)

            # Update the advantage estimator's parameters and return advantage estimates
            adv = self.particles[i].critic.update(rollouts[i],
                                                  use_empirical_returns=True)

            # Estimate policy gradients
            self.optimizers[i].zero_grad()
            policy_grad = -to.mean(log_prob * adv.detach())
            policy_grad.backward()  # step comes later than usual

            # Collect flattened parameter and gradient vectors
            policy_grads.append(self.expl_strats[i].param_grad)
            parameters.append(self.expl_strats[i].param_values)

        parameters = to.stack(parameters)
        policy_grads = to.stack(policy_grads)
        Kxx, dx_Kxx = self.kernel(parameters)
        grad_theta = (to.mm(Kxx, policy_grads / self.temperature) +
                      dx_Kxx) / self.num_particles

        for i in range(self.num_particles):
            self.expl_strats[i].param_grad = grad_theta[i]
            self.optimizers[i].step()
        self.updatecount += 1
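
The call self.kernel(parameters) above returns both the Gram matrix Kxx and the summed kernel gradients dx_Kxx that produce the particle interaction in the SVPG update. As a point of reference, below is a minimal sketch of such a kernel, assuming the squared-exponential kernel with the median-heuristic bandwidth that is standard in SVGD; the kernel class actually used by the project may differ.

import torch as to


def rbf_kernel(parameters: to.Tensor):
    """Squared-exponential kernel with median-heuristic bandwidth (illustrative SVGD-style sketch).

    :param parameters: stacked flat policy parameters of shape [num_particles, num_params]
    :return: Gram matrix Kxx and the summed kernel gradients dx_Kxx
    """
    num_particles = parameters.shape[0]
    # Pairwise squared Euclidean distances between the particles
    sq_dists = to.cdist(parameters, parameters, p=2).pow(2)
    # Median heuristic for the kernel bandwidth
    bandwidth = sq_dists.median() / to.log(to.tensor(float(num_particles)) + 1.0)
    bandwidth = to.clamp(bandwidth, min=1e-8)
    Kxx = to.exp(-sq_dists / bandwidth)
    # Row i of dx_Kxx holds sum_j grad_{x_j} k(x_j, x_i) = 2/h * (sum_j k_ij * x_i - sum_j k_ij * x_j)
    dx_Kxx = 2.0 / bandwidth * (Kxx.sum(dim=1, keepdim=True) * parameters - Kxx @ parameters)
    return Kxx, dx_Kxx
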
Code example #2
def test_action_statistics(env: SimEnv, policy: Policy):
    sigma = 1.0  # with lower values like 0.1 we can observe violations of the tolerances

    # Create an action-based exploration strategy
    explstrat = NormalActNoiseExplStrat(policy, std_init=sigma)

    # Sample a deterministic rollout
    ro_policy = rollout(env,
                        policy,
                        eval=True,
                        max_steps=1000,
                        stop_on_done=False,
                        seed=0)
    ro_policy.torch(to.get_default_dtype())

    # Run the exploration strategy on the previously sampled rollout
    if policy.is_recurrent:
        if isinstance(policy, TwoHeadedPolicy):
            act_expl, _, _ = explstrat(ro_policy.observations)
        else:
            act_expl, _ = explstrat(ro_policy.observations)
        # Get the hidden states from the deterministic rollout
        hidden_states = ro_policy.hidden_states
    else:
        if isinstance(policy, TwoHeadedPolicy):
            act_expl, _ = explstrat(ro_policy.observations)
        else:
            act_expl = explstrat(ro_policy.observations)
        # Just a placeholder that does not violate the expected format
        hidden_states = [0.0] * ro_policy.length

    ro_expl = StepSequence(
        actions=act_expl[:-1],  # drop the last action since the exploration strategy was also queried on the final observation
        observations=ro_policy.observations,
        rewards=ro_policy.rewards,  # don't care but necessary
        hidden_states=hidden_states,
    )
    ro_expl.torch()

    # Compute action statistics and the ground truth
    actstats = compute_action_statistics(ro_expl, explstrat)
    gt_logprobs = Normal(loc=ro_policy.actions,
                         scale=sigma).log_prob(ro_expl.actions)
    gt_entropy = Normal(loc=ro_policy.actions, scale=sigma).entropy()

    to.testing.assert_allclose(actstats.log_probs,
                               gt_logprobs,
                               rtol=1e-4,
                               atol=1e-5)
    to.testing.assert_allclose(actstats.entropy,
                               gt_entropy,
                               rtol=1e-4,
                               atol=1e-5)
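
The ground truth in the assertions above is simply the closed-form log-probability and entropy of a diagonal Gaussian centered on the deterministic policy actions with standard deviation sigma. The following self-contained illustration uses random stand-in tensors instead of the project's rollout data.

import torch as to
from torch.distributions import Normal

# Stand-ins for the deterministic policy actions and the exploration noise level
mean_actions = to.randn(1000, 2)
sigma = 1.0
noisy_actions = mean_actions + sigma * to.randn_like(mean_actions)

# The quantities the test compares against compute_action_statistics()
distr = Normal(loc=mean_actions, scale=sigma)
log_probs = distr.log_prob(noisy_actions)  # element-wise log-likelihood of the noisy actions
entropy = distr.entropy()  # closed form 0.5*log(2*pi*e*sigma^2) per action dimension
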
Code example #3
    def update(self, rollouts: Sequence[StepSequence]):
        # Turn the batch of rollouts into a list of steps
        concat_ros = StepSequence.concat(rollouts)
        concat_ros.torch(data_type=to.get_default_dtype())

        with to.no_grad():
            # Compute the action probabilities using the old (before update) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_old = act_stats.log_probs
            act_distr_old = act_stats.act_distr

            # Compute value predictions using the old (before the update) value function
            v_pred_old = self._critic.values(concat_ros)

        # Attach the old log probs and the old value predictions to the rollout
        concat_ros.add_data('log_probs_old', log_probs_old)
        concat_ros.add_data('v_pred_old', v_pred_old)

        # For logging the gradient norms
        policy_grad_norm = []
        value_fcn_grad_norm = []

        # Compute the value targets (empirical discounted returns) for all samples before fitting the V-fcn parameters
        adv = self._critic.gae(concat_ros)  # done with to.no_grad()
        v_targ = discounted_values(rollouts, self._critic.gamma).view(
            -1, 1)  # empirical discounted returns
        concat_ros.add_data('adv', adv)
        concat_ros.add_data('v_targ', v_targ)

        # Iterations over the whole data set
        for e in range(self.num_epoch):

            for batch in tqdm(concat_ros.split_shuffled_batches(
                    self.batch_size,
                    complete_rollouts=self._policy.is_recurrent
                    or isinstance(self._critic.value_fcn, RecurrentPolicy)),
                              total=num_iter_from_rollouts(
                                  None, concat_ros, self.batch_size),
                              desc=f'Epoch {e}',
                              unit='batches',
                              file=sys.stdout,
                              leave=False):
                # Reset the gradients
                self.optim.zero_grad()

                # Compute log of the action probabilities for the mini-batch
                log_probs = compute_action_statistics(
                    batch, self._expl_strat).log_probs.to(self.policy.device)

                # Compute value predictions for the mini-batch
                v_pred = self._critic.values(batch)

                # Compute combined loss and backpropagate
                loss = self.loss_fcn(log_probs, batch.log_probs_old, batch.adv,
                                     v_pred, batch.v_pred_old, batch.v_targ)
                loss.backward()

                # Clip the gradients if desired
                policy_grad_norm.append(
                    self.clip_grad(self._expl_strat.policy,
                                   self.max_grad_norm))
                value_fcn_grad_norm.append(
                    self.clip_grad(self._critic.value_fcn, self.max_grad_norm))

                # Call optimizer
                self.optim.step()

                if to.isnan(self._expl_strat.noise.std).any():
                    raise RuntimeError(
                        f'At least one exploration parameter became NaN! The exploration parameters are'
                        f'\n{self._expl_strat.std.detach().numpy()}')

            # Update the learning rate if a scheduler has been specified
            if self._lr_scheduler is not None:
                self._lr_scheduler.step()

        # Additional logging
        if self.log_loss:
            with to.no_grad():
                # Compute value predictions using the new (after the updates) value function approximator
                v_pred = self._critic.values(concat_ros).to(self.policy.device)
                v_loss_old = self._critic.loss_fcn(
                    v_pred_old.to(self.policy.device),
                    v_targ.to(self.policy.device)).to(self.policy.device)
                v_loss_new = self._critic.loss_fcn(v_pred, v_targ).to(
                    self.policy.device)
                value_fcn_loss_impr = v_loss_old - v_loss_new  # positive values are desired

                # Compute the action probabilities using the new (after the updates) policy
                act_stats = compute_action_statistics(concat_ros,
                                                      self._expl_strat)
                log_probs_new = act_stats.log_probs
                act_distr_new = act_stats.act_distr
                loss_after = self.loss_fcn(log_probs_new, log_probs_old, adv,
                                           v_pred, v_pred_old, v_targ)
                kl_avg = to.mean(kl_divergence(
                    act_distr_old,
                    act_distr_new))  # mean seeking a.k.a. inclusive KL

                # Compute explained variance (after the updates)
                explvar = explained_var(v_pred, v_targ)
                self.logger.add_value('explained var',
                                      explvar.detach().numpy())
                self.logger.add_value('V-fcn loss improvement',
                                      value_fcn_loss_impr.detach().numpy())
                self.logger.add_value('loss after',
                                      loss_after.detach().numpy())
                self.logger.add_value('KL(old_new)', kl_avg.item())

        # Logging
        self.logger.add_value(
            'avg expl strat std',
            to.mean(self._expl_strat.noise.std.data).detach().numpy())
        self.logger.add_value('expl strat entropy',
                              self._expl_strat.noise.get_entropy().item())
        self.logger.add_value('avg policy grad norm',
                              np.mean(policy_grad_norm))
        self.logger.add_value('avg V-fcn grad norm',
                              np.mean(value_fcn_grad_norm))
        if self._lr_scheduler is not None:
            self.logger.add_value('learning rate', self._lr_scheduler.get_lr())
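
The six-argument call self.loss_fcn(log_probs, batch.log_probs_old, batch.adv, v_pred, batch.v_pred_old, batch.v_targ) above combines a policy surrogate with a value-function term. Below is a minimal sketch of such a loss, assuming the usual PPO2-style clipping of both the probability ratio and the value predictions; eps_clip and value_coeff are illustrative defaults, not the project's settings.

import torch as to


def ppo2_loss_sketch(log_probs, log_probs_old, adv, v_pred, v_pred_old, v_targ,
                     eps_clip: float = 0.2, value_coeff: float = 0.5):
    """Clipped surrogate policy loss plus a clipped value-function loss (illustrative sketch).

    The argument order mirrors the loss_fcn call above; all tensors are assumed to broadcast.
    """
    # Probability ratio between the new and the old policy
    ratio = to.exp(log_probs - log_probs_old)
    pol_loss = -to.min(ratio * adv, to.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * adv).mean()

    # Value-function loss, clipped around the old value predictions
    v_pred_clipped = v_pred_old + to.clamp(v_pred - v_pred_old, -eps_clip, eps_clip)
    v_loss = to.max((v_pred - v_targ).pow(2), (v_pred_clipped - v_targ).pow(2)).mean()

    return pol_loss + value_coeff * v_loss
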
Code example #4
    def update(self, rollouts: Sequence[StepSequence]):
        # Turn the batch of rollouts into a list of steps
        concat_ros = StepSequence.concat(rollouts)
        concat_ros.torch(data_type=to.get_default_dtype())

        # Update the advantage estimator's parameters and return advantage estimates
        adv = self._critic.update(rollouts, use_empirical_returns=False)

        with to.no_grad():
            # Compute the action probabilities using the old (before update) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_old = act_stats.log_probs
            act_distr_old = act_stats.act_distr

        # Attach advantages and old log probs to rollout
        concat_ros.add_data('adv', adv)
        concat_ros.add_data('log_probs_old', log_probs_old)

        # For logging the gradient norms
        policy_grad_norm = []

        # Iterations over the whole data set
        for e in range(self.num_epoch):

            for batch in tqdm(concat_ros.split_shuffled_batches(
                    self.batch_size,
                    complete_rollouts=self._policy.is_recurrent),
                              total=num_iter_from_rollouts(
                                  None, concat_ros, self.batch_size),
                              desc=f'Epoch {e}',
                              unit='batches',
                              file=sys.stdout,
                              leave=False):
                # Reset the gradients
                self.optim.zero_grad()

                # Compute log of the action probabilities for the mini-batch
                log_probs = compute_action_statistics(
                    batch, self._expl_strat).log_probs

                # Compute policy loss and backpropagate
                loss = self.loss_fcn(log_probs, batch.log_probs_old, batch.adv)
                loss.backward()

                # Clip the gradients if desired
                policy_grad_norm.append(
                    self.clip_grad(self._expl_strat.policy,
                                   self.max_grad_norm))

                # Call optimizer
                self.optim.step()

                if to.isnan(self._expl_strat.noise.std).any():
                    raise RuntimeError(
                        f'At least one exploration parameter became NaN! The exploration parameters are'
                        f'\n{self._expl_strat.std.detach().numpy()}')

            # Update the learning rate if a scheduler has been specified
            if self._lr_scheduler is not None:
                self._lr_scheduler.step()

        # Additional logging
        if self.log_loss:
            with to.no_grad():
                act_stats = compute_action_statistics(concat_ros,
                                                      self._expl_strat)
                log_probs_new = act_stats.log_probs
                act_distr_new = act_stats.act_distr
                loss_after = self.loss_fcn(log_probs_new, log_probs_old, adv)
                kl_avg = to.mean(kl_divergence(
                    act_distr_old,
                    act_distr_new))  # mean seeking a.k.a. inclusive KL
                self.logger.add_value('loss after',
                                      loss_after.detach().numpy())
                self.logger.add_value('KL(old_new)', kl_avg.item())

        # Logging
        self.logger.add_value(
            'avg expl strat std',
            to.mean(self._expl_strat.noise.std.data).detach().numpy())
        self.logger.add_value('expl strat entropy',
                              self._expl_strat.noise.get_entropy().item())
        self.logger.add_value('avg policy grad norm',
                              np.mean(policy_grad_norm))
        if self._lr_scheduler is not None:
            self.logger.add_value('learning rate', self._lr_scheduler.get_lr())
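
The helper self.clip_grad(...) above is used to clip the policy gradients and to return their norm for logging. Below is a minimal sketch of such a helper, assuming it wraps torch.nn.utils.clip_grad_norm_; the project's actual implementation may differ.

import torch as to
from torch.nn.utils import clip_grad_norm_


def clip_grad_sketch(module: to.nn.Module, max_grad_norm=None) -> float:
    """Clip a module's gradients in place and return the total gradient norm (illustrative sketch)."""
    if max_grad_norm is not None:
        # Rescale the gradients so that their combined L2 norm is at most max_grad_norm
        total_norm = clip_grad_norm_(module.parameters(), max_grad_norm)
    else:
        # No clipping requested: only compute the overall gradient norm for logging
        grads = [p.grad.norm() for p in module.parameters() if p.grad is not None]
        total_norm = to.norm(to.stack(grads))
    return float(total_norm)
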
Code example #5
File: a2c.py  Project: arlene-kuehn/SimuRLacra
    def update(self, rollouts: Sequence[StepSequence]):
        # Turn the batch of rollouts into a list of steps
        concat_ros = StepSequence.concat(rollouts)
        concat_ros.torch(data_type=to.get_default_dtype())

        # Compute the value targets (empirical discounted returns) for all samples before fitting the V-fcn parameters
        adv = self._critic.gae(concat_ros)  # done with to.no_grad()
        v_targ = discounted_values(rollouts, self._critic.gamma).view(-1, 1).to(self.policy.device)  # empirical discounted returns

        with to.no_grad():
            # Compute value predictions and the GAE using the old (before the updates) value function approximator
            v_pred = self._critic.values(concat_ros)

            # Compute the action probabilities using the old (before update) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_old = act_stats.log_probs
            act_distr_old = act_stats.act_distr
            loss_before = self.loss_fcn(log_probs_old, adv, v_pred, v_targ)
            self.logger.add_value('loss before', loss_before, 4)

        concat_ros.add_data('adv', adv)
        concat_ros.add_data('v_targ', v_targ)

        # For logging the gradients' norms
        policy_grad_norm = []

        for batch in tqdm(concat_ros.split_shuffled_batches(
                self.batch_size,
                complete_rollouts=self._policy.is_recurrent
                or isinstance(self._critic.vfcn, RecurrentPolicy)),
                          total=num_iter_from_rollouts(
                              None, concat_ros, self.batch_size),
                          desc='Updating',
                          unit='batches',
                          file=sys.stdout,
                          leave=False):
            # Reset the gradients
            self.optim.zero_grad()

            # Compute log of the action probabilities for the mini-batch
            log_probs = compute_action_statistics(batch, self._expl_strat).log_probs

            # Compute value predictions for the mini-batch
            v_pred = self._critic.values(batch)

            # Compute combined loss and backpropagate
            loss = self.loss_fcn(log_probs, batch.adv, v_pred, batch.v_targ)
            loss.backward()

            # Clip the gradients if desired
            policy_grad_norm.append(self.clip_grad(self.expl_strat.policy, self.max_grad_norm))

            # Call optimizer
            self.optim.step()

        # Update the learning rate if a scheduler has been specified
        if self._lr_scheduler is not None:
            self._lr_scheduler.step()

        if to.isnan(self.expl_strat.noise.std).any():
            raise RuntimeError(f'At least one exploration parameter became NaN! The exploration parameters are'
                               f'\n{self.expl_strat.std.item()}')

        # Logging
        with to.no_grad():
            # Compute value predictions and the GAE using the new (after the updates) value function approximator
            v_pred = self._critic.values(concat_ros).to(self.policy.device)
            adv = self._critic.gae(concat_ros)  # done with to.no_grad()

            # Compute the action probabilities using the new (after the updates) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_new = act_stats.log_probs
            act_distr_new = act_stats.act_distr
            loss_after = self.loss_fcn(log_probs_new, adv, v_pred, v_targ)
            kl_avg = to.mean(
                kl_divergence(act_distr_old, act_distr_new))  # mean seeking a.k.a. inclusive KL
            explvar = explained_var(v_pred, v_targ)  # values close to 1 are desired
            self.logger.add_value('loss after', loss_after, 4)
            self.logger.add_value('KL(old_new)', kl_avg, 4)
            self.logger.add_value('explained var', explvar, 4)

        ent = self.expl_strat.noise.get_entropy()
        self.logger.add_value('avg expl strat std', to.mean(self.expl_strat.noise.std), 4)
        self.logger.add_value('expl strat entropy', to.mean(ent), 4)
        self.logger.add_value('avg grad norm policy', np.mean(policy_grad_norm), 4)
        if self._lr_scheduler is not None:
            self.logger.add_value('avg lr', np.mean(self._lr_scheduler.get_last_lr()), 6)
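
The four-argument call self.loss_fcn(log_probs, batch.adv, v_pred, batch.v_targ) above corresponds to a standard actor-critic objective: a policy-gradient term weighted by the advantages plus a value-regression term towards the empirical discounted returns. A minimal sketch follows, with value_coeff as an assumed weighting factor rather than the project's setting.

import torch as to


def a2c_loss_sketch(log_probs, adv, v_pred, v_targ, value_coeff: float = 0.5):
    """Combined actor-critic loss (illustrative sketch; the argument order mirrors the call above)."""
    # Policy-gradient term: the advantages act as fixed weights, so they are detached
    pol_loss = -(log_probs * adv.detach()).mean()
    # Value-function regression towards the empirical discounted returns
    v_loss = to.nn.functional.mse_loss(v_pred, v_targ)
    return pol_loss + value_coeff * v_loss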