Example #1
    def train_on_batch(self, batches):
        gt.blank_stamp()
        for i in range(self.n_iters):
            self._n_train_steps_total += 1
            for batch in batches:
                batch = utils.to_tensor_batch(batch, self.policy.device)
                losses, eval_stats = self.compute_loss(
                    batch,
                    skip_statistics=not self._need_to_update_eval_statistics)

                self.policy_optimizer.zero_grad()
                losses.policy_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(),
                                               max_norm=0.5)
                self.policy_optimizer.step()

                self.critic_optimizer.zero_grad()
                losses.critic_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(),
                                               max_norm=0.5)
                self.critic_optimizer.step()

                if self._need_to_update_eval_statistics:
                    self.eval_statistics = eval_stats
                    # Compute statistics using only one batch per epoch
                    self._need_to_update_eval_statistics = False
                gt.stamp('ppo training', unique=False)
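
The utils.to_tensor_batch helper used above is not shown in the excerpt; a plausible sketch, assuming the batch is a dict of NumPy arrays that must be moved onto the policy's device (a hypothetical helper, not the actual implementation):

import torch

def to_tensor_batch(batch, device):
    # Hypothetical helper: convert each array in the batch dict into a
    # float32 tensor on the given device.
    return {k: torch.as_tensor(v, dtype=torch.float32, device=device)
            for k, v in batch.items()}
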
Example #2
    def train_from_torch(self, batch):
        gt.blank_stamp()
        losses, stats, errors = self.compute_loss(
            batch,
            skip_statistics=not self._need_to_update_eval_statistics,
        )
        """
        Update networks
        """
        if self.use_automatic_entropy_tuning:
            self.alpha_optimizer.zero_grad()
            losses.alpha_loss.backward()
            self.alpha_optimizer.step()

        self.policy_optimizer.zero_grad()
        losses.policy_loss.backward()
        self.policy_optimizer.step()

        self.qf1_optimizer.zero_grad()
        losses.qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        losses.qf2_loss.backward()
        self.qf2_optimizer.step()

        self._n_train_steps_total += 1

        self.try_update_target_networks()
        if self._need_to_update_eval_statistics:
            self.eval_statistics = stats
            # Compute statistics using only one batch per epoch
            self._need_to_update_eval_statistics = False
        gt.stamp('sac training', unique=False)
        return errors
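
The try_update_target_networks call above is typically a periodic Polyak (soft) update of the target Q-networks; a minimal sketch of that pattern, assuming a standard soft-update rule (an illustrative helper, not the trainer's own code):

import torch

def soft_update(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # Assumed update rule: target <- tau * source + (1 - tau) * target,
    # applied parameter by parameter without tracking gradients.
    with torch.no_grad():
        for p_src, p_tgt in zip(source.parameters(), target.parameters()):
            p_tgt.mul_(1.0 - tau).add_(p_src, alpha=tau)
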
Example #3
import gtimer as gt
import time

time.sleep(0.1)            # time elapsed before gt.start() is not timed
gt.start()                 # start the root timer
time.sleep(0.1)
gt.stamp('first')          # record the elapsed interval under the stamp 'first'
gt.pause()                 # time spent while paused is not attributed to any stamp
time.sleep(0.1)
gt.resume()
gt.stamp('second')
time.sleep(0.1)
gt.blank_stamp('third')    # end the interval but discard the elapsed time
time.sleep(0.1)
gt.stop('fourth')          # stop timing, recording the final interval as 'fourth'
time.sleep(0.1)
print(gt.report())         # print a formatted summary of the recorded stamps
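
Putting the pieces together, a minimal sketch of how the per-iteration stamps from the trainer examples accumulate over a toy loop (the work functions and stamp labels below are placeholders, not taken from the examples above):

import gtimer as gt
import time

def collect_data():
    time.sleep(0.05)    # placeholder for environment interaction

def train_step():
    time.sleep(0.02)    # placeholder for a gradient update

gt.start()
for epoch in range(3):
    collect_data()
    gt.stamp('exploration', unique=False)   # same label accumulates across epochs
    train_step()
    gt.stamp('training', unique=False)
gt.stop('end')
print(gt.report())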