Beispiel #1
0
    def test_from_config_dict_without_cls(self):
        """Check that ``EncDecCTCModel.from_config_dict`` builds a model from a plain DictConfig.

        The top-level model config carries no ``cls`` entry of its own, so the
        class on which ``from_config_dict`` is invoked (here ``EncDecCTCModel``)
        decides what is instantiated. The nested preprocessor, encoder and
        decoder sections still name their classes via ``cls``.
        """
        # Character vocabulary: space, the 26 lowercase letters, apostrophe (28 symbols).
        vocabulary = [' ', *'abcdefghijklmnopqrstuvwxyz', "'"]

        jasper_block = dict(
            filters=1024,
            repeat=1,
            kernel=[1],
            stride=[1],
            dilation=[1],
            dropout=0.0,
            residual=False,
            separable=True,
            se=True,
            se_context_size=-1,
        )

        sections = {
            'preprocessor': {
                'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
                'params': {},
            },
            'encoder': {
                'cls': 'nemo.collections.asr.modules.ConvASREncoder',
                'params': {
                    'feat_in': 64,
                    'activation': 'relu',
                    'conv_mask': True,
                    'jasper': [jasper_block],
                },
            },
            'decoder': {
                'cls': 'nemo.collections.asr.modules.ConvASRDecoder',
                'params': {
                    'feat_in': 1024,
                    'num_classes': 28,
                    'vocabulary': vocabulary,
                },
            },
        }
        model_config = DictConfig(
            {name: DictConfig(section) for name, section in sections.items()})

        model = EncDecCTCModel.from_config_dict(config=model_config)
        assert isinstance(model, EncDecCTCModel)
def test_resolve_interpolation_without_parent() -> None:
    """Dereferencing a parentless interpolation node must raise ``IRE``.

    A standalone ``DictConfig`` whose content is ``"${foo}"`` has no parent
    config in which ``foo`` could be looked up, so ``_dereference_node`` is
    expected to fail with the message matched below.
    """
    # The assertion that used to follow this block referenced an undefined
    # name ``c`` and turned every run into a NameError; it has been removed.
    with raises(
            IRE,
            match=re.escape(
                "Cannot resolve interpolation for a node without a parent")):
        DictConfig(content="${foo}")._dereference_node()


def test_list_of_dicts() -> None:
    """A list of plain dicts becomes a config whose items allow attribute access."""
    items = OmegaConf.create([{"key1": "value1"}, {"key2": "value2"}])
    assert items[0].key1 == "value1"
    assert items[1].key2 == "value2"


@mark.parametrize("default", [None, 0, "default"])
@mark.parametrize(
    ("cfg", "key"),
    [
        (["???"], 0),
        ([DictConfig(content="???")], 0),
        ([ListConfig(content="???")], 0),
    ],
)
def test_list_get_return_default(cfg: List[Any], key: int,
                                 default: Any) -> None:
    """``ListConfig.get`` hands back the given default for a missing ("???") entry.

    Identity (``is``) is checked so the exact default object must be returned.
    """
    assert OmegaConf.create(cfg).get(key, default_value=default) is default


@mark.parametrize("default", [None, 0, "default"])
@mark.parametrize(
    ("cfg", "key", "expected"),
    [
        (["found"], 0, "found"),
Beispiel #4
0
    c1 = OmegaConf.create({
        "dataset": {
            "name": "imagenet",
            "path": "/datasets/imagenet"
        },
        "defaults": []
    })
    OmegaConf.set_struct(c1, True)
    c2 = copy.deepcopy(c1)
    with raises(ConfigKeyError):
        OmegaConf.merge(c2, OmegaConf.from_dotlist(["dataset.bad_key=yes"]))


@mark.parametrize(
    "cfg",
    [
        ListConfig(element_type=int, content=[]),
        DictConfig(content={}),
    ],
)
def test_deepcopy_preserves_container_type(cfg: Container) -> None:
    """A deep copy keeps the source container's ``element_type`` metadata."""
    clone: Container = copy.deepcopy(cfg)
    clone_element_type = clone._metadata.element_type
    assert clone_element_type == cfg._metadata.element_type


@mark.parametrize(
    "src, flag_name, func, expectation",
    [
        param(
            {},
            "struct",
            lambda c: c.__setitem__("foo", 1),
            raises(KeyError),
            id="struct_setiitem",
        ),
 def cfg(self, python_dict: Any, struct_mode: Optional[bool]) -> DictConfig:
     """Wrap ``python_dict`` in a DictConfig and apply ``struct_mode`` to it.

     ``struct_mode`` is passed unchanged to ``OmegaConf.set_struct``.
     """
     config: DictConfig = DictConfig(content=python_dict)
     OmegaConf.set_struct(config, struct_mode)
     return config
Beispiel #6
0
 "cfg, key, expected_is_missing, expectation",
 [
     ({}, "foo", False, does_not_raise()),
     ({
         "foo": True
     }, "foo", False, does_not_raise()),
     ({
         "foo": "${no_such_key}"
     }, "foo", False, raises(ConfigKeyError)),
     ({
         "foo": MISSING
     }, "foo", True, raises(MissingMandatoryValue)),
     pytest.param(
         {
             "foo": "${bar}",
             "bar": DictConfig(content=MISSING)
         },
         "foo",
         True,
         raises(MissingMandatoryValue),
         id="missing_interpolated_dict",
     ),
     pytest.param(
         {"foo": ListConfig(content="???")},
         "foo",
         True,
         raises(MissingMandatoryValue),
         id="missing_list",
     ),
     pytest.param(
         {"foo": DictConfig(content="???")},
Beispiel #7
0
def configure_logging(config: DictConfig = None) -> None:
    """
    Initialize the root logger. It is recommended to use Hydra to configure
    the training and pass the resulting config to this function.

    Logging is only configured on the process whose ``training.rank`` is 0,
    absent, or None; every other rank only creates the log directory. On the
    configured process a file handler appending to ``<log_dir>/experiment.log``
    is always attached, and a stdout stream handler is added when
    ``training.num_gpus_per_node`` is greater than 1.

    Args:
        config: A DictConfig from hydra.main. It is read for
            ``logging.log_dir`` (None falls back to ``"logs"``),
            ``logging.level`` (a ``logging`` level name such as ``"INFO"``),
            ``logging.color``, ``training.num_gpus_per_node`` and, optionally,
            ``training.rank``. When omitted, a default config is used:
            ``logs`` directory, INFO level, colored output, rank 0, one GPU
            per node.
    """
    if config is None:
        config = DictConfig({
            "logging": {
                "log_dir": "logs",
                "level": "INFO",
                "color": True,
            },
            "training": {
                "rank": 0,
                "num_gpus_per_node": 1,
            },
        })

    # Resolve the log directory for both the default and a caller-supplied
    # config (previously only the latter assigned ``log_dir``).
    log_dir = config.logging.log_dir
    if log_dir is None:
        log_dir = "logs"

    os.makedirs(log_dir, exist_ok=True)

    # Only set up logging for node 0 (or when no rank is configured).
    rank = getattr(config.training, "rank", None)
    if rank is not None and rank != 0:
        return

    root = logging.getLogger()
    root.setLevel(getattr(logging, config.logging.level))

    # Plain formatter for the file; optionally a colored one for the console.
    file_formatter = logging.Formatter(
        "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s")
    if config.logging.color:
        stream_formatter = colorlog.ColoredFormatter(
            "[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s"
        )
    else:
        stream_formatter = file_formatter

    # The console handler is only attached for multi-GPU runs.
    if config.training.num_gpus_per_node > 1:
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setFormatter(stream_formatter)
        root.addHandler(stream_handler)

    # mode="a" appends to an existing experiment.log instead of truncating it.
    file_handler = logging.FileHandler(os.path.join(log_dir, "experiment.log"),
                                       mode="a")
    file_handler.setFormatter(file_formatter)
    root.addHandler(file_handler)


# def get_original_cwd(config, resume_mode) -> str:
#     if resume_mode:
#         os.getcwd()
#     else:
#         return os.getcwd()
Beispiel #8
0
import onnx
import pytest
import torch
from omegaconf import DictConfig

from nemo.collections.tts.models import WaveGlowModel
from nemo.collections.tts.modules import WaveGlowModule
from nemo.core.classes import typecheck

# WaveGlow module config (12 flows, 80 mel channels).
mcfg = DictConfig(
    dict(
        _target_="nemo.collections.tts.modules.waveglow.WaveGlowModule",
        n_flows=12,
        n_group=8,
        n_mel_channels=80,
        n_early_every=4,
        n_early_size=2,
        n_wn_channels=512,
        n_wn_layers=8,
        wn_kernel_size=3,
    ))

# Filterbank feature-extractor config: no dither, 80 filters, stft_conv off.
pcfg = DictConfig(
    dict(
        _target_=
        "nemo.collections.asr.parts.preprocessing.features.FilterbankFeatures",
        dither=0.0,
        nfilt=80,
        stft_conv=False,
    ))

wcfg = DictConfig({
        ({
            "hello": "${foo}",
            "foo": "???"
        }, "", "hello"),
        ({
            "hello": None
        }, "", "hello"),
        ({
            "hello": "${foo}"
        }, "", "hello"),
        ({
            "hello": "${foo}",
            "foo": "???"
        }, "", "hello"),
        ({
            "hello": DictConfig(is_optional=True, content=None)
        }, "", "hello"),
        ({
            "hello": DictConfig(content="???")
        }, "", "hello"),
        ({
            "hello": DictConfig(content="${foo}")
        }, "", "hello"),
        ({
            "hello": ListConfig(is_optional=True, content=None)
        }, "", "hello"),
        ({
            "hello": ListConfig(content="???")
        }, "", "hello"),
    ],
)
Beispiel #10
0
 "cfg, key, expected_is_missing, expectation",
 [
     ({}, "foo", False, raises(ConfigKeyError)),
     ({
         "foo": True
     }, "foo", False, nullcontext()),
     ({
         "foo": "${no_such_key}"
     }, "foo", False, raises(InterpolationKeyError)),
     ({
         "foo": MISSING
     }, "foo", True, raises(MissingMandatoryValue)),
     param(
         {
             "foo": "${bar}",
             "bar": DictConfig(content=MISSING)
         },
         "foo",
         False,
         raises(InterpolationToMissingValueError),
         id="missing_interpolated_dict",
     ),
     param(
         {"foo": ListConfig(content="???")},
         "foo",
         True,
         raises(MissingMandatoryValue),
         id="missing_list",
     ),
     param(
         {"foo": DictConfig(content="???")},
Beispiel #11
0
 (UntypedList, "opt_list", Any, Any, True, Optional[List[Any]]),
 (UntypedDict, "dict", Any, Any, False, Dict[Any, Any]),
 (
     UntypedDict,
     "opt_dict",
     Any,
     Any,
     True,
     Optional[Dict[Any, Any]],
 ),
 (SubscriptedDict, "dict", int, str, False, Dict[str, int]),
 (SubscriptedList, "list", int, Any, False, List[int]),
 (
     DictConfig(
         content={"a": "foo"},
         ref_type=Dict[str, str],
         element_type=str,
         key_type=str,
     ),
     None,
     str,
     str,
     True,
     Optional[Dict[str, str]],
 ),
 (
     ListConfig(content=[1, 2], ref_type=List[int], element_type=int),
     None,
     int,
     Any,
     True,
     Optional[List[int]],
Beispiel #12
0
    def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
        """Merge ``src`` into ``dest`` in place.

        ``dest`` is mutated and nothing is returned. Despite the
        ``BaseContainer`` annotations, both arguments must be DictConfig
        instances (asserted below). After the per-key merge, ``dest`` takes
        ``src``'s ref type and object type, and every flag explicitly set on
        ``src`` (i.e. not None) overrides the same flag on ``dest``.

        Raises:
            Whatever ``dest._validate_merge`` raises for incompatible configs.
            ValidationError / ReadonlyConfigError raised while assigning a
            value node are passed to ``dest._format_and_raise`` together with
            the offending key (presumably re-raised with that context).
        """
        from omegaconf import AnyNode, DictConfig, OmegaConf, ValueNode

        assert isinstance(dest, DictConfig)
        assert isinstance(src, DictConfig)
        src_type = src._metadata.object_type
        src_ref_type = get_ref_type(src)
        assert src_ref_type is not None

        # If source DictConfig is:
        #  - an interpolation => set the destination DictConfig to be the same interpolation
        #  - None => set the destination DictConfig to None
        if src._is_interpolation() or src._is_none():
            dest._set_value(src._value())
            _update_types(node=dest,
                          ref_type=src_ref_type,
                          object_type=src_type)
            return

        dest._validate_merge(value=src)

        def expand(node: Container) -> None:
            """Replace ``node``'s value with an empty prototype it can be merged into.

            Typed nodes get ``{}`` for dict annotations, ``[]`` for list
            annotations, or the ref type itself otherwise (e.g. a structured
            config class). An untyped DictConfig gets ``{}``; any other
            untyped container is not expected here.
            """
            rt = node._metadata.ref_type
            val: Any
            if rt is not Any:
                if is_dict_annotation(rt):
                    val = {}
                elif is_list_annotation(rt):
                    val = []
                else:
                    val = rt
            elif isinstance(node, DictConfig):
                val = {}
            else:
                assert False

            node._set_value(val)

        if (src._is_missing() and not dest._is_missing()
                and is_structured_config(src_ref_type)):
            # Replace `src` with a prototype of its corresponding structured config
            # whose fields are all missing (to avoid overwriting fields in `dest`).
            src = _create_structured_with_missing_fields(ref_type=src_ref_type,
                                                         object_type=src_type)

        # A dest that is an interpolation or missing has no content to merge
        # into, so give it an empty prototype first (only if src has content).
        if (dest._is_interpolation()
                or dest._is_missing()) and not src._is_missing():
            expand(dest)

        # resolve=False: src interpolations are carried over unresolved.
        for key, src_value in src.items_ex(resolve=False):
            src_node = src._get_node(key, validate_access=False)
            dest_node = dest._get_node(key, validate_access=False)

            # Nested dicts must be merge-compatible, checked before any change.
            if isinstance(dest_node, DictConfig):
                dest_node._validate_merge(value=src_node)

            missing_src_value = _is_missing_value(src_value)

            # A None-valued dest container receiving a real value is first
            # expanded to an empty prototype so the merge below has a target.
            if (isinstance(dest_node, Container)
                    and OmegaConf.is_none(dest, key) and not missing_src_value
                    and not OmegaConf.is_none(src_value)):
                expand(dest_node)

            # If dest holds an interpolation that resolves to a container,
            # store that container at `key` so the merge goes into it.
            if dest_node is not None and dest_node._is_interpolation():
                target_node = dest_node._dereference_node(
                    throw_on_resolution_failure=False)
                if isinstance(target_node, Container):
                    dest[key] = target_node
                    dest_node = dest._get_node(key)

            if (dest_node is None
                    and is_structured_config(dest._metadata.element_type)
                    and not missing_src_value):
                # merging into a new node. Use element_type as a base
                dest[key] = DictConfig(content=dest._metadata.element_type,
                                       parent=dest)
                dest_node = dest._get_node(key)

            if dest_node is not None:
                if isinstance(dest_node, BaseContainer):
                    # Container <- container: recurse; container <- value:
                    # overwrite unless the src value is missing.
                    if isinstance(src_value, BaseContainer):
                        dest_node._merge_with(src_value)
                    elif not missing_src_value:
                        dest.__setitem__(key, src_value)
                else:
                    if isinstance(src_value, BaseContainer):
                        dest.__setitem__(key, src_value)
                    else:
                        assert isinstance(dest_node, ValueNode)
                        assert isinstance(src_node, ValueNode)
                        # Compare to literal missing, ignoring interpolation
                        src_node_missing = src_value == "???"
                        try:
                            if isinstance(dest_node, AnyNode):
                                # An untyped dest slot takes over src's node
                                # (and therefore its type).
                                if src_node_missing:
                                    node = copy.copy(src_node)
                                    # if src node is missing, use the value from the dest_node,
                                    # but validate it against the type of the src node before assignment
                                    node._set_value(dest_node._value())
                                else:
                                    node = src_node
                                dest.__setitem__(key, node)
                            else:
                                # Typed dest node keeps its own type; only the
                                # value is (validated and) replaced.
                                if not src_node_missing:
                                    dest_node._set_value(src_value)

                        except (ValidationError, ReadonlyConfigError) as e:
                            dest._format_and_raise(key=key,
                                                   value=src_value,
                                                   cause=e)
            else:
                # `key` does not exist in dest yet: copy src's node over.
                from omegaconf import open_dict

                if is_structured_config(src_type):
                    # verified to be compatible above in _validate_merge
                    # open_dict allows adding the new key even if dest is in struct mode.
                    with open_dict(dest):
                        dest[key] = src._get_node(key)
                else:
                    dest[key] = src._get_node(key)

        _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)

        # explicit flags on the source config are replacing the flag values in the destination
        flags = src._metadata.flags
        assert flags is not None
        for flag, value in flags.items():
            if value is not None:
                dest._set_flag(flag, value)
Beispiel #13
0
 ),
 pytest.param(
     Expected(
         create=lambda: create_readonly({"foo": "bar"}),
         op=lambda cfg: setattr(cfg, "foo", 20),
         exception_type=ReadonlyConfigError,
         msg="Cannot change read-only config container",
         key="foo",
         child_node=lambda cfg: cfg.foo,
     ),
     id="dict,readonly:set_attribute",
 ),
 pytest.param(
     Expected(
         create=lambda: OmegaConf.create(
             {"foo": DictConfig(is_optional=False, content={})}),
         op=lambda cfg: setattr(cfg, "foo", None),
         exception_type=ValidationError,
         msg="child 'foo' is not Optional",
         key="foo",
         full_key="foo",
         child_node=lambda cfg: cfg.foo,
     ),
     id="dict:setattr:not_optional:set_none",
 ),
 pytest.param(
     Expected(
         create=lambda: OmegaConf.structured(ConcretePlugin),
         op=lambda cfg: cfg.params.__setattr__("foo", "bar"),
         exception_type=ValidationError,
         msg="Value 'bar' could not be converted to Integer",
Beispiel #14
0
def ssl_model():
    """Build a SpeechEncDecSelfSupervisedModel from an inline config.

    The encoder is three identical Jasper blocks of width
    ``model_defaults['enc_hidden']``. The loss list holds a contrastive
    (``contr``) entry and an ``mlm`` entry whose ``targets_from_loss`` is
    ``"contr"``.
    """
    model_defaults = {'enc_hidden': 32, 'dec_out': 128}
    enc_hidden = model_defaults['enc_hidden']
    dec_out = model_defaults['dec_out']

    preprocessor = {
        'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
        'params': {'pad_to': 16, 'dither': 0},
    }

    # Three identical Jasper blocks; each iteration builds fresh dicts/lists.
    jasper_blocks = [
        {
            'filters': enc_hidden,
            'repeat': 1,
            'kernel': [1],
            'stride': [1],
            'dilation': [1],
            'dropout': 0.0,
            'residual': False,
            'separable': True,
            'se': True,
            'se_context_size': -1,
        }
        for _ in range(3)
    ]

    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': jasper_blocks,
        },
    }

    spec_augment = {
        '_target_': 'nemo.collections.asr.modules.MaskedPatchAugmentation',
        'freq_masks': 3,
        'freq_width': 20,
        'patch_size': 16,
        'mask_patches': 0.5,
    }

    contrastive_entry = {
        'decoder': {
            '_target_': 'nemo.collections.asr.modules.ConvASRDecoderReconstruction',
            'feat_in': enc_hidden,
            'feat_hidden': 128,
            'feat_out': dec_out,
            'stride_layers': 0,
            'non_stride_layers': 0,
            'stride_transpose': False,
        },
        'loss': {
            '_target_': 'nemo.collections.asr.losses.ContrastiveLoss',
            'in_dim': 64,
            'proj_dim': dec_out,
            'combine_time_steps': 1,
            'quantized_targets': True,
            'codebook_size': 64,
            'sample_from_same_utterance_only': True,
            'sample_from_non_masked': False,
            'num_negatives': 3,
        },
    }
    mlm_entry = {
        'decoder': {
            '_target_': 'nemo.collections.asr.modules.ConvASRDecoder',
            'feat_in': enc_hidden,
            'num_classes': 4096,
        },
        'loss': {
            '_target_': 'nemo.collections.asr.losses.MLMLoss',
            'combine_time_steps': 1,
        },
        'targets_from_loss': "contr",
    }
    loss_list = {'contr': contrastive_entry, 'mlm': mlm_entry}

    sections = {
        'preprocessor': preprocessor,
        'spec_augment': spec_augment,
        'model_defaults': model_defaults,
        'encoder': encoder,
        'loss_list': loss_list,
    }
    model_config = DictConfig(
        {name: DictConfig(section) for name, section in sections.items()})
    return SpeechEncDecSelfSupervisedModel(cfg=model_config)
Beispiel #15
0
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC

import torch
from omegaconf import DictConfig

# Default module-wide access config (both flags off); replaced via set_access_cfg().
_ACCESS_CFG = DictConfig(dict(detach=False, convert_to_cpu=False))
_ACCESS_ENABLED = False


def set_access_cfg(cfg: 'DictConfig'):
    """Replace the module-wide access configuration ``_ACCESS_CFG``.

    Args:
        cfg: The new configuration; must be a ``DictConfig``.

    Raises:
        TypeError: If ``cfg`` is not a ``DictConfig`` (this includes ``None``).
    """
    global _ACCESS_CFG
    # isinstance() already rejects None, so no separate None check is needed.
    if not isinstance(cfg, DictConfig):
        raise TypeError(
            f"cfg must be a DictConfig, got {type(cfg).__name__}")
    _ACCESS_CFG = cfg


class AccessMixin(ABC):
    """
    Allows access to output of intermediate layers of a model
    """
    def __init__(self):
def test_shallow_copy_missing() -> None:
    """Filling in a copy of a missing DictConfig leaves the original missing."""
    original = DictConfig(content=MISSING)
    duplicate = original.copy()
    duplicate._set_value({"foo": 1})
    assert duplicate.foo == 1
    assert original._is_missing()
Beispiel #17
0
 def test_assignment_of_non_subclass_1(self, module: Any) -> None:
     """Assigning a FaultyPlugin config to a Plugin-typed node fails validation."""
     plugin_node = DictConfig(module.Plugin, ref_type=module.Plugin)
     cfg = OmegaConf.create({"plugin": plugin_node})
     with raises(ValidationError):
         cfg.plugin = OmegaConf.structured(module.FaultyPlugin)
def test_shallow_copy_none() -> None:
    """Filling in a copy of a None DictConfig leaves the original None."""
    original = DictConfig(content=None)
    duplicate = original.copy()
    duplicate._set_value({"foo": 1})
    assert duplicate.foo == 1
    assert original._is_none()
Beispiel #19
0
     "a": 12
 }, {
     "a": AnyNode(12)
 }),
 # nested dict empty
 (dict(a=12, b=dict()), dict(a=12, b=dict())),
 # nested dict
 (dict(a=12, b=dict(c=10)), dict(a=12, b=dict(c=10))),
 # nested list
 (dict(a=12, b=[1, 2, 3]), dict(a=12, b=[1, 2, 3])),
 # nested list with any
 (dict(a=12, b=[1, 2, AnyNode(3)]), dict(a=12, b=[1, 2,
                                                  AnyNode(3)])),
 # In python 3.6+ insert order changes iteration order. this ensures that equality is preserved.
 (dict(a=1, b=2, c=3, d=4, e=5), dict(e=5, b=2, c=3, d=4, a=1)),
 (DictConfig(content=None), DictConfig(content=None)),
 pytest.param({"a": [1, 2]}, {"a": [1, 2]}, id="list_in_dict"),
 # With interpolations
 ([10, "${0}"], [10, 10]),
 (dict(a=12, b="${a}"), dict(a=12, b=12)),
 # With missing interpolation
 pytest.param([10, "${0}"], [10, 10], id="list_simple_interpolation"),
 pytest.param({"a": "${ref_error}"}, {"a": "${ref_error}"},
              id="dict==dict,ref_error"),
 pytest.param({"a": "???"}, {"a": "???"}, id="dict==dict,missing"),
 pytest.param(User, User, id="User==User"),
 pytest.param({
     "name": "poo",
     "age": 7
 },
              User(name="poo", age=7),
def test_assign_to_reftype_none_or_any(ref_type: Any, assign: Any) -> None:
    """A node declared with ``ref_type`` accepts ``assign`` and reads it back equal."""
    typed_node = DictConfig(ref_type=ref_type, content={})
    cfg = OmegaConf.create({"foo": typed_node})
    cfg.foo = assign
    assert cfg.foo == assign
     }},
     id="conf_missing_dict",
 ),
 pytest.param(
     [{}, ConfWithMissingDict],
     {"dict": "???"},
     id="merge_missing_dict_into_missing_dict",
 ),
 ([{
     "user": User
 }, {
     "user": Group
 }], pytest.raises(ValidationError)),
 (
     [{
         "user": DictConfig(ref_type=User, content=User)
     }, {
         "user": Group
     }],
     pytest.raises(ValidationError),
 ),
 ([Plugin, ConcretePlugin], ConcretePlugin),
 pytest.param(
     [{
         "user": "******"
     }, {
         "user": Group
     }],
     {"user": Group},
     id="merge_into_missing_node",
 ),
 def _test_assign(self, ref_type: Any, value: Any, assign: Any,
                  expectation: Any) -> None:
     """Assign ``assign`` to a ``ref_type``-typed node holding ``value``.

     ``expectation`` is a context manager entered around the assignment only
     (presumably a ``raises(...)`` or a no-op context).
     """
     typed_node = DictConfig(ref_type=ref_type, content=value)
     cfg = OmegaConf.create({"foo": typed_node})
     with expectation:
         cfg.foo = assign
Beispiel #23
0
class TestCopy:
    """Copy behaviour of OmegaConf containers.

    Every test receives ``copy_method``, supplied by a fixture defined outside
    this class, and checks that the copy is a distinct container that compares
    equal to the source and keeps interpolations as interpolations.
    """

    @mark.parametrize(
        "src",
        [
            # lists (ids must be unique so each case can be selected with -k)
            param(OmegaConf.create([]), id="list_empty"),
            param(OmegaConf.create([1, 2]), id="list"),
            param(OmegaConf.create(["a", "b", "c"]), id="list_str"),
            param(ListConfig(content=None), id="list_none"),
            param(ListConfig(content="???"), id="list_missing"),
            # dicts
            param(OmegaConf.create({}), id="dict_empty"),
            param(OmegaConf.create({"a": "b"}), id="dict"),
            param(OmegaConf.create({"a": {
                "b": []
            }}), id="dict_nested"),
            param(DictConfig(content=None), id="dict_none"),
        ],
    )
    def test_copy(self, copy_method: Any, src: Any) -> None:
        """The copy is a different object that compares equal to the source."""
        cp = copy_method(src)
        assert src is not cp
        assert src == cp

    @mark.parametrize(
        "src",
        [
            param(
                DictConfig(content={
                    "a": {
                        "c": 10
                    },
                    "b": DictConfig(content="${a}")
                }),
                id="dict_inter",
            )
        ],
    )
    def test_copy_dict_inter(self, copy_method: Any, src: Any) -> None:
        """Copying an interpolated DictConfig node keeps it an interpolation."""
        # test direct copying of the b node (without de-referencing by accessing)
        cp = copy_method(src._get_node("b"))
        assert src.b is not cp
        assert OmegaConf.is_interpolation(src, "b")
        assert OmegaConf.is_interpolation(cp)
        assert src._get_node("b")._value() == cp._value()

        # test copy of src and ensure interpolation is copied as interpolation
        cp2 = copy_method(src)
        assert OmegaConf.is_interpolation(cp2, "b")

    @mark.parametrize(
        "src,interpolating_key,interpolated_key",
        [([1, 2, "${0}"], 2, 0), ({
            "a": 10,
            "b": "${a}"
        }, "b", "a")],
    )
    def test_copy_with_interpolation(self, copy_method: Any, src: Any,
                                     interpolating_key: str,
                                     interpolated_key: str) -> None:
        """Original and copy each keep their own live interpolation."""
        cfg = OmegaConf.create(src)
        assert cfg[interpolated_key] == cfg[interpolating_key]
        cp = copy_method(cfg)
        assert id(cfg) != id(cp)
        assert cp[interpolated_key] == cp[interpolating_key]
        assert cfg[interpolated_key] == cp[interpolating_key]

        # Interpolation is preserved in original
        cfg[interpolated_key] = "XXX"
        assert cfg[interpolated_key] == cfg[interpolating_key]

        # Test interpolation is preserved in copy
        cp[interpolated_key] = "XXX"
        assert cp[interpolated_key] == cp[interpolating_key]

    def test_list_shallow_copy_is_deepcopy(self, copy_method: Any) -> None:
        """Nested lists are not shared between the source and its copy."""
        cfg = OmegaConf.create([[10, 20]])
        cp = copy_method(cfg)
        assert cfg is not cp
        assert cfg[0] is not cp[0]
 def _test_merge(self, ref_type: Any, value: Any, assign: Any,
                 expectation: Any) -> None:
     """Merge ``{"foo": assign}`` into a config whose ``foo`` is ``ref_type``-typed.

     ``expectation`` is a context manager entered around the merge only.
     """
     base = OmegaConf.create(
         {"foo": DictConfig(ref_type=ref_type, content=value)})
     override = {"foo": assign}
     with expectation:
         OmegaConf.merge(base, override)
            7,
            InterpolationResultNode,
            id="convert_str_to_int",
        ),
        param(
            MissingList(list=SI("${oc.create:[a, b, c]}")),
            "list",
            ListConfig(["a", "b", "c"]),
            ListConfig,
            id="list_str",
        ),
        param(
            MissingDict(dict=SI("${oc.create:{key1: val1, key2: val2}}")),
            "dict",
            DictConfig({
                "key1": "val1",
                "key2": "val2"
            }),
            DictConfig,
            id="dict_str",
        ),
    ],
)
def test_interpolation_type_validated_ok(
    cfg: Any,
    key: str,
    expected_value: Any,
    expected_node_type: Any,
    common_resolvers: Any,
) -> Any:
    def drop_last(s: str) -> str:
        return s[0:-1]  # drop last character from string `s`
Beispiel #26
0
 # simple
 (dict(a=12), dict(a=12)),
 # any vs raw
 (dict(a=12), dict(a=AnyNode(12))),
 # nested dict empty
 (dict(a=12, b=dict()), dict(a=12, b=dict())),
 # nested dict
 (dict(a=12, b=dict(c=10)), dict(a=12, b=dict(c=10))),
 # nested list
 (dict(a=12, b=[1, 2, 3]), dict(a=12, b=[1, 2, 3])),
 # nested list with any
 (dict(a=12, b=[1, 2, AnyNode(3)]), dict(a=12, b=[1, 2,
                                                  AnyNode(3)])),
 # In python 3.6+ insert order changes iteration order. this ensures that equality is preserved.
 (dict(a=1, b=2, c=3, d=4, e=5), dict(e=5, b=2, c=3, d=4, a=1)),
 (DictConfig(content=None), DictConfig(content=None)),
 # With interpolations
 ([10, "${0}"], [10, 10]),
 (dict(a=12, b="${a}"), dict(a=12, b=12)),
 # With missing interpolation
 ([10, "${0}"], [10, 10]),
 (dict(a="${missing}"), dict(a="${missing}")),
 (User, User),
 ({
     "name": "poo",
     "age": 7
 }, User(name="poo", age=7)),
 (Group, Group),
 ({
     "group": {
         "admin": None
def test_resolve_interpolation_without_parent_no_throw() -> None:
    """``_maybe_dereference_node`` returns None for a parentless interpolation."""
    assert DictConfig(content="${foo}")._maybe_dereference_node() is None
        Main function that runs the entire image composition process
        :return: None
        """
        self._validate_and_process_cfg()
        self._generate_images()
        self._create_info()
        log.info(f'Done composing images for {self.cfg.description}')


if __name__ == '__main__':
    # Example run: compose 2 images (512x512 png, up to 15 foregrounds) from
    # tools/experiments/image_composition into its training/ sub-directory.
    cfg = DictConfig(
        dict(
            option=True,
            name='experiments',
            input_dir='tools/experiments/image_composition',
            output_dir='tools/experiments/image_composition/training',
            num_images=2,
            max_foregrounds=15,
            output_width=512,
            output_height=512,
            output_type='png',
            description='experiment',
            url='none',
            version='1.0',
            contributor='DK',
            license_name='free',
            license_url='none',
        ))
    image_comp_training: ImageComposition = ImageComposition(cfg)
    image_comp_training.main()
Beispiel #29
0
def asr_model():
    """Build an EncDecRNNTModel from an inline config.

    The config has one Jasper encoder block of width
    ``model_defaults['enc_hidden']``, an RNNT decoder and joint network, and
    ``greedy_batch`` decoding. The labels are space, a-z and apostrophe.
    """
    labels = [' ', *'abcdefghijklmnopqrstuvwxyz', "'"]

    model_defaults = {'enc_hidden': 1024, 'pred_hidden': 64}

    sections = {
        'preprocessor': {
            'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
            'params': {},
        },
        'model_defaults': model_defaults,
        'encoder': {
            'cls': 'nemo.collections.asr.modules.ConvASREncoder',
            'params': {
                'feat_in': 64,
                'activation': 'relu',
                'conv_mask': True,
                'jasper': [{
                    'filters': model_defaults['enc_hidden'],
                    'repeat': 1,
                    'kernel': [1],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                }],
            },
        },
        'decoder': {
            '_target_': 'nemo.collections.asr.modules.RNNTDecoder',
            'prednet': {
                'pred_hidden': model_defaults['pred_hidden'],
                'pred_rnn_layers': 1,
            },
        },
        'joint': {
            '_target_': 'nemo.collections.asr.modules.RNNTJoint',
            'jointnet': {'joint_hidden': 32, 'activation': 'relu'},
        },
        'decoding': {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}},
    }

    model_config = DictConfig({
        'labels': ListConfig(labels),
        **{name: DictConfig(section) for name, section in sections.items()},
    })
    return EncDecRNNTModel(cfg=model_config)
Beispiel #30
0
) -> None:
    c = OmegaConf.create(input_)
    c[key] = value
    assert c[key] == value
    assert c[key] == value._value()


@pytest.mark.parametrize(  # type: ignore
    "input_",
    [
        pytest.param([1, 2, 3], id="list"),
        pytest.param([1, 2, {"a": 3}], id="dict_in_list"),
        pytest.param([1, 2, [10, 20]], id="list_in_list"),
        pytest.param({"b": {"b": 10}}, id="dict_in_dict"),
        pytest.param({"b": [False, 1, "2", 3.0, Color.RED]}, id="list_in_dict"),
        pytest.param({"b": DictConfig(content=None)}, id="none_dictconfig"),
        pytest.param({"b": ListConfig(content=None)}, id="none_listconfig"),
        pytest.param({"b": DictConfig(content="???")}, id="missing_dictconfig"),
        pytest.param({"b": ListConfig(content="???")}, id="missing_listconfig"),
    ],
)
def test_to_container_returns_primitives(input_: Any) -> None:
    def assert_container_with_primitives(item: Any) -> None:
        if isinstance(item, list):
            for v in item:
                assert_container_with_primitives(v)
        elif isinstance(item, dict):
            for _k, v in item.items():
                assert_container_with_primitives(v)
        else:
            assert isinstance(item, (int, float, str, bool, type(None), Enum))