コード例 #1
0
ファイル: gae.py プロジェクト: arlene-kuehn/SimuRLacra
    def update(self, rollouts: Sequence[StepSequence], use_empirical_returns: bool = False):
        """
        Fit the value function of the advantage estimator to the given rollouts, then re-estimate the advantages.

        The value function is trained for `num_epoch` passes over shuffled mini-batches, minimizing `self.loss_fcn`
        between its predictions and the value targets. The advantages are computed afterwards, i.e. with the fitted
        value function.

        :param rollouts: batch of rollouts
        :param use_empirical_returns: if `True`, regress onto the discounted returns from the rollouts; if `False`,
                                      regress onto bootstrapped targets obtained from `self.tdlamda_returns`, which
                                      uses the current value function
        :return: tensor of advantages after the V-function update
        """
        # Flatten the batch of rollouts into a single step sequence of torch tensors
        steps = StepSequence.concat(rollouts)
        steps.torch(data_type=to.get_default_dtype())

        # Regression targets for the value function
        if not use_empirical_returns:
            # Bootstrapped targets computed with the current value function
            targets = self.tdlamda_returns(concat_ros=steps)
        else:
            # Empirical discounted returns, reshaped into a column
            targets = discounted_values(rollouts, self.gamma).view(-1, 1)
        steps.add_data('v_targ', targets)

        # Loss before fitting, only used to log the improvement
        with to.no_grad():
            loss_before = self.loss_fcn(self.values(steps), targets)

        grad_norms = []
        for epoch in range(self.num_epoch):
            # Recurrent value functions need complete rollouts in every mini-batch
            batches = steps.split_shuffled_batches(
                self.batch_size, complete_rollouts=isinstance(self.vfcn, RecurrentPolicy)
            )
            num_batches = num_iter_from_rollouts(None, steps, self.batch_size)
            progress = tqdm(batches, total=num_batches, desc=f'Epoch {epoch}', unit='batches', file=sys.stdout,
                            leave=False)

            for minibatch in progress:
                self.optim.zero_grad()
                minibatch_loss = self.loss_fcn(self.values(minibatch), minibatch.v_targ)
                minibatch_loss.backward()
                # Clip the gradients (if desired) and keep their norm for logging
                grad_norms.append(Algorithm.clip_grad(self.vfcn, self.max_grad_norm))
                self.optim.step()

            # The learning rate scheduler (if any) is stepped once per epoch
            if self._lr_scheduler is not None:
                self._lr_scheduler.step()

        # Advantages are estimated with the freshly fitted value function
        adv = self.gae(steps)  # is done with to.no_grad()

        with to.no_grad():
            v_pred_after = self.values(steps)
            loss_after = self.loss_fcn(v_pred_after, targets)
            loss_improvement = loss_before - loss_after  # positive values are desired
            explvar = explained_var(v_pred_after, targets)  # values close to 1 are desired

        # Log metrics of the fitted value function (explained variance uses the updated predictions)
        self.logger.add_value('explained var critic', explvar, 4)
        self.logger.add_value('loss improv critic', loss_improvement, 4)
        self.logger.add_value('avg grad norm critic', np.mean(grad_norms), 4)
        if self._lr_scheduler is not None:
            self.logger.add_value('lr critic', self._lr_scheduler.get_last_lr(), 6)

        return adv
コード例 #2
0
    def update(self, rollouts: Sequence[StepSequence]):
        """
        Update the policy and the value function from a batch of rollouts.

        Before any parameter update, the following are computed: the log-probabilities and action distributions of
        the old policy, the predictions of the old value function, the advantages (via the critic's GAE), and the
        empirical discounted returns (value targets). Then, for `num_epoch` epochs, the combined loss `self.loss_fcn`
        is minimized over shuffled mini-batches.

        :param rollouts: batch of rollouts
        :raises RuntimeError: if at least one standard deviation of the exploration noise became NaN
        """
        # Turn the batch of rollouts into a list of steps
        concat_ros = StepSequence.concat(rollouts)
        concat_ros.torch(data_type=to.get_default_dtype())

        with to.no_grad():
            # Compute the action probabilities using the old (before update) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_old = act_stats.log_probs
            act_distr_old = act_stats.act_distr

            # Compute value predictions using the old (before update) value function
            v_pred_old = self._critic.values(concat_ros)

        # Attach the old log probs and the old value predictions to the rollout so they are split into mini-batches
        concat_ros.add_data('log_probs_old', log_probs_old)
        concat_ros.add_data('v_pred_old', v_pred_old)

        # For logging the gradient norms
        policy_grad_norm = []
        value_fcn_grad_norm = []

        # Estimate the advantages and compute the value targets (empirical discounted returns) for all samples before
        # fitting the V-fcn parameters
        adv = self._critic.gae(concat_ros)  # done with to.no_grad()
        v_targ = discounted_values(rollouts, self._critic.gamma).view(-1, 1)  # empirical discounted returns
        concat_ros.add_data('adv', adv)
        concat_ros.add_data('v_targ', v_targ)

        # Iterations over the whole data set
        for e in range(self.num_epoch):
            # Recurrent policies or value functions require complete rollouts in every mini-batch
            complete_rollouts = self._policy.is_recurrent or isinstance(self._critic.value_fcn, RecurrentPolicy)

            for batch in tqdm(
                concat_ros.split_shuffled_batches(self.batch_size, complete_rollouts=complete_rollouts),
                total=num_iter_from_rollouts(None, concat_ros, self.batch_size),
                desc=f'Epoch {e}',
                unit='batches',
                file=sys.stdout,
                leave=False,
            ):
                # Reset the gradients
                self.optim.zero_grad()

                # Compute log of the action probabilities for the mini-batch
                log_probs = compute_action_statistics(batch, self._expl_strat).log_probs.to(self.policy.device)

                # Compute value predictions for the mini-batch
                v_pred = self._critic.values(batch)

                # Compute combined loss and backpropagate
                loss = self.loss_fcn(log_probs, batch.log_probs_old, batch.adv, v_pred, batch.v_pred_old, batch.v_targ)
                loss.backward()

                # Clip the gradients if desired
                policy_grad_norm.append(self.clip_grad(self._expl_strat.policy, self.max_grad_norm))
                value_fcn_grad_norm.append(self.clip_grad(self._critic.value_fcn, self.max_grad_norm))

                # Call optimizer
                self.optim.step()

                # Report exactly the tensor that was checked. Move it to the CPU first because Tensor.numpy() is only
                # defined for CPU tensors; otherwise the conversion would raise and hide this error.
                expl_std = self._expl_strat.noise.std
                if to.isnan(expl_std).any():
                    raise RuntimeError(
                        f'At least one exploration parameter became NaN! The exploration parameters are'
                        f'\n{expl_std.detach().cpu().numpy()}')

            # Update the learning rate if a scheduler has been specified
            if self._lr_scheduler is not None:
                self._lr_scheduler.step()

        # Additional logging
        if self.log_loss:
            with to.no_grad():
                # Compare the old and the new value predictions against the targets on one common device
                device = self.policy.device
                v_pred_old_dev = v_pred_old.to(device)
                v_targ_dev = v_targ.to(device)

                # Compute value predictions using the new (after the updates) value function approximator
                v_pred = self._critic.values(concat_ros).to(device)
                v_loss_old = self._critic.loss_fcn(v_pred_old_dev, v_targ_dev)
                v_loss_new = self._critic.loss_fcn(v_pred, v_targ_dev)
                value_fcn_loss_impr = v_loss_old - v_loss_new  # positive values are desired

                # Compute the action probabilities using the new (after the updates) policy
                act_stats = compute_action_statistics(concat_ros, self._expl_strat)
                log_probs_new = act_stats.log_probs
                act_distr_new = act_stats.act_distr
                loss_after = self.loss_fcn(log_probs_new, log_probs_old, adv, v_pred, v_pred_old_dev, v_targ_dev)
                kl_avg = to.mean(kl_divergence(act_distr_old, act_distr_new))  # mean seeking a.k.a. inclusive KL

                # Compute explained variance (after the updates)
                explvar = explained_var(v_pred, v_targ_dev)  # values close to 1 are desired
                self.logger.add_value('explained var', explvar.detach().cpu().numpy())
                self.logger.add_value('V-fcn loss improvement', value_fcn_loss_impr.detach().cpu().numpy())
                self.logger.add_value('loss after', loss_after.detach().cpu().numpy())
                self.logger.add_value('KL(old_new)', kl_avg.item())

        # Logging
        self.logger.add_value('avg expl strat std', to.mean(self._expl_strat.noise.std.data).detach().cpu().numpy())
        self.logger.add_value('expl strat entropy', self._expl_strat.noise.get_entropy().item())
        self.logger.add_value('avg policy grad norm', np.mean(policy_grad_norm))
        self.logger.add_value('avg V-fcn grad norm', np.mean(value_fcn_grad_norm))
        if self._lr_scheduler is not None:
            # get_last_lr() returns the rates from the last scheduler step. get_lr() is meant to be called from within
            # step() and may return wrong values (with a warning) when called from outside.
            self.logger.add_value('learning rate', self._lr_scheduler.get_last_lr())
コード例 #3
0
ファイル: a2c.py プロジェクト: arlene-kuehn/SimuRLacra
    def update(self, rollouts: Sequence[StepSequence]):
        """
        Update the policy and the value function with one pass of mini-batch gradient steps over the given rollouts.

        The advantages (via the critic's GAE) and the value targets (empirical discounted returns) are computed
        before the update. The combined loss `self.loss_fcn` is minimized over shuffled mini-batches. Afterwards, the
        loss, the mean KL divergence between the old and the new action distributions, and the explained variance are
        logged.

        :param rollouts: batch of rollouts
        :raises RuntimeError: if at least one standard deviation of the exploration noise became NaN
        """
        # Turn the batch of rollouts into a list of steps
        concat_ros = StepSequence.concat(rollouts)
        concat_ros.torch(data_type=to.get_default_dtype())

        # Compute the advantages and the value targets (empirical discounted returns) for all samples before fitting
        # the V-fcn parameters
        adv = self._critic.gae(concat_ros)  # done with to.no_grad()
        v_targ = discounted_values(rollouts, self._critic.gamma).view(-1, 1).to(self.policy.device)  # empirical discounted returns

        with to.no_grad():
            # Compute value predictions using the old (before the updates) value function approximator
            v_pred = self._critic.values(concat_ros)

            # Compute the action probabilities using the old (before update) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_old = act_stats.log_probs
            act_distr_old = act_stats.act_distr
            loss_before = self.loss_fcn(log_probs_old, adv, v_pred, v_targ)
            self.logger.add_value('loss before', loss_before, 4)

        concat_ros.add_data('adv', adv)
        concat_ros.add_data('v_targ', v_targ)

        # For logging the gradients' norms
        policy_grad_norm = []

        for batch in tqdm(concat_ros.split_shuffled_batches(
            self.batch_size,
            complete_rollouts=self._policy.is_recurrent or isinstance(self._critic.vfcn, RecurrentPolicy)),
            total=num_iter_from_rollouts(None, concat_ros, self.batch_size),
            desc='Updating', unit='batches', file=sys.stdout, leave=False):
            # Reset the gradients
            self.optim.zero_grad()

            # Compute log of the action probabilities for the mini-batch
            log_probs = compute_action_statistics(batch, self._expl_strat).log_probs

            # Compute value predictions for the mini-batch
            v_pred = self._critic.values(batch)

            # Compute combined loss and backpropagate
            loss = self.loss_fcn(log_probs, batch.adv, v_pred, batch.v_targ)
            loss.backward()

            # Clip the gradients if desired
            # NOTE(review): only the policy's gradients are clipped (and recorded) here; confirm whether the value
            # function's gradients should be clipped too
            policy_grad_norm.append(self.clip_grad(self.expl_strat.policy, self.max_grad_norm))

            # Call optimizer
            self.optim.step()

        # Update the learning rate if a scheduler has been specified
        if self._lr_scheduler is not None:
            self._lr_scheduler.step()

        # Report exactly the tensor that was checked. Tensor.item() only works for one-element tensors, so for a
        # multi-dimensional std it would raise a ValueError and hide this error. Print the whole (CPU) array instead.
        expl_std = self.expl_strat.noise.std
        if to.isnan(expl_std).any():
            raise RuntimeError(f'At least one exploration parameter became NaN! The exploration parameters are'
                               f'\n{expl_std.detach().cpu().numpy()}')

        # Logging
        with to.no_grad():
            # Compute value predictions and the GAE using the new (after the updates) value function approximator
            v_pred = self._critic.values(concat_ros).to(self.policy.device)
            adv = self._critic.gae(concat_ros)  # done with to.no_grad()

            # Compute the action probabilities using the new (after the updates) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_new = act_stats.log_probs
            act_distr_new = act_stats.act_distr
            loss_after = self.loss_fcn(log_probs_new, adv, v_pred, v_targ)
            kl_avg = to.mean(
                kl_divergence(act_distr_old, act_distr_new))  # mean seeking a.k.a. inclusive KL
            explvar = explained_var(v_pred, v_targ)  # values close to 1 are desired
            self.logger.add_value('loss after', loss_after, 4)
            self.logger.add_value('KL(old_new)', kl_avg, 4)
            self.logger.add_value('explained var', explvar, 4)

        ent = self.expl_strat.noise.get_entropy()
        self.logger.add_value('avg expl strat std', to.mean(self.expl_strat.noise.std), 4)
        self.logger.add_value('expl strat entropy', to.mean(ent), 4)
        self.logger.add_value('avg grad norm policy', np.mean(policy_grad_norm), 4)
        if self._lr_scheduler is not None:
            self.logger.add_value('avg lr', np.mean(self._lr_scheduler.get_last_lr()), 6)