Example #1
    def add(
        self,
        batch: Batch,
        buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add a batch of data into ReplayBufferManager.

        The length (first dimension) of each field in the data must equal the
        length of buffer_ids. By default, buffer_ids is [0, 1, ..., buffer_num - 1].

        Return (current_index, episode_reward, episode_length, episode_start_index).
        If the episode is not finished, episode_reward and episode_length are 0.
        """
        # preprocess batch
        new_batch = Batch()
        for key in set(self._reserved_keys).intersection(batch.keys()):
            new_batch.__dict__[key] = batch[key]
        batch = new_batch
        assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
        if self._save_only_last_obs:
            batch.obs = batch.obs[:, -1]
        if not self._save_obs_next:
            batch.pop("obs_next", None)
        elif self._save_only_last_obs:
            batch.obs_next = batch.obs_next[:, -1]
        # get index
        if buffer_ids is None:
            buffer_ids = np.arange(self.buffer_num)
        ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], []
        for batch_idx, buffer_id in enumerate(buffer_ids):
            ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index(
                batch.rew[batch_idx], batch.done[batch_idx]
            )
            ptrs.append(ptr + self._offset[buffer_id])
            ep_lens.append(ep_len)
            ep_rews.append(ep_rew)
            ep_idxs.append(ep_idx + self._offset[buffer_id])
            self.last_index[buffer_id] = ptr + self._offset[buffer_id]
            self._lengths[buffer_id] = len(self.buffers[buffer_id])
        ptrs = np.array(ptrs)
        try:
            self._meta[ptrs] = batch
        except ValueError:
            batch.rew = batch.rew.astype(float)
            batch.done = batch.done.astype(bool)
            if self._meta.is_empty():
                self._meta = _create_value(  # type: ignore
                    batch, self.maxsize, stack=False)
            else:  # dynamic key pops up in batch
                _alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
            self._set_batch_for_children()
            self._meta[ptrs] = batch
        return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)
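For context, here is a minimal usage sketch. It is an assumption-laden illustration rather than code from the source: it presumes the Tianshou 0.4-style API shown above, in which VectorReplayBuffer is built on ReplayBufferManager, and the sizes and field shapes are arbitrary.

    import numpy as np
    from tianshou.data import Batch, VectorReplayBuffer

    # 4 sub-buffers sharing a total capacity of 20; the batch carries one
    # transition per sub-buffer, so every field has first dimension 4
    buf = VectorReplayBuffer(20, 4)
    batch = Batch(
        obs=np.zeros((4, 3)),
        act=np.zeros(4),
        rew=np.ones(4),
        done=np.array([False, False, True, False]),
    )
    ptrs, ep_rews, ep_lens, ep_idxs = buf.add(batch)  # buffer_ids defaults to [0, 1, 2, 3]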
Example #2
    def add(
        self,
        batch: Batch,
        buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add a batch of data into replay buffer.

        :param Batch batch: the input data batch. Its keys must belong to the 7
            reserved keys, and "obs", "act", "rew", and "done" are required.
        :param buffer_ids: kept for consistency with the other buffers' add
            functions; if it is not None, the input batch's first dimension is
            assumed to be 1.

        Return (current_index, episode_reward, episode_length, episode_start_index).
        If the episode is not finished, episode_reward and episode_length are 0.
        """
        # preprocess batch
        b = Batch()
        for key in set(self._reserved_keys).intersection(batch.keys()):
            b.__dict__[key] = batch[key]
        batch = b
        assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
        stacked_batch = buffer_ids is not None
        if stacked_batch:
            assert len(batch) == 1
        if self._save_only_last_obs:
            batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1]
        if not self._save_obs_next:
            batch.pop("obs_next", None)
        elif self._save_only_last_obs:
            batch.obs_next = (
                batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1]
            )
        # get ptr
        if stacked_batch:
            rew, done = batch.rew[0], batch.done[0]
        else:
            rew, done = batch.rew, batch.done
        ptr, ep_rew, ep_len, ep_idx = list(
            map(lambda x: np.array([x]), self._add_index(rew, done))
        )
        try:
            self._meta[ptr] = batch
        except ValueError:
            stack = not stacked_batch
            batch.rew = batch.rew.astype(float)
            batch.done = batch.done.astype(bool)
            if self._meta.is_empty():
                self._meta = _create_value(  # type: ignore
                    batch, self.maxsize, stack)
            else:  # dynamic key pops up in batch
                _alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
            self._meta[ptr] = batch
        return ptr, ep_rew, ep_len, ep_idx
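A single-buffer counterpart, under the same assumptions (the Batch-based add shown above, with "obs", "act", "rew", "done" as the required keys; obs_next is optional):

    import numpy as np
    from tianshou.data import Batch, ReplayBuffer

    buf = ReplayBuffer(10)  # capacity of 10 transitions
    # buffer_ids is left as None, so a single un-batched transition is expected
    ptr, ep_rew, ep_len, ep_idx = buf.add(
        Batch(obs=np.zeros(3), act=0, rew=1.0, done=False, obs_next=np.zeros(3))
    )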
Example #3
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     weight = batch.pop("weight", 1.0)
     # critic 1
     current_q1 = self.critic1(batch.obs, batch.act).flatten()
     target_q = batch.returns.flatten()
     td1 = current_q1 - target_q
     critic1_loss = (td1.pow(2) * weight).mean()
     # critic1_loss = F.mse_loss(current_q1, target_q)
     self.critic1_optim.zero_grad()
     critic1_loss.backward()
     self.critic1_optim.step()
     # critic 2
     current_q2 = self.critic2(batch.obs, batch.act).flatten()
     td2 = current_q2 - target_q
     critic2_loss = (td2.pow(2) * weight).mean()
     # critic2_loss = F.mse_loss(current_q2, target_q)
     self.critic2_optim.zero_grad()
     critic2_loss.backward()
     self.critic2_optim.step()
     batch.weight = (td1 + td2) / 2.0  # prio-buffer
     if self._cnt % self._freq == 0:
         actor_loss = -self.critic1(batch.obs,
                                    self(batch, eps=0.0).act).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self._last = actor_loss.item()
         self.actor_optim.step()
         self.sync_weight()
     self._cnt += 1
     return {
         "loss/actor": self._last,
         "loss/critic1": critic1_loss.item(),
         "loss/critic2": critic2_loss.item(),
     }
Example #4
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     all_dist = self(batch).logits
     act = to_torch(batch.act, dtype=torch.long, device=all_dist.device)
     curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
     target_dist = batch.returns.unsqueeze(1)
     # calculate each element's difference between curr_dist and target_dist
     u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
     huber_loss = (
         u * (self.tau_hat -
              (target_dist - curr_dist).detach().le(0.).float()).abs()
     ).sum(-1).mean(1)
     qr_loss = (huber_loss * weight).mean()
     # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
     # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
     batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
     # add CQL loss
     q = self.compute_q_value(all_dist, None)
     dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
     negative_sampling = q.logsumexp(1).mean()
     min_q_loss = negative_sampling - dataset_expec
     loss = qr_loss + min_q_loss * self._min_q_weight
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {
         "loss": loss.item(),
         "loss/qr": qr_loss.item(),
         "loss/cql": min_q_loss.item(),
     }
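For orientation, a sketch in the snippet's own notation (not text from the source): negative_sampling - dataset_expec is the CQL regularizer

    L_CQL = E_{(s, a) ~ D} [ log sum_{a'} exp Q(s, a') - Q(s, a) ],

and the final objective is loss = qr_loss + min_q_weight * L_CQL, i.e. the quantile-regression TD loss plus a penalty that pushes down the values of actions outside the dataset relative to the logged actions.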
Example #5
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     action_batch = self(batch)
     curr_dist, taus = action_batch.logits, action_batch.taus
     act = batch.act
     curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
     target_dist = batch.returns.unsqueeze(1)
     # calculate each element's difference between curr_dist and target_dist
     dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
     huber_loss = (
         dist_diff *
         (taus.unsqueeze(2) -
          (target_dist - curr_dist).detach().le(0.).float()).abs()
     ).sum(-1).mean(1)
     loss = (huber_loss * weight).mean()
     # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
     # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
     batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {"loss": loss.item()}
Example #6
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)
        target_q = batch.returns.flatten()
        act = to_torch(batch.act[:, np.newaxis],
                       device=target_q.device,
                       dtype=torch.long)

        # critic 1
        current_q1 = self.critic1(batch.obs).gather(1, act).flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()

        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()

        # critic 2
        current_q2 = self.critic2(batch.obs).gather(1, act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        dist = self(batch).dist
        entropy = dist.entropy()
        with torch.no_grad():
            current_q1a = self.critic1(batch.obs)
            current_q2a = self.critic2(batch.obs)
            q = torch.min(current_q1a, current_q2a)
        actor_loss = -(self._alpha * entropy +
                       (dist.probs * q).sum(dim=-1)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_prob = -entropy.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore

        return result
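As a reading aid (a sketch matching the code above, not text from the source): since entropy = -sum_a pi(a|s) log pi(a|s), the actor term -(alpha * H(pi(.|s)) + sum_a pi(a|s) * min(Q1, Q2)(s, a)), averaged over the batch, equals the discrete-SAC policy loss

    J_pi = E_s [ sum_a pi(a|s) * ( alpha * log pi(a|s) - min(Q1, Q2)(s, a) ) ],

with the Q-values held fixed by the torch.no_grad() block.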
Example #7
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        weight = batch.pop("weight", 1.0)

        # critic 1
        current_q1 = self.critic1(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()
        # critic1_loss = F.mse_loss(current_q1, target_q)
        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()

        # critic 2
        current_q2 = self.critic2(batch.obs, batch.act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()
        # critic2_loss = F.mse_loss(current_q2, target_q)
        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        obs_result = self(batch)
        a = obs_result.act
        current_q1a = self.critic1(batch.obs, a).flatten()
        current_q2a = self.critic2(batch.obs, a).flatten()
        actor_loss = (self._alpha * obs_result.log_prob.flatten()
                      - torch.min(current_q1a, current_q2a)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore

        return result
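A note on the temperature step in the _is_auto_alpha branch (sketch only): with the log-probabilities detached, the loss

    J(alpha) = -E[ log_alpha * ( log pi(a|s) + target_entropy ) ]

drives log_alpha up when the policy entropy -E[log pi] falls below the target and down when it exceeds it, which is the standard automatic entropy tuning rule.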
Example #8
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        weight = batch.pop("weight", 1.0)
        out = self(batch)
        curr_dist_orig = out.logits
        taus, tau_hats = out.fractions.taus, out.fractions.tau_hats
        act = batch.act
        curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
        target_dist = batch.returns.unsqueeze(1)
        # calculate each element's difference between curr_dist and target_dist
        dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
        huber_loss = (
            dist_diff *
            (tau_hats.unsqueeze(2) -
             (target_dist - curr_dist).detach().le(0.).float()).abs()
        ).sum(-1).mean(1)
        quantile_loss = (huber_loss * weight).mean()
        # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
        # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
        batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
        # calculate fraction loss
        with torch.no_grad():
            sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]
            sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :]
            # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
            # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169
            values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
            signs_1 = sa_quantiles > torch.cat(
                [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1)

            values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
            signs_2 = sa_quantiles < torch.cat(
                [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1)

            gradient_of_taus = (torch.where(signs_1, values_1, -values_1) +
                                torch.where(signs_2, values_2, -values_2))
        fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()
        # calculate entropy loss
        entropy_loss = out.fractions.entropies.mean()
        fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss
        self._fraction_optim.zero_grad()
        fraction_entropy_loss.backward(retain_graph=True)
        self._fraction_optim.step()
        self.optim.zero_grad()
        quantile_loss.backward()
        self.optim.step()
        self._iter += 1
        return {
            "loss": quantile_loss.item() + fraction_entropy_loss.item(),
            "loss/quantile": quantile_loss.item(),
            "loss/fraction": fraction_loss.item(),
            "loss/entropy": entropy_loss.item()
        }
Example #9
 def get_loss_batch(self, batch: Batch) -> Dict[str, float]:
     weight = batch.pop("weight", 1.0)
     with torch.no_grad():
         current_q = self.critic(batch.obs, batch.act).flatten()
         target_q = batch.returns.flatten()
         td = current_q - target_q
         critic_loss = (td.pow(2) * weight).mean()
         action = self(batch).act
         actor_loss = -self.critic(batch.obs, action).mean()
     return {
         "la": actor_loss.item(),
         "lc": critic_loss.item(),
     }
Example #10
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop('weight', 1.)
     q = self(batch, eps=0., is_learning=True).logits
     # q = torch.tensor(q, requires_grad=True)
     # q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.policy.reshape((-1, 1)), q)
     loss = F.mse_loss(r, q)
     loss.backward()
     self.optim.step()
     self._cnt += 1
     return {'loss': loss.item()}
Example #11
    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
        weight = batch.pop('weight', 1.)
        # critic 1
        current_q1 = self.critic1(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()
        # critic1_loss = F.mse_loss(current_q1, target_q)
        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()
        # critic 2
        current_q2 = self.critic2(batch.obs, batch.act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()
        # critic2_loss = F.mse_loss(current_q2, target_q)
        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        batch.weight = (td1 + td2) / 2.  # prio-buffer
        # actor
        obs_result = self(batch, explorating=False)
        a = obs_result.act
        current_q1a = self.critic1(batch.obs, a).flatten()
        current_q2a = self.critic2(batch.obs, a).flatten()
        actor_loss = (self._alpha * obs_result.log_prob.flatten()
                      - torch.min(current_q1a, current_q2a)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._automatic_alpha_tuning:
            log_prob = (obs_result.log_prob + self._target_entropy).detach()
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.exp()

        self.sync_weight()

        result = {
            'loss/actor': actor_loss.item(),
            'loss/critic1': critic1_loss.item(),
            'loss/critic2': critic2_loss.item(),
        }
        if self._automatic_alpha_tuning:
            result['loss/alpha'] = alpha_loss.item()
        return result
Example #12
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop('weight', 1.)
     q = self(batch, eps=0.).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns, q).flatten()
     td = r - q
     loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     loss.backward()
     self.optim.step()
     self._cnt += 1
     return {'loss': loss.item()}
Example #13
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     q = self(batch).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns.flatten(), q)
     td = r - q
     loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {"loss": loss.item()}
Example #14
    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        batch = super().process_fn(batch, buffer, indice)
        with torch.no_grad():
            # the degree of on-policyness
            opd = self(batch, output_dqn=False, output_opd=True).logits
            opd = opd[np.arange(len(opd)), batch.act]
            opd_weights = torch.sigmoid(opd * self.opd_temperature)
            opd_weights = opd_weights / torch.sum(opd_weights) * len(opd)
        opd_weights = opd_weights.detach().cpu().numpy()
        ori_weight = batch.pop("weight", 1.0)
        ori_weight *= opd_weights
        batch.weight = ori_weight

        return batch
Example #15
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     with torch.no_grad():
         target_dist = self._target_dist(batch)
     weight = batch.pop("weight", 1.0)
     curr_dist = self(batch).logits
     act = batch.act
     curr_dist = curr_dist[np.arange(len(act)), act, :]
     cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1)
     loss = (cross_entropy * weight).mean()
     # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100
     batch.weight = cross_entropy.detach()  # prio-buffer
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {"loss": loss.item()}
Example #16
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     weight = batch.pop("weight", 1.0)
     self.optim.zero_grad()
     q = self(batch, eps=0.).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns, q).flatten()
     c = torch.nn.SmoothL1Loss(reduction='none')
     # c = lambda r, q: (r - q).pow(2)
     td = c(r, q)
     loss = (td * weight).mean()
     batch.weight = td  # prio-buffer: store the per-sample error, not the scalar loss
     loss.backward()
     if self.grad_norm_clipping:
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm_clipping)
     self.optim.step()
     self._cnt += 1
     return {'loss': loss.item()}
Example #17
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     q = self(batch).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns.flatten(), q)
     td = r - q
     loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     loss.backward()
     # gradient clipping
     torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
     self.optim.step()
     self._cnt += 1
     for param_group in self.optim.param_groups:
         lr = param_group['lr']
     return {"loss": loss.item(), "lr": lr}
Example #18
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device)
     q = self(batch).logits
     act_mask = torch.zeros_like(q)
     act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1)
     act_q = q * act_mask
     returns = batch.returns
     returns = returns * act_mask
     td_error = returns - act_q
     loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean()
     batch.weight = td_error.sum(-1).sum(-1)  # prio-buffer
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {"loss": loss.item()}
Example #19
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     weight = batch.pop("weight", 1.0)
     current_q = self.critic(batch.obs, batch.act).flatten()
     target_q = batch.returns.flatten()
     td = current_q - target_q
     critic_loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     self.critic_optim.zero_grad()
     critic_loss.backward()
     self.critic_optim.step()
     action = self(batch).act
     actor_loss = -self.critic(batch.obs, action).mean()
     self.actor_optim.zero_grad()
     actor_loss.backward()
     self.actor_optim.step()
     self.sync_weight()
     return {
         "loss/actor": actor_loss.item(),
         "loss/critic": critic_loss.item(),
     }
Example #20
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     weight = batch.pop('weight', 1.)
     current_q = self.critic(batch.obs, batch.act).flatten()
     target_q = batch.returns.flatten()
     td = current_q - target_q
     critic_loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     self.critic_optim.zero_grad()
     critic_loss.backward()
     self.critic_optim.step()
     action = self(batch, explorating=False).act
     actor_loss = -self.critic(batch.obs, action).mean()
     self.actor_optim.zero_grad()
     actor_loss.backward()
     self.actor_optim.step()
     self.sync_weight()
     return {
         'loss/actor': actor_loss.item(),
         'loss/critic': critic_loss.item(),
     }
Example #21
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        weight = batch.pop("weight", 1.0)
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        returns = to_torch_as(batch.returns.flatten(), q)
        td_error = returns - q

        if self._clip_loss_grad:
            y = q.reshape(-1, 1)
            t = returns.reshape(-1, 1)
            loss = torch.nn.functional.huber_loss(y, t, reduction="mean")
        else:
            loss = (td_error.pow(2) * weight).mean()

        batch.weight = td_error  # prio-buffer
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {"loss": loss.item()}
Example #22
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()

        # compute dqn loss
        weight = batch.pop("weight", 1.0)
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = to_torch_as(batch.returns.flatten(), q)
        weight = to_torch_as(weight, q)
        td = r - q
        dqn_loss = (td.pow(2) * weight).mean()
        batch.weight = td  # prio-buffer

        # compute opd loss
        slow_preds = self(batch, output_dqn=False, output_opd=True).logits
        fast_preds = self(batch,
                          input="fast_obs",
                          output_dqn=False,
                          output_opd=True).logits
        # act_dim + 1 classes; the last class indicates that the state is off-policy
        slow_label = torch.ones_like(
            slow_preds[:, 0], dtype=torch.int64) * int(self.model.output_dim)
        fast_label = to_torch_as(torch.tensor(batch.fast_act), slow_label)
        opd_loss = F.cross_entropy(slow_preds, slow_label) + \
                    F.cross_entropy(fast_preds, fast_label)

        # compute loss and back prop
        loss = dqn_loss + self.opd_loss_coeff * opd_loss
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {
            "loss": loss.item(),
            "dqn_loss": dqn_loss.item(),
            "opd_loss": opd_loss.item(),
        }
Example #23
from itertools import starmap

import networkx as nx
import numpy as np
import pytest
import torch

from tianshou.data import Batch


def test_batch():
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object and b.f == Batch
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert 'a' in b and b.a == 3
    assert b.pop('a') == 3
    assert 'a' not in b
    with pytest.raises(AssertionError):
        Batch({1: 2})
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    Batch(a=[])
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    data = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, result in data:
        bs = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [bs[i].a.tolist() for i in range(len(bs))] == result
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a=set((1, 2, 1))).shape == []
    assert batch2.shape[0] == 1
    assert 'a' in batch2 and all([i in batch2.a for i in 'bcd'])
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})
    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])

    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None

    # nx.Graph corner case
    assert Batch(a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object
    g1 = nx.Graph()
    g1.add_nodes_from(list(range(10)))
    g2 = nx.Graph()
    g2.add_nodes_from(list(range(20)))
    assert Batch(a=np.array([g1, g2])).a.dtype == object