예제 #1
0
def register_builtin_tasks():
    """Pass the task classes listed below to ``register_tasks`` as one tuple."""
    builtin_tasks = (
        ContextualIntentSlotTask_Deprecated,
        DisjointMultitask,
        DocClassificationTask_Deprecated,
        DocumentClassificationTask,
        DocumentRegressionTask,
        EnsembleTask,
        EnsembleTask_Deprecated,
        JointTextTask_Deprecated,
        LMTask,
        LMTask_Deprecated,
        NewDisjointMultitask,
        PairClassificationTask_Deprecated,
        PairwiseClassificationTask,
        QueryDocumentPairwiseRankingTask,
        QueryDocumentPairwiseRankingTask_Deprecated,
        SemanticParsingTask_Deprecated,
        SeqNNTask,
        SeqNNTask_Deprecated,
        WordTaggingTask,
        WordTaggingTask_Deprecated,
    )
    register_tasks(builtin_tasks)
예제 #2
0
def add_include(path):
    """
    Import tasks (and associated components) from the folder name.

    Every ``*.py`` file directly inside ``path`` (except ``__init__.py``) is
    imported as ``<package>.<module>``, where ``<package>`` is ``path`` with
    its directory separators turned into dots. Classes defined in an imported
    module whose direct bases include ``Task_Deprecated`` or ``NewTask`` are
    passed to ``register_tasks``; other classes defined there are only logged.

    Args:
        path: folder to scan, e.g. ``"my_tasks"`` or ``"contrib/my_tasks"``.
            The resulting dotted name must be importable.
    """
    eprint("Including:", path)
    # A folder such as "contrib/tasks", "./tasks" or "tasks/" has to become a
    # dotted package name before it can be handed to import_module.
    package = os.path.normpath(path).replace(os.sep, ".")
    module_files = glob.glob(os.path.join(path, "*.py"))
    module_names = [
        os.path.splitext(os.path.basename(f))[0]
        for f in module_files
        if os.path.isfile(f) and not f.endswith("__init__.py")
    ]
    for mod_name in module_names:
        mod_path = package + "." + mod_name
        eprint("... importing module:", mod_path)
        my_module = importlib.import_module(mod_path)

        for _, cls in inspect.getmembers(my_module, inspect.isclass):
            if cls.__module__ != mod_path:
                # Defined in some other module and merely visible here.
                continue
            if Task_Deprecated in cls.__bases__ or NewTask in cls.__bases__:
                eprint("... task:", cls.__name__)
                register_tasks(cls)
            else:
                # The class was already loaded together with its module by
                # the import above, so logging it is all that is left to do.
                eprint("... importing:", cls)
예제 #3
0
def register_builtin_tasks():
    """Pass the task classes listed below to ``register_tasks`` as one tuple."""
    builtin_tasks = (
        BertPairRegressionTask,
        DisjointMultitask,
        DocClassificationTask_Deprecated,
        DocumentClassificationTask,
        DocumentRegressionTask,
        EnsembleTask,
        IntentSlotTask,
        LMTask,
        LMTask_Deprecated,
        MaskedLMTask,
        NewBertClassificationTask,
        NewBertPairClassificationTask,
        NewDisjointMultitask,
        PairwiseClassificationTask,
        QueryDocumentPairwiseRankingTask,
        QueryDocumentPairwiseRankingTask_Deprecated,
        SemanticParsingTask,
        SemanticParsingTask_Deprecated,
        SeqNNTask,
        SeqNNTask_Deprecated,
        SquadQATask,
        WordTaggingTask,
        WordTaggingTask_Deprecated,
    )
    register_tasks(builtin_tasks)
예제 #4
0
def gen_config_impl(task_name, options):
    """
    Build a ``PyTextConfig`` for ``task_name`` with ``options`` applied.

    Each entry of ``options`` is either ``"<param path>=<value>"``, which
    overrides a single parameter found under the root config, or the name of
    a component class whose config replaces a matching component.

    Returns the resulting ``PyTextConfig``. Raises ``Exception`` when the
    task or an option is unknown or matches more than one candidate.
    """
    # Resolve every requested name first and hand the results to
    # register_tasks before any class is looked up by name.
    located = [locate(option) for option in options]
    located.append(locate(task_name))
    register_tasks(located)

    candidates = find_config_class(task_name)
    if not candidates:
        raise Exception(
            f"Unknown task class: {task_name} (try fully qualified class name?)"
        )
    if len(candidates) > 1:
        raise Exception(f"Multiple tasks named {task_name}: {candidates}")

    task_class = next(iter(candidates))
    # Prefer an "example_config" attribute over the plain Config class.
    make_task_config = getattr(task_class, "example_config", task_class.Config)
    root = PyTextConfig(task=make_task_config(), version=LATEST_VERSION)
    eprint("INFO - Applying task option:", task_class.__name__)

    for option in options:
        if "=" in option:
            # "path=value": override exactly one parameter.
            param_path, value = option.split("=", 1)
            matches = find_param(root, "." + param_path)
            if len(matches) != 1:
                if not matches:
                    raise Exception(f"Unknown parameter option: {option}")
                raise Exception(
                    f"Multiple possibilities for {option}: {', '.join(matches)}"
                )
            eprint("INFO - Applying parameter option to", matches[0], ":", option)
            replace_param(root, matches[0].split("."), value)
            continue

        # Otherwise the option names a component class to swap in.
        component_classes = find_config_class(option)
        if not component_classes:
            raise Exception(f"Not a component class: {option}")
        if len(component_classes) > 1:
            raise Exception(
                f"Multiple component named {option}: {component_classes}"
            )
        component_class = next(iter(component_classes))

        attr_chain = replace_components(
            root, option, get_subclasses(component_class)
        )
        if not attr_chain:
            raise Exception(f"Unknown class option: {option}")

        eprint(
            "INFO - Applying class option:",
            "->".join(reversed(attr_chain)),
            "=",
            option,
        )
        # attr_chain is walked in reverse: all entries but the first lead from
        # the root to the owner, and the first entry is the attribute to set.
        owner = root
        for attr_name in reversed(attr_chain[1:]):
            owner = getattr(owner, attr_name)
        if hasattr(component_class, "Config"):
            replacement = component_class.Config()
        else:
            replacement = component_class()
        setattr(owner, attr_chain[0], replacement)

    return root
예제 #5
0
def register_builtin_tasks():
    """Pass the task classes listed below to ``register_tasks`` as one tuple."""
    builtin_tasks = (
        DocClassificationTask,
        WordTaggingTask,
        JointTextTask,
        LMTask,
        EnsembleTask,
        PairClassification,
        SeqNNTask,
        ContextualIntentSlotTask,
        SemanticParsingTask,
        DisjointMultitask,
    )
    register_tasks(builtin_tasks)
예제 #6
0
def register_builtin_tasks():
    """Pass the task classes listed below to ``register_tasks`` as one tuple."""
    builtin_tasks = (
        ContextualIntentSlotTask_Deprecated,
        DisjointMultitask,
        DocClassificationTask_Deprecated,
        EnsembleTask_Deprecated,
        JointTextTask_Deprecated,
        LMTask_Deprecated,
        NewDisjointMultitask,
        PairClassificationTask_Deprecated,
        SemanticParsingTask_Deprecated,
        SeqNNTask_Deprecated,
        WordTaggingTask_Deprecated,
    )
    register_tasks(builtin_tasks)
예제 #7
0
파일: main.py 프로젝트: parety/pytext
def gen_config_impl(task_name, options):
    """
    Build the config for ``task_name`` with the component classes named in
    ``options`` swapped in, and return it through ``config_to_json``.

    Raises ``Exception`` when the task or an option is unknown or matches
    more than one class.
    """
    # Resolve every requested name first and hand the results to
    # register_tasks before any class is looked up by name.
    located = [locate(option) for option in options]
    located.append(locate(task_name))
    register_tasks(located)

    candidates = find_config_class(task_name)
    if not candidates:
        raise Exception(
            f"Unknown task class: {task_name} (try fully qualified class name?)"
        )
    if len(candidates) > 1:
        raise Exception(f"Multiple tasks named {task_name}: {candidates}")

    task_class = next(iter(candidates))
    # Prefer an "example_config" attribute over the plain Config class.
    make_task_config = getattr(task_class, "example_config", task_class.Config)
    root = PyTextConfig(task=make_task_config())

    # Swap in each component named in options in place of the default.
    for option in options:
        component_classes = find_config_class(option)
        if not component_classes:
            raise Exception(f"Not a component class: {option}")
        if len(component_classes) > 1:
            raise Exception(
                f"Multiple component named {option}: {component_classes}"
            )
        component_class = next(iter(component_classes))

        attr_chain = replace_components(
            root, option, set(component_class.__bases__)
        )
        if not attr_chain:
            raise Exception(f"Unknown option: {option}")

        eprint(
            "INFO - Applying option:",
            "->".join(reversed(attr_chain)),
            "=",
            option,
        )
        # attr_chain is walked in reverse: all entries but the first lead from
        # the root to the owner, and the first entry is the attribute to set.
        owner = root
        for attr_name in reversed(attr_chain[1:]):
            owner = getattr(owner, attr_name)
        if hasattr(component_class, "Config"):
            replacement = component_class.Config()
        else:
            replacement = component_class()
        setattr(owner, attr_chain[0], replacement)

    return config_to_json(PyTextConfig, root)