示例#1
0
    def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args,
                **extra_kwargs):
        """Build the registered component selected by ``cfg``.

        How the component name is read depends on the type of ``cfg``:

        * ``DictConfig``: the name is ``cfg._name``. If that name has a
          registered dataclass, ``cfg`` is passed through ``merge_with_parent``
          together with a fresh instance of the dataclass.
        * ``str``: ``cfg`` is itself the name. If a dataclass is registered
          under it, ``cfg`` is replaced by a default instance of that
          dataclass.
        * anything else (e.g. an argparse ``Namespace``): the name is the
          attribute called ``registry_name``, if present. A registered
          dataclass is filled via ``populate_dataclass``.

        Returns ``None`` when no name was found and ``required`` is false.
        Otherwise it calls ``cls.build_<registry_name>`` when the registered
        class defines that attribute, else the class itself. The call receives
        ``cfg`` plus any extra positional/keyword arguments.

        Raises:
            ValueError: no component name was found and ``required`` is set.
        """
        if isinstance(cfg, DictConfig):
            selected = cfg._name
            if selected and selected in DATACLASS_REGISTRY:
                cfg = merge_with_parent(DATACLASS_REGISTRY[selected](), cfg)
        elif isinstance(cfg, str):
            selected = cfg
            if selected in DATACLASS_REGISTRY:
                cfg = DATACLASS_REGISTRY[selected]()
        else:
            selected = getattr(cfg, registry_name, None)
            if selected in DATACLASS_REGISTRY:
                cfg = populate_dataclass(DATACLASS_REGISTRY[selected](), cfg)

        if selected is None:
            if not required:
                return None
            raise ValueError("{} is required!".format(registry_name))

        component_cls = REGISTRY[selected]
        # Prefer a dedicated ``build_<registry_name>`` factory when the class
        # provides one; otherwise the class constructor is the builder.
        factory = getattr(component_cls, "build_" + registry_name, component_cls)
        return factory(cfg, *extra_args, **extra_kwargs)
示例#2
0
def build_model(cfg: DictConfig, task):
    """Resolve the model class described by ``cfg`` and build it for ``task``.

    The model type comes from ``cfg._name`` or, failing that, ``cfg.arch``.
    Suppose neither is set and ``cfg`` holds exactly one entry. Then that
    entry's key is taken as the model type and its value as the model config;
    this happens when the config is nested in a directory named after the
    model type.

    Legacy architectures (``ARCH_MODEL_REGISTRY``) are checked before
    config-driven models (``MODEL_REGISTRY`` via ``MODEL_DATACLASS_REGISTRY``).
    Whenever the type has a registered dataclass, ``cfg`` is passed through
    ``merge_with_parent`` with an instance of it. The result of
    ``model.build_model(cfg, task)`` is returned.

    Raises:
        Exception: the single nested key is not a registered model dataclass.
        AssertionError: no registered model matches the resolved type.
    """
    model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

    if not model_type and len(cfg) == 1:
        # The config is nested under a single key named after the model type.
        model_type = next(iter(cfg))
        if model_type not in MODEL_DATACLASS_REGISTRY:
            raise Exception(
                "Could not infer model type from directory. Please add _name field to indicate model type"
            )
        cfg = cfg[model_type]

    if model_type in ARCH_MODEL_REGISTRY:
        # case 1: legacy models
        model = ARCH_MODEL_REGISTRY[model_type]
    elif model_type in MODEL_DATACLASS_REGISTRY:
        # case 2: config-driven models
        model = MODEL_REGISTRY[model_type]
    else:
        model = None

    if model_type in MODEL_DATACLASS_REGISTRY:
        # Arch name and model name can be the same, so a legacy model may
        # also pick up dataclass defaults here.
        cfg = merge_with_parent(MODEL_DATACLASS_REGISTRY[model_type](), cfg)

    assert model is not None, f"Could not infer model type from {cfg}"

    return model.build_model(cfg, task)
示例#3
0
def add_defaults(cfg: DictConfig) -> None:
    """This function adds default values that are stored in dataclasses that hydra doesn't know about

    Looks at every field of ``FairseqConfig`` whose declared type is ``Any``
    and which is set in ``cfg``. For each one, it finds the dataclass
    registered for the selected component, keyed by the sub-config's
    ``_name``:

    * ``task``: ``TASK_DATACLASS_REGISTRY``.
    * ``model``: ``MODEL_DATACLASS_REGISTRY``, after mapping arch names
      through ``ARCH_MODEL_NAME_REGISTRY``.
    * any other key in ``REGISTRIES``: that registry's
      ``"dataclass_registry"``.

    When a dataclass is found, ``cfg[k]`` is replaced in place by
    ``merge_with_parent(dc, sub_config)``. Nothing is returned.
    """

    # Imported inside the function; presumably this avoids import cycles
    # between these fairseq modules (NOTE(review): confirm).
    from fairseq.registry import REGISTRIES
    from fairseq.tasks import TASK_DATACLASS_REGISTRY
    from fairseq.models import ARCH_MODEL_NAME_REGISTRY, MODEL_DATACLASS_REGISTRY
    from fairseq.dataclass.utils import merge_with_parent
    from typing import Any

    for k, v in FairseqConfig.__dataclass_fields__.items():
        field_cfg = cfg.get(k)
        # Only untyped (``Any``) fields that the user actually set are
        # candidates for dataclass-backed defaults.
        if field_cfg is None or v.type != Any:
            continue

        dc = None

        if isinstance(field_cfg, str):
            # A bare string is shorthand for ``{"_name": <string>}``.
            # (The former self-assignment of ``__dict__["_parent"]`` here was
            # a no-op and has been dropped.)
            field_cfg = DictConfig({"_name": field_cfg})

        name = field_cfg.get("_name")

        if k == "task":
            dc = TASK_DATACLASS_REGISTRY.get(name)
        elif k == "model":
            # Legacy arch names are mapped to the model name they belong to.
            name = ARCH_MODEL_NAME_REGISTRY.get(name, name)
            dc = MODEL_DATACLASS_REGISTRY.get(name)
        elif k in REGISTRIES:
            dc = REGISTRIES[k]["dataclass_registry"].get(name)

        if dc is not None:
            # NOTE(review): other call sites pass an instance (``dc()``) to
            # merge_with_parent; this one passes the dataclass type itself.
            # Confirm both are accepted before unifying.
            cfg[k] = merge_with_parent(dc, field_cfg)
示例#4
0
文件: __init__.py 项目: tma15/fairseq
def build_model(cfg: FairseqDataclass, task, from_checkpoint=False):
    """Resolve the model class described by ``cfg`` and build it for ``task``.

    The model type comes from ``cfg._name`` or, failing that, ``cfg.arch``.
    Suppose neither is set and ``cfg`` is a container with exactly one entry.
    Then that entry's key is taken as the model type and its value as the
    model config; this happens when the config is nested in a directory named
    after the model type.

    Legacy architectures (``ARCH_MODEL_REGISTRY``) are checked before
    config-driven models (``MODEL_REGISTRY``). Then one of two things happens:

    * The type has a registered dataclass: ``cfg`` is converted with
      ``dc.from_namespace`` (argparse ``Namespace``) or merged via
      ``merge_with_parent`` (other configs; ``from_checkpoint`` is forwarded).
    * Otherwise, if the type is in ``ARCH_CONFIG_REGISTRY``: that arch
      function is called on ``cfg`` to apply architecture defaults in place.

    Args:
        cfg: model config (OmegaConf config, dataclass, or argparse Namespace).
        task: forwarded unchanged to ``model.build_model``.
        from_checkpoint: forwarded to ``merge_with_parent``.

    Returns:
        The result of ``model.build_model(cfg, task)``.

    Raises:
        Exception: the single nested key is not a registered model dataclass.
        AssertionError: no registered model matches the resolved type.
    """

    model = None
    model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

    # argparse.Namespace (handled below) has no len(). Calling len() on it
    # raised TypeError before the informative "could not infer" error below
    # could be reported.
    if (
        not model_type
        and not isinstance(cfg, argparse.Namespace)
        and len(cfg) == 1
    ):
        # this is hit if config object is nested in directory that is named after model type

        model_type = next(iter(cfg))
        if model_type in MODEL_DATACLASS_REGISTRY:
            cfg = cfg[model_type]
        else:
            raise Exception(
                "Could not infer model type from directory. Please add _name field to indicate model type. "
                f"Available models: {MODEL_DATACLASS_REGISTRY.keys()}"
                f" Requested model type: {model_type}"
            )

    if model_type in ARCH_MODEL_REGISTRY:
        # case 1: legacy models
        model = ARCH_MODEL_REGISTRY[model_type]
    elif model_type in MODEL_DATACLASS_REGISTRY:
        # case 2: config-driven models
        model = MODEL_REGISTRY[model_type]

    if model_type in MODEL_DATACLASS_REGISTRY:
        # set defaults from dataclass. note that arch name and model name can be the same
        dc = MODEL_DATACLASS_REGISTRY[model_type]

        if isinstance(cfg, argparse.Namespace):
            cfg = dc.from_namespace(cfg)
        else:
            cfg = merge_with_parent(dc(), cfg, from_checkpoint)
    else:
        if model_type in ARCH_CONFIG_REGISTRY:
            with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack():
                # this calls the different "arch" functions (like base_architecture()) that you indicate
                # if you specify --arch on the command line. this is only applicable to the old argparse based models
                # hydra models should expose different architectures via different config files
                # it will modify the cfg object and default parameters according to the arch
                ARCH_CONFIG_REGISTRY[model_type](cfg)

    assert model is not None, (
        f"Could not infer model type from {cfg}. "
        "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys()) +
        f" Requested model type: {model_type}")

    return model.build_model(cfg, task)
def setup_task(cfg: DictConfig, **kwargs):
    """Look up the task described by ``cfg`` and delegate to its ``setup_task``.

    Two config shapes are recognised:

    * legacy: ``cfg.task`` is a string naming an entry of ``TASK_REGISTRY``.
      ``cfg`` is forwarded unchanged.
    * structured: ``cfg._name`` names a task that has a registered dataclass.
      ``cfg`` is first passed through ``merge_with_parent`` with an instance
      of that dataclass.

    Extra keyword arguments are forwarded to ``task.setup_task``.

    Raises:
        AssertionError: no task could be resolved from ``cfg``.
    """
    legacy_name = getattr(cfg, "task", None)

    if isinstance(legacy_name, str):
        # legacy tasks
        task = TASK_REGISTRY[legacy_name]
    else:
        task = None
        structured_name = getattr(cfg, "_name", None)
        if structured_name and structured_name in TASK_DATACLASS_REGISTRY:
            cfg = merge_with_parent(TASK_DATACLASS_REGISTRY[structured_name](), cfg)
            task = TASK_REGISTRY[structured_name]

    assert task is not None, f"Could not infer task type from {cfg}"

    return task.setup_task(cfg, **kwargs)
示例#6
0
文件: __init__.py 项目: Fei-WL/CCMT
def build_model(cfg: FairseqDataclass, task):
    """Resolve the model class described by ``cfg`` and build it for ``task``.

    The model type comes from ``cfg._name`` or, failing that, ``cfg.arch``.
    Suppose neither is set and ``cfg`` is a container with exactly one entry.
    Then that entry's key is taken as the model type and its value as the
    model config; this happens when the config is nested in a directory named
    after the model type.

    Legacy architectures (``ARCH_MODEL_REGISTRY``) are checked before
    config-driven models (``MODEL_REGISTRY``). Whenever the type has a
    registered dataclass, ``cfg`` is adjusted before building:

    * argparse ``Namespace``: converted with ``populate_dataclass``.
    * otherwise: merged with ``merge_with_parent``.

    Returns:
        The result of ``model.build_model(cfg, task)``.

    Raises:
        Exception: the single nested key is not a registered model dataclass.
        AssertionError: no registered model matches the resolved type.
    """

    model = None
    model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

    # argparse.Namespace (handled below) has no len(). Calling len() on it
    # raised TypeError before the informative assertion below could fire.
    if (
        not model_type
        and not isinstance(cfg, argparse.Namespace)
        and len(cfg) == 1
    ):
        # this is hit if config object is nested in directory that is named after model type

        model_type = next(iter(cfg))
        if model_type in MODEL_DATACLASS_REGISTRY:
            cfg = cfg[model_type]
        else:
            raise Exception(
                "Could not infer model type from directory. Please add _name field to indicate model type. "
                f"Available models: {MODEL_DATACLASS_REGISTRY.keys()}"
                f" Requested model type: {model_type}"
            )

    if model_type in ARCH_MODEL_REGISTRY:
        # case 1: legacy models
        model = ARCH_MODEL_REGISTRY[model_type]
    elif model_type in MODEL_DATACLASS_REGISTRY:
        # case 2: config-driven models
        model = MODEL_REGISTRY[model_type]

    if model_type in MODEL_DATACLASS_REGISTRY:
        # set defaults from dataclass. note that arch name and model name can be the same
        dc = MODEL_DATACLASS_REGISTRY[model_type]
        if isinstance(cfg, argparse.Namespace):
            cfg = populate_dataclass(dc(), cfg)
        else:
            cfg = merge_with_parent(dc(), cfg)

    # The message is built with f-strings because model_type is frequently
    # None exactly when this assertion fails. The previous
    # ``"..." + model_type`` raised TypeError while formatting the message,
    # hiding the real error.
    assert model is not None, (
        f"Could not infer model type from {cfg}. "
        f"Available models: {MODEL_DATACLASS_REGISTRY.keys()}"
        f" Requested model type: {model_type}"
    )

    return model.build_model(cfg, task)
示例#7
0
def setup_task(cfg: FairseqDataclass, **kwargs):
    """Look up the task described by ``cfg`` and delegate to its ``setup_task``.

    Two config shapes are recognised:

    * structured: ``cfg.task`` is not a string, and ``cfg._name`` names a task
      with a registered dataclass. ``cfg`` is passed through
      ``merge_with_parent`` with an instance of that dataclass.
    * legacy: ``cfg.task`` is a string naming an entry of ``TASK_REGISTRY``.
      If a dataclass is also registered under that name, ``cfg`` is converted
      with ``populate_dataclass``.

    Extra keyword arguments are forwarded to ``task.setup_task``.

    Raises:
        AssertionError: no task could be resolved from ``cfg``.
    """
    task_name = getattr(cfg, "task", None)

    if not isinstance(task_name, str):
        # Structured config: the task is identified by its ``_name`` field.
        task = None
        task_name = getattr(cfg, "_name", None)
        if task_name and task_name in TASK_DATACLASS_REGISTRY:
            cfg = merge_with_parent(TASK_DATACLASS_REGISTRY[task_name](), cfg)
            task = TASK_REGISTRY[task_name]
    else:
        # legacy tasks: the registry lookup happens before any dataclass work.
        task = TASK_REGISTRY[task_name]
        if task_name in TASK_DATACLASS_REGISTRY:
            cfg = populate_dataclass(TASK_DATACLASS_REGISTRY[task_name](), cfg)

    assert task is not None, f"Could not infer task type from {cfg}. Available tasks: {TASK_REGISTRY.keys()}"

    return task.setup_task(cfg, **kwargs)