def add(
    self,
    batch: Batch,
    buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Add a batch of data into ReplayBufferManager.

    Each of the data's length (first dimension) must equal the length of
    buffer_ids. By default, buffer_ids is [0, 1, ..., buffer_num - 1].

    Return (current_index, episode_reward, episode_length,
    episode_start_index). If the episode is not finished, the return values
    of episode_length and episode_reward are 0.
    """
    # preprocess batch
    new_batch = Batch()
    for key in set(self._reserved_keys).intersection(batch.keys()):
        new_batch.__dict__[key] = batch[key]
    batch = new_batch
    assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
    if self._save_only_last_obs:
        batch.obs = batch.obs[:, -1]
    if not self._save_obs_next:
        batch.pop("obs_next", None)
    elif self._save_only_last_obs:
        batch.obs_next = batch.obs_next[:, -1]
    # get index
    if buffer_ids is None:
        buffer_ids = np.arange(self.buffer_num)
    ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], []
    for batch_idx, buffer_id in enumerate(buffer_ids):
        ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index(
            batch.rew[batch_idx], batch.done[batch_idx]
        )
        ptrs.append(ptr + self._offset[buffer_id])
        ep_lens.append(ep_len)
        ep_rews.append(ep_rew)
        ep_idxs.append(ep_idx + self._offset[buffer_id])
        self.last_index[buffer_id] = ptr + self._offset[buffer_id]
        self._lengths[buffer_id] = len(self.buffers[buffer_id])
    ptrs = np.array(ptrs)
    try:
        self._meta[ptrs] = batch
    except ValueError:
        batch.rew = batch.rew.astype(float)
        batch.done = batch.done.astype(bool)
        if self._meta.is_empty():
            self._meta = _create_value(  # type: ignore
                batch, self.maxsize, stack=False)
        else:  # dynamic key pops up in batch
            _alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
        self._set_batch_for_children()
        self._meta[ptrs] = batch
    return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)

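# Usage sketch (not part of the class): feeding one vectorized transition
# into a manager-backed buffer. The VectorReplayBuffer name and the sizes
# below are illustrative assumptions; the call and the 4-array return follow
# the docstring above.
import numpy as np
from tianshou.data import Batch, VectorReplayBuffer

buf = VectorReplayBuffer(total_size=20, buffer_num=4)  # 4 sub-buffers
transition = Batch(
    obs=np.zeros((4, 3)),  # one transition per sub-buffer
    act=np.zeros(4),
    rew=np.ones(4),
    done=np.array([False, True, False, False]),
)
# buffer_ids defaults to [0, 1, 2, 3]
ptrs, ep_rews, ep_lens, ep_idxs = buf.add(transition)
# ep_rews/ep_lens are nonzero only where done is True (sub-buffer 1 here)
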
def add(
    self,
    batch: Batch,
    buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Add a batch of data into the replay buffer.

    :param Batch batch: the input data batch. Its keys must belong to the 7
        reserved keys, and "obs", "act", "rew", "done" are required.
    :param buffer_ids: kept for consistency with other buffers' add
        functions; if it is not None, we assume the input batch's first
        dimension is always 1.

    Return (current_index, episode_reward, episode_length,
    episode_start_index). If the episode is not finished, the return values
    of episode_length and episode_reward are 0.
    """
    # preprocess batch
    b = Batch()
    for key in set(self._reserved_keys).intersection(batch.keys()):
        b.__dict__[key] = batch[key]
    batch = b
    assert set(["obs", "act", "rew", "done"]).issubset(batch.keys())
    stacked_batch = buffer_ids is not None
    if stacked_batch:
        assert len(batch) == 1
    if self._save_only_last_obs:
        batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1]
    if not self._save_obs_next:
        batch.pop("obs_next", None)
    elif self._save_only_last_obs:
        batch.obs_next = (
            batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1]
        )
    # get ptr
    if stacked_batch:
        rew, done = batch.rew[0], batch.done[0]
    else:
        rew, done = batch.rew, batch.done
    ptr, ep_rew, ep_len, ep_idx = list(
        map(lambda x: np.array([x]), self._add_index(rew, done))
    )
    try:
        self._meta[ptr] = batch
    except ValueError:
        stack = not stacked_batch
        batch.rew = batch.rew.astype(float)
        batch.done = batch.done.astype(bool)
        if self._meta.is_empty():
            self._meta = _create_value(  # type: ignore
                batch, self.maxsize, stack)
        else:  # dynamic key pops up in batch
            _alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
        self._meta[ptr] = batch
    return ptr, ep_rew, ep_len, ep_idx

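# Usage sketch (not part of the class): the two input conventions accepted
# above, assuming a Tianshou-style ReplayBuffer. Shapes are illustrative.
import numpy as np
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10)
# buffer_ids=None: a plain, unstacked transition
ptr, ep_rew, ep_len, ep_idx = buf.add(
    Batch(obs=np.zeros(3), act=0, rew=1.0, done=False))
# buffer_ids given: the batch carries a leading dimension of size 1
ptr, ep_rew, ep_len, ep_idx = buf.add(
    Batch(obs=np.zeros((1, 3)), act=np.array([0]),
          rew=np.array([1.0]), done=np.array([True])),
    buffer_ids=[0])
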
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    weight = batch.pop("weight", 1.0)
    # critic 1
    current_q1 = self.critic1(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td1 = current_q1 - target_q
    critic1_loss = (td1.pow(2) * weight).mean()
    # critic1_loss = F.mse_loss(current_q1, target_q)
    self.critic1_optim.zero_grad()
    critic1_loss.backward()
    self.critic1_optim.step()
    # critic 2
    current_q2 = self.critic2(batch.obs, batch.act).flatten()
    td2 = current_q2 - target_q
    critic2_loss = (td2.pow(2) * weight).mean()
    # critic2_loss = F.mse_loss(current_q2, target_q)
    self.critic2_optim.zero_grad()
    critic2_loss.backward()
    self.critic2_optim.step()
    batch.weight = (td1 + td2) / 2.0  # prio-buffer
    if self._cnt % self._freq == 0:
        actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self._last = actor_loss.item()
        self.actor_optim.step()
        self.sync_weight()
    self._cnt += 1
    return {
        "loss/actor": self._last,
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop("weight", 1.0)
    all_dist = self(batch).logits
    act = to_torch(batch.act, dtype=torch.long, device=all_dist.device)
    curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
    target_dist = batch.returns.unsqueeze(1)
    # calculate each element's difference between curr_dist and target_dist
    u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
    huber_loss = (
        u * (self.tau_hat -
             (target_dist - curr_dist).detach().le(0.).float()).abs()
    ).sum(-1).mean(1)
    qr_loss = (huber_loss * weight).mean()
    # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
    # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
    batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
    # add CQL loss
    q = self.compute_q_value(all_dist, None)
    dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
    negative_sampling = q.logsumexp(1).mean()
    min_q_loss = negative_sampling - dataset_expec
    loss = qr_loss + min_q_loss * self._min_q_weight
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {
        "loss": loss.item(),
        "loss/qr": qr_loss.item(),
        "loss/cql": min_q_loss.item(),
    }

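# Worked example (standalone sketch) of the CQL term above: logsumexp over
# the action dimension minus the Q-value of the dataset action. The numbers
# are made up for illustration.
import torch

q = torch.tensor([[1.0, 2.0, 0.5]])              # Q(s, .) for one state
act = torch.tensor([[1]])                        # dataset action index
dataset_expec = q.gather(1, act).mean()          # Q(s, a_data) = 2.0
negative_sampling = q.logsumexp(1).mean()        # ~= 2.46
min_q_loss = negative_sampling - dataset_expec   # ~= 0.46; pushes down
                                                 # Q of unseen actions
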
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop("weight", 1.0)
    action_batch = self(batch)
    curr_dist, taus = action_batch.logits, action_batch.taus
    act = batch.act
    curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
    target_dist = batch.returns.unsqueeze(1)
    # calculate each element's difference between curr_dist and target_dist
    dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
    huber_loss = (
        dist_diff *
        (taus.unsqueeze(2) -
         (target_dist - curr_dist).detach().le(0.).float()).abs()
    ).sum(-1).mean(1)
    loss = (huber_loss * weight).mean()
    # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
    # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
    batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {"loss": loss.item()}

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    weight = batch.pop("weight", 1.0)
    target_q = batch.returns.flatten()
    act = to_torch(
        batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)
    # critic 1
    current_q1 = self.critic1(batch.obs).gather(1, act).flatten()
    td1 = current_q1 - target_q
    critic1_loss = (td1.pow(2) * weight).mean()
    self.critic1_optim.zero_grad()
    critic1_loss.backward()
    self.critic1_optim.step()
    # critic 2
    current_q2 = self.critic2(batch.obs).gather(1, act).flatten()
    td2 = current_q2 - target_q
    critic2_loss = (td2.pow(2) * weight).mean()
    self.critic2_optim.zero_grad()
    critic2_loss.backward()
    self.critic2_optim.step()
    batch.weight = (td1 + td2) / 2.0  # prio-buffer
    # actor
    dist = self(batch).dist
    entropy = dist.entropy()
    with torch.no_grad():
        current_q1a = self.critic1(batch.obs)
        current_q2a = self.critic2(batch.obs)
        q = torch.min(current_q1a, current_q2a)
    actor_loss = -(self._alpha * entropy + (dist.probs * q).sum(dim=-1)).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    if self._is_auto_alpha:
        log_prob = -entropy.detach() + self._target_entropy
        alpha_loss = -(self._log_alpha * log_prob).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        self._alpha = self._log_alpha.detach().exp()
    self.sync_weight()
    result = {
        "loss/actor": actor_loss.item(),
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }
    if self._is_auto_alpha:
        result["loss/alpha"] = alpha_loss.item()
        result["alpha"] = self._alpha.item()  # type: ignore
    return result

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    weight = batch.pop("weight", 1.0)
    # critic 1
    current_q1 = self.critic1(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td1 = current_q1 - target_q
    critic1_loss = (td1.pow(2) * weight).mean()
    # critic1_loss = F.mse_loss(current_q1, target_q)
    self.critic1_optim.zero_grad()
    critic1_loss.backward()
    self.critic1_optim.step()
    # critic 2
    current_q2 = self.critic2(batch.obs, batch.act).flatten()
    td2 = current_q2 - target_q
    critic2_loss = (td2.pow(2) * weight).mean()
    # critic2_loss = F.mse_loss(current_q2, target_q)
    self.critic2_optim.zero_grad()
    critic2_loss.backward()
    self.critic2_optim.step()
    batch.weight = (td1 + td2) / 2.0  # prio-buffer
    # actor
    obs_result = self(batch)
    a = obs_result.act
    current_q1a = self.critic1(batch.obs, a).flatten()
    current_q2a = self.critic2(batch.obs, a).flatten()
    actor_loss = (self._alpha * obs_result.log_prob.flatten() -
                  torch.min(current_q1a, current_q2a)).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    if self._is_auto_alpha:
        log_prob = obs_result.log_prob.detach() + self._target_entropy
        alpha_loss = -(self._log_alpha * log_prob).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        self._alpha = self._log_alpha.detach().exp()
    self.sync_weight()
    result = {
        "loss/actor": actor_loss.item(),
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }
    if self._is_auto_alpha:
        result["loss/alpha"] = alpha_loss.item()
        result["alpha"] = self._alpha.item()  # type: ignore
    return result

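# Standalone sketch of the auto-alpha update above: alpha is parameterized
# through log_alpha, and the loss drives the (detached) policy log-prob
# toward -target_entropy. Values are made up for illustration.
import torch

log_alpha = torch.zeros(1, requires_grad=True)
log_prob = torch.tensor([-1.0])       # sampled-action log-prob (assumed)
target_entropy = -2.0                 # commonly -dim(action_space)
alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
alpha_loss.backward()                 # grad = -(log_prob + target_entropy)
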
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    weight = batch.pop("weight", 1.0)
    out = self(batch)
    curr_dist_orig = out.logits
    taus, tau_hats = out.fractions.taus, out.fractions.tau_hats
    act = batch.act
    curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2)
    target_dist = batch.returns.unsqueeze(1)
    # calculate each element's difference between curr_dist and target_dist
    dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
    huber_loss = (
        dist_diff *
        (tau_hats.unsqueeze(2) -
         (target_dist - curr_dist).detach().le(0.).float()).abs()
    ).sum(-1).mean(1)
    quantile_loss = (huber_loss * weight).mean()
    # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
    # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
    batch.weight = dist_diff.detach().abs().sum(-1).mean(1)  # prio-buffer
    # calculate fraction loss
    with torch.no_grad():
        sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :]
        sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :]
        # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
        # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169
        values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
        signs_1 = sa_quantiles > torch.cat(
            [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1)
        values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
        signs_2 = sa_quantiles < torch.cat(
            [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1)
        gradient_of_taus = (torch.where(signs_1, values_1, -values_1) +
                            torch.where(signs_2, values_2, -values_2))
    fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()
    # calculate entropy loss
    entropy_loss = out.fractions.entropies.mean()
    fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss
    self._fraction_optim.zero_grad()
    fraction_entropy_loss.backward(retain_graph=True)
    self._fraction_optim.step()
    self.optim.zero_grad()
    quantile_loss.backward()
    self.optim.step()
    self._iter += 1
    return {
        "loss": quantile_loss.item() + fraction_entropy_loss.item(),
        "loss/quantile": quantile_loss.item(),
        "loss/fraction": fraction_loss.item(),
        "loss/entropy": entropy_loss.item(),
    }

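# Note on gradient_of_taus above (a sketch of the underlying formula): per
# the FQF paper (Yang et al., 2019), the 1-Wasserstein gradient w.r.t. each
# interior fraction tau_i is
#     dW1/dtau_i = 2 * F_inv(tau_i) - F_inv(tau_hat_i) - F_inv(tau_hat_{i-1})
# where F_inv is the quantile function. The values_*/signs_* terms compute
# the same quantity with explicit sign handling, which is why no autograd
# through the quantile network is needed for the fraction-proposal update.
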
def get_loss_batch(self, batch: Batch) -> Dict[str, float]:
    weight = batch.pop("weight", 1.0)
    with torch.no_grad():
        current_q = self.critic(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td = current_q - target_q
        critic_loss = (td.pow(2) * weight).mean()
        action = self(batch).act
        actor_loss = -self.critic(batch.obs, action).mean()
    return {
        "la": actor_loss.item(),
        "lc": critic_loss.item(),
    }

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop('weight', 1.)
    q = self(batch, eps=0., is_learning=True).logits
    # q = torch.tensor(q, requires_grad=True)
    # q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.policy.reshape((-1, 1)), q)
    loss = F.mse_loss(r, q)
    loss.backward()
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    weight = batch.pop('weight', 1.)
    # critic 1
    current_q1 = self.critic1(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td1 = current_q1 - target_q
    critic1_loss = (td1.pow(2) * weight).mean()
    # critic1_loss = F.mse_loss(current_q1, target_q)
    self.critic1_optim.zero_grad()
    critic1_loss.backward()
    self.critic1_optim.step()
    # critic 2
    current_q2 = self.critic2(batch.obs, batch.act).flatten()
    td2 = current_q2 - target_q
    critic2_loss = (td2.pow(2) * weight).mean()
    # critic2_loss = F.mse_loss(current_q2, target_q)
    self.critic2_optim.zero_grad()
    critic2_loss.backward()
    self.critic2_optim.step()
    batch.weight = (td1 + td2) / 2.  # prio-buffer
    # actor
    obs_result = self(batch, explorating=False)
    a = obs_result.act
    current_q1a = self.critic1(batch.obs, a).flatten()
    current_q2a = self.critic2(batch.obs, a).flatten()
    actor_loss = (self._alpha * obs_result.log_prob.flatten() -
                  torch.min(current_q1a, current_q2a)).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    if self._automatic_alpha_tuning:
        log_prob = (obs_result.log_prob + self._target_entropy).detach()
        alpha_loss = -(self._log_alpha * log_prob).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        self._alpha = self._log_alpha.exp()
    self.sync_weight()
    result = {
        'loss/actor': actor_loss.item(),
        'loss/critic1': critic1_loss.item(),
        'loss/critic2': critic2_loss.item(),
    }
    if self._automatic_alpha_tuning:
        result['loss/alpha'] = alpha_loss.item()
    return result

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop('weight', 1.)
    q = self(batch, eps=0.).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns, q).flatten()
    td = r - q
    loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    loss.backward()
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop("weight", 1.0)
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns.flatten(), q)
    td = r - q
    loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {"loss": loss.item()}

def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    batch = super().process_fn(batch, buffer, indice)
    with torch.no_grad():
        # the degree of on-policyness
        opd = self(batch, output_dqn=False, output_opd=True).logits
        opd = opd[np.arange(len(opd)), batch.act]
        opd_weights = torch.sigmoid(opd * self.opd_temperature)
        opd_weights = opd_weights / torch.sum(opd_weights) * len(opd)
        opd_weights = opd_weights.detach().cpu().numpy()
    ori_weight = batch.pop("weight", 1.0)
    ori_weight *= opd_weights
    batch.weight = ori_weight
    return batch

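# Standalone sketch of the weight rescaling above: sigmoid squashes the
# on-policyness logits, then the weights are normalized to average 1 so the
# overall loss scale stays unchanged. Numbers are made up.
import torch

opd = torch.tensor([2.0, -2.0, 0.0])
w = torch.sigmoid(opd * 1.0)     # [0.88, 0.12, 0.50]
w = w / torch.sum(w) * len(w)    # [1.76, 0.24, 1.00]; mean(w) == 1
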
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    with torch.no_grad():
        target_dist = self._target_dist(batch)
    weight = batch.pop("weight", 1.0)
    curr_dist = self(batch).logits
    act = batch.act
    curr_dist = curr_dist[np.arange(len(act)), act, :]
    cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1)
    loss = (cross_entropy * weight).mean()
    # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100
    batch.weight = cross_entropy.detach()  # prio-buffer
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {"loss": loss.item()}

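# Standalone sketch of the cross-entropy step above: curr_dist rows are
# softmax probabilities over atoms, target_dist is the projected Bellman
# distribution, and 1e-8 guards the log. Numbers are made up.
import torch

target = torch.tensor([[0.0, 0.5, 0.5]])
curr = torch.tensor([[0.2, 0.4, 0.4]])
ce = -(target * torch.log(curr + 1e-8)).sum(1)  # ~= 0.916 (= -ln 0.4)
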
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    weight = batch.pop("weight", 1.0)
    self.optim.zero_grad()
    q = self(batch, eps=0.).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns, q).flatten()
    c = torch.nn.SmoothL1Loss(reduction='none')
    # c = lambda r, q: (r - q).pow(2)
    td = c(r, q)
    loss = (td * weight).mean()
    batch.weight = td  # prio-buffer (per-sample error, not the scalar loss)
    loss.backward()
    if self.grad_norm_clipping:
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.grad_norm_clipping)
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop("weight", 1.0)
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns.flatten(), q)
    td = r - q
    loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    loss.backward()
    # gradient clipping
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
    self.optim.step()
    self._cnt += 1
    for param_group in self.optim.param_groups:
        lr = param_group['lr']
    return {"loss": loss.item(), "lr": lr}

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop("weight", 1.0)
    act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device)
    q = self(batch).logits
    act_mask = torch.zeros_like(q)
    act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1)
    act_q = q * act_mask
    returns = batch.returns
    returns = returns * act_mask
    td_error = returns - act_q
    loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean()
    batch.weight = td_error.sum(-1).sum(-1)  # prio-buffer
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {"loss": loss.item()}

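# Standalone sketch of the scatter mask above: a one-hot mask over the
# action dimension keeps only the chosen action's Q-value per row.
import torch

q = torch.tensor([[1.0, 2.0, 3.0]])
act = torch.tensor([[1]])
mask = torch.zeros_like(q).scatter_(-1, act, 1)  # [[0., 1., 0.]]
masked_q = q * mask                              # [[0., 2., 0.]]
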
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    weight = batch.pop("weight", 1.0)
    current_q = self.critic(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td = current_q - target_q
    critic_loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()
    action = self(batch).act
    actor_loss = -self.critic(batch.obs, action).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    self.sync_weight()
    return {
        "loss/actor": actor_loss.item(),
        "loss/critic": critic_loss.item(),
    }

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    weight = batch.pop('weight', 1.)
    current_q = self.critic(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td = current_q - target_q
    critic_loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()
    action = self(batch, explorating=False).act
    actor_loss = -self.critic(batch.obs, action).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    self.sync_weight()
    return {
        'loss/actor': actor_loss.item(),
        'loss/critic': critic_loss.item(),
    }

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop("weight", 1.0)
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]
    returns = to_torch_as(batch.returns.flatten(), q)
    td_error = returns - q
    if self._clip_loss_grad:
        y = q.reshape(-1, 1)
        t = returns.reshape(-1, 1)
        loss = torch.nn.functional.huber_loss(y, t, reduction="mean")
    else:
        loss = (td_error.pow(2) * weight).mean()
    batch.weight = td_error  # prio-buffer
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {"loss": loss.item()}

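# Standalone sketch of why the Huber branch clips gradients: for a TD error
# far beyond delta (default 1.0), the Huber gradient w.r.t. q saturates at
# +-delta, while the squared-loss gradient grows linearly with the error.
import torch

q = torch.tensor([0.0], requires_grad=True)
returns = torch.tensor([10.0])
loss = torch.nn.functional.huber_loss(q, returns, reduction="mean")
loss.backward()
print(q.grad)  # tensor([-1.]) -- the squared loss would give -20 here
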
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    # compute dqn loss
    weight = batch.pop("weight", 1.0)
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns.flatten(), q)
    weight = to_torch_as(weight, q)
    td = r - q
    dqn_loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    # compute opd loss
    slow_preds = self(batch, output_dqn=False, output_opd=True).logits
    fast_preds = self(
        batch, input="fast_obs", output_dqn=False, output_opd=True
    ).logits
    # act_dim + 1 classes; the last class indicates the state is off-policy
    slow_label = torch.ones_like(
        slow_preds[:, 0], dtype=torch.int64) * int(self.model.output_dim)
    fast_label = to_torch_as(torch.tensor(batch.fast_act), slow_label)
    opd_loss = F.cross_entropy(slow_preds, slow_label) + \
        F.cross_entropy(fast_preds, fast_label)
    # compute loss and back-propagate
    loss = dqn_loss + self.opd_loss_coeff * opd_loss
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {
        "loss": loss.item(),
        "dqn_loss": dqn_loss.item(),
        "opd_loss": opd_loss.item(),
    }

def test_batch():
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object and b.f == Batch
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert 'a' in b and b.a == 3
    assert b.pop('a') == 3
    assert 'a' not in b
    with pytest.raises(AssertionError):
        Batch({1: 2})
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    Batch(a=[])
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    data = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, result in data:
        bs = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [bs[i].a.tolist() for i in range(len(bs))] == result
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a=set((1, 2, 1))).shape == []
    assert batch2.shape[0] == 1
    assert 'a' in batch2 and all([i in batch2.a for i in 'bcd'])
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})
    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])
    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None
    # nx.Graph corner case
    assert Batch(
        a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object
    g1 = nx.Graph()
    g1.add_nodes_from(list(range(10)))
    g2 = nx.Graph()
    g2.add_nodes_from(list(range(20)))
    assert Batch(a=np.array([g1, g2])).a.dtype == object