コード例 #1
0
ファイル: __init__.py プロジェクト: ardiya/multical
 def instantiate_board(config):
   """Build a calibration board object from a board config.

   Dispatches on ``config._type_``:
     * ``"charuco"``   -> config merged with the ``CharucoConfig`` schema, passed to ``CharucoBoard``
     * ``"aprilgrid"`` -> config merged with the ``AprilConfig`` schema, passed to ``AprilGrid``

   Args:
     config: board config exposing a ``_type_`` attribute (presumably an
       OmegaConf DictConfig loaded from YAML -- confirm against callers).

   Returns:
     A ``CharucoBoard`` or ``AprilGrid`` instance.

   Raises:
     AssertionError: if ``config._type_`` is not a supported board type.
   """
   board_type = config._type_

   if board_type == "charuco":
     schema = OmegaConf.structured(CharucoConfig)
     # NOTE(review): `aruco_params` is not defined in this function; it must be
     # resolved from an enclosing/module scope -- confirm it exists there.
     return CharucoBoard(aruco_params=aruco_params, **merge_schema(config, schema))

   if board_type == "aprilgrid":
     schema = OmegaConf.structured(AprilConfig)
     return AprilGrid(**merge_schema(config, schema))

   # Raise explicitly instead of `assert False`: assert statements are removed
   # under `python -O`, which would make this function silently return None.
   # AssertionError is kept so existing callers that catch it keep working.
   # Only the board types handled above are listed as valid options.
   raise AssertionError(
     f"unknown board type: {board_type}, options are (charuco | aprilgrid)")
コード例 #2
0
def legacy_model_config_to_new_model_config(model_cfg: DictConfig) -> DictConfig:
    """
    Convert an old style punctuation/capitalization model config into the format described by
    :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_config.PunctuationCapitalizationModelConfig`.

    Old style configs predate the ``common_dataset_parameters`` section: they keep shared dataset settings under
    ``dataset`` and size batches with ``batch_size`` rather than ``tokens_in_batch``. Tarred datasets are not
    supported by old style configs.

    Args:
        model_cfg: an old style model config. ``dataset``, ``class_labels``, ``punct_label_ids``,
            ``capit_label_ids``, ``tokenizer``, ``language_model`` and ``optim`` are read directly;
            ``train_ds``, ``validation_ds``, ``test_ds``, ``punct_head`` and ``capit_head`` are optional.

    Returns:
        a structured config built from
            :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_config.PunctuationCapitalizationModelConfig`
    """
    train_section = model_cfg.get('train_ds')
    validation_section = model_cfg.get('validation_ds')
    test_section = model_cfg.get('test_ds')
    legacy_dataset = model_cfg.dataset
    punct_head_section = model_cfg.get('punct_head', {})
    capit_head_section = model_cfg.get('capit_head', {})

    def _convert_split(split_cfg, is_train):
        # Splits absent from the old config stay ``None`` in the new one.
        if split_cfg is None:
            return None
        return legacy_data_config_to_new_data_config(split_cfg, legacy_dataset, train=is_train)

    def _build_head(head_section, num_layers_key):
        # Old configs name the layer-count key per head ('punct_num_fc_layers' / 'capit_num_fc_layers');
        # every missing value falls back to the ``HeadConfig`` class default.
        return HeadConfig(
            num_fc_layers=head_section.get(num_layers_key, HeadConfig.num_fc_layers),
            fc_dropout=head_section.get('fc_dropout', HeadConfig.fc_dropout),
            activation=head_section.get('activation', HeadConfig.activation),
            use_transformer_init=head_section.get('use_transformer_init', HeadConfig.use_transformer_init),
        )

    class_labels = model_cfg.class_labels
    shared_dataset_params = CommonDatasetParametersConfig(
        pad_label=legacy_dataset.pad_label,
        ignore_extra_tokens=legacy_dataset.get(
            'ignore_extra_tokens', CommonDatasetParametersConfig.ignore_extra_tokens
        ),
        ignore_start_end=legacy_dataset.get('ignore_start_end', CommonDatasetParametersConfig.ignore_start_end),
        punct_label_ids=model_cfg.punct_label_ids,
        capit_label_ids=model_cfg.capit_label_ids,
    )
    converted = PunctuationCapitalizationModelConfig(
        class_labels=class_labels,
        common_dataset_parameters=shared_dataset_params,
        train_ds=_convert_split(train_section, True),
        validation_ds=_convert_split(validation_section, False),
        test_ds=_convert_split(test_section, False),
        punct_head=_build_head(punct_head_section, 'punct_num_fc_layers'),
        capit_head=_build_head(capit_head_section, 'capit_num_fc_layers'),
        tokenizer=model_cfg.tokenizer,
        language_model=model_cfg.language_model,
        optim=model_cfg.optim,
    )
    return OmegaConf.structured(converted)
コード例 #3
0
ファイル: _utils.py プロジェクト: zhaodan2000/omegaconf
def _ensure_container(target: Any, flags: Optional[Dict[str, bool]] = None) -> Any:
    """
    Coerce ``target`` into an OmegaConf container.

    Primitive containers are wrapped with ``OmegaConf.create`` and structured configs with
    ``OmegaConf.structured``; both receive ``flags``. Any other input is returned unchanged,
    provided it already is an OmegaConf config.

    :param target: object to convert.
    :param flags: optional flags forwarded to the OmegaConf factory call.
    :return: an OmegaConf config object.
    :raises AssertionError: if ``target`` is not, and cannot be converted into, an OmegaConf config.
    """
    from omegaconf import OmegaConf

    if is_primitive_container(target):
        # Type-narrowing invariant: primitive containers handled here are lists or dicts.
        assert isinstance(target, (list, dict))
        target = OmegaConf.create(target, flags=flags)
    elif is_structured_config(target):
        target = OmegaConf.structured(target, flags=flags)

    # This validates caller input, so it must not be a bare `assert`: assert statements
    # are removed under `python -O`, which would let non-config objects pass through.
    # AssertionError is kept so callers relying on the previous exception type still work.
    if not OmegaConf.is_config(target):
        raise AssertionError(
            f"Invalid input of type {type(target).__name__}: expected dict, list, DictConfig, "
            "ListConfig, dataclass, dataclass instance, attr class or attr class instance"
        )
    return target
コード例 #4
0
def _ensure_container(target: Any, flags: Optional[Dict[str, bool]] = None) -> Any:
    """
    Return ``target`` as an OmegaConf container, converting it when necessary.

    * ``dict`` / ``list``            -> ``OmegaConf.create(target, flags=flags)``
    * dataclass / attr class (or instance) -> ``OmegaConf.structured(target, flags=flags)``
    * an existing OmegaConf config   -> returned as-is

    :param target: object to convert.
    :param flags: optional flags forwarded to the OmegaConf factory call.
    :return: an OmegaConf config object.
    :raises ValueError: if ``target`` matches none of the supported kinds.
    """
    from omegaconf import OmegaConf

    if is_primitive_container(target):
        # Type-narrowing invariant for the primitive branch.
        assert isinstance(target, (list, dict))
        return OmegaConf.create(target, flags=flags)

    if is_structured_config(target):
        return OmegaConf.structured(target, flags=flags)

    if OmegaConf.is_config(target):
        return target

    raise ValueError(
        "Invalid input. Supports one of "
        + "[dict,list,DictConfig,ListConfig,dataclass,dataclass instance,attr class,attr class instance]"
    )