def import_user_module(args):
    """Import the user-supplied package named by ``args.user_dir``.

    The directory is resolved to an absolute path. If that path does not
    exist, the same relative path is tried against the directory containing
    this file and then against that directory's parent. The resolved
    directory's parent is inserted at the front of ``sys.path`` and the
    directory's base name is imported as a top-level module. If the directory
    contains ``tasks`` or ``models`` sub-directories, they are handed to
    ``fairseq.tasks.import_tasks`` / ``fairseq.models.import_models`` under
    the ``<module_name>.tasks`` / ``<module_name>.models`` namespaces.

    Each resolved path is processed at most once per process; the set of
    already-seen paths is stored on the function object itself
    (``import_user_module.memo``).

    Args:
        args: Any object; only its optional ``user_dir`` attribute is read.
            Nothing happens when the attribute is missing or ``None``.

    Raises:
        FileNotFoundError: ``user_dir`` could not be located at any of the
            candidate locations.
        ImportError: a *different* module with the same name is already
            present in ``sys.modules`` (including non-package modules that
            have no ``__path__``).
    """
    module_path = getattr(args, "user_dir", None)
    if module_path is None:
        return

    module_path = os.path.abspath(args.user_dir)
    # NOTE(review): the isfile() check on the parent presumably lets paths
    # that live inside an archive file through -- confirm the intent.
    if not os.path.exists(module_path) and not os.path.isfile(
        os.path.dirname(module_path)
    ):
        # Fall back to locations relative to this file's directory, then to
        # that directory's parent; the first existing candidate wins.
        here = os.path.dirname(__file__)
        candidates = (
            os.path.join(here, args.user_dir),
            os.path.join(here, "..", args.user_dir),
        )
        for candidate in candidates:
            if os.path.exists(candidate):
                module_path = candidate
                break
        else:
            raise FileNotFoundError(module_path)

    # ensure that user modules are only imported once
    import_user_module.memo = getattr(import_user_module, "memo", set())
    if module_path in import_user_module.memo:
        return
    # The path is recorded before the import is attempted, so a path whose
    # import raised is not retried by later calls.
    import_user_module.memo.add(module_path)

    module_parent, module_name = os.path.split(module_path)
    if module_name not in sys.modules:
        sys.path.insert(0, module_parent)
        importlib.import_module(module_name)

        tasks_path = os.path.join(module_path, "tasks")
        if os.path.exists(tasks_path):
            from fairseq.tasks import import_tasks

            import_tasks(tasks_path, f"{module_name}.tasks")

        models_path = os.path.join(module_path, "models")
        if os.path.exists(models_path):
            from fairseq.models import import_models

            import_models(models_path, f"{module_name}.models")
    # Plain (non-package) modules have no ``__path__``; treat that as "not
    # this directory" so the caller gets the descriptive ImportError below
    # instead of an AttributeError.
    elif module_path in getattr(sys.modules[module_name], "__path__", ()):
        logger.info(f"--user-dir={module_path} has already been imported.")
    else:
        raise ImportError(
            "Failed to import --user-dir={} because the corresponding module name "
            "({}) is not globally unique. Please rename the directory to "
            "something unique and try again.".format(module_path, module_name)
        )
# Copyright (c) Yiming Wang # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import os from fairseq.tasks import import_tasks # automatically import any Python files in the tasks/ directory tasks_dir = os.path.dirname(__file__) import_tasks(tasks_dir, "espresso.tasks")