Example #1
    def check(self, name, pre_data, data):
        global has_error
        if pre_data is None and isinstance(data, np.ndarray):
            if (data == 0).all():
                LOG.i(f"name {name} is None")
            else:
                LOG.e(f"name {name} is non-zero")
            return
        if type(pre_data) != type(data):
            LOG.e(
                f"type not match, {pre_data.__class__.__name__}!={data.__class__.__name__}, name: {name}"
            )
            has_error += 1
            return
        if isinstance(pre_data, (list, tuple)):
            if len(pre_data) != len(data):
                has_error += 1
                LOG.e(
                    f"Name <{name}> len not match, {len(pre_data)} != {len(data)}"
                )
            n = max(len(pre_data), len(data))
            for i in range(n):
                a = pre_data[i] if i < len(pre_data) else "None"
                b = data[i] if i < len(data) else "None"
                self.check(name + f".{i}", a, b)
        elif isinstance(pre_data, np.ndarray):
            if len(pre_data.shape) == 0:
                pre_data = np.array([pre_data])
            if len(data.shape) == 0:
                data = np.array([data])
            if pre_data.shape != data.shape:
                has_error += 1
                LOG.e(
                    f"Ndarray shape <{name}> not match {pre_data.shape} != {data.shape}"
                )
                return
            self.check_array(name, pre_data, data)
        elif isinstance(pre_data, dict):
            if len(pre_data) != len(data):
                has_error += 1
                LOG.w(
                    f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}"
                )
            for k in pre_data:
                pv = pre_data[k]
                if k not in data:
                    has_error += 1
                    msg = f"Key <{k}> not in data, Name <{name}>"
                    if isinstance(pv, np.ndarray):
                        LOG.e(msg)
                    else:
                        LOG.w(msg)
                    continue
                self.check(name + f".{k}", pre_data[k], data[k])
        else:
            if pre_data != data:
                has_error += 1
                LOG.e(
                    f"Type: {type(pre_data).__name__} Name <{name}> not match {pre_data} != {data}"
                )
Example #2
def compile_extern():
    # compile llvm passes
    if cc_type != "clang":
        return
    global kernel_opt_flags
    cache_path_llvm = os.path.join(cache_path, "llvm")
    jittor_path_llvm = os.path.join(jittor_path, "extern", "llvm")
    clang_dir = os.path.dirname(get_full_path_of_executable(cc_path))
    assert clang_dir.endswith(
        "bin") and "llvm" in clang_dir, f"Wrong clang_dir: {clang_dir}"
    llvm_include = os.path.abspath(os.path.join(clang_dir, "..", "include"))
    assert os.path.isdir(llvm_include), "LLVM include path not found"
    make_cache_dir(cache_path_llvm)
    files = os.listdir(jittor_path_llvm)
    # test_pass.cc is used to test for link problems with the llvm pass plugin
    test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
    with open(test_pass_path, 'w') as f:
        f.write("int main() {return 0;}")

    # -fno-rtti fixes a link error

    # -Wl,-znodelete fixes a segfault
    # https://github.com/sampsyo/llvm-pass-skeleton/issues/7#issuecomment-401834287

    # -D_GLIBCXX_USE_CXX11_ABI=0 fixes undefined symbol: createPrinterPass
    # https://stackoverflow.com/questions/37366291/undefined-symbol-for-self-built-llvm-opt

    # try different flags
    try_flags = [
        " -Wl,-znodelete -D_GLIBCXX_USE_CXX11_ABI=0 ",
        " -Wl,-znodelete ",
    ]
    found_flags_id = -1
    for fname in files:
        for i, flag in enumerate(try_flags):
            if found_flags_id != -1 and found_flags_id != i:
                continue
            so_name = os.path.join(cache_path_llvm,
                                   os.path.splitext(fname)[0] + f".{i}.so")
            compile(cc_path,
                    f"{cc_flags} {opt_flags} {flag} -I'{llvm_include}'",
                    [os.path.join(jittor_path_llvm, fname)], so_name)
            # if no working flag set has been found yet, test this one
            if found_flags_id == -1:
                try:
                    s = run_cmd(
                        f"{cc_path} {cc_flags} -Xclang -load -Xclang '{so_name}' {test_pass_path}",
                        cache_path_llvm,
                        print_error=False)
                except Exception as e:
                    LOG.v(f"Try flag {flag} failed: {e}")
                    continue
                found_flags_id = i
            kernel_opt_flags += f" -Xclang -load -Xclang '{so_name}' "
            break
        else:
            LOG.w("Clang is used, but LLVM pass plugin is unable to link.")
            break
    LOG.vv(f"Compile extern llvm passes: {str(files)}")
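The loop above probes candidate flag sets and keeps the first one that produces a plugin that links. A minimal sketch of that probing pattern follows; `pick_flags`, `build`, and `smoke_test` are hypothetical helpers standing in for `compile()` and `run_cmd()`, not part of the source.

def pick_flags(candidates, build, smoke_test):
    for i, flags in enumerate(candidates):
        artifact = build(flags)        # e.g. compile the pass plugin with these flags
        try:
            smoke_test(artifact)       # e.g. load the plugin against a trivial main()
        except Exception:
            continue                   # this flag combination does not link; try the next
        return i, flags
    return -1, None                    # none of the candidates worked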
Example #3
    def check_array(self, name, a, b):
        rtol = self.rtol
        atol = self.atol
        global has_error
        err = np.abs(a-b)
        tol = atol + rtol * np.abs(b)
        is_error = np.logical_or(err > tol, (a >= -1e-5) != (b >= -1e-5))
        index = np.where(is_error)
        assert len(index)>0
        if len(index[0]) == 0:
            return

        has_error += 1
        LOG.w(f"Ndarray <{name}> not match, shape:{a.shape}")
        i = tuple( i[0] for i in index )
        err_rate = is_error.mean()
        LOG.w(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), err_rate:{err_rate*100:.3f}% amean({a.mean()}) bmean({b.mean()}) astd({a.std()}) bstd({b.std()}) ")
        if err_rate > 0.01:
            LOG.e("!"*10+"Very HIGH err rate"+"!"*10)
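A small numpy sketch of the element-wise tolerance test used above (|a-b| > atol + rtol*|b|, plus a sign-flip check); the function name and the rtol/atol values are illustrative, not the class defaults.

import numpy as np

def error_mask(a, b, rtol=1e-3, atol=1e-3):
    err = np.abs(a - b)
    tol = atol + rtol * np.abs(b)                # per-element tolerance, scaled by |b|
    sign_flip = (a >= -1e-5) != (b >= -1e-5)     # values that changed sign also count as errors
    return np.logical_or(err > tol, sign_flip)

a = np.array([1.0, 2.0, -3.0])
b = np.array([1.0005, 2.1, 3.0])
mask = error_mask(a, b)
print(mask, f"err_rate={mask.mean()*100:.1f}%")  # -> [False  True  True] err_rate=66.7%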
Example #4
    def __iter__(self):
        if self.total_len is None:
            self.total_len = len(self)
        if self.shuffle == False:
            index_list = get_order_list(self.total_len)
        else:
            index_list = get_random_list(self.total_len)
        
        # scatter index_list across all mpi processes
        # scatter rule:
        #   batch 1   batch 2
        # [........] [........] ...
        #  00011122   00011122
        # if the last batch is smaller than world_size,
        # pad it to world_size:
        #  last batch
        # [.] -> [012]
        if jt.in_mpi:
            world_size = mpi.world_size()
            world_rank = mpi.world_rank()
            index_list = np.int32(index_list)
            mpi.broadcast(index_list, 0)

            assert self.batch_size >= world_size, \
                f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
            real_batch_size = (self.batch_size-1) // world_size + 1
            if real_batch_size * world_size != self.batch_size:
                LOG.w("Batch size is not divisible by MPI world size, "
                      "The distributed version may be different from "
                      "the single-process version.")
            fix_batch = self.total_len // self.batch_size
            last_batch = self.total_len - fix_batch * self.batch_size
            fix_batch_l = index_list[0:fix_batch*self.batch_size] \
                .reshape(-1,self.batch_size)
            fix_batch_l = fix_batch_l[
                :,real_batch_size*world_rank:real_batch_size*(world_rank+1)]
            real_batch_size = fix_batch_l.shape[1]
            fix_batch_l = fix_batch_l.flatten()
            if not self.drop_last and last_batch > 0:
                last_batch_l = index_list[-last_batch:]
                real_last_batch = (last_batch-1)//world_size+1
                l = real_last_batch * world_rank
                r = l + real_last_batch
                if r > last_batch: r = last_batch
                if l >= r: l = r-1
                index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
            else:
                index_list = fix_batch_l

            self.real_len = len(index_list)
            self.real_batch_size = real_batch_size
            assert self.total_len // self.batch_size == \
                self.real_len // self.real_batch_size
        else:
            self.real_len = self.total_len
            self.real_batch_size = self.batch_size
            
        self.batch_len = self.__batch_len__()
        
        if not hasattr(self, "workers") and self.num_workers:
            self._init_workers()
        
        if self.num_workers:
            self._stop_all_workers()
            self.index_list_numpy[:] = index_list
            gid_obj = self.gid.get_obj()
            gid_lock = self.gid.get_lock()
            with gid_lock:
                gid_obj.value = 0
                self.gidc.notify_all()
            start = time.time()
            self.batch_time = 0
            for i in range(self.batch_len):
                # check first without taking the lock
                if gid_obj.value <= i:
                    with gid_lock:
                        if gid_obj.value <= i:
                            if mp_log_v:
                                print("wait")
                            self.gidc.wait()
                now = time.time()
                self.wait_time = now - start
                start = now

                self.last_id = i
                worker_id = self.idmap[i]
                w = self.workers[worker_id]
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
                batch = w.buffer.recv()
                now = time.time()
                self.recv_time = now - start
                start = now

                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
                batch = self.to_jittor(batch)
                now = time.time()
                self.to_jittor_time = now - start
                start = now

                yield batch

                now = time.time()
                self.batch_time = now - start
                start = now
        else:
            batch_data = []
            for idx in index_list:
                batch_data.append(self[int(idx)])
                if len(batch_data) == self.real_batch_size:
                    batch_data = self.collate_batch(batch_data)
                    batch_data = self.to_jittor(batch_data)
                    yield batch_data
                    batch_data = []

            # whether to yield the remainder depends on drop_last
            if not self.drop_last and len(batch_data) > 0:
                batch_data = self.collate_batch(batch_data)
                batch_data = self.to_jittor(batch_data)
                yield batch_data
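A standalone sketch of the scatter rule described in the comments above; `scatter_indices` and passing world_size/world_rank as plain ints are illustrative stand-ins for the mpi calls, not the Dataset's actual method.

import numpy as np

def scatter_indices(index_list, batch_size, world_size, world_rank, drop_last=False):
    index_list = np.asarray(index_list)
    per_rank = (batch_size - 1) // world_size + 1       # ceil(batch_size / world_size)
    full = len(index_list) // batch_size
    tail = len(index_list) - full * batch_size
    batches = index_list[:full * batch_size].reshape(-1, batch_size)
    mine = batches[:, per_rank * world_rank: per_rank * (world_rank + 1)].flatten()
    if not drop_last and tail > 0:
        last = index_list[-tail:]
        per_rank_last = (tail - 1) // world_size + 1
        l = per_rank_last * world_rank
        r = min(l + per_rank_last, tail)
        if l >= r:
            l = r - 1                                   # pad: ranks past the end reuse the final index
        mine = np.concatenate([mine, last[l:r]])
    return mine

print(scatter_indices(range(14), batch_size=6, world_size=3, world_rank=2))
# -> [ 4  5 10 11 13]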
Example #5
if has_cuda:
    nvcc_flags = convert_nvcc_flags(cc_flags)
    nvcc_version = list(jit_utils.get_int_version(nvcc_path))
    max_arch = 1000
    if nvcc_version < [11]:
        max_arch = 75
    elif nvcc_version < [11, 1]:
        max_arch = 80
    if len(flags.cuda_archs):
        min_arch = 30
        archs = []
        for arch in flags.cuda_archs:
            if arch < min_arch:
                LOG.w(f"CUDA arch({arch})<{min_arch} is not supported")
                continue
            if arch > max_arch:
                LOG.w(
                    f"CUDA arch({arch})>{max_arch} will be backward-compatible"
                )
                arch = max_arch
            archs.append(arch)
        flags.cuda_archs = archs
        nvcc_flags += f" -arch=compute_{min(archs)} "
        nvcc_flags += ''.join(map(lambda x: f' -code=sm_{x} ', archs))

flags.cc_path = cc_path
flags.cc_type = cc_type
flags.cc_flags = cc_flags + kernel_opt_flags
flags.nvcc_path = nvcc_path
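A short sketch of how the -arch/-code flags above are assembled for a given arch list; `build_arch_flags` and its default bounds are illustrative, not part of the source.

def build_arch_flags(archs, min_arch=30, max_arch=80):
    # drop unsupported archs, clamp newer ones to max_arch (backward-compatible)
    archs = [min(a, max_arch) for a in archs if a >= min_arch]
    flags = f" -arch=compute_{min(archs)} "
    flags += ''.join(f' -code=sm_{a} ' for a in archs)
    return flags

print(build_arch_flags([52, 61, 86]))
# -> " -arch=compute_52  -code=sm_52  -code=sm_61  -code=sm_80 "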
Example #6
    def _get_index_list(self):
        if self.total_len is None:
            self.total_len = len(self)
        # total_len may be overridden by the sampler
        total_len = self.total_len
        if self.sampler:
            index_list = list(self.sampler.__iter__())
            total_len = len(index_list)
            # check that this is not a batch sampler
            if len(index_list):
                assert not isinstance(
                    index_list[0],
                    (list, tuple)), "Batch sampler not support yet."
        elif self.shuffle == False:
            index_list = get_order_list(self.total_len)
        else:
            # use _shuffle_rng to generate a shuffle list that is
            # consistent across processes
            # index_list = get_random_list(self.total_len)
            index_list = self._shuffle_rng.permutation(range(self.total_len))

        # scatter index_list across all mpi processes
        # scatter rule:
        #   batch 1   batch 2
        # [........] [........] ...
        #  00011122   00011122
        # if the last batch is smaller than world_size,
        # pad it to world_size:
        #  last batch
        # [.] -> [012]
        if jt.in_mpi:
            world_size = mpi.world_size()
            world_rank = mpi.world_rank()
            index_list = np.int32(index_list)
            # TODO: mpi broadcast in a subprocess has a bug, fix it
            # mpi.broadcast(index_list, 0)

            assert self.batch_size >= world_size, \
                f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
            real_batch_size = (self.batch_size - 1) // world_size + 1
            if real_batch_size * world_size != self.batch_size:
                LOG.w("Batch size is not divisible by MPI world size, "
                      "The distributed version may be different from "
                      "the single-process version.")
            fix_batch = total_len // self.batch_size
            last_batch = total_len - fix_batch * self.batch_size
            fix_batch_l = index_list[0:fix_batch*self.batch_size] \
                .reshape(-1,self.batch_size)
            fix_batch_l = fix_batch_l[:, real_batch_size *
                                      world_rank:real_batch_size *
                                      (world_rank + 1)]
            real_batch_size = fix_batch_l.shape[1]
            fix_batch_l = fix_batch_l.flatten()
            if not self.drop_last and last_batch > 0:
                last_batch_l = index_list[-last_batch:]
                real_last_batch = (last_batch - 1) // world_size + 1
                l = real_last_batch * world_rank
                r = l + real_last_batch
                if r > last_batch: r = last_batch
                if l >= r: l = r - 1
                index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
            else:
                index_list = fix_batch_l

            self.real_len = len(index_list)
            self.real_batch_size = real_batch_size
            assert total_len // self.batch_size == \
                self.real_len // self.real_batch_size, f"Number of batches({total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}"
        else:
            self.real_len = self.total_len
            self.real_batch_size = self.batch_size
        self.batch_len = self.__batch_len__()
        return index_list
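A brief sketch of the "consistent shuffle across processes" idea behind `_shuffle_rng`: any process that constructs a RandomState with the same seed obtains the same permutation, so workers agree on the index order without communicating. The seed value here is illustrative.

import numpy as np

shuffle_rng = np.random.RandomState(12345)
perm = shuffle_rng.permutation(10)   # identical in every process seeded with 12345
print(perm)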
Example #7
    def get_param_name(self, p):
        if id(p) not in self.param_name_map:
            LOG.w("Param name not found", p.shape, id(p))
            return "noname" + str(list(p.shape))
        return self.param_name_map[id(p)]