示例#1
0
def setup_imports():
    """Import all trainer, dataset and model modules so they can register.

    Importing the modules under ``ocpmodels.datasets``, ``ocpmodels.models``
    and ``ocpmodels.trainers`` is done for its side effects (presumably the
    registry decorators in those modules -- confirm in
    ``ocpmodels.common.registry``). Every ``.py`` file found in those folders
    is imported; ``trainers`` is searched recursively, the other two only at
    their top level.

    The package root comes from the registry key ``"ocpmodels_root"`` when it
    is set, otherwise it is the parent of the directory containing this file.
    After a successful run the registry key ``"imports_setup"`` is set to
    ``True`` so later calls return immediately.
    """
    from ocpmodels.common.registry import registry

    # First, check if imports are already setup
    has_already_setup = registry.get("imports_setup", no_warning=True)
    if has_already_setup:
        return

    # Automatically load all of the modules, so that
    # they register with registry
    root_folder = registry.get("ocpmodels_root", no_warning=True)
    if root_folder is None:
        root_folder = os.path.dirname(os.path.abspath(__file__))
        root_folder = os.path.join(root_folder, "..")

    importlib.import_module("ocpmodels.common.meter")

    # (sub-package, glob pattern inside it), in import order.
    subpackages = (
        ("datasets", "*.py"),
        ("models", "*.py"),
        ("trainers", os.path.join("**", "*.py")),
    )

    for subpackage, pattern in subpackages:
        folder = os.path.join(root_folder, subpackage)
        for path in glob.glob(os.path.join(folder, pattern), recursive=True):
            # Derive the dotted name from the path relative to the folder the
            # glob ran in, instead of searching the full path for "/models"
            # etc.: this works with any OS path separator, is unaffected by
            # parent directory names, and keeps sub-directories of nested
            # trainer files (trainers/sub/foo.py -> ocpmodels.trainers.sub.foo).
            relative = os.path.splitext(os.path.relpath(path, folder))[0]
            module_name = ".".join(
                ["ocpmodels", subpackage] + relative.split(os.sep)
            )
            importlib.import_module(module_name)

    registry.register("imports_setup", True)
示例#2
0
文件: utils.py 项目: wood-b/ocp
def setup_imports():
    """Import every registrable module so that it registers itself.

    Modules under ``ocpmodels.datasets``, ``ocpmodels.models``,
    ``ocpmodels.trainers`` and ``ocpmodels.tasks`` are imported for their
    side effects (presumably registry decorators -- confirm in
    ``ocpmodels.common.registry``). ``trainers`` is searched recursively; the
    other sub-packages only at their top level.

    If an ``experimental`` directory exists next to the package root
    (``<root>/../experimental``), every ``.py`` file below it is imported as
    ``experimental.<subdirs>.<module>``, except files below the directories
    listed one per line in ``experimental/.ignore``.

    The package root comes from the registry key ``"ocpmodels_root"`` when it
    is set, otherwise it is the parent of the directory containing this file.
    After a successful run ``"imports_setup"`` is registered as ``True`` so
    later calls return immediately.
    """
    from ocpmodels.common.registry import registry

    # First, check if imports are already setup
    has_already_setup = registry.get("imports_setup", no_warning=True)
    if has_already_setup:
        return

    # Automatically load all of the modules, so that
    # they register with registry
    root_folder = registry.get("ocpmodels_root", no_warning=True)
    if root_folder is None:
        root_folder = os.path.dirname(os.path.abspath(__file__))
        root_folder = os.path.join(root_folder, "..")

    def module_name_for(path, base_folder, prefix):
        # Dotted module name of `path` relative to `base_folder`, e.g.
        # base/sub/foo.py -> prefix + ["sub", "foo"]. Splitting the relative
        # path (instead of searching the full path for "/models" or ".py")
        # is independent of the OS separator and of directory names such as
        # ones starting with "py".
        relative = os.path.splitext(os.path.relpath(path, base_folder))[0]
        return ".".join(prefix + relative.split(os.sep))

    importlib.import_module("ocpmodels.common.logger")

    # (sub-package, glob pattern inside it), in import order.
    subpackages = (
        ("datasets", "*.py"),
        ("models", "*.py"),
        ("trainers", os.path.join("**", "*.py")),
        ("tasks", "*.py"),
    )
    for subpackage, pattern in subpackages:
        folder = os.path.join(root_folder, subpackage)
        for path in glob.glob(os.path.join(folder, pattern), recursive=True):
            importlib.import_module(
                module_name_for(path, folder, ["ocpmodels", subpackage])
            )

    # Optional experimental code lives beside the package root and is
    # imported as top-level `experimental.*` modules.
    project_root = os.path.join(root_folder, "..")
    experimental_folder = os.path.join(project_root, "experimental")
    if os.path.isdir(experimental_folder):
        experimental_files = glob.glob(
            os.path.join(experimental_folder, "**", "*.py"),
            recursive=True,
        )
        # Ignore certain directories within experimental
        ignore_file = os.path.join(experimental_folder, ".ignore")
        if os.path.exists(ignore_file):
            ignored = set()
            with open(ignore_file) as ignore_handle:
                for line in ignore_handle.read().splitlines():
                    entry = line.strip()
                    # An empty entry would glob the whole experimental folder.
                    if not entry:
                        continue
                    ignore_pattern = os.path.join(
                        experimental_folder, entry, "**", "*.py"
                    )
                    ignored.update(
                        os.path.normpath(p)
                        for p in glob.glob(ignore_pattern, recursive=True)
                    )
            # Filter rather than list.remove(): duplicate or overlapping
            # entries must not raise ValueError.
            experimental_files = [
                path
                for path in experimental_files
                if os.path.normpath(path) not in ignored
            ]
        for path in experimental_files:
            importlib.import_module(module_name_for(path, project_root, []))

    registry.register("imports_setup", True)