Example #1
0
    def prepare_functions(self):
        """Load C++/ATen signature pairs from the registration declarations.

        Every line of ``self._reg_dec_file_path`` of the form
        ``<C++ signature>; // <ATen schema>`` is considered. ``at::`` and
        ``c10::`` qualifiers are stripped from the C++ signature. If the
        trailing text contains both ``schema`` and ``compound`` it is decoded
        as JSON and its ``"schema"`` field is used as the ATen signature.

        Entries rejected by ``self.is_tensor_api`` (applied to the C++
        signature text) or ``self.is_bypass_func`` (applied to the parsed
        ``CPPSig``) are skipped. Accepted entries are reconciled with
        ``self.cross_correct_sig`` and appended to ``self._sigs`` as
        ``(cpp_sig, aten_sig, cpp_func_sig, aten_func_sig)``. Any exception
        raised while handling an accepted entry is recorded in
        ``self._err_info`` as ``(cpp_func_sig, message)`` and echoed to stderr.

        Afterwards the whole of ``self._func_file_path`` is read into
        ``self._func_data``.

        Raises:
            AssertionError: if any entry failed to parse. Raised explicitly
                rather than via an ``assert`` statement so the check is not
                skipped when Python runs with ``-O``.
        """
        decl_re = re.compile(r'\s*([^\s].*); //\s+(.*)')

        # Use ``with`` so the declarations file is closed deterministically.
        with open(self._reg_dec_file_path, 'r') as reg_file:
            for line in reg_file:
                m = decl_re.match(line)
                if not m:
                    continue
                cpp_func_sig = m.group(1).replace('at::', '').replace('c10::', '')
                aten_func_sig_literal = m.group(2)

                if not self.is_tensor_api(cpp_func_sig):
                    continue

                try:
                    cpp_sig = CPPSig(cpp_func_sig)
                    if self.is_bypass_func(cpp_sig):
                        continue

                    # Some declarations carry a JSON object rather than a bare
                    # schema; decoding it inside the ``try`` means a malformed
                    # literal is reported with its signature instead of
                    # aborting the whole scan.
                    aten_func_sig = aten_func_sig_literal
                    if ("schema" in aten_func_sig_literal
                            and "compound" in aten_func_sig_literal):
                        aten_func_sig = json.loads(aten_func_sig_literal)["schema"]

                    aten_sig = AtenSig(aten_func_sig)

                    self.cross_correct_sig(cpp_sig, aten_sig)

                    self._sigs.append(
                        (cpp_sig, aten_sig, cpp_func_sig, aten_func_sig))
                except Exception as e:
                    # Collect every failure so they are all reported before
                    # the final check raises.
                    self._err_info.append((cpp_func_sig, str(e)))
                    print('Error parsing "{}": {}'.format(cpp_func_sig, e),
                          file=sys.stderr)

        with open(self._func_file_path, 'r') as ff:
            self._func_data = ff.read()

        print('Extracted {} functions ({} errors) from {}'.format(
            len(self._sigs), len(self._err_info), self._reg_dec_file_path),
              file=sys.stderr)
        if self._err_info:
            raise AssertionError('{} signature(s) failed to parse'.format(
                len(self._err_info)))
    def prepare_functions(self):
        """Collect sparse C++ signatures that have a registration declaration.

        Reads, in order:

        * ``self._sparse_dec_file_path``: each line matching ``<sig>);``
          contributes a candidate sparse signature (the text up to and
          including the closing parenthesis).
        * ``self._sparse_attr_file_path`` into ``self._sparse_attr_data``.
        * ``self._func_file_path`` into ``self._func_data``.
        * ``self._reg_dec_file_path``: each line of the form
          ``<C++ signature>; // <ATen schema>``, with ``at::`` and ``c10::``
          stripped from the C++ signature. If the trailing text contains both
          ``schema`` and ``compound`` it is decoded as JSON and its
          ``"schema"`` field is used as the ATen signature.

        For each registration entry accepted by ``self.is_tensor_api`` and not
        rejected by ``self.is_bypass_func``, every sparse signature equal to
        the C++ signature once all spaces are removed is parsed as a
        ``CPPSig``. Its ``is_tensor_member_func`` flag is set from
        ``self.is_tensor_member_function``, it is reconciled with the ATen
        signature via ``self.cross_correct_sig``, and
        ``(sparse_sig, aten_sig, sparse_cpp_sig_str, aten_func_sig)`` is
        appended to ``self._sigs``. Exceptions while handling an entry are
        recorded in ``self._err_info`` and echoed to stderr.

        Raises:
            AssertionError: if any entry failed to parse. Raised explicitly
                rather than via an ``assert`` statement so the check is not
                skipped when Python runs with ``-O``.
        """
        sparse_decl_re = re.compile(r'\s*([^\s].*\));')
        reg_decl_re = re.compile(r'\s*([^\s].*); //\s+(.*)')

        # Parse SparseCPUType.h. Signatures are grouped by their space-free
        # form so each registration entry needs one dict lookup instead of a
        # rescan of every sparse signature. Each group keeps file order, so
        # matches are appended in the same order as a linear scan would.
        # NOTE(review): sparse signatures keep any ``at::``/``c10::``
        # qualifiers while registration signatures have them stripped, so
        # qualified sparse declarations can never match. Confirm this is
        # intended.
        sparse_sigs_by_key = {}
        with open(self._sparse_dec_file_path, 'r') as sparse_file:
            for line in sparse_file:
                m = sparse_decl_re.match(line)
                if not m:
                    continue
                sig_str = m.group(1)
                sparse_sigs_by_key.setdefault(
                    sig_str.replace(' ', ''), []).append(sig_str)

        # Parse SparseAttrType.h
        with open(self._sparse_attr_file_path, 'r') as ff:
            self._sparse_attr_data = ff.read()

        # Parse Functions.h
        with open(self._func_file_path, 'r') as ff:
            self._func_data = ff.read()

        # Parse RegistrationDeclarations.h
        with open(self._reg_dec_file_path, 'r') as reg_file:
            for line in reg_file:
                m = reg_decl_re.match(line)
                if not m:
                    continue
                cpp_func_sig = m.group(1).replace('at::', '').replace('c10::', '')
                aten_func_sig_literal = m.group(2)

                if not self.is_tensor_api(cpp_func_sig):
                    continue

                try:
                    cpp_sig = CPPSig(cpp_func_sig)
                    if self.is_bypass_func(cpp_sig):
                        continue

                    # Some declarations carry a JSON object rather than a bare
                    # schema; decoding it inside the ``try`` reports a
                    # malformed literal with its signature instead of
                    # aborting the whole scan.
                    aten_func_sig = aten_func_sig_literal
                    if ("schema" in aten_func_sig_literal
                            and "compound" in aten_func_sig_literal):
                        aten_func_sig = json.loads(aten_func_sig_literal)["schema"]

                    matches = sparse_sigs_by_key.get(cpp_func_sig.replace(' ', ''), ())
                    for sparse_cpp_sig_str in matches:
                        sparse_sig = CPPSig(sparse_cpp_sig_str)
                        sparse_sig.is_tensor_member_func = self.is_tensor_member_function(
                            sparse_sig.def_name)
                        aten_sig = AtenSig(aten_func_sig)
                        self.cross_correct_sig(sparse_sig, aten_sig)
                        self._sigs.append(
                            (sparse_sig, aten_sig, sparse_cpp_sig_str, aten_func_sig))
                except Exception as e:
                    # Collect every failure so they are all reported before
                    # the final check raises.
                    self._err_info.append((cpp_func_sig, str(e)))
                    print('Error parsing "{}": {}'.format(cpp_func_sig, e), file=sys.stderr)

        print('Extracted {} functions ({} errors) from {}'.format(
              len(self._sigs),
              len(self._err_info),
              self._reg_dec_file_path),
            file=sys.stderr)
        if self._err_info:
            raise AssertionError('{} signature(s) failed to parse'.format(
                len(self._err_info)))