Example #1
def test_none_cases(instantiate_func: Any) -> None:
    assert instantiate_func(None) is None

    cfg = {
        "_target_": "tests.instantiate.ArgsClass",
        "none_dict": DictConfig(None),
        "none_list": ListConfig(None),
        "dict": {
            "field": 10,
            "none_dict": DictConfig(None),
            "none_list": ListConfig(None),
        },
        "list": [
            10,
            DictConfig(None),
            ListConfig(None),
        ],
    }
    ret = instantiate_func(cfg)
    assert ret.kwargs["none_dict"] is None
    assert ret.kwargs["none_list"] is None
    assert ret.kwargs["dict"].field == 10
    assert ret.kwargs["dict"].none_dict is None
    assert ret.kwargs["dict"].none_list is None
    assert ret.kwargs["list"][0] == 10
    assert ret.kwargs["list"][1] is None
    assert ret.kwargs["list"][2] is None
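A minimal sketch of the behavior this test depends on, assuming only that omegaconf is installed: both DictConfig and ListConfig accept None as content, and converting such a config back to a plain container yields None.

from omegaconf import DictConfig, ListConfig, OmegaConf

none_dict = DictConfig(None)  # a DictConfig wrapping None content
none_list = ListConfig(None)  # a ListConfig wrapping None content

# A None-content config converts back to a plain None.
assert OmegaConf.to_container(none_dict) is None
assert OmegaConf.to_container(none_list) is None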
Example #2
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Required loss function
        if not WARP_RNNT_AVAILABLE:
            raise ImportError(
                "Could not import `warprnnt_pytorch`.\n"
                "Please visit https://github.com/HawkAaron/warp-transducer "
                "and follow the steps in the readme to build and install the "
                "pytorch bindings for RNNT Loss, or use the provided docker "
                "container that supports RNN-T loss.")

        # Tokenizer is necessary for this model
        if 'tokenizer' not in cfg:
            raise ValueError(
                "`cfg` must have `tokenizer` config to create a tokenizer !")

        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)

        # Setup the tokenizer
        self._setup_tokenizer(cfg.tokenizer)

        # Initialize a dummy vocabulary
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        # Set the new vocabulary
        with open_dict(cfg):
            cfg.labels = ListConfig(list(vocabulary))

        with open_dict(cfg.decoder):
            cfg.decoder.vocab_size = len(vocabulary)

        with open_dict(cfg.joint):
            cfg.joint.num_classes = len(vocabulary)
            cfg.joint.vocabulary = ListConfig(list(vocabulary))
            cfg.joint.jointnet.encoder_hidden = cfg.model_defaults.enc_hidden
            cfg.joint.jointnet.pred_hidden = cfg.model_defaults.pred_hidden

        super().__init__(cfg=cfg, trainer=trainer)

        # Setup decoding object
        self.decoding = RNNTBPEDecoding(
            decoding_cfg=self.cfg.decoding,
            decoder=self.decoder,
            joint=self.joint,
            tokenizer=self.tokenizer,
        )

        # Setup wer object
        self.wer = RNNTBPEWER(decoding=self.decoding,
                              batch_dim_index=0,
                              use_cer=False,
                              log_prediction=True,
                              dist_sync_on_step=True)

        # Setup fused Joint step if flag is set
        if self.joint.fuse_loss_wer:
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)
Example #3
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        # Tokenizer is necessary for this model
        if 'tokenizer' not in cfg:
            raise ValueError(
                "`cfg` must have `tokenizer` config to create a tokenizer !")

        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)

        # Setup the tokenizer
        self._setup_tokenizer(cfg.tokenizer)

        # Initialize a dummy vocabulary
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        # Set the new vocabulary
        with open_dict(cfg):
            cfg.labels = ListConfig(list(vocabulary))

        with open_dict(cfg.decoder):
            cfg.decoder.vocab_size = len(vocabulary)

        with open_dict(cfg.joint):
            cfg.joint.num_classes = len(vocabulary)
            cfg.joint.vocabulary = ListConfig(list(vocabulary))
            cfg.joint.jointnet.encoder_hidden = cfg.model_defaults.enc_hidden
            cfg.joint.jointnet.pred_hidden = cfg.model_defaults.pred_hidden

        super().__init__(cfg=cfg, trainer=trainer)

        # Setup decoding object
        self.decoding = RNNTBPEDecoding(
            decoding_cfg=self.cfg.decoding,
            decoder=self.decoder,
            joint=self.joint,
            tokenizer=self.tokenizer,
        )

        # Setup wer object
        self.wer = RNNTBPEWER(
            decoding=self.decoding,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            log_prediction=self._cfg.get('log_prediction', True),
            dist_sync_on_step=True,
        )

        # Setup fused Joint step if flag is set
        if self.joint.fuse_loss_wer:
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)
Example #4
    def __init__(self, cfg: DictConfig, trainer=None):
        if 'tokenizer' not in cfg:
            raise ValueError(
                "`cfg` must have `tokenizer` config to create a tokenizer !")

        self.tokenizer_cfg = OmegaConf.to_container(cfg.tokenizer, resolve=True)  # type: dict
        self.tokenizer_dir = self.tokenizer_cfg.pop('dir')  # Remove tokenizer directory
        self.tokenizer_type = self.tokenizer_cfg.pop('type').lower()  # Remove tokenizer_type

        # Setup the tokenizer
        self._setup_tokenizer()

        # Initialize a dummy vocabulary
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        # Set the new vocabulary
        cfg.decoder.params.vocabulary = ListConfig(list(vocabulary.values()))

        # Override number of classes if placeholder provided
        if cfg.decoder.params['num_classes'] < 1:
            logging.info(
                "\nReplacing placeholder number of classes ({}) with actual number of classes - {}"
                .format(cfg.decoder.params['num_classes'], len(vocabulary)))
            cfg.decoder.params['num_classes'] = len(vocabulary)

        super().__init__(cfg=cfg, trainer=trainer)

        # Setup metric objects
        self._wer = WERBPE(tokenizer=self.tokenizer,
                           batch_dim_index=0,
                           use_cer=False,
                           ctc_decode=True)
Example #5
    def test_direct_creation_of_listconfig_or_dictconfig(self, input_: Any) -> None:
        if isinstance(input_, Sequence):
            cfg = ListConfig(input_)  # type: ignore
            assert isinstance(cfg, ListConfig)
        else:
            cfg = DictConfig(input_)  # type: ignore
            assert isinstance(cfg, DictConfig)
Example #6
def test_create_from_listconfig_preserves_metadata() -> None:
    cfg1 = ListConfig(element_type=int, is_optional=False, content=[1, 2, 3])
    OmegaConf.set_struct(cfg1, True)
    OmegaConf.set_readonly(cfg1, True)
    cfg2 = OmegaConf.create(cfg1)
    assert cfg1 == cfg2
    assert cfg1._metadata == cfg2._metadata
Example #7
    def __init__(self, cfg: DictConfig, trainer=None):
        if 'tokenizer' not in cfg:
            raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !")

        # Setup the tokenizer
        self._setup_tokenizer(cfg.tokenizer)

        # Initialize a dummy vocabulary
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        # Set the new vocabulary
        with open_dict(cfg):
            if "params" in cfg.decoder:
                cfg.decoder.params.vocabulary = ListConfig(list(vocabulary.values()))
            else:
                cfg.decoder.vocabulary = ListConfig(list(vocabulary.values()))

        # Override number of classes if placeholder provided
        if "params" in cfg.decoder:
            num_classes = cfg.decoder["params"]["num_classes"]
        else:
            num_classes = cfg.decoder["num_classes"]

        if num_classes < 1:
            logging.info(
                "\nReplacing placeholder number of classes ({}) with actual number of classes - {}".format(
                    num_classes, len(vocabulary)
                )
            )
            if "params" in cfg.decoder:
                cfg.decoder["params"]["num_classes"] = len(vocabulary)
            else:
                cfg.decoder["num_classes"] = len(vocabulary)

        super().__init__(cfg=cfg, trainer=trainer)

        # Setup metric objects
        self._wer = WERBPE(
            tokenizer=self.tokenizer,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
        )
Example #8
    def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
        """
        Changes the vocabulary of the tokenizer used during the CTC decoding process.
        Use this method when fine-tuning from a pre-trained model.
        This method changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example,
        you would use it if you want to reuse a pretrained encoder when fine-tuning on data in another language, or when
        you need the model to learn capitalization, punctuation and/or special characters.

        Args:
            new_tokenizer_dir: Path to the new tokenizer directory.
            new_tokenizer_type: Either `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers,
                whereas `wpe` is used for `BertTokenizer`.

        Returns: None

        """
        if not os.path.isdir(new_tokenizer_dir):
            raise NotADirectoryError(
                f'New tokenizer dir must be non-empty path to a directory. But I got: {new_tokenizer_dir}'
            )

        if new_tokenizer_type.lower() not in ('bpe', 'wpe'):
            raise ValueError('New tokenizer type must be either `bpe` or `wpe`')

        self.tokenizer_dir = new_tokenizer_dir  # Set the new tokenizer directory
        self.tokenizer_type = new_tokenizer_type.lower()  # Set the new tokenizer type

        # Setup the tokenizer
        self._setup_tokenizer()

        # Initialize a dummy vocabulary
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        # Set the new vocabulary
        decoder_config = copy.deepcopy(self.decoder.to_config_dict())
        decoder_config.params.vocabulary = ListConfig(list(vocabulary.values()))

        # Override number of classes if placeholder provided
        logging.info(
            "\nReplacing old number of classes ({}) with new number of classes - {}".format(
                decoder_config['params']['num_classes'], len(vocabulary)
            )
        )
        decoder_config['params']['num_classes'] = len(vocabulary)

        del self.decoder
        self.decoder = EncDecCTCModelBPE.from_config_dict(decoder_config)
        del self.loss
        self.loss = CTCLoss(num_classes=self.decoder.num_classes_with_blank - 1, zero_infinity=True)
        self._wer = WERBPE(tokenizer=self.tokenizer, batch_dim_index=0, use_cer=False, ctc_decode=True)

        # Update config
        OmegaConf.set_struct(self._cfg.decoder, False)
        self._cfg.decoder = decoder_config
        OmegaConf.set_struct(self._cfg.decoder, True)

        logging.info(f"Changed tokenizer to {self.decoder.vocabulary} vocabulary.")
Example #9
    def __init__(self, cfg: DictConfig, trainer=None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        if 'tokenizer' not in cfg:
            raise ValueError(
                "`cfg` must have `tokenizer` config to create a tokenizer !")

        # Setup the tokenizer
        self._setup_tokenizer(cfg.tokenizer)

        # Initialize a dummy vocabulary
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        # Set the new vocabulary
        with open_dict(cfg):
            # sidestepping the potential overlapping tokens issue in aggregate tokenizers
            if self.tokenizer_type == "agg":
                cfg.decoder.vocabulary = ListConfig(vocabulary)
            else:
                cfg.decoder.vocabulary = ListConfig(list(vocabulary.keys()))

        # Override number of classes if placeholder provided
        num_classes = cfg.decoder["num_classes"]

        if num_classes < 1:
            logging.info(
                "\nReplacing placeholder number of classes ({}) with actual number of classes - {}"
                .format(num_classes, len(vocabulary)))
            cfg.decoder["num_classes"] = len(vocabulary)

        super().__init__(cfg=cfg, trainer=trainer)

        # Setup metric objects
        self._wer = WERBPE(
            tokenizer=self.tokenizer,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
        )
Example #10
    def change_labels(self, new_labels: List[str]):
        """
        Changes labels used by the decoder model. Use this method when fine-tuning from a pre-trained model.
        This method changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example,
        you would use it if you want to reuse a pretrained encoder when fine-tuning on another dataset.

        If new_labels == self.decoder.vocabulary then nothing will be changed.

        Args:

            new_labels: list with new labels. Must contain at least 2 elements. Typically, \
            this is set of labels for the dataset.

        Returns: None

        """
        if new_labels is not None and not isinstance(new_labels, ListConfig):
            new_labels = ListConfig(new_labels)

        if self._cfg.labels == new_labels:
            logging.warning(
                f"Old labels ({self._cfg.labels}) and new labels ({new_labels}) match. Not changing anything"
            )
        else:
            if new_labels is None or len(new_labels) == 0:
                raise ValueError(
                    f'New labels must be non-empty list of labels. But I got: {new_labels}'
                )

            # Update config
            self._cfg.labels = new_labels

            decoder_config = self.decoder.to_config_dict()
            new_decoder_config = copy.deepcopy(decoder_config)
            self._update_decoder_config(new_decoder_config)
            del self.decoder
            self.decoder = EncDecClassificationModel.from_config_dict(
                new_decoder_config)

            OmegaConf.set_struct(self._cfg.decoder, False)
            self._cfg.decoder = new_decoder_config
            OmegaConf.set_struct(self._cfg.decoder, True)

            if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
                self._cfg.train_ds.labels = new_labels

            if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
                self._cfg.validation_ds.labels = new_labels

            if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
                self._cfg.test_ds.labels = new_labels

            logging.info(
                f"Changed decoder output to {self.decoder.num_classes} labels."
            )
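A hypothetical usage sketch; the label set is a placeholder (per the docstring above, it must contain at least 2 elements):

# `model` is an EncDecClassificationModel restored elsewhere.
model.change_labels(["yes", "no", "up", "down", "left", "right"])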
Example #11
import dataclasses
import logging
from collections import abc
from pydoc import locate  # stand-in: the original project may provide its own locate() helper


def instantiate(cfg):
    """
    Recursively instantiate objects defined in dictionaries by
    "_target_" and arguments.

    Args:
        cfg: a dict-like object with "_target_" that defines the caller, and
            other keys that define the arguments

    Returns:
        object instantiated by cfg
    """
    from omegaconf import ListConfig, DictConfig, OmegaConf

    if isinstance(cfg, ListConfig):
        lst = [instantiate(x) for x in cfg]
        return ListConfig(lst, flags={"allow_objects": True})
    if isinstance(cfg, list):
        # Specialize for list, because many classes take
        # list[objects] as arguments, such as ResNet, DatasetMapper
        return [instantiate(x) for x in cfg]

    # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
    # instantiate it to the actual dataclass.
    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(
            cfg._metadata.object_type):
        return OmegaConf.to_object(cfg)

    if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
        # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
        # but faster: https://github.com/facebookresearch/hydra/issues/1200
        cfg = {k: instantiate(v) for k, v in cfg.items()}
        cls = cfg.pop("_target_")
        cls = instantiate(cls)

        if isinstance(cls, str):
            cls_name = cls
            cls = locate(cls_name)
            assert cls is not None, cls_name
        else:
            try:
                cls_name = cls.__module__ + "." + cls.__qualname__
            except Exception:
                # target could be anything, so the above could fail
                cls_name = str(cls)
        assert callable(
            cls), f"_target_ {cls} does not define a callable object"
        try:
            return cls(**cfg)
        except TypeError:
            logger = logging.getLogger(__name__)
            logger.error(f"Error when instantiating {cls_name}!")
            raise
    return cfg  # return as-is if we don't know what to do
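A usage sketch for this helper, assuming the imports above; `datetime.timedelta` is just a convenient stdlib target:

cfg = {"_target_": "datetime.timedelta", "hours": 1, "minutes": 30}
td = instantiate(cfg)  # locates datetime.timedelta and calls it with the remaining keys
assert td.total_seconds() == 5400.0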
Example #12
def speech_classification_model():
    preprocessor = {
        'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
        'params': dict({}),
    }
    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [{
                'filters': 32,
                'repeat': 1,
                'kernel': [1],
                'stride': [1],
                'dilation': [1],
                'dropout': 0.0,
                'residual': False,
                'separable': True,
                'se': True,
                'se_context_size': -1,
            }],
        },
    }

    decoder = {
        'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification',
        'params': {
            'feat_in': 32,
            'num_classes': 30,
        },
    }

    modelConfig = DictConfig({
        'preprocessor': DictConfig(preprocessor),
        'encoder': DictConfig(encoder),
        'decoder': DictConfig(decoder),
        'labels': ListConfig(["dummy_cls_{}".format(i + 1) for i in range(30)]),
    })
    model = EncDecClassificationModel(cfg=modelConfig)
    return model
Example #13
def values(key: str, _root_: BaseContainer, _parent_: Container) -> ListConfig:
    assert isinstance(_parent_, BaseContainer)
    in_dict = _get_and_validate_dict_input(key,
                                           parent=_parent_,
                                           resolver_name="oc.dict.values")

    content = in_dict._content
    assert isinstance(content, dict)

    ret = ListConfig([])
    for k in content:
        ref_node = AnyNode(f"${{{key}.{k!s}}}")
        ret.append(ref_node)

    # Finalize result by setting proper type and parent.
    element_type: Any = in_dict._metadata.element_type
    ret._metadata.element_type = element_type
    ret._metadata.ref_type = List[element_type]
    ret._set_parent(_parent_)

    return ret
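This function is the engine behind omegaconf's built-in `oc.dict.values` resolver; end users typically reach it through interpolation. A sketch, assuming omegaconf >= 2.1 (where the resolver ships):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"workers": {"alice": 10, "bob": 20}, "vals": "${oc.dict.values:workers}"}
)
assert list(cfg.vals) == [10, 20]  # resolves to a ListConfig of the dict's values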
Example #14
    def _list_merge(dest: Any, src: Any) -> None:
        from omegaconf import DictConfig, ListConfig, OmegaConf

        assert isinstance(dest, ListConfig)
        assert isinstance(src, ListConfig)

        if src._is_none():
            dest._set_value(None)
        elif src._is_missing():
            # do not change dest if src is MISSING.
            if dest._metadata.element_type is Any:
                dest._metadata.element_type = src._metadata.element_type
        elif src._is_interpolation():
            dest._set_value(src._value())
        else:
            temp_target = ListConfig(content=[], parent=dest._get_parent())
            temp_target.__dict__["_metadata"] = copy.deepcopy(
                dest.__dict__["_metadata"])
            is_optional, et = _resolve_optional(dest._metadata.element_type)
            if is_structured_config(et):
                prototype = DictConfig(et, ref_type=et, is_optional=is_optional)
                for item in src._iter_ex(resolve=False):
                    if isinstance(item, DictConfig):
                        item = OmegaConf.merge(prototype, item)
                    temp_target.append(item)
            else:
                for item in src._iter_ex(resolve=False):
                    temp_target.append(item)

            dest.__dict__["_content"] = temp_target.__dict__["_content"]

        # explicit flags on the source config are replacing the flag values in the destination
        flags = src._metadata.flags
        assert flags is not None
        for flag, value in flags.items():
            if value is not None:
                dest._set_flag(flag, value)
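The branch structure above is observable through the public merge API; a quick sketch of the default case, where the source list replaces the destination's content wholesale:

from omegaconf import OmegaConf

dst = OmegaConf.create({"a": [1, 2, 3]})
src = OmegaConf.create({"a": [4, 5]})
assert OmegaConf.merge(dst, src).a == [4, 5]  # content replaced, not concatenated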
Example #15
         lambda value, is_optional, key=None: EnumNode(
             enum_type=Color, value=value, is_optional=is_optional, key=key
         ),
         [Color.RED],
     ),
     # DictConfig
     (
         lambda value, is_optional, key=None: DictConfig(
             is_optional=is_optional, content=value, key=key),
         [{}, {"foo": "bar"}],
     ),
     # ListConfig
     (
         lambda value, is_optional, key=None: ListConfig(
             is_optional=is_optional, content=value, key=key),
         [[], [1, 2, 3]],
     ),
     # dataclass
     (
         lambda value, is_optional, key=None: DictConfig(
             ref_type=Group, is_optional=is_optional, content=value, key=key),
         [Group, Group()],
     ),
 ],
 ids=(
     "BooleanNode",
     "FloatNode",
Example #16
    User,
    does_not_raise,
)


@mark.parametrize(
    "target_type, value, expected",
    [
        # Any
        param(Any, "foo", AnyNode("foo"), id="any"),
        param(Any, True, AnyNode(True), id="any"),
        param(Any, 1, AnyNode(1), id="any"),
        param(Any, 1.0, AnyNode(1.0), id="any"),
        param(Any, Color.RED, AnyNode(Color.RED), id="any"),
        param(Any, {}, DictConfig(content={}), id="any_as_dict"),
        param(Any, [], ListConfig(content=[]), id="any_as_list"),
        # int
        param(int, "foo", ValidationError, id="int"),
        param(int, True, ValidationError, id="int"),
        param(int, 1, IntegerNode(1), id="int"),
        param(int, 1.0, ValidationError, id="int"),
        param(int, Color.RED, ValidationError, id="int"),
        # float
        param(float, "foo", ValidationError, id="float"),
        param(float, True, ValidationError, id="float"),
        param(float, 1, FloatNode(1), id="float"),
        param(float, 1.0, FloatNode(1.0), id="float"),
        param(float, Color.RED, ValidationError, id="float"),
        # bool
        param(bool, "foo", ValidationError, id="bool"),
        param(bool, True, BooleanNode(True), id="bool"),
Example #17
    assert str(node) == str(output_)


# testing invalid conversions
@pytest.mark.parametrize(
    "type_,input_",
    [
        (IntegerNode, "abc"),
        (IntegerNode, 10.1),
        (IntegerNode, "-1132c"),
        (FloatNode, "abc"),
        (IntegerNode, "-abc"),
        (BooleanNode, "Nope"),
        (BooleanNode, "Yup"),
        (StringNode, [1, 2]),
        (StringNode, ListConfig([1, 2])),
        (StringNode, {"foo": "var"}),
        (FloatNode, DictConfig({"foo": "var"})),
        (IntegerNode, [1, 2]),
        (IntegerNode, ListConfig([1, 2])),
        (IntegerNode, {"foo": "var"}),
        (IntegerNode, DictConfig({"foo": "var"})),
        (BooleanNode, [1, 2]),
        (BooleanNode, ListConfig([1, 2])),
        (BooleanNode, {"foo": "var"}),
        (BooleanNode, DictConfig({"foo": "var"})),
        (FloatNode, [1, 2]),
        (FloatNode, ListConfig([1, 2])),
        (FloatNode, {"foo": "var"}),
        (FloatNode, DictConfig({"foo": "var"})),
        (AnyNode, [1, 2]),
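Each of these pairs is expected to raise ValidationError when the node is constructed; a minimal sketch using omegaconf's public exports:

from omegaconf import IntegerNode
from omegaconf.errors import ValidationError

try:
    IntegerNode("abc")  # "abc" cannot be converted to an int
except ValidationError:
    print("rejected as expected")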
Example #18
         },
     ),
     {"a": "${b}", "b": 1},
     id="dict_merge_inter_to_missing",
 ),
 param(
     (
         {"a": [0], "b": [1]},
         {"a": ListConfig(content="${b}"), "b": "???"},
     ),
     {"a": ListConfig(content="${b}"), "b": [1]},
     id="dict_with_list_merge_inter_to_missing",
 ),
 # lists
 (([1, 2, 3], [4, 5, 6]), [4, 5, 6]),
 (([[1, 2, 3]], [[4, 5, 6]]), [[4, 5, 6]]),
 (([1, 2, {
     "a": 10
 }], [4, 5, {
Example #19
class TestGetWithDefault:
    @mark.parametrize(
        "d,select,key",
        [
            ({"key": {"subkey": 2}}, "", "missing"),
            ({"key": {"subkey": 2}}, "key", "missing"),
            ({"key": "???"}, "", "key"),
            ({"key": DictConfig(content="???")}, "", "key"),
            ({"key": ListConfig(content="???")}, "", "key"),
        ],
    )
    def test_dict_get_with_default(self, d: Any, select: Any, key: Any,
                                   default_val: Any,
                                   struct: Optional[bool]) -> None:
        c = OmegaConf.create(d)
        c = OmegaConf.select(c, select)
        OmegaConf.set_struct(c, struct)
        assert c.get(key, default_val) == default_val

    @mark.parametrize(
        ("d", "select", "key", "expected"),
        [
            ({"key": "value"}, "", "key", "value"),
            ({"key": None}, "", "key", None),
            ({"key": {"subkey": None}}, "key", "subkey", None),
            ({"key": DictConfig(is_optional=True, content=None)}, "", "key", None),
            ({"key": ListConfig(is_optional=True, content=None)}, "", "key", None),
        ],
    )
    def test_dict_get_not_returning_default(
        self,
        d: Any,
        select: Any,
        key: Any,
        expected: Any,
        default_val: Any,
        struct: Optional[bool],
    ) -> None:
        c = OmegaConf.create(d)
        c = OmegaConf.select(c, select)
        OmegaConf.set_struct(c, struct)
        assert c.get(key, default_val) == expected

    @mark.parametrize(
        "d,exc",
        [
            ({"key": "${foo}"}, InterpolationKeyError),
            (
                {"key": "${foo}", "foo": "???"},
                InterpolationToMissingValueError,
            ),
            ({"key": DictConfig(content="${foo}")}, InterpolationKeyError),
        ],
    )
    def test_dict_get_with_default_errors(self, d: Any, exc: type,
                                          struct: Optional[bool],
                                          default_val: Any) -> None:
        c = OmegaConf.create(d)
        OmegaConf.set_struct(c, struct)
        with raises(exc):
            c.get("key", default_value=123)
Example #20
    def change_vocabulary(self,
                          new_tokenizer_dir: str,
                          new_tokenizer_type: str,
                          decoding_cfg: Optional[DictConfig] = None):
        """
        Changes the vocabulary used during the RNNT decoding process. Use this method when fine-tuning from a pre-trained model.
        This method changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example,
        you would use it if you want to reuse a pretrained encoder when fine-tuning on data in another language, or when
        you need the model to learn capitalization, punctuation and/or special characters.

        Args:
            new_tokenizer_dir: Directory path to tokenizer.
            new_tokenizer_type: Type of tokenizer. Can be either `bpe` or `wpe`.
            decoding_cfg: A config for the decoder, which is optional. If the decoding type
                needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.

        Returns: None

        """
        if not os.path.isdir(new_tokenizer_dir):
            raise NotADirectoryError(
                f'New tokenizer dir must be non-empty path to a directory. But I got: {new_tokenizer_dir}'
            )

        if new_tokenizer_type.lower() not in ('bpe', 'wpe'):
            raise ValueError(
                'New tokenizer type must be either `bpe` or `wpe`')

        tokenizer_cfg = OmegaConf.create({
            'dir': new_tokenizer_dir,
            'type': new_tokenizer_type
        })

        # Setup the tokenizer
        self._setup_tokenizer(tokenizer_cfg)

        # Initialize a dummy vocabulary
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        joint_config = self.joint.to_config_dict()
        new_joint_config = copy.deepcopy(joint_config)
        new_joint_config['vocabulary'] = ListConfig(list(vocabulary.values()))
        new_joint_config['num_classes'] = len(vocabulary)
        del self.joint
        self.joint = EncDecRNNTBPEModel.from_config_dict(new_joint_config)

        decoder_config = self.decoder.to_config_dict()
        new_decoder_config = copy.deepcopy(decoder_config)
        new_decoder_config.vocab_size = len(vocabulary)
        del self.decoder
        self.decoder = EncDecRNNTBPEModel.from_config_dict(new_decoder_config)

        del self.loss
        self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1)

        if decoding_cfg is None:
            # Assume same decoding config as before
            decoding_cfg = self.cfg.decoding

        self.decoding = RNNTBPEDecoding(
            decoding_cfg=decoding_cfg,
            decoder=self.decoder,
            joint=self.joint,
            tokenizer=self.tokenizer,
        )

        self.wer = RNNTBPEWER(
            decoding=self.decoding,
            batch_dim_index=self.wer.batch_dim_index,
            use_cer=self.wer.use_cer,
            log_prediction=self.wer.log_prediction,
            dist_sync_on_step=True,
        )

        # Setup fused Joint step
        if self.joint.fuse_loss_wer:
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)

        # Update config
        with open_dict(self.cfg.joint):
            self.cfg.joint = new_joint_config

        with open_dict(self.cfg.decoder):
            self.cfg.decoder = new_decoder_config

        with open_dict(self.cfg.decoding):
            self.cfg.decoding = decoding_cfg

        logging.info(
            f"Changed decoder to output the {self.joint.vocabulary} vocabulary."
        )
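A hypothetical call site; the tokenizer path is a placeholder, and the commented line shows the optional decoding override (`my_decoding_cfg` is a placeholder whose keys depend on the model's decoding config schema):

# `model` is an EncDecRNNTBPEModel restored elsewhere.
model.change_vocabulary(new_tokenizer_dir="/path/to/new_tokenizer",
                        new_tokenizer_type="bpe")
# model.change_vocabulary("/path/to/new_tokenizer", "bpe",
#                         decoding_cfg=my_decoding_cfg)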
Example #21
        }, "", "hello"),
        ({"hello": "${foo}", "foo": "???"}, "", "hello"),
        ({"hello": DictConfig(is_optional=True, content=None)}, "", "hello"),
        ({"hello": DictConfig(content="???")}, "", "hello"),
        ({"hello": DictConfig(content="${foo}")}, "", "hello"),
        ({"hello": ListConfig(is_optional=True, content=None)}, "", "hello"),
        ({"hello": ListConfig(content="???")}, "", "hello"),
    ],
)
def test_dict_get_with_default(d: Any, select: Any, key: Any, default_val: Any,
                               struct: Any) -> None:
    c = OmegaConf.create(d)
    c = OmegaConf.select(c, select)
    OmegaConf.set_struct(c, struct)
    assert c.get(key, default_val) == default_val


def test_map_expansion() -> None:
Example #22
def test_create_untyped_list() -> None:
    from omegaconf._utils import get_ref_type

    cfg = ListConfig(ref_type=List, content=[])
    assert get_ref_type(cfg) == Optional[List]
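The same helper reports a parametrized ref type when one is given; a small sketch (note that `get_ref_type` is a private omegaconf utility, as the import path suggests):

from typing import List, Optional

from omegaconf import ListConfig
from omegaconf._utils import get_ref_type

cfg = ListConfig(ref_type=List[int], content=[1, 2, 3])
assert get_ref_type(cfg) == Optional[List[int]]  # optional by default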
Example #23
     id="list:get_nox_ex:invalid_index_type",
 ),
 pytest.param(
     Expected(
         create=lambda: OmegaConf.create([1, 2, 3]),
         op=lambda cfg: cfg._get_node(20),
         exception_type=IndexError,
         msg="list index out of range",
         key=20,
         full_key="[20]",
     ),
     id="list:get_node_ex:index_out_of_range",
 ),
 pytest.param(
     Expected(
         create=lambda: ListConfig(content=None),
         op=lambda cfg: cfg._get_node(20),
         exception_type=TypeError,
         msg="Cannot get_node from a ListConfig object representing None",
         key=20,
         full_key="[20]",
     ),
     id="list:get_node_none",
 ),
 pytest.param(
     Expected(
         create=lambda: ListConfig(content="???"),
         op=lambda cfg: cfg._get_node(20),
         exception_type=MissingMandatoryValue,
         msg="Cannot get_node from a missing ListConfig",
         key=20,
Example #24
 (IntegerNode, Color.RED),
 (IntegerNode, b"123"),
 (FloatNode, "abc"),
 (FloatNode, Color.RED),
 (FloatNode, b"10.1"),
 (BytesNode, "abc"),
 (BytesNode, 23),
 (BytesNode, Color.RED),
 (BytesNode, 3.14),
 (BytesNode, True),
 (BooleanNode, "Nope"),
 (BooleanNode, "Yup"),
 (BooleanNode, Color.RED),
 (BooleanNode, b"True"),
 (IntegerNode, [1, 2]),
 (IntegerNode, ListConfig([1, 2])),
 (IntegerNode, {"foo": "var"}),
 (IntegerNode, b"10"),
 (IntegerNode, DictConfig({"foo": "var"})),
 (BytesNode, [1, 2]),
 (BytesNode, ListConfig([1, 2])),
 (BytesNode, {"foo": "var"}),
 (BytesNode, DictConfig({"foo": "var"})),
 (BooleanNode, [1, 2]),
 (BooleanNode, ListConfig([1, 2])),
 (BooleanNode, {"foo": "var"}),
 (BooleanNode, DictConfig({"foo": "var"})),
 (FloatNode, [1, 2]),
 (FloatNode, ListConfig([1, 2])),
 (FloatNode, {"foo": "var"}),
 (FloatNode, DictConfig({"foo": "var"})),
Example #25
         "a": [1, 2, 3]
     }),
     {"a": [1, 2, 3]},
     id="list_merge_into_missing",
 ),
 pytest.param(
     ({"a": [1, 2, 3]}, {"a": "???"}),
     {"a": "???"},
     id="list_merge_missing_onto",
 ),
 pytest.param(
     ([1, 2, 3], ListConfig(content=MISSING)),
     ListConfig(content=MISSING),
     id="list_merge_missing_onto2",
 ),
 # Interpolations
 # value interpolation
 pytest.param(
     ({
         "d1": 1,
         "inter": "${d1}"
     }, {
         "d1": 2
     }),
     {
         "d1": 2,
         "inter": 2
Example #26
    c = OmegaConf.create(input_)
    c[key] = value
    assert c[key] == value
    assert c[key] == value._value()


@pytest.mark.parametrize(
    "input_",
    [
        pytest.param([1, 2, 3], id="list"),
        pytest.param([1, 2, {"a": 3}], id="dict_in_list"),
        pytest.param([1, 2, [10, 20]], id="list_in_list"),
        pytest.param({"b": {"b": 10}}, id="dict_in_dict"),
        pytest.param({"b": [False, 1, "2", 3.0, Color.RED]}, id="list_in_dict"),
        pytest.param({"b": DictConfig(content=None)}, id="none_dictconfig"),
        pytest.param({"b": ListConfig(content=None)}, id="none_listconfig"),
        pytest.param({"b": DictConfig(content="???")}, id="missing_dictconfig"),
        pytest.param({"b": ListConfig(content="???")}, id="missing_listconfig"),
    ],
)
def test_to_container_returns_primitives(input_: Any) -> None:
    def assert_container_with_primitives(item: Any) -> None:
        if isinstance(item, list):
            for v in item:
                assert_container_with_primitives(v)
        elif isinstance(item, dict):
            for _k, v in item.items():
                assert_container_with_primitives(v)
        else:
            assert isinstance(item, (int, float, str, bool, type(None), Enum))
Example #27
@pytest.mark.parametrize("struct", [None, True, False])  # type: ignore
@pytest.mark.parametrize("default_val", [4, True, False, None])  # type: ignore
@pytest.mark.parametrize(  # type: ignore
    "d,select,key",
    [
        ({"hello": {"a": 2}}, "", "missing"),
        ({"hello": {"a": 2}}, "hello", "missing"),
        ({"hello": "???"}, "", "hello"),
        ({"hello": "${foo}", "foo": "???"}, "", "hello"),
        ({"hello": None}, "", "hello"),
        ({"hello": "${foo}"}, "", "hello"),
        ({"hello": "${foo}", "foo": "???"}, "", "hello"),
        ({"hello": DictConfig(is_optional=True, content=None)}, "", "hello"),
        ({"hello": DictConfig(content="???")}, "", "hello"),
        ({"hello": DictConfig(content="${foo}")}, "", "hello"),
        ({"hello": ListConfig(is_optional=True, content=None)}, "", "hello"),
        ({"hello": ListConfig(content="???")}, "", "hello"),
    ],
)
def test_dict_get_with_default(
    d: Any, select: Any, key: Any, default_val: Any, struct: Any
) -> None:
    c = OmegaConf.create(d)
    c = OmegaConf.select(c, select)
    OmegaConf.set_struct(c, struct)
    assert c.get(key, default_val) == default_val


def test_map_expansion() -> None:
    c = OmegaConf.create("{a: 2, b: 10}")
    assert isinstance(c, DictConfig)
Example #28
            id="dict:interpolation_value_error",
        ),
        param(
            DictConfig({"a": 10, "b": "foo_${a}"}),
            {"a": AnyNode(10), "b": AnyNode("foo_${a}")},
            id="dict:str_interpolation_value",
        ),
        param(DictConfig("${zzz}"), {}, id="dict:inter_error"),
        # ListConfig
        param(
            ListConfig(["a", "b"]),
            {"0": AnyNode("a"), "1": AnyNode("b")},
            id="list",
        ),
        param(
            ListConfig(["${1}", 10]),
            {"0": AnyNode("${1}"), "1": AnyNode(10)},
            id="list:interpolation_value",
        ),
        param(ListConfig("${zzz}"), {}, id="list:inter_error"),
    ],
)
Example #29
class TestCopy:
    @pytest.mark.parametrize(
        "src",
        [
            # lists
            pytest.param(OmegaConf.create([]), id="list_empty"),
            pytest.param(OmegaConf.create([1, 2]), id="list"),
            pytest.param(OmegaConf.create(["a", "b", "c"]), id="list"),
            pytest.param(ListConfig(content=None), id="list_none"),
            pytest.param(ListConfig(content="???"), id="list_missing"),
            # dicts
            pytest.param(OmegaConf.create({}), id="dict_empty"),
            pytest.param(OmegaConf.create({"a": "b"}), id="dict"),
            pytest.param(OmegaConf.create({"a": {"b": []}}), id="dict"),
            pytest.param(DictConfig(content=None), id="dict_none"),
        ],
    )
    def test_copy(self, copy_method: Any, src: Any) -> None:
        cp = copy_method(src)
        assert src is not cp
        assert src == cp

    @pytest.mark.parametrize(
        "src",
        [
            pytest.param(
                DictConfig(content={"a": {"c": 10}, "b": DictConfig(content="${a}")}),
                id="dict_inter",
            )
        ],
    )
    def test_copy_dict_inter(self, copy_method: Any, src: Any) -> None:
        # test direct copying of the b node (without de-referencing by accessing)
        cp = copy_method(src._get_node("b"))
        assert src.b is not cp
        assert OmegaConf.is_interpolation(src, "b")
        assert OmegaConf.is_interpolation(cp)
        assert src._get_node("b")._value() == cp._value()

        # test copy of src and ensure interpolation is copied as interpolation
        cp2 = copy_method(src)
        assert OmegaConf.is_interpolation(cp2, "b")

    @pytest.mark.parametrize(
        "src,interpolating_key,interpolated_key",
        [([1, 2, "${0}"], 2, 0), ({"a": 10, "b": "${a}"}, "b", "a")],
    )
    def test_copy_with_interpolation(
        self, copy_method: Any, src: Any, interpolating_key: str, interpolated_key: str
    ) -> None:
        cfg = OmegaConf.create(src)
        assert cfg[interpolated_key] == cfg[interpolating_key]
        cp = copy_method(cfg)
        assert id(cfg) != id(cp)
        assert cp[interpolated_key] == cp[interpolating_key]
        assert cfg[interpolated_key] == cp[interpolating_key]

        # Interpolation is preserved in original
        cfg[interpolated_key] = "XXX"
        assert cfg[interpolated_key] == cfg[interpolating_key]

        # Test interpolation is preserved in copy
        cp[interpolated_key] = "XXX"
        assert cp[interpolated_key] == cp[interpolating_key]

    def test_list_copy_is_shallow(self, copy_method: Any) -> None:
        cfg = OmegaConf.create([[10, 20]])
        cp = copy_method(cfg)
        assert id(cfg) != id(cp)
        assert id(cfg[0]) == id(cp[0])
Example #30
        },
        DictConfig({"a": 0}),
    ],
)
def test_dict_assignment_deepcopy_semantics(node: Any) -> None:
    cfg = OmegaConf.create()
    cfg.foo = node
    node["a"] = 1
    assert cfg.foo.a == 0


@pytest.mark.parametrize(
    "node",
    [
        [1, 2],
        ListConfig([1, 2]),
    ],
)
def test_list_assignment_deepcopy_semantics(node: Any) -> None:
    cfg = OmegaConf.create()
    cfg.foo = node
    node[1] = 10
    assert cfg.foo[1] == 2


@pytest.mark.parametrize("d", [{"a": {"b": 10}}, {"a": {"b": {"c": 10}}}])
def test_assign_does_not_modify_src_config(d: Any) -> None:
    cfg1 = OmegaConf.create(d)
    cfg2 = OmegaConf.create({})
    cfg2.a = cfg1.a
    assert cfg1 == d