def build_model(cfg: FairseqDataclass, task):
    """Resolve the model class named by ``cfg`` and build it for ``task``.

    The model type is read from ``cfg._name`` or, failing that, ``cfg.arch``.
    If neither is set and ``cfg`` holds exactly one entry, that entry's key is
    taken as the model type and its value becomes the effective config.

    Args:
        cfg: model configuration (an ``argparse.Namespace`` or a mapping-like
            config object).
        task: forwarded unchanged to the model class's ``build_model``.

    Returns:
        Whatever ``model.build_model(cfg, task)`` returns for the resolved
        model class.

    Raises:
        Exception: the single nested key is not in ``MODEL_DATACLASS_REGISTRY``.
        AssertionError: no model class could be resolved for the model type.
    """
    model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

    if not model_type and len(cfg) == 1:
        # The config is nested one level down, under a key named after the
        # model type.
        model_type = next(iter(cfg))
        if model_type not in MODEL_DATACLASS_REGISTRY:
            raise Exception(
                "Could not infer model type from directory. Please add _name field to indicate model type"
            )
        cfg = cfg[model_type]

    has_dataclass = model_type in MODEL_DATACLASS_REGISTRY

    if model_type in ARCH_MODEL_REGISTRY:
        # Legacy models are looked up by architecture name.
        model = ARCH_MODEL_REGISTRY[model_type]
    elif has_dataclass:
        # Config-driven models are looked up by model name.
        model = MODEL_REGISTRY[model_type]
    else:
        model = None

    if has_dataclass:
        # Fill in defaults from the registered dataclass. This is independent
        # of the lookup above because arch name and model name can coincide.
        dataclass_cls = MODEL_DATACLASS_REGISTRY[model_type]
        apply_defaults = (
            populate_dataclass
            if isinstance(cfg, argparse.Namespace)
            else merge_with_parent
        )
        cfg = apply_defaults(dataclass_cls(), cfg)

    assert model is not None, f"Could not infer model type from {cfg}"
    return model.build_model(cfg, task)
def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs):
    """Instantiate the registered component selected by ``cfg``.

    The selection depends on the type of ``cfg``:

    * ``DictConfig``: the choice is ``cfg._name``.
    * ``str``: the string itself is the choice; if a dataclass is registered
      for it, a default instance of that dataclass replaces ``cfg``.
    * anything else: the choice is read from the attribute named
      ``registry_name`` (``None`` if absent); if a dataclass is registered for
      it, ``cfg`` is replaced by the result of ``populate_dataclass``.

    Returns:
        The result of calling the class's ``build_<registry_name>`` attribute
        when it has one, otherwise of calling the class itself, with
        ``cfg`` followed by the extra positional and keyword arguments.
        ``None`` when no choice was made and the registry is not required.

    Raises:
        ValueError: no choice was made and the registry is required.
    """
    if isinstance(cfg, DictConfig):
        selected = cfg._name
    elif isinstance(cfg, str):
        selected = cfg
        if selected in DATACLASS_REGISTRY:
            cfg = DATACLASS_REGISTRY[selected]()
    else:
        selected = getattr(cfg, registry_name, None)
        if selected in DATACLASS_REGISTRY:
            # NOTE(review): other populate_dataclass callers visible in this
            # file pass the dataclass instance first; confirm the expected
            # argument order here.
            cfg = populate_dataclass(cfg, DATACLASS_REGISTRY[selected]())

    if selected is None:
        if not required:
            return None
        raise ValueError('{} is required!'.format(registry_name))

    target = REGISTRY[selected]
    # Prefer a dedicated factory attribute; fall back to calling the class.
    factory = getattr(target, "build_" + registry_name, target)
    return factory(cfg, *extra_args, **extra_kwargs)
def build_model(cfg: FairseqDataclass, task):
    """Resolve the model class named by ``cfg`` and build it for ``task``.

    The model type is read from ``cfg._name`` or, failing that, ``cfg.arch``.
    If neither is set and ``cfg`` holds exactly one entry, that entry's key is
    taken as the model type and its value becomes the effective config.

    When a dataclass is registered for the model type, its defaults are merged
    into ``cfg``. Otherwise, if an architecture function is registered for the
    model type, it is called on ``cfg``.

    Args:
        cfg: model configuration (an ``argparse.Namespace`` or a mapping-like
            config object).
        task: forwarded unchanged to the model class's ``build_model``.

    Returns:
        Whatever ``model.build_model(cfg, task)`` returns for the resolved
        model class.

    Raises:
        Exception: the single nested key is not in ``MODEL_DATACLASS_REGISTRY``.
        AssertionError: no model class could be resolved for the model type.
    """
    model = None
    model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

    if not model_type and len(cfg) == 1:
        # this is hit if config object is nested in directory that is named after model type
        model_type = next(iter(cfg))
        if model_type in MODEL_DATACLASS_REGISTRY:
            cfg = cfg[model_type]
        else:
            # f-string rather than "+": the key is not guaranteed to be a str.
            raise Exception(
                "Could not infer model type from directory. Please add _name field to indicate model type. "
                f"Available models: {MODEL_DATACLASS_REGISTRY.keys()}"
                f" Requested model type: {model_type}"
            )

    if model_type in ARCH_MODEL_REGISTRY:
        # case 1: legacy models
        model = ARCH_MODEL_REGISTRY[model_type]
    elif model_type in MODEL_DATACLASS_REGISTRY:
        # case 2: config-driven models
        model = MODEL_REGISTRY[model_type]

    if model_type in MODEL_DATACLASS_REGISTRY:
        # set defaults from dataclass. note that arch name and model name can be the same
        dc = MODEL_DATACLASS_REGISTRY[model_type]
        if isinstance(cfg, argparse.Namespace):
            cfg = populate_dataclass(dc(), cfg)
        else:
            cfg = merge_with_parent(dc(), cfg)
    elif model_type in ARCH_CONFIG_REGISTRY:
        with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack():
            # this calls the different "arch" functions (like base_architecture()) that you indicate
            # if you specify --arch on the command line. this is only applicable to the old argparse based models
            # hydra models should expose different architectures via different config files
            # it will modify the cfg object and default parameters according to the arch
            ARCH_CONFIG_REGISTRY[model_type](cfg)

    # Build the message with f-strings only. Calling str.format() on text that
    # already contains the rendered cfg fails when that rendering includes
    # braces, which would mask this assertion with a formatting error.
    assert model is not None, (
        f"Could not infer model type from {cfg}. "
        f"Available models: {MODEL_DATACLASS_REGISTRY.keys()}"
        f" Requested model type: {model_type}"
    )
    return model.build_model(cfg, task)
def setup_task(cfg: FairseqDataclass, **kwargs):
    """Resolve the task class named by ``cfg`` and delegate to its ``setup_task``.

    If ``cfg.task`` is a string, it is looked up directly in ``TASK_REGISTRY``
    (legacy path) and, when a dataclass is registered for it, ``cfg`` is
    replaced by the result of ``populate_dataclass``. Otherwise ``cfg._name``
    is used, and a task is resolved only when a dataclass is registered under
    that name.

    Args:
        cfg: task configuration.
        **kwargs: forwarded unchanged to the task class's ``setup_task``.

    Returns:
        Whatever ``task.setup_task(cfg, **kwargs)`` returns.

    Raises:
        KeyError: a legacy task name is not present in ``TASK_REGISTRY``.
        AssertionError: no task class could be resolved.
    """
    task = None
    legacy_name = getattr(cfg, "task", None)

    if isinstance(legacy_name, str):
        # Legacy tasks: the name comes straight from the ``task`` attribute.
        task = TASK_REGISTRY[legacy_name]
        if legacy_name in TASK_DATACLASS_REGISTRY:
            dataclass_cls = TASK_DATACLASS_REGISTRY[legacy_name]
            cfg = populate_dataclass(dataclass_cls(), cfg)
    else:
        config_name = getattr(cfg, "_name", None)
        if config_name and config_name in TASK_DATACLASS_REGISTRY:
            dataclass_cls = TASK_DATACLASS_REGISTRY[config_name]
            cfg = merge_with_parent(dataclass_cls(), cfg)
            task = TASK_REGISTRY[config_name]

    assert task is not None, f"Could not infer task type from {cfg}. Available tasks: {TASK_REGISTRY.keys()}"
    return task.setup_task(cfg, **kwargs)