Example #1
    def evaluate(self):
        """Evaluate."""
        eval_env = VecEpsilonGreedy(VecFrameStack(self.env, self.frame_stack),
                                    self.eval_eps)
        self.qf.eval()
        misc.set_env_to_eval_mode(eval_env)

        # Eval policy
        os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'eval',
                               self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(eval_env, self.qf, self.eval_num_episodes, outfile,
                            self.device)
        logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'],
                          self.t, time.time())
        logger.add_scalar('eval/mean_episode_length', stats['mean_length'],
                          self.t, time.time())

        # Record policy
        os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'video',
                               self.ckptr.format.format(self.t) + '.mp4')
        rl_record(eval_env, self.qf, self.record_num_episodes, outfile,
                  self.device)

        self.qf.train()
        misc.set_env_to_train_mode(self.env)
        self.data_manager.manual_reset()
Example #2
    def __call__(self, ob, state_in=None):
        """Produce decision from model."""
        if self.t < self.policy_training_start:
            outs = self.pi(ob, state_in, deterministic=True)
        else:
            outs = self.pi(ob, state_in)

        def _res_norm(ac):
            return ac.abs().sum(dim=1).mean()
        residual_norm = nest.map_structure(_res_norm, outs.action)
        if isinstance(residual_norm, torch.Tensor):
            logger.add_scalar('actor/l1_residual_norm', residual_norm, self.t,
                              time.time())
            self.t += outs.action.shape[0]
        else:
            self.t += nest.flatten(outs.action)[0].shape[0]
            for k, v in residual_norm.items():
                logger.add_scalar(f'actor/{k}_residual_norm', v, self.t,
                                  time.time())
        data = {'action': outs.action,
                'value': self.vf(ob).value,
                'logp': outs.dist.log_prob(outs.action),
                'dist': outs.dist.to_tensors()}
        if outs.state_out:
            data['state'] = outs.state_out
        return data
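Note: the nested branch above exists because outs.action can be a dict of tensors rather than a single tensor. Below is a minimal sketch of what nest.map_structure does in that case; the import path and the example action dict are assumptions for illustration, only the helper mirrors the snippet above.

import torch
from dl import nest  # assumed import path for the nest utilities used above

# Hypothetical dict-structured action batch (batch size 8, two components).
actions = {'arm': torch.randn(8, 4), 'gripper': torch.randn(8, 1)}

def _res_norm(ac):
    return ac.abs().sum(dim=1).mean()

# Returns a dict with the same keys, holding one scalar L1 norm per component;
# the actor above then logs each under 'actor/<key>_residual_norm'.
residual_norm = nest.map_structure(_res_norm, actions)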
Example #3
    def evaluate(self):
        """Evaluate model."""
        self.pi.eval()
        misc.set_env_to_eval_mode(self.env)

        # Eval policy
        os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True)
        outfile = os.path.join(self.logdir, 'eval',
                               self.ckptr.format.format(self.t) + '.json')
        stats = rl_evaluate(self.env, self.pi, self.eval_num_episodes, outfile,
                            self.device)
        logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'],
                          self.t, time.time())
        logger.add_scalar('eval/mean_episode_length', stats['mean_length'],
                          self.t, time.time())

        # Record policy
        # os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True)
        # outfile = os.path.join(self.logdir, 'video',
        #                        self.ckptr.format.format(self.t) + '.mp4')
        # rl_record(self.env, self.pi, self.record_num_episodes, outfile,
        #           self.device)

        self.pi.train()
        misc.set_env_to_train_mode(self.env)
Example #4
File: td3.py Project: amackeith/dl
    def loss(self, batch):
        """Loss function."""
        # compute QFunction loss.
        with torch.no_grad():
            target_action = self.target_pi(batch['next_obs']).action
            noise = (torch.randn_like(target_action) *
                     self.policy_noise).clamp(-self.policy_noise_clip,
                                              self.policy_noise_clip)
            target_action = (target_action + noise).clamp(-1., 1.)
            target_q1 = self.target_qf1(batch['next_obs'], target_action).value
            target_q2 = self.target_qf2(batch['next_obs'], target_action).value
            target_q = torch.min(target_q1, target_q2)
            qtarg = self.reward_scale * batch['reward'].float() + (
                (1.0 - batch['done']) * self.gamma * target_q)

        q1 = self.qf1(batch['obs'], batch['action']).value
        q2 = self.qf2(batch['obs'], batch['action']).value
        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
        qf_loss = self.qf_criterion(q1, qtarg) + self.qf_criterion(q2, qtarg)

        # compute policy loss
        if self.t % self.policy_update_period == 0:
            action = self.pi(batch['obs'], deterministic=True).action
            q = self.qf1(batch['obs'], action).value
            pi_loss = -q.mean()
        else:
            pi_loss = torch.zeros_like(qf_loss)

        # log losses
        if self.t % self.log_period < self.update_period:
            logger.add_scalar('loss/qf', qf_loss, self.t, time.time())
            if self.t % self.policy_update_period == 0:
                logger.add_scalar('loss/pi', pi_loss, self.t, time.time())
        return pi_loss, qf_loss
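A minimal sketch of how the (pi_loss, qf_loss) pair returned above might be consumed; the helper below and its optimizer arguments are assumptions for illustration, not taken from td3.py.

def apply_td3_losses(alg, batch, opt_pi, opt_qf):
    """Sketch: consume the TD3 losses with delayed policy updates."""
    pi_loss, qf_loss = alg.loss(batch)

    # Q-functions are updated on every step.
    opt_qf.zero_grad()
    qf_loss.backward()
    opt_qf.step()

    # The policy is only updated on the delayed steps; on other steps loss()
    # returns a constant zero tensor with no graph to backpropagate through.
    if alg.t % alg.policy_update_period == 0:
        opt_pi.zero_grad()
        pi_loss.backward()
        opt_pi.step()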
Example #5
File: trainer.py Project: cbschaff/rsa
    def step(self):
        # Get batch.
        if self._diter is None:
            self._diter = self.dtrain.__iter__()
        try:
            batch = self._diter.__next__()
        except StopIteration:
            self.epochs += 1
            self._diter = None
            return self.epochs
        batch = nest.map_structure(lambda x: x.to(self.device), batch)

        # compute loss
        ob, ac = batch
        self.model.train()
        loss = -self.model(ob).log_prob(ac).mean()

        logger.add_scalar('train/loss',
                          loss.detach().cpu().numpy(), self.t, time.time())

        # update model
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        # increment step
        self.t += min(
            len(self.data) - (self.t % len(self.data)), self.batch_size)
        return self.epochs
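A tiny usage sketch for step() above: it returns self.epochs, which only increments when the data iterator is exhausted, so a driver loop can simply poll it. The function name and trainer argument are assumptions.

def train_for_epochs(trainer, num_epochs):
    """Sketch: run updates until num_epochs full passes over the data."""
    while trainer.epochs < num_epochs:
        trainer.step()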
Example #6
def log_stats():
    from dl import logger
    cpu_util, mem_util, gpus = get_stats()
    timestamp = time.time()
    logger.add_scalar('hardware/cpu_util', cpu_util, walltime=timestamp)
    logger.add_scalar('hardware/mem_util', mem_util, walltime=timestamp)
    for gpu in gpus:
        logger.add_scalar(f'hardware/gpu{gpu.id}/util', gpu.util,
                          walltime=timestamp)
        logger.add_scalar(f'hardware/gpu{gpu.id}/mem_util', gpu.memutil,
                          walltime=timestamp)
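A minimal sketch of calling log_stats on a fixed interval from a daemon thread; the wrapper below and its period argument are hypothetical, only log_stats comes from the snippet above.

import threading
import time

def start_hardware_logging(period_seconds=10.0):
    """Sketch: log CPU/GPU utilization every period_seconds in the background."""
    def _loop():
        while True:
            log_stats()
            time.sleep(period_seconds)

    thread = threading.Thread(target=_loop, daemon=True)
    thread.start()
    return thread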
Example #7
File: trainer.py Project: takuma-ynd/dl
    def evaluate(self):
        """Evaluate model."""
        self.model.eval()

        accuracy = []
        with torch.no_grad():
            for batch in self.dtest:
                x, y = nest.map_structure(lambda x: x.to(self.device), batch)
                y_hat = self.model(x).argmax(-1)
                accuracy.append((y_hat == y).float().mean().cpu().numpy())

            logger.add_scalar('test_accuracy', np.mean(accuracy), self.epochs,
                              time.time())
Example #8
    def step(self, action):
        """Step."""
        obs, rews, dones, infos = self.venv.step(action)
        if not self._eval:
            self.t += np.sum(np.logical_not(self._dones))
        for i, d in enumerate(self._dones):  # handle synced resets
            if not d:
                self.lens[i] += 1
                self.rews[i] += rews[i]
            else:
                assert dones[i]
        for i, done in enumerate(dones):
            if done and not self._dones[i]:
                if not self._eval:
                    logger.add_scalar('env/episode_length', self.lens[i],
                                      self.t, time.time())
                    logger.add_scalar('env/episode_reward', self.rews[i],
                                      self.t, time.time())
                self.lens[i] = 0
                self.rews[i] = 0.
        # log unwrapped episode stats if they exist
        if 'episode_info' in infos[0]:
            for i, info in enumerate(infos):
                epinfo = info['episode_info']
                if epinfo['done'] and not self._eval and not self._dones[i]:
                    logger.add_scalar('env/unwrapped_episode_length',
                                      epinfo['length'], self.t, time.time())
                    logger.add_scalar('env/unwrapped_episode_reward',
                                      epinfo['reward'], self.t, time.time())
        self._dones = np.logical_or(dones, self._dones)

        return obs, rews, dones, infos
Example #9
    def loss(self, batch):
        """Loss."""
        q = self.qf(batch['obs'], batch['action']).value

        with torch.no_grad():
            target = self._compute_target(batch['reward'], batch['next_obs'],
                                          batch['done'])

        assert target.shape == q.shape
        err = self.criterion(target, q)
        self.buffer.update_priorities(batch['idxes'],
                                      err.detach().cpu().numpy() + 1e-6)
        assert err.shape == batch['weights'].shape
        err = batch['weights'] * err
        loss = err.mean()

        if self.t % self.log_period < self.update_period:
            logger.add_scalar('alg/maxq', torch.max(q).detach().cpu().numpy(),
                              self.t, time.time())
            logger.add_scalar('alg/loss', loss.detach().cpu().numpy(), self.t,
                              time.time())
            logger.add_scalar('alg/epsilon',
                              self.eps_schedule.value(self._actor.t),
                              self.t, time.time())
            logger.add_scalar('alg/beta', self.beta_schedule.value(self.t),
                              self.t, time.time())
        return loss
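Note that self.criterion above has to return one error per transition (no reduction) so that the result can be re-weighted by batch['weights'] and reused as new priorities. A minimal sketch of a compatible criterion, assuming a Huber loss is wanted:

import torch.nn as nn

# Element-wise Huber loss: reduction='none' keeps the per-transition errors
# needed both for the importance weighting and for update_priorities above.
criterion = nn.SmoothL1Loss(reduction='none')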
Example #10
File: trainer.py Project: cbschaff/rsa
    def evaluate(self):
        """Evaluate model."""
        self.model.eval()

        nll = 0.
        with torch.no_grad():
            for batch in self.dtrain:
                ob, ac = nest.map_structure(lambda x: x.to(self.device), batch)
                nll += -self.model(ob).log_prob(ac).sum()
            avg_nll = nll / len(self.data)

            logger.add_scalar('train/NLL', nll, self.epochs, time.time())
            logger.add_scalar('train/AVG_NLL', avg_nll, self.epochs,
                              time.time())
Example #11
    def __call__(self, obs):
        """Act."""
        self.t += nest.flatten(obs)[0].shape[0]
        if self.should_take_zero_action():
            if self.zero_action is None:
                with torch.no_grad():
                    self.zero_action = nest.map_structure(
                        torch.zeros_like,
                        self.pi(obs).action)
            return {'action': self.zero_action}
        else:
            ac = self.pi(obs).action
            with torch.no_grad():
                ac_norm = ac.abs().mean().cpu().numpy()
                logger.add_scalar('alg/residual_norm', ac_norm, self.t,
                                  time.time())
            return {'action': ac}
Example #12
File: alpha_zero.py Project: takuma-ynd/dl
    def step(self):
        """Step alpha zero."""
        self.pi.train()
        self.t += self.data_manager.play_game(self.n_sims)

        # fill replay buffer if needed
        while not self.buffer.full():
            self.t += self.data_manager.play_game(self.n_sims)

        for _ in range(self.batches_per_game):
            batch = self.data_manager.sample(self.batch_size)
            self.opt.zero_grad()
            loss = self.loss(batch)
            loss['total'].backward()
            self.opt.step()
            for k, v in loss.items():
                logger.add_scalar(f'loss/{k}',
                                  v.detach().cpu().numpy(), self.t,
                                  time.time())
        return self.t
Example #13
    def __call__(self, ob, state_in=None, mask=None):
        """Produce decision from model."""
        if self.t < self.policy_training_start:
            outs = self.pi(ob, state_in, mask, deterministic=True)
            if not torch.allclose(outs.action, torch.zeros_like(outs.action)):
                raise ValueError("Pi should be initialized to output zero "
                                 "actions so that an accurate value function "
                                 "can be learned for the base policy.")
        else:
            outs = self.pi(ob, state_in, mask)
        residual_norm = torch.mean(torch.sum(torch.abs(outs.action), dim=1))
        logger.add_scalar('actor/l1_residual_norm', residual_norm, self.t,
                          time.time())
        self.t += outs.action.shape[0]
        data = {
            'action': outs.action,
            'value': outs.value,
            'logp': outs.dist.log_prob(outs.action)
        }
        if outs.state_out:
            data['state'] = outs.state_out
        return data
Example #14
File: ddpg.py Project: amackeith/dl
    def loss(self, batch):
        """Loss function."""
        # compute QFunction loss.
        with torch.no_grad():
            target_action = self.target_pi(batch['next_obs']).action
            target_q = self.target_qf(batch['next_obs'], target_action).value
            qtarg = self.reward_scale * batch['reward'].float() + (
                (1.0 - batch['done']) * self.gamma * target_q)

        q = self.qf(batch['obs'], batch['action']).value
        assert qtarg.shape == q.shape
        qf_loss = self.qf_criterion(q, qtarg)

        # compute policy loss
        action = self.pi(batch['obs'], deterministic=True).action
        q = self.qf(batch['obs'], action).value
        pi_loss = -q.mean()

        # log losses
        if self.t % self.log_period < self.update_period:
            logger.add_scalar('loss/qf', qf_loss, self.t, time.time())
            logger.add_scalar('loss/pi', pi_loss, self.t, time.time())
        return pi_loss, qf_loss
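The target networks referenced above (self.target_pi, self.target_qf) are typically kept in sync with a Polyak average of the online networks after each update; a minimal sketch, where the helper name and the coefficient tau are assumptions:

import torch

def polyak_update(net, target_net, tau=0.005):
    """Sketch: soft (Polyak) update of a target network toward the online net."""
    with torch.no_grad():
        for p, tp in zip(net.parameters(), target_net.parameters()):
            tp.data.mul_(1.0 - tau).add_(tau * p.data)

In the snippets above this would be called as, e.g., polyak_update(self.qf, self.target_qf) and polyak_update(self.pi, self.target_pi) after the optimizer steps.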
Example #15
    def loss(self, batch):
        """Compute loss."""
        q = self.qf(batch['obs'], batch['action']).value

        with torch.no_grad():
            target = self._compute_target(batch['reward'], batch['next_obs'],
                                          batch['done'])

        assert target.shape == q.shape
        loss = self.criterion(target, q).mean()
        if self.t % self.log_period < self.update_period:
            logger.add_scalar('alg/maxq', torch.max(q).detach().cpu().numpy(),
                              self.t, time.time())
            logger.add_scalar('alg/loss', loss.detach().cpu().numpy(), self.t,
                              time.time())
            logger.add_scalar('alg/epsilon',
                              self.eps_schedule.value(self._actor.t),
                              self.t, time.time())
        return loss
Example #16
    def step(self):
        """Compute rollout, loss, and update model."""
        self.pi.train()
        # adjust learning rate
        lr_frac = self.lr_decay_rate**(self.t // self.lr_decay_freq)
        for g in self.opt.param_groups:
            g['lr'] = self.pi_lr * lr_frac
        for g in self.opt_l.param_groups:
            g['lr'] = self.lambda_lr * lr_frac

        self.data_manager.rollout()
        self.t += self.data_manager.rollout_length * self.nenv
        losses = {}
        for _ in range(self.epochs_per_rollout):
            for batch in self.data_manager.sampler():
                loss = self.loss(batch)
                if losses == {}:
                    losses = {k: [] for k in loss}
                for k, v in loss.items():
                    losses[k].append(v.detach().cpu().numpy())
                if self.t >= max(self.policy_training_start,
                                 self.lambda_training_start):
                    self.opt_l.zero_grad()
                    loss['lambda'].backward(retain_graph=True)
                    self.opt_l.step()
                self.opt.zero_grad()
                loss['total'].backward()
                if self.max_grad_norm:
                    nn.utils.clip_grad_norm_(self.pi.parameters(),
                                             self.max_grad_norm)
                self.opt.step()
        for k, v in losses.items():
            logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time())
        logger.add_scalar('alg/lr_pi', self.opt.param_groups[0]['lr'], self.t,
                          time.time())
        logger.add_scalar('alg/lr_lambda', self.opt_l.param_groups[0]['lr'],
                          self.t, time.time())
        return self.t
Example #17
    def step(self):
        """Compute rollout, loss, and update model."""
        self.pi.train()
        self.t += self.data_manager.rollout()
        losses = {}
        for _ in range(self.epochs_per_rollout):
            for batch in self.data_manager.sampler():
                self.opt.zero_grad()
                loss = self.loss(batch)
                if losses == {}:
                    losses = {k: [] for k in loss}
                for k, v in loss.items():
                    losses[k].append(v.detach().cpu().numpy())
                loss['total'].backward()
                if self.max_grad_norm:
                    norm = nn.utils.clip_grad_norm_(self.pi.parameters(),
                                                    self.max_grad_norm)
                    logger.add_scalar('alg/grad_norm', norm, self.t,
                                      time.time())
                    logger.add_scalar('alg/grad_norm_clipped',
                                      min(norm, self.max_grad_norm),
                                      self.t, time.time())
                self.opt.step()
        for k, v in losses.items():
            logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time())

        data = self.data_manager.storage.get_rollout()
        value_error = data['vpred'].data - data['q_mc'].data
        logger.add_scalar('alg/value_error_mean',
                          value_error.mean().cpu().numpy(), self.t, time.time())
        logger.add_scalar('alg/value_error_std',
                          value_error.std().cpu().numpy(), self.t, time.time())

        logger.add_scalar('alg/kl', self.compute_kl(), self.t, time.time())
        return self.t
Example #18
    def loss(self, batch):
        """Compute loss."""
        if self.data_manager.recurrent:
            outs = self.pi(batch['obs'], batch['state'], batch['mask'])
        else:
            outs = self.pi(batch['obs'])
        loss = {}

        # compute policy loss
        if self.t < self.policy_training_start:
            pi_loss = torch.Tensor([0.0]).to(self.device)
        else:
            logp = outs.dist.log_prob(batch['action'])
            assert logp.shape == batch['logp'].shape
            ratio = torch.exp(logp - batch['logp'])
            assert ratio.shape == batch['atarg'].shape
            ploss1 = ratio * batch['atarg']
            ploss2 = torch.clamp(ratio, 1.0 - self.clip_param,
                                 1.0 + self.clip_param) * batch['atarg']
            pi_loss = -torch.min(ploss1, ploss2).mean()
        loss['pi'] = pi_loss

        # compute value loss
        vloss1 = 0.5 * self.mse(outs.value, batch['vtarg'])
        vpred_clipped = batch['vpred'] + (outs.value - batch['vpred']).clamp(
            -self.clip_param, self.clip_param)
        vloss2 = 0.5 * self.mse(vpred_clipped, batch['vtarg'])
        vf_loss = torch.max(vloss1, vloss2).mean()
        loss['value'] = vf_loss

        # compute entropy loss
        if self.t < self.policy_training_start:
            ent_loss = torch.Tensor([0.0]).to(self.device)
        else:
            ent_loss = outs.dist.entropy().mean()
        loss['entropy'] = ent_loss

        # compute residual regularizer
        if self.t < self.policy_training_start:
            reg_loss = torch.Tensor([0.0]).to(self.device)
        else:
            if self.l2_reg:
                reg_loss = outs.dist.rsample().pow(2).sum(dim=-1).mean()
            else:  # huber loss
                ac_norm = torch.norm(outs.dist.rsample(), dim=-1)
                reg_loss = self.huber(ac_norm, torch.zeros_like(ac_norm))
        loss['reg'] = reg_loss

        ###############################
        # Constrained loss added here.
        ###############################

        # soft plus on lambda to constrain it to be positive.
        lambda_ = F.softplus(self.log_lambda_)
        logger.add_scalar('alg/lambda', lambda_, self.t, time.time())
        logger.add_scalar('alg/lambda_', self.log_lambda_, self.t, time.time())
        if self.t < max(self.policy_training_start,
                        self.lambda_training_start):
            loss['lambda'] = torch.Tensor([0.0]).to(self.device)
        else:
            neps = (1.0 - batch['mask']).sum()
            loss['lambda'] = (
                lambda_ *
                (batch['reward'].sum() - self.reward_threshold * neps) /
                batch['reward'].size()[0])
        if self.t >= self.policy_training_start:
            loss['pi'] = (reg_loss + lambda_ * loss['pi']) / (1. + lambda_)
        loss['total'] = (loss['pi'] + self.vf_coef * vf_loss -
                         self.ent_coef * ent_loss)
        return loss
Example #19
    def loss(self, batch):
        """Loss function."""
        pi_out = self.pi(batch['obs'], reparameterization_trick=True)
        logp = pi_out.dist.log_prob(pi_out.action)
        q1 = self.qf1(batch['obs'], batch['action']).value
        q2 = self.qf2(batch['obs'], batch['action']).value

        # alpha loss
        if self.automatic_entropy_tuning:
            ent_error = logp + self.target_entropy
            alpha_loss = -(self.log_alpha * ent_error.detach()).mean()
            self.opt_alpha.zero_grad()
            alpha_loss.backward()
            self.opt_alpha.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = self.alpha
            alpha_loss = 0

        # qf loss
        with torch.no_grad():
            next_pi_out = self.pi(batch['next_obs'])
            next_ac_logp = next_pi_out.dist.log_prob(next_pi_out.action)
            q1_next = self.target_qf1(batch['next_obs'],
                                      next_pi_out.action).value
            q2_next = self.target_qf2(batch['next_obs'],
                                      next_pi_out.action).value
            qnext = torch.min(q1_next, q2_next) - alpha * next_ac_logp
            qtarg = batch['reward'] + (1.0 -
                                       batch['done']) * self.gamma * qnext

        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
        qf1_loss = self.mse_loss(q1, qtarg)
        qf2_loss = self.mse_loss(q2, qtarg)

        # pi loss
        pi_loss = None
        if self.t % self.policy_update_period == 0:
            q1_pi = self.qf1(batch['obs'], pi_out.action).value
            q2_pi = self.qf2(batch['obs'], pi_out.action).value
            min_q_pi = torch.min(q1_pi, q2_pi)
            assert min_q_pi.shape == logp.shape
            pi_loss = (alpha * logp - min_q_pi).mean()

            # log pi loss about as frequently as other losses
            if self.t % self.log_period < self.policy_update_period:
                logger.add_scalar('loss/pi', pi_loss, self.t, time.time())

        if self.t % self.log_period < self.update_period:
            if self.automatic_entropy_tuning:
                logger.add_scalar('alg/log_alpha',
                                  self.log_alpha.detach().cpu().numpy(),
                                  self.t, time.time())
                scalars = {
                    "target": self.target_entropy,
                    "entropy": -torch.mean(logp.detach()).cpu().numpy().item()
                }
                logger.add_scalars('alg/entropy', scalars, self.t, time.time())
            else:
                logger.add_scalar(
                    'alg/entropy',
                    -torch.mean(logp.detach()).cpu().numpy().item(), self.t,
                    time.time())
            logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time())
            logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time())
            logger.add_scalar('alg/qf1',
                              q1.mean().detach().cpu().numpy(), self.t,
                              time.time())
            logger.add_scalar('alg/qf2',
                              q2.mean().detach().cpu().numpy(), self.t,
                              time.time())
        return pi_loss, qf1_loss, qf2_loss
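A sketch of the state the automatic entropy tuning above relies on: log_alpha is a learnable scalar with its own optimizer, and target_entropy is commonly set to the negative action dimension. The helper below and its defaults are assumptions for illustration.

import torch

def make_entropy_tuning_state(action_dim, device, lr=3e-4):
    """Sketch: create the log_alpha, target_entropy, and opt_alpha used above."""
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    target_entropy = -float(action_dim)  # common SAC heuristic
    opt_alpha = torch.optim.Adam([log_alpha], lr=lr)
    return log_alpha, target_entropy, opt_alpha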
Example #20
    def step(self):
        """Compute rollout, loss, and update model."""
        self.pi.train()
        self.t += self.data_manager.rollout()
        losses = {'pi': [], 'vf': [], 'ent': [], 'kl': [], 'total': [],
                  'kl_pen': []}

        #######################
        # Update pi
        #######################

        if self.t >= self.policy_training_start:
            kl_too_big = False
            for _ in range(self.epochs_pi):
                if kl_too_big:
                    break
                for batch in self.data_manager.sampler():
                    self.opt_pi.zero_grad()
                    loss = self.loss_pi(batch)
                    # break if new policy is too different from old policy
                    if loss['kl'] > 4 * self.kl_target:
                        kl_too_big = True
                        break
                    loss['total'].backward()

                    for k, v in loss.items():
                        losses[k].append(v.detach().cpu().numpy())

                    if self.max_grad_norm:
                        norm = nn.utils.clip_grad_norm_(self.pi.parameters(),
                                                        self.max_grad_norm)
                        logger.add_scalar('alg/grad_norm', norm, self.t,
                                          time.time())
                        logger.add_scalar('alg/grad_norm_clipped',
                                          min(norm, self.max_grad_norm),
                                          self.t, time.time())
                    self.opt_pi.step()

        #######################
        # Update value function
        #######################
        for _ in range(self.epochs_vf):
            for batch in self.data_manager.sampler():
                self.opt_vf.zero_grad()
                loss = self.loss_vf(batch)
                losses['vf'].append(loss.detach().cpu().numpy())
                loss.backward()
                if self.max_grad_norm:
                    norm = nn.utils.clip_grad_norm_(self.vf.parameters(),
                                                    self.max_grad_norm)
                    logger.add_scalar('alg/vf_grad_norm', norm, self.t,
                                      time.time())
                    logger.add_scalar('alg/vf_grad_norm_clipped',
                                      min(norm, self.max_grad_norm),
                                      self.t, time.time())
                self.opt_vf.step()

        for k, v in losses.items():
            if len(v) > 0:
                logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time())

        # update weight on kl to match kl_target.
        if self.t >= self.policy_training_start:
            kl = self.compute_kl()
            if kl > 10.0 * self.kl_target and self.kl_weight < self.initial_kl_weight:
                self.kl_weight = self.initial_kl_weight
            elif kl > 1.3 * self.kl_target:
                self.kl_weight *= self.alpha
            elif kl < 0.7 * self.kl_target:
                self.kl_weight /= self.alpha

            logger.add_scalar('alg/kl', kl, self.t, time.time())
            logger.add_scalar('alg/kl_weight', self.kl_weight, self.t, time.time())

        data = self.data_manager.storage.get_rollout()
        value_error = data['vpred'].data - data['q_mc'].data
        logger.add_scalar('alg/value_error_mean',
                          value_error.mean().cpu().numpy(), self.t, time.time())
        logger.add_scalar('alg/value_error_std',
                          value_error.std().cpu().numpy(), self.t, time.time())
        return self.t
Example #21
    def loss(self, batch):
        """Loss function."""
        dist = self.pi(batch['obs']).dist
        q1 = self.qf1(batch['obs'], batch['action']).value
        q2 = self.qf2(batch['obs'], batch['action']).value

        # alpha loss
        if self.automatic_entropy_tuning:
            ent_error = dist.entropy() - self.target_entropy
            alpha_loss = self.log_alpha * ent_error.detach().mean()
            self.opt_alpha.zero_grad()
            alpha_loss.backward()
            self.opt_alpha.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = self.alpha
            alpha_loss = 0

        # qf loss
        with torch.no_grad():
            next_dist = self.pi(batch['next_obs']).dist
            q1_next = self.target_qf1(batch['next_obs']).qvals
            q2_next = self.target_qf2(batch['next_obs']).qvals
            qmin = torch.min(q1_next, q2_next)
            # explicitly compute the expectation over next actions
            qnext = torch.sum(qmin * next_dist.probs,
                              dim=1) + alpha * next_dist.entropy()
            qtarg = batch['reward'] + (1.0 -
                                       batch['done']) * self.gamma * qnext

        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
        qf1_loss = self.mse_loss(q1, qtarg)
        qf2_loss = self.mse_loss(q2, qtarg)

        # pi loss
        pi_loss = None
        if self.t % self.policy_update_period == 0:
            with torch.no_grad():
                q1_pi = self.qf1(batch['obs']).qvals
                q2_pi = self.qf2(batch['obs']).qvals
                min_q_pi = torch.min(q1_pi, q2_pi)
            assert min_q_pi.shape == dist.logits.shape
            target_dist = CatDist(logits=min_q_pi)
            pi_dist = CatDist(logits=alpha * dist.logits)
            pi_loss = pi_dist.kl(target_dist).mean()

            # log pi loss about as frequently as other losses
            if self.t % self.log_period < self.policy_update_period:
                logger.add_scalar('loss/pi', pi_loss, self.t, time.time())

        if self.t % self.log_period < self.update_period:
            if self.automatic_entropy_tuning:
                logger.add_scalar('alg/log_alpha',
                                  self.log_alpha.detach().cpu().numpy(),
                                  self.t, time.time())
                scalars = {
                    "target": self.target_entropy,
                    "entropy":
                    dist.entropy().mean().detach().cpu().numpy().item()
                }
                logger.add_scalars('alg/entropy', scalars, self.t, time.time())
            else:
                logger.add_scalar(
                    'alg/entropy',
                    dist.entropy().mean().detach().cpu().numpy().item(),
                    self.t, time.time())
            logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time())
            logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time())
            logger.add_scalar('alg/qf1',
                              q1.mean().detach().cpu().numpy(), self.t,
                              time.time())
            logger.add_scalar('alg/qf2',
                              q2.mean().detach().cpu().numpy(), self.t,
                              time.time())
        return pi_loss, qf1_loss, qf2_loss
Example #22
    def loss(self, batch):
        """Loss function."""
        pi_out = self.pi(batch['obs'], reparameterization_trick=True)
        logp = pi_out.dist.log_prob(pi_out.action)
        q1 = self.qf1(batch['obs'], batch['action']).value
        q2 = self.qf2(batch['obs'], batch['action']).value

        # alpha loss
        should_update_policy = (self.t >= self.policy_training_start
                                and self.t % self.policy_update_period == 0)
        if self.automatic_entropy_tuning:
            if should_update_policy:
                ent_error = logp + self.target_entropy
                alpha_loss = -(self.log_alpha * ent_error.detach()).mean()
                self.opt_alpha.zero_grad()
                alpha_loss.backward()
                self.opt_alpha.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = self.alpha
            alpha_loss = 0

        # qf loss
        with torch.no_grad():
            next_pi_out = self.pi(batch['next_obs'])
            next_ac = next_pi_out.action
            # Account for the fact that we are learning about the base policy
            # before we start updating the residual policy
            if self.t < self.policy_training_start:
                next_ac = nest.map_structure(torch.zeros_like, next_ac)
            next_ac_logp = next_pi_out.dist.log_prob(next_ac)

            q1_next = self.target_qf1(batch['next_obs'], next_ac).value
            q2_next = self.target_qf2(batch['next_obs'], next_ac).value
            qnext = torch.min(q1_next, q2_next) - alpha * next_ac_logp
            qtarg = batch['reward'] + (1.0 -
                                       batch['done']) * self.gamma * qnext

        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
        qf1_loss = self.mse_loss(q1,
                                 qtarg) + self.q_reg_weight * (q1**2).mean()
        qf2_loss = self.mse_loss(q2,
                                 qtarg) + self.q_reg_weight * (q2**2).mean()

        # pi loss
        pi_loss = None
        if should_update_policy:
            q1_pi = self.qf1(batch['obs'], pi_out.action).value
            q2_pi = self.qf2(batch['obs'], pi_out.action).value
            min_q_pi = torch.min(q1_pi, q2_pi)
            assert min_q_pi.shape == logp.shape
            pi_loss = (alpha * logp - min_q_pi).mean()
            action_reg = self.action_reg_weight * (pi_out.action**2).mean()
            pi_loss = pi_loss + action_reg

            # log pi loss about as frequently as other losses
            if self.t % self.log_period < self.policy_update_period:
                logger.add_scalar('loss/pi', pi_loss, self.t, time.time())

        if self.t % self.log_period < self.update_period:
            if self.automatic_entropy_tuning:
                logger.add_scalar('alg/log_alpha',
                                  self.log_alpha.detach().cpu().numpy(),
                                  self.t, time.time())
                scalars = {
                    "target": self.target_entropy,
                    "entropy": -torch.mean(logp.detach()).cpu().numpy().item()
                }
                logger.add_scalars('alg/entropy', scalars, self.t, time.time())
            else:
                logger.add_scalar(
                    'alg/entropy',
                    -torch.mean(logp.detach()).cpu().numpy().item(), self.t,
                    time.time())
            logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time())
            logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time())
            logger.add_scalar('alg/qf1',
                              q1.mean().detach().cpu().numpy(), self.t,
                              time.time())
            logger.add_scalar('alg/qf2',
                              q2.mean().detach().cpu().numpy(), self.t,
                              time.time())
        return pi_loss, qf1_loss, qf2_loss
Example #23
    def step(self):
        """Compute rollout, loss, and update model."""
        self.pi.train()
        self.t += self.data_manager.rollout()
        losses = {'pi': [], 'vf': [], 'ent': [], 'kl': [], 'total': [],
                  'kl_pen': [], 'rnd': []}
        if self.norm_advantages:
            atarg = self.data_manager.storage.data['atarg']
            atarg = atarg[:, 0] + self.rnd_coef * atarg[:, 1]
            self.data_manager.storage.data['atarg'][:, 0] -= atarg.mean()
            self.data_manager.storage.data['atarg'] /= atarg.std() + 1e-5

        #######################
        # Update pi
        #######################
        kl_too_big = False
        for _ in range(self.epochs_pi):
            if kl_too_big:
                break
            for batch in self.data_manager.sampler():
                self.opt_pi.zero_grad()
                loss = self.loss_pi(batch)
                # break if new policy is too different from old policy
                if loss['kl'] > 4 * self.kl_target:
                    kl_too_big = True
                    break
                loss['total'].backward()

                for k, v in loss.items():
                    losses[k].append(v.detach().cpu().numpy())

                if self.max_grad_norm:
                    norm = nn.utils.clip_grad_norm_(self.pi.parameters(),
                                                    self.max_grad_norm)
                    logger.add_scalar('alg/grad_norm', norm, self.t,
                                      time.time())
                    logger.add_scalar('alg/grad_norm_clipped',
                                      min(norm, self.max_grad_norm),
                                      self.t, time.time())
                self.opt_pi.step()

        #######################
        # Update value function
        #######################
        for _ in range(self.epochs_vf):
            for batch in self.data_manager.sampler():
                self.opt_vf.zero_grad()
                loss = self.loss_vf(batch)
                losses['vf'].append(loss.detach().cpu().numpy())
                loss.backward()
                if self.max_grad_norm:
                    norm = nn.utils.clip_grad_norm_(self.vf.parameters(),
                                                    self.max_grad_norm)
                    logger.add_scalar('alg/vf_grad_norm', norm, self.t,
                                      time.time())
                    logger.add_scalar('alg/vf_grad_norm_clipped',
                                      min(norm, self.max_grad_norm),
                                      self.t, time.time())
                self.opt_vf.step()

        #######################
        # Update RND
        #######################
        for batch in self.data_manager.sampler():
            self.rnd_update_count += 1
            if self.rnd_update_count % self.rnd_subsample_rate == 0:
                loss = self.rnd.update(batch['obs'])
                losses['rnd'].append(loss.detach().cpu().numpy())

        for k, v in losses.items():
            logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time())

        # update weight on kl to match kl_target.
        kl = self.compute_kl()
        if kl > 10.0 * self.kl_target and self.kl_weight < self.initial_kl_weight:
            self.kl_weight = self.initial_kl_weight
        elif kl > 1.3 * self.kl_target:
            self.kl_weight *= self.alpha
        elif kl < 0.7 * self.kl_target:
            self.kl_weight /= self.alpha

        logger.add_scalar('alg/kl', kl, self.t, time.time())
        logger.add_scalar('alg/kl_weight', self.kl_weight, self.t, time.time())
        avg_return = self.data_manager.storage.data['return'][:, 0].mean(dim=0)
        avg_return = avg_return[0] + self.rnd_coef * avg_return[1]
        logger.add_scalar('alg/return', avg_return, self.t, time.time())

        # log value errors
        errors = []
        for batch in self.data_manager.sampler():
            errors.append(batch['vpred'] - batch['q_mc'])
        errors = torch.cat(errors)
        logger.add_scalar('alg/value_error_mean',
                          errors.mean().cpu().numpy(), self.t, time.time())
        logger.add_scalar('alg/value_error_std',
                          errors.std().cpu().numpy(), self.t, time.time())

        return self.t
Example #24
    def step(self):
        """Compute rollout, loss, and update model."""
        self.pi.train()
        self.t += self.data_manager.rollout()
        losses = {
            'pi': [],
            'vf': [],
            'ent': [],
            'kl': [],
            'total': [],
            'kl_pen': [],
            'rnd': [],
            'ide': []
        }
        # if self.norm_advantages:
        #     atarg = self.data_manager.storage.data['atarg']
        #     atarg = atarg[:, 0] + self.ngu_coef * atarg[:, 1]
        #     self.data_manager.storage.data['atarg'][:, 0] -= atarg.mean()
        #     self.data_manager.storage.data['atarg'] /= atarg.std() + 1e-5

        if self.t >= self.policy_training_starts:
            #######################
            # Update pi
            #######################
            kl_too_big = False
            for _ in range(self.epochs_pi):
                if kl_too_big:
                    break
                for batch in self.data_manager.sampler():
                    self.opt_pi.zero_grad()
                    loss = self.loss_pi(batch)
                    # break if new policy is too different from old policy
                    if loss['kl'] > 4 * self.kl_target:
                        kl_too_big = True
                        break
                    loss['total'].backward()

                    for k, v in loss.items():
                        losses[k].append(v.detach().cpu().numpy())

                    if self.max_grad_norm:
                        norm = nn.utils.clip_grad_norm_(
                            self.pi.parameters(), self.max_grad_norm)
                        logger.add_scalar('alg/grad_norm', norm, self.t,
                                          time.time())
                        logger.add_scalar('alg/grad_norm_clipped',
                                          min(norm, self.max_grad_norm),
                                          self.t, time.time())
                    self.opt_pi.step()

            #######################
            # Update value function
            #######################
            for _ in range(self.epochs_vf):
                for batch in self.data_manager.sampler():
                    self.opt_vf.zero_grad()
                    loss = self.loss_vf(batch)
                    losses['vf'].append(loss.detach().cpu().numpy())
                    loss.backward()
                    if self.max_grad_norm:
                        norm = nn.utils.clip_grad_norm_(
                            self.vf.parameters(), self.max_grad_norm)
                        logger.add_scalar('alg/vf_grad_norm', norm, self.t,
                                          time.time())
                        logger.add_scalar('alg/vf_grad_norm_clipped',
                                          min(norm, self.max_grad_norm),
                                          self.t, time.time())
                    self.opt_vf.step()

        rollout = self.data_manager.storage.data
        lens = self.data_manager.storage.sequence_lengths.int()
        #######################
        # Store rollout_data in replay_buffer
        #######################
        for i in range(self.nenv):
            if self.buffer.num_in_buffer < self.buffer_size or i % self.ngu_subsample_freq == 0:
                for step in range(lens[i]):

                    def _f(x):
                        return x[step][i].cpu().numpy()

                    data = nest.map_structure(_f, rollout)
                    idx = self.buffer.store_observation(data['obs'])
                    self.buffer.store_effect(
                        idx, {
                            'done': data['done'],
                            'action': data['action'],
                            'reward': data['reward']
                        })

        #######################
        # Update NGU
        #######################
        for _ in range(self.ngu_updates):
            batch = self.buffer.sample(self.ngu_batch_size)

            def _to_torch(data):
                if isinstance(data, np.ndarray):
                    return torch.from_numpy(data).to(self.device)
                else:
                    return data

            batch = nest.map_structure(_to_torch, batch)
            loss = self.ngu.update_rnd(batch['obs']['ob'])
            losses['rnd'].append(loss.detach().cpu().numpy())
            not_done = torch.logical_not(batch['done'])
            loss = self.ngu.update_ide(batch['obs']['ob'][not_done],
                                       batch['next_obs']['ob'][not_done],
                                       batch['action'][not_done].long())
            losses['ide'].append(loss.detach().cpu().numpy())

        for k, v in losses.items():
            if len(v) == 0:
                continue
            logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time())

        # update weight on kl to match kl_target.
        if self.t >= self.policy_training_starts:
            kl = self.compute_kl()
            if kl > 10.0 * self.kl_target and self.kl_weight < self.initial_kl_weight:
                self.kl_weight = self.initial_kl_weight
            elif kl > 1.3 * self.kl_target:
                self.kl_weight *= self.alpha
            elif kl < 0.7 * self.kl_target:
                self.kl_weight /= self.alpha
        else:
            kl = 0.0

        logger.add_scalar('alg/kl', kl, self.t, time.time())
        logger.add_scalar('alg/kl_weight', self.kl_weight, self.t, time.time())
        avg_return = self.data_manager.storage.data['return'][:, 0].mean(dim=0)
        avg_return = (avg_return[0] +
                      self.ngu_coef * avg_return[1]) / (1 + self.ngu_coef)
        logger.add_scalar('alg/return', avg_return, self.t, time.time())

        # log value errors
        errors = []
        for batch in self.data_manager.sampler():
            errors.append(batch['vpred'] - batch['q_mc'])
        errors = torch.cat(errors)
        logger.add_scalar('alg/value_error_mean',
                          errors.mean().cpu().numpy(), self.t, time.time())
        logger.add_scalar('alg/value_error_std',
                          errors.std().cpu().numpy(), self.t, time.time())

        return self.t
Example #25
File: sac.py Project: takuma-ynd/dl
    def loss(self, batch):
        """Loss function."""
        pi_out = self.pi(batch['obs'], reparameterization_trick=self.rsample)
        if self.discrete:
            new_ac = pi_out.action
            ent = pi_out.dist.entropy()
        else:
            assert isinstance(pi_out.dist, TanhNormal), (
                "It is strongly encouraged that you use a TanhNormal "
                "action distribution for continuous action spaces.")
            if self.rsample:
                new_ac, new_pth_ac = pi_out.dist.rsample(
                                                    return_pretanh_value=True)
            else:
                new_ac, new_pth_ac = pi_out.dist.sample(
                                                    return_pretanh_value=True)
            logp = pi_out.dist.log_prob(new_ac, new_pth_ac)
        q1 = self.qf1(batch['obs'], batch['action']).value
        q2 = self.qf2(batch['obs'], batch['action']).value
        v = self.vf(batch['obs']).value

        # alpha loss
        if self.automatic_entropy_tuning:
            if self.discrete:
                ent_error = -ent + self.target_entropy
            else:
                ent_error = logp + self.target_entropy
            alpha_loss = -(self.log_alpha * ent_error.detach()).mean()
            self.opt_alpha.zero_grad()
            alpha_loss.backward()
            self.opt_alpha.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = 1
            alpha_loss = 0

        # qf loss
        vtarg = self.target_vf(batch['next_obs']).value
        qtarg = self.reward_scale * batch['reward'].float() + (
                    (1.0 - batch['done']) * self.gamma * vtarg)
        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
        qf1_loss = self.qf_criterion(q1, qtarg.detach())
        qf2_loss = self.qf_criterion(q2, qtarg.detach())

        # vf loss
        q1_outs = self.qf1(batch['obs'], new_ac)
        q1_new = q1_outs.value
        q2_new = self.qf2(batch['obs'], new_ac).value
        q = torch.min(q1_new, q2_new)
        if self.discrete:
            vtarg = q + alpha * ent
        else:
            vtarg = q - alpha * logp
        assert v.shape == vtarg.shape
        vf_loss = self.vf_criterion(v, vtarg.detach())

        # pi loss
        pi_loss = None
        if self.t % self.policy_update_period == 0:
            if self.discrete:
                target_dist = CatDist(logits=q1_outs.qvals.detach())
                pi_dist = CatDist(logits=alpha * pi_out.dist.logits)
                pi_loss = pi_dist.kl(target_dist).mean()
            else:
                if self.rsample:
                    assert q.shape == logp.shape
                    pi_loss = (alpha*logp - q1_new).mean()
                else:
                    pi_targ = q1_new - v
                    assert pi_targ.shape == logp.shape
                    pi_loss = (logp * (alpha * logp - pi_targ).detach()).mean()

                pi_loss += self.policy_mean_reg_weight * (
                                            pi_out.dist.normal.mean**2).mean()

            # log pi loss about as frequently as other losses
            if self.t % self.log_period < self.policy_update_period:
                logger.add_scalar('loss/pi', pi_loss, self.t, time.time())

        if self.t % self.log_period < self.update_period:
            if self.automatic_entropy_tuning:
                logger.add_scalar('ent/log_alpha',
                                  self.log_alpha.detach().cpu().numpy(), self.t,
                                  time.time())
                if self.discrete:
                    scalars = {"target": self.target_entropy,
                               "entropy": ent.mean().detach().cpu().numpy().item()}
                else:
                    scalars = {"target": self.target_entropy,
                               "entropy": -torch.mean(
                                            logp.detach()).cpu().numpy().item()}
                logger.add_scalars('ent/entropy', scalars, self.t, time.time())
            else:
                if self.discrete:
                    logger.add_scalar(
                            'ent/entropy',
                            ent.mean().detach().cpu().numpy().item(),
                            self.t, time.time())
                else:
                    logger.add_scalar(
                            'ent/entropy',
                            -torch.mean(logp.detach()).cpu().numpy().item(),
                            self.t, time.time())
            logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time())
            logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time())
            logger.add_scalar('loss/vf', vf_loss, self.t, time.time())
        return pi_loss, qf1_loss, qf2_loss, vf_loss