Example #1
    def __iter__(self):
        if not self.shuffle:
            index_list = get_order_list(self.total_len)
        else:
            index_list = get_random_list(self.total_len)
            
        self.batch_len = len(self)
        if "batch_len" in os.environ:
            self.batch_len = int(os.environ["batch_len"])
        
        if not hasattr(self, "workers") and self.num_workers:
            self._init_workers()
        
        if self.num_workers:
            self._stop_all_workers()
            self.index_list_numpy[:] = index_list
            gid_obj = self.gid.get_obj()
            gid_lock = self.gid.get_lock()
            with gid_lock:
                gid_obj.value = 0
                self.gidc.notify_all()
            for i in range(self.batch_len):
                # fast path: check without the lock; only lock and wait if the batch is not ready
                if gid_obj.value <= i:
                    with gid_lock:
                        if gid_obj.value <= i:
                            if mp_log_v:
                                print("wait")
                            self.gidc.wait()
                worker_id = self.idmap[i]
                w = self.workers[worker_id]
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
                batch = w.buffer.recv()
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
                batch = self.to_jittor(batch)
                yield batch
        else:
            batch_data = []
            for idx in index_list:
                batch_data.append(self[int(idx)])
                if len(batch_data) == self.batch_size:
                    batch_data = self.collate_batch(batch_data)
                    batch_data = self.to_jittor(batch_data)
                    yield batch_data
                    batch_data = []

            # unless drop_last is set, yield the final (possibly smaller) batch
            if not self.drop_last and len(batch_data) > 0:
                batch_data = self.collate_batch(batch_data)
                batch_data = self.to_jittor(batch_data)
                yield batch_data
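To see how this __iter__ is typically driven, here is a minimal sketch of a consumer, assuming the standard jittor.dataset.Dataset interface; the ToyDataset class, its sizes, and the set_attrs arguments are illustrative, not taken from the example above:

    import numpy as np
    from jittor.dataset import Dataset

    class ToyDataset(Dataset):
        def __init__(self):
            super().__init__()
            # hypothetical data: 10 samples of 3 features plus integer labels
            self.x = np.random.rand(10, 3).astype(np.float32)
            self.y = np.arange(10)
            # total_len tells __iter__ how many indices to draw;
            # shuffle/batch_size/drop_last drive the branches shown above
            self.set_attrs(total_len=10, batch_size=4, shuffle=True, drop_last=False)

        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]

    # each iteration yields one collated batch converted by to_jittor;
    # with drop_last=False the last batch holds the remaining 2 samples
    for batch_x, batch_y in ToyDataset():
        print(batch_x.shape, batch_y.shape)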
Example #2
    def __iter__(self):
        if not self.shuffle:
            index_list = get_order_list(self.total_len)
        else:
            index_list = get_random_list(self.total_len)
        batch_data = []
        for idx in index_list:
            batch_data.append(self[int(idx)])
            if len(batch_data) == self.batch_size:
                batch_data = self.collate_batch(batch_data)
                yield batch_data
                batch_data = []

        # unless drop_last is set, yield the final (possibly smaller) batch
        if not self.drop_last and len(batch_data) > 0:
            batch_data = self.collate_batch(batch_data)
            yield batch_data
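This variant is the plain single-process loop: append samples until batch_size is reached, yield, and only keep the trailing partial batch when drop_last is False. A standalone sketch of just that batching logic (plain Python, no Jittor; batcher is a hypothetical helper name):

    def batcher(index_list, batch_size, drop_last=False):
        # mirrors the loop above, minus collate_batch / self[idx]
        batch = []
        for idx in index_list:
            batch.append(idx)
            if len(batch) == batch_size:
                yield batch
                batch = []
        # the trailing partial batch is kept only when drop_last is False
        if not drop_last and batch:
            yield batch

    print(list(batcher(range(7), 3)))                  # [[0, 1, 2], [3, 4, 5], [6]]
    print(list(batcher(range(7), 3, drop_last=True)))  # [[0, 1, 2], [3, 4, 5]]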
Example #3
    def __iter__(self):
        if self.total_len is None:
            self.total_len = len(self)
        if not self.shuffle:
            index_list = get_order_list(self.total_len)
        else:
            index_list = get_random_list(self.total_len)
        
        # scatter index_list across all MPI processes
        # scatter rule (illustrated by the NumPy sketch after this example):
        #   batch 1   batch 2
        # [........] [........] ...
        #  00011122   00011122
        # if the last batch is smaller than world_size,
        # pad it to world_size:
        #  last batch
        # [.] -> [012]
        if jt.in_mpi:
            world_size = mpi.world_size()
            world_rank = mpi.world_rank()
            index_list = np.int32(index_list)
            mpi.broadcast(index_list, 0)

            assert self.batch_size >= world_size, \
                f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
            real_batch_size = (self.batch_size-1) // world_size + 1
            if real_batch_size * world_size != self.batch_size:
                LOG.w("Batch size is not divisible by MPI world size; "
                      "the distributed version may behave differently from "
                      "the single-process version.")
            fix_batch = self.total_len // self.batch_size
            last_batch = self.total_len - fix_batch * self.batch_size
            fix_batch_l = index_list[0:fix_batch*self.batch_size] \
                .reshape(-1,self.batch_size)
            fix_batch_l = fix_batch_l[
                :,real_batch_size*world_rank:real_batch_size*(world_rank+1)]
            real_batch_size = fix_batch_l.shape[1]
            fix_batch_l = fix_batch_l.flatten()
            if not self.drop_last and last_batch > 0:
                last_batch_l = index_list[-last_batch:]
                real_last_batch = (last_batch-1)//world_size+1
                l = real_last_batch * world_rank
                r = l + real_last_batch
                if r > last_batch: r = last_batch
                if l >= r: l = r-1
                index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
            else:
                index_list = fix_batch_l

            self.real_len = len(index_list)
            self.real_batch_size = real_batch_size
            assert self.total_len // self.batch_size == \
                self.real_len // self.real_batch_size
        else:
            self.real_len = self.total_len
            self.real_batch_size = self.batch_size
            
        self.batch_len = self.__batch_len__()
        
        if not hasattr(self, "workers") and self.num_workers:
            self._init_workers()
        
        if self.num_workers:
            self._stop_all_workers()
            self.index_list_numpy[:] = index_list
            gid_obj = self.gid.get_obj()
            gid_lock = self.gid.get_lock()
            with gid_lock:
                gid_obj.value = 0
                self.gidc.notify_all()
            start = time.time()
            self.batch_time = 0
            for i in range(self.batch_len):
                # fast path: check without the lock; only lock and wait if the batch is not ready
                if gid_obj.value <= i:
                    with gid_lock:
                        if gid_obj.value <= i:
                            if mp_log_v:
                                print("wait")
                            self.gidc.wait()
                now = time.time()
                self.wait_time = now - start
                start = now

                self.last_id = i
                worker_id = self.idmap[i]
                w = self.workers[worker_id]
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
                batch = w.buffer.recv()
                now = time.time()
                self.recv_time = now - start
                start = now

                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
                batch = self.to_jittor(batch)
                now = time.time()
                self.to_jittor_time = now - start
                start = now

                yield batch

                now = time.time()
                self.batch_time = now - start
                start = now
        else:
            batch_data = []
            for idx in index_list:
                batch_data.append(self[int(idx)])
                if len(batch_data) == self.real_batch_size:
                    batch_data = self.collate_batch(batch_data)
                    batch_data = self.to_jittor(batch_data)
                    yield batch_data
                    batch_data = []

            # unless drop_last is set, yield the final (possibly smaller) batch
            if not self.drop_last and len(batch_data) > 0:
                batch_data = self.collate_batch(batch_data)
                batch_data = self.to_jittor(batch_data)
                yield batch_data
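The scatter rule above splits every full batch column-wise across MPI ranks: each rank keeps a contiguous slice of ceil(batch_size / world_size) indices from every batch. A minimal NumPy sketch of that index layout, reusing the variable names from the example (scatter_fix_batches is an illustrative helper; no MPI is needed to see the result):

    import numpy as np

    def scatter_fix_batches(index_list, batch_size, world_size, world_rank):
        # per-rank share of each full batch, rounded up
        real_batch_size = (batch_size - 1) // world_size + 1
        fix_batch = len(index_list) // batch_size
        fix_batch_l = np.asarray(index_list[:fix_batch * batch_size]) \
            .reshape(-1, batch_size)
        # every rank keeps its own column slice of every full batch
        fix_batch_l = fix_batch_l[
            :, real_batch_size * world_rank:real_batch_size * (world_rank + 1)]
        return fix_batch_l.flatten()

    idx = np.arange(16)  # two full batches of batch_size=8, world_size=3
    for rank in range(3):
        print(rank, scatter_fix_batches(idx, 8, 3, rank))
    # rank 0 -> [0 1 2 8 9 10], rank 1 -> [3 4 5 11 12 13], rank 2 -> [6 7 14 15]

Note that the last rank can end up with a narrower slice, which is why the example re-reads real_batch_size from fix_batch_l.shape[1] before flattening.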
Example #4
    def __iter__(self):
        if not self.shuffle:
            index_list = get_order_list(self.total_len)
        else:
            index_list = get_random_list(self.total_len)

        # scatter index_list across all MPI processes
        # scatter rule (illustrated by the sketch after this example):
        # [........]
        #  000111
        #       222
        # make sure every process gets the same length
        # (the last slice is shifted back and may overlap the previous one)
        if mpi:
            index_list = np.int32(index_list)
            mpi.broadcast(index_list, 0)
            real_len = (self.total_len - 1) // mpi.world_size() + 1
            offset = mpi.world_rank() * real_len
            if offset + real_len > self.total_len:
                offset -= offset + real_len - self.total_len
            index_list = index_list[offset:offset + real_len]
            self.real_len = real_len
            assert real_len == len(index_list)
        else:
            self.real_len = self.total_len

        self.batch_len = len(self)
        if "batch_len" in os.environ:
            self.batch_len = int(os.environ["batch_len"])

        if not hasattr(self, "workers") and self.num_workers:
            self._init_workers()

        if self.num_workers:
            self._stop_all_workers()
            self.index_list_numpy[:] = index_list
            gid_obj = self.gid.get_obj()
            gid_lock = self.gid.get_lock()
            with gid_lock:
                gid_obj.value = 0
                self.gidc.notify_all()
            for i in range(self.batch_len):
                # fast path: check without the lock; only lock and wait if the batch is not ready
                if gid_obj.value <= i:
                    with gid_lock:
                        if gid_obj.value <= i:
                            if mp_log_v:
                                print("wait")
                            self.gidc.wait()
                worker_id = self.idmap[i]
                w = self.workers[worker_id]
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
                batch = w.buffer.recv()
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} recv",
                          type(batch).__name__,
                          [type(b).__name__ for b in batch])
                batch = self.to_jittor(batch)
                yield batch
        else:
            batch_data = []
            for idx in index_list:
                batch_data.append(self[int(idx)])
                if len(batch_data) == self.batch_size:
                    batch_data = self.collate_batch(batch_data)
                    batch_data = self.to_jittor(batch_data)
                    yield batch_data
                    batch_data = []

            # unless drop_last is set, yield the final (possibly smaller) batch
            if not self.drop_last and len(batch_data) > 0:
                batch_data = self.collate_batch(batch_data)
                batch_data = self.to_jittor(batch_data)
                yield batch_data
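Here the scatter is a simpler contiguous split: every rank takes a slice of the same length real_len = ceil(total_len / world_size), and the last slice is shifted back so it stays inside the index list, possibly overlapping the previous rank's slice. A minimal NumPy sketch of this rule (contiguous_shard is an illustrative helper name, not part of the code above):

    import numpy as np

    def contiguous_shard(index_list, world_size, world_rank):
        total_len = len(index_list)
        # every rank gets the same length, rounded up
        real_len = (total_len - 1) // world_size + 1
        offset = world_rank * real_len
        # shift the last shard back into range; it may overlap its neighbour
        if offset + real_len > total_len:
            offset -= offset + real_len - total_len
        return index_list[offset:offset + real_len]

    idx = np.arange(8)
    for rank in range(3):
        print(rank, contiguous_shard(idx, 3, rank))
    # rank 0 -> [0 1 2], rank 1 -> [3 4 5], rank 2 -> [5 6 7] (index 5 appears twice)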
Example #5
    def _get_index_list(self):
        if self.total_len is None:
            self.total_len = len(self)
        # total_len may be overridden by the sampler below
        total_len = self.total_len
        if self.sampler:
            index_list = list(self.sampler.__iter__())
            total_len = len(index_list)
            # make sure this is not a batch sampler
            if len(index_list):
                assert not isinstance(
                    index_list[0],
                    (list, tuple)), "Batch sampler not supported yet."
        elif not self.shuffle:
            index_list = get_order_list(self.total_len)
        else:
            index_list = get_random_list(self.total_len)

        # scatter index_list across all MPI processes
        # scatter rule (same as in Example #3):
        #   batch 1   batch 2
        # [........] [........] ...
        #  00011122   00011122
        # if the last batch is smaller than world_size,
        # pad it to world_size:
        #  last batch
        # [.] -> [012]
        if jt.in_mpi:
            world_size = mpi.world_size()
            world_rank = mpi.world_rank()
            index_list = np.int32(index_list)
            # TODO: mpi broadcast in a subprocess has a bug; fix it
            # mpi.broadcast(index_list, 0)

            assert self.batch_size >= world_size, \
                f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
            real_batch_size = (self.batch_size - 1) // world_size + 1
            if real_batch_size * world_size != self.batch_size:
                LOG.w("Batch size is not divisible by MPI world size; "
                      "the distributed version may behave differently from "
                      "the single-process version.")
            fix_batch = total_len // self.batch_size
            last_batch = total_len - fix_batch * self.batch_size
            fix_batch_l = index_list[0:fix_batch*self.batch_size] \
                .reshape(-1,self.batch_size)
            fix_batch_l = fix_batch_l[
                :, real_batch_size*world_rank:real_batch_size*(world_rank+1)]
            real_batch_size = fix_batch_l.shape[1]
            fix_batch_l = fix_batch_l.flatten()
            if not self.drop_last and last_batch > 0:
                last_batch_l = index_list[-last_batch:]
                real_last_batch = (last_batch - 1) // world_size + 1
                l = real_last_batch * world_rank
                r = l + real_last_batch
                if r > last_batch: r = last_batch
                if l >= r: l = r - 1
                index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
            else:
                index_list = fix_batch_l

            self.real_len = len(index_list)
            self.real_batch_size = real_batch_size
            assert total_len // self.batch_size == \
                self.real_len // self.real_batch_size, (
                    f"Number of batches does not match "
                    f"({total_len // self.batch_size} != {self.real_len // self.real_batch_size}); "
                    f"total_len: {total_len}, batch_size: {self.batch_size}, "
                    f"real_len: {self.real_len}, real_batch_size: {self.real_batch_size}")
        else:
            self.real_len = self.total_len
            self.real_batch_size = self.batch_size
        self.batch_len = self.__batch_len__()
        return index_list
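When a sampler is set, the index order comes from the sampler instead of get_order_list / get_random_list, and total_len is taken from the number of indices it yields; the assert only requires that each yielded element is a plain index, not a list or tuple. A minimal sketch of a compatible sampler (EvenIndexSampler is a hypothetical name, not part of the code above):

    class EvenIndexSampler:
        """Yields the even indices of a dataset of the given length."""

        def __init__(self, total_len):
            self.total_len = total_len

        def __iter__(self):
            return iter(range(0, self.total_len, 2))

    sampler = EvenIndexSampler(10)
    print(list(sampler))  # [0, 2, 4, 6, 8] -- plain ints, so the
                          # "Batch sampler not supported yet." assert passes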