def __init__(self, source, module_name="TORCH_EXTENSION_NAME", path=None, temps="temps", warm_start=False):
    """
    Store the C++ source text and prepare (or reuse) a temporary build directory.

    Only ``self.source`` is stored here; no files are written or copied by
    this constructor. The process working directory is always restored to
    the directory that was current on entry, even if an error is raised.

    Parameters
    ----------
    source: str
        all the cpp source text (the text itself, not a file name).
    module_name: str
        name of the module. The placeholder default only triggers a printed
        reminder to set a real name.
    path: str, optional
        directory to build in; passed to ``def_pwd`` when not None.
    temps: str
        name of a sub-directory of ``path`` used as the build directory.
        Falsy -> ``path`` itself is used and ``self.temps`` is None.
    warm_start: bool
        reuse an existing ``temps`` directory instead of creating it.

    Raises
    ------
    FileNotFoundError
        if ``warm_start`` is True but ``temps`` does not exist.
    """
    self.init_path = os.getcwd()
    # check file
    if module_name == "TORCH_EXTENSION_NAME":
        print("please re set your module name")
    self.source = source
    try:
        if path is not None:
            def_pwd(path)
        module_dir = Path().absolute()
        # temps:
        if temps:
            if warm_start:
                if not os.path.isdir(temps):
                    raise FileNotFoundError(
                        "Try to warm start but without no exist {} found".format(temps))
            elif not os.path.isdir(temps):
                os.mkdir(temps)
            else:
                warnings.warn(
                    "There is exist {temps}. Duplicate files will be overwritten."
                    "please use remove() to delete temporary file after each test."
                    .format(temps=temps))
            self.temps = temps
            self.path = module_dir / temps
        else:
            self.temps = None
            self.path = module_dir
        # check module_name
        self.module_name = module_name
        self.build = True
        self.functions = []
    finally:
        # Restore the caller's cwd on success AND on error; previously a
        # failed warm start left the process inside ``path``.
        os.chdir(self.init_path)
def quick_import(module_name, path=None, build=False, suffix="so", with_html=False, re_build_func=re_build_func,
                 re_build_func_kwargs=None):
    """
    Import a compiled extension file (.so, .pyd, ...) as a module.

    Moves into ``path`` via ``def_pwd`` (the working directory is NOT
    restored afterwards), optionally rebuilds, then imports the module.

    Parameters
    ----------
    module_name: str
        module name; a candidate file must contain it in its name.
    path: str, optional
        directory containing the compiled module.
    build: bool
        rebuild before importing. Default False: only look for an existing
        ``module_name*.so`` / ``.pyd`` / ``.dll`` file.
    suffix: str
        file extension of the compiled module, e.g. "so" or "pyd"
        (a leading dot is accepted).
    with_html: bool
        passed to ``re_build_func`` (used by cython to emit an annotation html).
    re_build_func: callable
        build function, called as ``re_build_func(module_name, with_html)``
        when ``re_build_func_kwargs`` is None.
    re_build_func_kwargs: dict, optional
        if given, ``re_build_func(**re_build_func_kwargs)`` is called instead.

    Returns
    -------
    module

    Raises
    ------
    FileNotFoundError
        if no matching compiled file is found in the directory.
    """
    def_pwd(path)
    print("Move to {}".format(os.getcwd()))
    if build:
        if re_build_func_kwargs is None:
            re_build_func(module_name, with_html)
        else:
            re_build_func(**re_build_func_kwargs)

    # Compare against the real extension: a plain substring test let
    # suffix "so" match unrelated files such as "source.cpp".
    ext_tail = "." + suffix.lstrip(".")
    candidates = [i for i in os.listdir() if module_name in i and i.endswith(ext_tail)]
    if not candidates:
        # Build a single message: passing several positional strings to
        # FileNotFoundError makes OSError treat them as (errno, strerror,
        # filename) and garbles the text.
        raise FileNotFoundError(
            ": There is no './{}.***.{}' in '{}',\n"
            "There are just {},\n"
            "Please try to build=True again.".format(
                module_name, suffix, os.getcwd(), os.listdir()))

    module = import_module(module_name, os.getcwd())
    msg = "The {module_name} module methods:{methods}"
    names = [i for i in dir(module) if "__" not in i]
    print(msg.format(module_name=module_name, methods=names))
    return module
def test_data2(self):
    """Convert the checked structures to graph data and iterate them in batches of 3."""
    # NOTE(review): semantics of change=False live in def_pwd (not shown) — confirm.
    def_pwd("./raw", change=False)
    converter = BaseStructureGraphGEO()
    graphs = converter.transform_and_to_data(self.data0_checked)
    for batch in DataLoader(graphs, batch_size=3):
        print(batch)
def test_data(self):
    """Transform the checked structures with StructureGraphGEO and save them with save_mode="i"."""
    # NOTE(review): semantics of change=False live in def_pwd (not shown) — confirm.
    def_pwd("./raw", change=False)
    graph_options = dict(
        nn_strategy="find_points_in_spheres",
        bond_generator=None,
        atom_converter=None,
        bond_converter=None,
        state_converter=None,
        cutoff=5.0,
    )
    converter = StructureGraphGEO(**graph_options)
    converter.transform_and_save(self.data0_checked, save_mode="i")
def test_CrystalGraph42(self):
    """Save graphs built with "find_xyz_in_spheres", then load them via DatasetGEO and index one item."""
    # NOTE(review): semantics of change=False live in def_pwd (not shown) — confirm.
    def_pwd("./raw", change=False)
    graph_options = dict(
        nn_strategy="find_xyz_in_spheres",
        bond_generator=None,
        atom_converter=None,
        bond_converter=None,
        state_converter=None,
        cutoff=5.0,
    )
    StructureGraphGEO(**graph_options).transform_and_save(self.data0_checked, save_mode="i")
    dataset = DatasetGEO(".", load_mode="i", re_process_init=False)
    # Item 2 is accessed twice; presumably this exercises a repeated
    # (possibly cached) access — TODO confirm intent.
    item = dataset[2]
    item = dataset[2]
def to_path(self, new_path, flatten=False, add_dir="3-layer", pop=0, n_jobs=1):
    """
    Copy the merged file list into ``new_path``.

    Parameters
    ----------
    new_path: str
        destination path; passed through ``def_pwd``.
    flatten: bool, dict
        flatten the filtered files. If a dict, keys are specific dir names
        and values are True, e.g. ``flatten = {"asp": True}``.
    add_dir: list, int
        prefix the top dir name(s) to each file to avoid name clashes.
        Only meaningful with flatten=True.
    pop: int (negative)
        drop the last n layers, i.e. copy by directory instead of by file.
        Only meaningful with flatten=False. Default 0.
    n_jobs: int
        number of parallel jobs; 1 copies sequentially with a progress bar.

    Raises
    ------
    UserWarning
        (raised, not warned) when the destination names contain fewer
        distinct entries than the sources, i.e. flattening would collide.

    Returns
    -------
    None; files are copied into ``new_path`` via ``self.copy_user``.
    """
    self.file_list_merge = self.merge(pop=pop)
    new_path = def_pwd(new_path)
    self.file_list_merge_new = self.merge(
        path=new_path,
        flatten=flatten,
        add_dir=add_dir,
        refresh_file_list=False,
        pop=pop,
    )

    unique_dst = len(set(self.file_list_merge_new))
    unique_src = len(set(self.file_list_merge))
    if unique_dst < unique_src:
        raise UserWarning(
            "There are same name files after flatten folders. "
            "you can change add_dir to add difference prefix to files",
        )

    src_dst_pairs = zip(self.file_list_merge, self.file_list_merge_new)
    if n_jobs == 1:
        # Materialize so tqdm knows the total length.
        for pair in tqdm(list(src_dst_pairs)):
            self.copy_user(pair)
    else:
        parallelize(n_jobs, self.copy_user, src_dst_pairs, mode="j", respective=False)
def __init__(self, path=None, filename="filename", prefix: str = None):
    """
    Parameters
    ----------
    path: str
        root directory, e.g. /data_cluster or F:data_cluster/data1;
        passed to ``def_pwd`` and stored unchanged as ``self._path``.
    filename: str
        default file name, stored as ``self.default_filename``.
    prefix: str, optional
        prefix for all file names; any falsy value becomes "".
    """
    self._prefix = prefix if prefix else ""
    def_pwd(path)
    self._path = path
    self._filename = ""
    self.default_filename = filename
    self._file_list = []
def __init__(self, path=None, suffix=None):
    """
    Parameters
    ----------
    path: str
        root directory of all files; replaced by the value ``def_pwd`` returns.
    suffix: str, optional
        file suffix filter passed to ``check_file``, e.g. ".txt".
    """
    resolved = def_pwd(path)
    self.path = resolved
    # Split on either "/" or "\" so Windows and POSIX paths both work.
    self.parents = re.split(r'[\\/]', str(resolved))
    self.file_list = check_file(resolved, resolved, suffix=suffix)
    # Immutable snapshot of the initial file list.
    self.init_file = tuple(self.file_list)
    self.file_list_merge, self.file_list_merge_new, self.file_dir = [], [], []
def __init__(self, file, path=None, temps="temps", warm_start=False, only_file=True):
    """
    Locate a source file and copy it (or its whole directory) into a temps dir.

    The process working directory is always restored to the one current on
    entry, including when an error is raised.

    Parameters
    ----------
    file: str
        file name. If ``path`` is given it must not contain a directory
        part; otherwise a directory part is split off into ``path``.
    path: str, optional
        directory containing ``file``; passed to ``def_pwd``.
    temps: str
        sub-directory (of ``path``) that receives the copies. Falsy -> no
        copying; ``self.path`` is ``path`` itself and ``self.temps`` is None.
    warm_start: bool
        reuse an existing ``temps`` directory without copying anything.
    only_file: bool
        copy only ``file``; if False, copy every entry of the directory
        except ``temps`` itself (directories are replaced wholesale).

    Raises
    ------
    AssertionError
        if ``path`` is given and ``file`` also contains "/" or "\\".
    IOError
        if ``file`` does not exist in the target directory.
    FileNotFoundError
        if ``warm_start`` is True but ``temps`` does not exist.
    """
    self.init_path = os.getcwd()
    # check file
    if "/" in file or "\\" in file:
        if path:
            # Explicit raise (same exception type as the former asserts) so the
            # check is not stripped under `python -O`. Note "\\" is ONE
            # backslash; the former r"\\" only caught a doubled backslash.
            raise AssertionError("Path must in one of `path` parameter or the `filename`.")
        path = Path(file).parent
        file = os.path.split(file)[-1]

    self.file = file
    self.check_suffix(self.file)

    try:
        if path is not None:
            def_pwd(path)
        if not os.path.isfile(file):
            raise IOError(
                "No file named {} in {}, please re-site your path".format(
                    file, os.getcwd()))
        module_dir = Path().absolute()
        # temps:
        if temps:
            if warm_start:
                if not os.path.isdir(temps):
                    raise FileNotFoundError(
                        "Try to warm start but without no exist {} found".format(temps))
            else:
                if only_file:
                    files = [self.file]
                else:
                    # Everything in the directory except the temps dir itself.
                    files = [name for name in os.listdir() if name != temps]
                if not os.path.isdir(temps):
                    os.mkdir(temps)
                else:
                    warnings.warn(
                        "There is exist {temps}. Duplicate files will be overwritten."
                        "please use remove() to delete temporary file after test."
                        .format(temps=temps))
                for name in files:
                    if os.path.isdir(name):
                        target = module_dir / temps / name
                        # copytree refuses an existing destination: drop it first.
                        if os.path.isdir(target):
                            shutil.rmtree(target)
                        shutil.copytree(name, target)
                    else:
                        shutil.copy(name, temps)
            self.temps = temps
            self.path = module_dir / temps
        else:
            self.temps = None
            self.path = module_dir
        # check module_name
        self.module_name = get_name_without_suffix(os.path.split(self.file)[-1])
        self.build = True
    finally:
        # Restore the caller's cwd on success AND on error.
        os.chdir(self.init_path)