def __iter__(self):
    if self.shuffle == False:
        index_list = get_order_list(self.total_len)
    else:
        index_list = get_random_list(self.total_len)
    self.batch_len = len(self)
    if "batch_len" in os.environ:
        self.batch_len = int(os.environ["batch_len"])
    if not hasattr(self, "workers") and self.num_workers:
        self._init_workers()
    if self.num_workers:
        self._stop_all_workers()
        self.index_list_numpy[:] = index_list
        gid_obj = self.gid.get_obj()
        gid_lock = self.gid.get_lock()
        with gid_lock:
            gid_obj.value = 0
            self.gidc.notify_all()
        for i in range(self.batch_len):
            # try not get lock first
            if gid_obj.value <= i:
                with gid_lock:
                    if gid_obj.value <= i:
                        if mp_log_v:
                            print("wait")
                        self.gidc.wait()
            worker_id = self.idmap[i]
            w = self.workers[worker_id]
            if mp_log_v:
                print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
            batch = w.buffer.recv()
            if mp_log_v:
                print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__,
                      [ type(b).__name__ for b in batch ])
            batch = self.to_jittor(batch)
            yield batch
    else:
        batch_data = []
        for idx in index_list:
            batch_data.append(self[int(idx)])
            if len(batch_data) == self.batch_size:
                batch_data = self.collate_batch(batch_data)
                batch_data = self.to_jittor(batch_data)
                yield batch_data
                batch_data = []
        # depend on drop_last
        if not self.drop_last and len(batch_data) > 0:
            batch_data = self.collate_batch(batch_data)
            batch_data = self.to_jittor(batch_data)
            yield batch_data
def __iter__(self):
    if self.shuffle == False:
        index_list = get_order_list(self.total_len)
    else:
        index_list = get_random_list(self.total_len)
    batch_data = []
    for idx in index_list:
        batch_data.append(self[int(idx)])
        if len(batch_data) == self.batch_size:
            batch_data = self.collate_batch(batch_data)
            yield batch_data
            batch_data = []
    # depend on drop_last
    if not self.drop_last and len(batch_data) > 0:
        batch_data = self.collate_batch(batch_data)
        yield batch_data
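A minimal, standalone sketch of the batching loop above (not part of the class; `batches` is a hypothetical helper): items are collected until `batch_size` is reached, and a final short batch is emitted only when `drop_last` is False.

# Hypothetical illustration of the batching/drop_last behaviour above.
def batches(items, batch_size, drop_last):
    buf = []
    for x in items:
        buf.append(x)
        if len(buf) == batch_size:
            yield buf
            buf = []
    # emit the leftover short batch only if drop_last is disabled
    if not drop_last and buf:
        yield buf

print(list(batches(range(7), 3, drop_last=False)))  # -> [[0, 1, 2], [3, 4, 5], [6]]
print(list(batches(range(7), 3, drop_last=True)))   # -> [[0, 1, 2], [3, 4, 5]]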
def __iter__(self):
    if self.total_len is None:
        self.total_len = len(self)
    if self.shuffle == False:
        index_list = get_order_list(self.total_len)
    else:
        index_list = get_random_list(self.total_len)
    # scatter index_list for all mpi process
    # scatter rule:
    #   batch 1    batch 2
    #  [........] [........] ...
    #   00011122   00011122
    # if last batch is smaller than world_size
    # pad to world_size
    #   last batch
    #   [.] -> [012]
    if jt.in_mpi:
        world_size = mpi.world_size()
        world_rank = mpi.world_rank()
        index_list = np.int32(index_list)
        mpi.broadcast(index_list, 0)

        assert self.batch_size >= world_size, \
            f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
        real_batch_size = (self.batch_size-1) // world_size + 1
        if real_batch_size * world_size != self.batch_size:
            LOG.w("Batch size is not divisible by MPI world size, "
                  "The distributed version may be different from "
                  "the single-process version.")
        fix_batch = self.total_len // self.batch_size
        last_batch = self.total_len - fix_batch * self.batch_size
        fix_batch_l = index_list[0:fix_batch*self.batch_size] \
            .reshape(-1, self.batch_size)
        fix_batch_l = fix_batch_l[
            :, real_batch_size*world_rank:real_batch_size*(world_rank+1)]
        real_batch_size = fix_batch_l.shape[1]
        fix_batch_l = fix_batch_l.flatten()
        if not self.drop_last and last_batch > 0:
            last_batch_l = index_list[-last_batch:]
            real_last_batch = (last_batch-1)//world_size+1
            l = real_last_batch * world_rank
            r = l + real_last_batch
            if r > last_batch:
                r = last_batch
            if l >= r:
                l = r-1
            index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
        else:
            index_list = fix_batch_l
        self.real_len = len(index_list)
        self.real_batch_size = real_batch_size
        assert self.total_len // self.batch_size == \
            self.real_len // self.real_batch_size
    else:
        self.real_len = self.total_len
        self.real_batch_size = self.batch_size

    self.batch_len = self.__batch_len__()

    if not hasattr(self, "workers") and self.num_workers:
        self._init_workers()

    if self.num_workers:
        self._stop_all_workers()
        self.index_list_numpy[:] = index_list
        gid_obj = self.gid.get_obj()
        gid_lock = self.gid.get_lock()
        with gid_lock:
            gid_obj.value = 0
            self.gidc.notify_all()
        start = time.time()
        self.batch_time = 0
        for i in range(self.batch_len):
            # try not get lock first
            if gid_obj.value <= i:
                with gid_lock:
                    if gid_obj.value <= i:
                        if mp_log_v:
                            print("wait")
                        self.gidc.wait()
            now = time.time()
            self.wait_time = now - start
            start = now

            self.last_id = i
            worker_id = self.idmap[i]
            w = self.workers[worker_id]
            if mp_log_v:
                print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
            batch = w.buffer.recv()

            now = time.time()
            self.recv_time = now - start
            start = now

            if mp_log_v:
                print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__,
                      [ type(b).__name__ for b in batch ])
            batch = self.to_jittor(batch)

            now = time.time()
            self.to_jittor_time = now - start
            start = now

            yield batch

            now = time.time()
            self.batch_time = now - start
            start = now
    else:
        batch_data = []
        for idx in index_list:
            batch_data.append(self[int(idx)])
            if len(batch_data) == self.real_batch_size:
                batch_data = self.collate_batch(batch_data)
                batch_data = self.to_jittor(batch_data)
                yield batch_data
                batch_data = []
        # depend on drop_last
        if not self.drop_last and len(batch_data) > 0:
            batch_data = self.collate_batch(batch_data)
            batch_data = self.to_jittor(batch_data)
            yield batch_data
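A standalone sketch of the per-batch scatter rule in the MPI branch above (the "00011122" picture): each rank keeps a contiguous slice of `real_batch_size` columns out of every full batch. `scatter_per_batch` and the example values are assumptions for illustration only, not part of the dataset class.

# Hypothetical illustration of the per-batch MPI scatter rule.
import numpy as np

def scatter_per_batch(index_list, batch_size, world_size, world_rank):
    # ceil-divide each batch across ranks, as the code above does
    real_batch_size = (batch_size - 1) // world_size + 1
    fix_batch = len(index_list) // batch_size
    full = np.asarray(index_list[:fix_batch * batch_size]).reshape(-1, batch_size)
    # take this rank's columns from every full batch
    mine = full[:, real_batch_size * world_rank:real_batch_size * (world_rank + 1)]
    return mine.flatten()

# 16 samples, batch_size=8, 3 ranks -> real_batch_size=3; rank 2 gets the
# trailing two columns of each batch ("00011122").
print(scatter_per_batch(np.arange(16), 8, 3, 2))  # -> [ 6  7 14 15]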
def __iter__(self):
    if self.shuffle == False:
        index_list = get_order_list(self.total_len)
    else:
        index_list = get_random_list(self.total_len)
    # scatter index_list for all mpi process
    # scatter rule:
    #   [........]
    #    000111
    #         222
    # make sure each process has the same len
    if mpi:
        index_list = np.int32(index_list)
        mpi.broadcast(index_list, 0)
        real_len = (self.total_len - 1) // mpi.world_size() + 1
        offset = mpi.world_rank() * real_len
        if offset + real_len > self.total_len:
            offset -= offset + real_len - self.total_len
        index_list = index_list[offset:offset + real_len]
        self.real_len = real_len
        assert real_len == len(index_list)
    else:
        self.real_len = self.total_len
    self.batch_len = len(self)
    if "batch_len" in os.environ:
        self.batch_len = int(os.environ["batch_len"])
    if not hasattr(self, "workers") and self.num_workers:
        self._init_workers()
    if self.num_workers:
        self._stop_all_workers()
        self.index_list_numpy[:] = index_list
        gid_obj = self.gid.get_obj()
        gid_lock = self.gid.get_lock()
        with gid_lock:
            gid_obj.value = 0
            self.gidc.notify_all()
        for i in range(self.batch_len):
            # try not get lock first
            if gid_obj.value <= i:
                with gid_lock:
                    if gid_obj.value <= i:
                        if mp_log_v:
                            print("wait")
                        self.gidc.wait()
            worker_id = self.idmap[i]
            w = self.workers[worker_id]
            if mp_log_v:
                print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
            batch = w.buffer.recv()
            if mp_log_v:
                print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__,
                      [type(b).__name__ for b in batch])
            batch = self.to_jittor(batch)
            yield batch
    else:
        batch_data = []
        for idx in index_list:
            batch_data.append(self[int(idx)])
            if len(batch_data) == self.batch_size:
                batch_data = self.collate_batch(batch_data)
                batch_data = self.to_jittor(batch_data)
                yield batch_data
                batch_data = []
        # depend on drop_last
        if not self.drop_last and len(batch_data) > 0:
            batch_data = self.collate_batch(batch_data)
            batch_data = self.to_jittor(batch_data)
            yield batch_data
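A standalone sketch of the contiguous scatter rule used in this variant (the "000111 / 222" picture): every rank takes a slice of the same length `real_len`, and the last rank's window is shifted back so it stays inside the dataset, possibly overlapping the previous rank. `scatter_contiguous` and the example values are assumptions for illustration only.

# Hypothetical illustration of the contiguous MPI scatter rule.
import numpy as np

def scatter_contiguous(index_list, total_len, world_size, world_rank):
    real_len = (total_len - 1) // world_size + 1   # ceil(total_len / world_size)
    offset = world_rank * real_len
    if offset + real_len > total_len:              # shift the last window back inside
        offset -= offset + real_len - total_len
    return index_list[offset:offset + real_len]

# 8 samples on 3 ranks -> real_len=3; rank 2 is shifted back by one so all
# ranks see the same length (with overlap).
idx = np.arange(8)
print([scatter_contiguous(idx, 8, 3, r).tolist() for r in range(3)])
# -> [[0, 1, 2], [3, 4, 5], [5, 6, 7]]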
def _get_index_list(self):
    if self.total_len is None:
        self.total_len = len(self)
    # maybe rewrite by sampler
    total_len = self.total_len
    if self.sampler:
        index_list = list(self.sampler.__iter__())
        total_len = len(index_list)
        # check is not batch sampler
        if len(index_list):
            assert not isinstance(index_list[0], (list, tuple)), \
                "Batch sampler not support yet."
    elif self.shuffle == False:
        index_list = get_order_list(self.total_len)
    else:
        index_list = get_random_list(self.total_len)

    # scatter index_list for all mpi process
    # scatter rule:
    #   batch 1    batch 2
    #  [........] [........] ...
    #   00011122   00011122
    # if last batch is smaller than world_size
    # pad to world_size
    #   last batch
    #   [.] -> [012]
    if jt.in_mpi:
        world_size = mpi.world_size()
        world_rank = mpi.world_rank()
        index_list = np.int32(index_list)
        # TODO: mpi broadcast in subprocess has bug, fix it
        # mpi.broadcast(index_list, 0)

        assert self.batch_size >= world_size, \
            f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
        real_batch_size = (self.batch_size - 1) // world_size + 1
        if real_batch_size * world_size != self.batch_size:
            LOG.w("Batch size is not divisible by MPI world size, "
                  "The distributed version may be different from "
                  "the single-process version.")
        fix_batch = total_len // self.batch_size
        last_batch = total_len - fix_batch * self.batch_size
        fix_batch_l = index_list[0:fix_batch * self.batch_size] \
            .reshape(-1, self.batch_size)
        fix_batch_l = fix_batch_l[
            :, real_batch_size * world_rank:real_batch_size * (world_rank + 1)]
        real_batch_size = fix_batch_l.shape[1]
        fix_batch_l = fix_batch_l.flatten()
        if not self.drop_last and last_batch > 0:
            last_batch_l = index_list[-last_batch:]
            real_last_batch = (last_batch - 1) // world_size + 1
            l = real_last_batch * world_rank
            r = l + real_last_batch
            if r > last_batch:
                r = last_batch
            if l >= r:
                l = r - 1
            index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
        else:
            index_list = fix_batch_l
        self.real_len = len(index_list)
        self.real_batch_size = real_batch_size
        assert total_len // self.batch_size == \
            self.real_len // self.real_batch_size, \
            f"Number of batches({total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, " \
            f"total_len: {total_len}, batch_size: {self.batch_size}, " \
            f"real_len: {self.real_len}, real_batch_size: {self.real_batch_size}"
    else:
        self.real_len = self.total_len
        self.real_batch_size = self.batch_size

    self.batch_len = self.__batch_len__()
    return index_list