Example #1
def main(cfg: DictConfig):
    cfg = setup_cfg_gpu(cfg)
    logger.info("CFG (after gpu  configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))

    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)

    tensorizer, encoder, _ = init_biencoder_components(
        cfg.encoder.encoder_model_type, cfg, inference_only=True)
    with omegaconf.open_dict(cfg):
        cfg.others = DictConfig(
            {"is_matching": isinstance(encoder, Match_BiEncoder)})

    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device,
                                            cfg.n_gpu, cfg.local_rank,
                                            cfg.fp16)
    encoder.eval()

    if cfg.others.is_matching:
        match_layer = MatchLayer(encoder=get_model_obj(encoder))
        match_layer, _ = setup_for_distributed_mode(match_layer, None,
                                                    cfg.device, cfg.n_gpu,
                                                    cfg.local_rank, cfg.fp16)
        match_layer.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")

    encoder_prefix = (encoder_path if encoder_path else "question_model") + "."
    prefix_len = len(encoder_prefix)

    logger.info("Encoder state prefix %s", encoder_prefix)
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith(encoder_prefix)
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # load weights from the model file for the match layer
    if cfg.others.is_matching:
        model_to_load = get_model_obj(match_layer)
        logger.info("Loading saved match layer state ...")

        match_layer_state = {
            key: value
            for key, value in saved_state.model_dict.items()
            if key.startswith("linear")
        }
        logger.info(
            f"Loading saved match layer state with {len(match_layer_state)} weight matrices..."
        )
        model_to_load.load_state_dict(match_layer_state)

    # get questions & answers
    questions = []
    question_answers = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()

    for ds_item in qa_src.data:
        question, answers = ds_item.query, ds_item.answers
        questions.append(question)
        question_answers.append(answers)

    index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
    logger.info("Index class %s ", type(index))
    index_buffer_sz = index.buffer_size
    index.init_index(vector_size)

    if not cfg.others.is_matching:
        retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer,
                                        index)
    else:
        retriever = LocalFaissRetrieverWithMatchModels(cfg, encoder,
                                                       match_layer,
                                                       cfg.batch_size,
                                                       tensorizer, index)

    logger.info("Using special token %s", qa_src.special_query_token)
    questions_tensor = retriever.generate_question_vectors(
        questions, query_token=qa_src.special_query_token)

    if qa_src.selector:
        logger.info("Using custom representation token selector")
        retriever.selector = qa_src.selector

    id_prefixes = []
    ctx_sources = []
    for ctx_src in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src])
        id_prefixes.append(ctx_src.id_prefix)
        ctx_sources.append(ctx_src)

    logger.info("id_prefixes per dataset: %s", id_prefixes)

    # index all passages
    ctx_files_patterns = cfg.encoded_ctx_files
    index_path = cfg.index_path

    logger.info("ctx_files_patterns: %s", ctx_files_patterns)
    if ctx_files_patterns:
        assert len(ctx_files_patterns) == len(
            id_prefixes), "ctx len={} pref len={}".format(
                len(ctx_files_patterns), len(id_prefixes))
    else:
        assert (
            index_path
        ), "Either encoded_ctx_files or index_path parameter should be set."

    input_paths = []
    path_id_prefixes = []
    for i, pattern in enumerate(ctx_files_patterns):
        pattern_files = glob.glob(pattern)
        pattern_id_prefix = id_prefixes[i]
        input_paths.extend(pattern_files)
        path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))

    logger.info("Embeddings files id prefixes: %s", path_id_prefixes)

    if index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths,
                                     index_buffer_sz,
                                     path_id_prefixes=path_id_prefixes)
        if index_path:
            retriever.index.serialize(index_path)

    # get top k results
    if cfg.others.is_matching:
        top_ids_and_scores_first_stage, top_ids_and_scores_second_stage = retriever.get_top_docs(
            questions_tensor.numpy(),
            top_docs=cfg.n_docs,
            top_docs_match=cfg.n_docs_match)
        top_ids_and_scores = top_ids_and_scores_first_stage
    else:
        top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(),
                                                    top_docs=cfg.n_docs)

    # we no longer need the index
    retriever = None

    all_passages = {}
    for ctx_src in ctx_sources:
        ctx_src.load_data_to(all_passages)

    if len(all_passages) == 0:
        raise RuntimeError(
            "No passages data found. Please specify ctx_file param properly.")

    if cfg.validate_as_tables:
        questions_doc_hits = validate_tables(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )
        if cfg.others.is_matching:
            questions_doc_hits_match = validate_tables(
                all_passages,
                question_answers,
                top_ids_and_scores_second_stage,
                cfg.validation_workers,
                cfg.match,
            )
    else:
        questions_doc_hits = validate(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )
        if cfg.others.is_matching:
            questions_doc_hits_match = validate(
                all_passages,
                question_answers,
                top_ids_and_scores_second_stage,
                cfg.validation_workers,
                cfg.match,
            )

    if cfg.out_file:
        save_results(
            all_passages,
            questions,
            question_answers,
            top_ids_and_scores,
            questions_doc_hits,
            cfg.out_file,
        )
        if cfg.others.is_matching:
            out_file, _ = os.path.splitext(cfg.out_file)
            out_file = f"{out_file}_match.json"
            save_results(
                all_passages,
                questions,
                question_answers,
                top_ids_and_scores_second_stage,
                questions_doc_hits_match,
                out_file,
            )

    if cfg.kilt_out_file:
        kilt_ctx = next(
            iter([
                ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)
            ]), None)
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        assert hasattr(cfg, "kilt_out_file")
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file,
                                 cfg.kilt_out_file)
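
The open_dict block near the top of this example is the standard way to add a new key (here cfg.others) to a Hydra/OmegaConf config that is in struct mode. A minimal, self-contained sketch of that pattern, with illustrative key names rather than the ones above:

from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"encoder": {"hidden": 768}})
OmegaConf.set_struct(cfg, True)          # unknown keys now raise instead of being created

with open_dict(cfg):                     # temporarily lift the struct restriction
    cfg.others = {"is_matching": False}  # add a new sub-config

assert cfg.others.is_matching is False
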
Example #2
) -> None:
    c = OmegaConf.create(input_)
    c[key] = value
    assert c[key] == value
    assert c[key] == value._value()


@pytest.mark.parametrize(  # type: ignore
    "input_",
    [
        pytest.param([1, 2, 3], id="list"),
        pytest.param([1, 2, {"a": 3}], id="dict_in_list"),
        pytest.param([1, 2, [10, 20]], id="list_in_list"),
        pytest.param({"b": {"b": 10}}, id="dict_in_dict"),
        pytest.param({"b": [False, 1, "2", 3.0, Color.RED]}, id="list_in_dict"),
        pytest.param({"b": DictConfig(content=None)}, id="none_dictconfig"),
        pytest.param({"b": ListConfig(content=None)}, id="none_listconfig"),
        pytest.param({"b": DictConfig(content="???")}, id="missing_dictconfig"),
        pytest.param({"b": ListConfig(content="???")}, id="missing_listconfig"),
    ],
)
def test_to_container_returns_primitives(input_: Any) -> None:
    def assert_container_with_primitives(item: Any) -> None:
        if isinstance(item, list):
            for v in item:
                assert_container_with_primitives(v)
        elif isinstance(item, dict):
            for _k, v in item.items():
                assert_container_with_primitives(v)
        else:
            assert isinstance(item, (int, float, str, bool, type(None), Enum))

    # The call site was truncated in this excerpt; presumably the config is
    # created, converted back to plain containers, and checked recursively.
    c = OmegaConf.create(input_)
    res = OmegaConf.to_container(c, resolve=True)
    assert_container_with_primitives(res)
Example #3
    def launch_task_run_or_die(
            self,
            run_config: DictConfig,
            shared_state: Optional[SharedTaskState] = None) -> str:
        """
        Parse the given arguments and launch a job.
        """
        set_mephisto_log_level(level=run_config.get("log_level", "info"))

        requester, provider_type = self._get_requester_and_provider_from_config(
            run_config)

        # Next get the abstraction classes, and run validation
        # before anything is actually created in the database
        blueprint_type = run_config.blueprint._blueprint_type
        architect_type = run_config.architect._architect_type
        BlueprintClass = get_blueprint_from_type(blueprint_type)
        ArchitectClass = get_architect_from_type(architect_type)
        CrowdProviderClass = get_crowd_provider_from_type(provider_type)

        if shared_state is None:
            shared_state = BlueprintClass.SharedStateClass()

        BlueprintClass.assert_task_args(run_config, shared_state)
        ArchitectClass.assert_task_args(run_config, shared_state)
        CrowdProviderClass.assert_task_args(run_config, shared_state)

        # Find an existing task or create a new one
        task_name = run_config.task.get("task_name", None)
        if task_name is None:
            task_name = blueprint_type
            logger.warning(
                f"Task is using the default blueprint name {task_name} as a name, "
                "as no task_name is provided")
        tasks = self.db.find_tasks(task_name=task_name)
        task_id = None
        if len(tasks) == 0:
            task_id = self.db.new_task(task_name, blueprint_type)
        else:
            task_id = tasks[0].db_id

        logger.info(f"Creating a task run under task name: {task_name}")

        # Create a new task run
        new_run_id = self.db.new_task_run(
            task_id,
            requester.db_id,
            json.dumps(OmegaConf.to_yaml(run_config, resolve=True)),
            provider_type,
            blueprint_type,
            requester.is_sandbox(),
        )
        task_run = TaskRun.get(self.db, new_run_id)

        live_run = self._create_live_task_run(
            run_config,
            shared_state,
            task_run,
            ArchitectClass,
            BlueprintClass,
            CrowdProviderClass,
        )

        try:
            # If anything fails after here, we have to cleanup the architect
            # Setup and deploy the server
            built_dir = live_run.architect.prepare()
            task_url = live_run.architect.deploy()

            # TODO(#102) maybe the cleanup (destruction of the server configuration?) should only
            # happen after everything has already been reviewed, this way it's possible to
            # retrieve the exact build directory to review a task for real
            live_run.architect.cleanup()

            # Register the task with the provider
            live_run.provider.setup_resources_for_task_run(
                task_run, run_config, shared_state, task_url)

            live_run.client_io.launch_channels()
        except (KeyboardInterrupt, Exception) as e:
            logger.error(
                "Encountered error while launching run, shutting down",
                exc_info=True)
            try:
                live_run.architect.shutdown()
            except (KeyboardInterrupt, Exception) as architect_exception:
                logger.exception(
                    f"Could not shut down architect: {architect_exception}",
                    exc_info=True,
                )
            raise e

        live_run.task_launcher.create_assignments()
        live_run.task_launcher.launch_units(task_url)

        self._task_runs_tracked[task_run.db_id] = live_run
        task_run.update_completion_progress(status=False)

        return task_run.db_id
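
For reference, OmegaConf.to_yaml(run_config, resolve=True) as used above returns a YAML string, while OmegaConf.to_container returns plain dicts and lists that json.dumps can serialize directly. A minimal sketch of both forms, with illustrative values:

import json
from omegaconf import OmegaConf

cfg = OmegaConf.create({"task": {"task_name": "demo"}, "log_level": "info"})

yaml_str = OmegaConf.to_yaml(cfg, resolve=True)      # YAML text, interpolations resolved
as_dict = OmegaConf.to_container(cfg, resolve=True)  # plain dict / list containers
json_str = json.dumps(as_dict)                       # JSON text
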
Example #4
class TestCopy:
    @mark.parametrize(
        "src",
        [
            # lists
            param(OmegaConf.create([]), id="list_empty"),
            param(OmegaConf.create([1, 2]), id="list"),
            param(OmegaConf.create(["a", "b", "c"]), id="list"),
            param(ListConfig(content=None), id="list_none"),
            param(ListConfig(content="???"), id="list_missing"),
            # dicts
            param(OmegaConf.create({}), id="dict_empty"),
            param(OmegaConf.create({"a": "b"}), id="dict"),
            param(OmegaConf.create({"a": {"b": []}}), id="dict"),
            param(DictConfig(content=None), id="dict_none"),
        ],
    )
    def test_copy(self, copy_method: Any, src: Any) -> None:
        cp = copy_method(src)
        assert src is not cp
        assert src == cp

    @mark.parametrize(
        "src",
        [
            param(
                DictConfig(content={"a": {"c": 10}, "b": DictConfig(content="${a}")}),
                id="dict_inter",
            )
        ],
    )
    def test_copy_dict_inter(self, copy_method: Any, src: Any) -> None:
        # test direct copying of the b node (without de-referencing by accessing)
        cp = copy_method(src._get_node("b"))
        assert src.b is not cp
        assert OmegaConf.is_interpolation(src, "b")
        assert OmegaConf.is_interpolation(cp)
        assert src._get_node("b")._value() == cp._value()

        # test copy of src and ensure interpolation is copied as interpolation
        cp2 = copy_method(src)
        assert OmegaConf.is_interpolation(cp2, "b")

    @mark.parametrize(
        "src,interpolating_key,interpolated_key",
        [([1, 2, "${0}"], 2, 0), ({"a": 10, "b": "${a}"}, "b", "a")],
    )
    def test_copy_with_interpolation(
        self, copy_method: Any, src: Any, interpolating_key: str, interpolated_key: str
    ) -> None:
        cfg = OmegaConf.create(src)
        assert cfg[interpolated_key] == cfg[interpolating_key]
        cp = copy_method(cfg)
        assert id(cfg) != id(cp)
        assert cp[interpolated_key] == cp[interpolating_key]
        assert cfg[interpolated_key] == cp[interpolating_key]

        # Interpolation is preserved in original
        cfg[interpolated_key] = "XXX"
        assert cfg[interpolated_key] == cfg[interpolating_key]

        # Test interpolation is preserved in copy
        cp[interpolated_key] = "XXX"
        assert cp[interpolated_key] == cp[interpolating_key]

    def test_list_shallow_copy_is_deepcopy(self, copy_method: Any) -> None:
        cfg = OmegaConf.create([[10, 20]])
        cp = copy_method(cfg)
        assert cfg is not cp
        assert cfg[0] is not cp[0]
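
The cases above pin down the copy semantics of OmegaConf containers: copy.copy and copy.deepcopy both produce an independent config, and interpolations are copied as interpolations rather than being resolved. A minimal sketch with illustrative values:

import copy
from omegaconf import OmegaConf

cfg = OmegaConf.create({"a": 10, "b": "${a}"})
cp = copy.deepcopy(cfg)

cfg.a = 20
assert cfg.b == 20   # the interpolation in the original tracks the new value
assert cp.b == 10    # the copy carries its own "${a}" interpolation
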
Example #5
def test_resolve_interpolation_without_parent_no_throw() -> None:
    cfg = DictConfig(content="${foo}")
    assert cfg._dereference_node(throw_on_resolution_failure=False) is None
Example #6
def citrinet_model():
    preprocessor = {
        'cls':
        'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
        'params': dict({})
    }
    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in':
            80,
            'activation':
            'relu',
            'conv_mask':
            True,
            'jasper': [
                {
                    'filters': 512,
                    'repeat': 1,
                    'kernel': [5],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                },
                {
                    'filters': 512,
                    'repeat': 5,
                    'kernel': [11],
                    'stride': [2],
                    'dilation': [1],
                    'dropout': 0.1,
                    'residual': True,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                    'stride_last': True,
                    'residual_mode': 'stride_add',
                },
                {
                    'filters': 512,
                    'repeat': 5,
                    'kernel': [13],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.1,
                    'residual': True,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                },
                {
                    'filters': 640,
                    'repeat': 1,
                    'kernel': [41],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': True,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                },
            ],
        },
    }

    decoder = {
        'cls': 'nemo.collections.asr.modules.ConvASRDecoder',
        'params': {
            'feat_in': 640,
            'num_classes': 1024,
            'vocabulary': list(chr(i % 28) for i in range(0, 1024))
        },
    }

    modelConfig = DictConfig({
        'preprocessor': DictConfig(preprocessor),
        'encoder': DictConfig(encoder),
        'decoder': DictConfig(decoder)
    })
    citri_model = EncDecSpeakerLabelModel(cfg=modelConfig)
    return citri_model
Example #7
    def _register_vocab_from_tokenizer(
        self,
        vocab_file_config_path: str = 'tokenizer.vocab_file',
        vocab_dict_config_path: str = 'tokenizer_vocab_dict',
        cfg: DictConfig = None,
    ):
        """Creates vocab file from tokenizer if vocab file is None.

        Args:
            vocab_file_config_path: path to the vocab_file in the config
            vocab_dict_config_path: path to the vocab_dict in the config
            cfg: tokenizer config
        """
        if self.tokenizer is None:
            raise ValueError(
                'Instantiate self.tokenizer before registering vocab from it.')
        else:
            if isinstance(self.tokenizer, AutoTokenizer):
                # extract vocab from tokenizer
                vocab_dict = self.tokenizer.tokenizer.get_vocab()

                # for fast and slow tokenizer vocabularies compatibility
                vocab_dict = dict(
                    sorted(vocab_dict.items(), key=lambda item: item[1]))

                # get hash of vocab_dict to create a unique directory to write vocab_dict and vocab_file
                m = hashlib.md5()
                if 'tokenizer_name' in cfg:
                    if cfg.tokenizer_name is not None:
                        # different pretrained models with the same vocab will have different hash
                        m.update(cfg.tokenizer_name.encode())
                # get string representation of vocab_dict
                vocab_dict_str = json.dumps(vocab_dict,
                                            sort_keys=True).encode()
                m.update(vocab_dict_str)
                vocab_dict_hash = m.hexdigest()

                hash_path = os.path.join(NEMO_NLP_TMP, vocab_dict_hash)
                os.makedirs(hash_path, exist_ok=True)

                vocab_json_src = os.path.join(hash_path,
                                              vocab_dict_config_path)

                with open(vocab_json_src, 'w', encoding='utf-8') as f:
                    f.write(
                        json.dumps(vocab_dict, indent=2, sort_keys=True) +
                        '\n')
                self.register_artifact(config_path=vocab_dict_config_path,
                                       src=vocab_json_src)
                # create vocab file
                vocab_file_src = os.path.join(hash_path,
                                              vocab_file_config_path)
                with open(vocab_file_src, 'w', encoding='utf-8') as f:
                    for key in vocab_dict:
                        f.write(key + '\n')

                cfg.vocab_file = vocab_file_src
                self.register_artifact(config_path=vocab_file_config_path,
                                       src=vocab_file_src)
            else:
                logging.info(
                    f'Registering tokenizer vocab for {self.tokenizer} is not yet supported. Please override this method if needed.'
                )
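
The md5-over-sorted-JSON step above is what makes the temporary vocab directory deterministic: the same vocabulary (plus tokenizer name) always maps to the same hash path. A minimal sketch of just that hashing step, with hypothetical names:

import hashlib
import json

def vocab_dir_name(vocab_dict, tokenizer_name=None):
    m = hashlib.md5()
    if tokenizer_name:
        # different pretrained models with the same vocab get different hashes
        m.update(tokenizer_name.encode())
    m.update(json.dumps(vocab_dict, sort_keys=True).encode())
    return m.hexdigest()

# key order does not matter thanks to sort_keys=True
assert vocab_dir_name({"a": 0, "b": 1}) == vocab_dir_name({"b": 1, "a": 0})
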
Example #8
    def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
        """merge src into dest and return a new copy, does not modified input"""
        from omegaconf import AnyNode, DictConfig, ValueNode

        assert isinstance(dest, DictConfig)
        assert isinstance(src, DictConfig)
        src_type = src._metadata.object_type
        src_ref_type = get_ref_type(src)
        assert src_ref_type is not None

        # If source DictConfig is:
        #  - None => set the destination DictConfig to None
        #  - an interpolation => set the destination DictConfig to be the same interpolation
        if src._is_none() or src._is_interpolation():
            dest._set_value(src._value())
            _update_types(node=dest,
                          ref_type=src_ref_type,
                          object_type=src_type)
            return

        dest._validate_merge(value=src)

        def expand(node: Container) -> None:
            rt = node._metadata.ref_type
            val: Any
            if rt is not Any:
                if is_dict_annotation(rt):
                    val = {}
                elif is_list_annotation(rt):
                    val = []
                else:
                    val = rt
            elif isinstance(node, DictConfig):
                val = {}
            else:
                assert False

            node._set_value(val)

        if (src._is_missing() and not dest._is_missing()
                and is_structured_config(src_ref_type)):
            # Replace `src` with a prototype of its corresponding structured config
            # whose fields are all missing (to avoid overwriting fields in `dest`).
            src = _create_structured_with_missing_fields(ref_type=src_ref_type,
                                                         object_type=src_type)

        if (dest._is_interpolation()
                or dest._is_missing()) and not src._is_missing():
            expand(dest)

        src_items = src.items_ex(
            resolve=False) if not src._is_missing() else []
        for key, src_value in src_items:
            src_node = src._get_node(key, validate_access=False)
            dest_node = dest._get_node(key, validate_access=False)
            assert src_node is None or isinstance(src_node, Node)
            assert dest_node is None or isinstance(dest_node, Node)

            if isinstance(dest_node, DictConfig):
                dest_node._validate_merge(value=src_node)

            missing_src_value = _is_missing_value(src_value)

            if (isinstance(dest_node, Container) and dest_node._is_none()
                    and not missing_src_value
                    and not _is_none(src_value, resolve=True)):
                expand(dest_node)

            if dest_node is not None and dest_node._is_interpolation():
                target_node = dest_node._dereference_node(
                    throw_on_resolution_failure=False)
                if isinstance(target_node, Container):
                    dest[key] = target_node
                    dest_node = dest._get_node(key)

            if (dest_node is None
                    and is_structured_config(dest._metadata.element_type)
                    and not missing_src_value):
                # merging into a new node. Use element_type as a base
                dest[key] = DictConfig(content=dest._metadata.element_type,
                                       parent=dest)
                dest_node = dest._get_node(key)

            if dest_node is not None:
                if isinstance(dest_node, BaseContainer):
                    if isinstance(src_value, BaseContainer):
                        dest_node._merge_with(src_value)
                    elif not missing_src_value:
                        dest.__setitem__(key, src_value)
                else:
                    if isinstance(src_value, BaseContainer):
                        dest.__setitem__(key, src_value)
                    else:
                        assert isinstance(dest_node, ValueNode)
                        assert isinstance(src_node, ValueNode)
                        # Compare to literal missing, ignoring interpolation
                        src_node_missing = _is_missing_literal(src_value)
                        try:
                            if isinstance(dest_node, AnyNode):
                                if src_node_missing:
                                    node = copy.copy(src_node)
                                    # if src node is missing, use the value from the dest_node,
                                    # but validate it against the type of the src node before assignment
                                    node._set_value(dest_node._value())
                                else:
                                    node = src_node
                                dest.__setitem__(key, node)
                            else:
                                if not src_node_missing:
                                    dest_node._set_value(src_value)

                        except (ValidationError, ReadonlyConfigError) as e:
                            dest._format_and_raise(key=key,
                                                   value=src_value,
                                                   cause=e)
            else:
                from omegaconf import open_dict

                if is_structured_config(src_type):
                    # verified to be compatible above in _validate_merge
                    with open_dict(dest):
                        dest[key] = src._get_node(key)
                else:
                    dest[key] = src._get_node(key)

        _update_types(node=dest, ref_type=src_ref_type, object_type=src_type)

        # explicit flags on the source config are replacing the flag values in the destination
        flags = src._metadata.flags
        assert flags is not None
        for flag, value in flags.items():
            if value is not None:
                dest._set_flag(flag, value)
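
The _map_merge internals above are what give OmegaConf.merge its user-visible behavior around missing ('???') values: a missing value in the source does not overwrite a concrete value in the destination (see the dict_merge_missing_onto cases in the next example). A minimal sketch of that behavior:

from omegaconf import OmegaConf

dst = OmegaConf.create({"a": {"b": 10}})
src = OmegaConf.create({"a": "???"})

merged = OmegaConf.merge(dst, src)
assert merged.a.b == 10   # the missing source value leaves the destination intact
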
Example #9
         }
     }, {
         "a": "???"
     }),
     {"a": {
         "b": 10
     }},
     id="dict_merge_missing_onto",
 ),
 pytest.param(
     ({
         "a": {
             "b": 10
         }
     }, {
         "a": DictConfig(content="???")
     }),
     {"a": {
         "b": 10
     }},
     id="dict_merge_missing_onto",
 ),
 pytest.param(
     ({}, {
         "a": "???"
     }),
     {"a": "???"},
     id="dict_merge_missing_onto_no_node",
 ),
 pytest.param(
     (
Example #10
        assert ret == None  # noqa E711


@mark.parametrize(
    "target_type, value, expected",
    [
        # Any
        param(Any, "foo", AnyNode("foo"), id="any"),
        param(Any, b"binary", AnyNode(b"binary"), id="any"),
        param(Any, Path("hello.txt"), AnyNode(Path("hello.txt")), id="any"),
        param(Any, True, AnyNode(True), id="any"),
        param(Any, 1, AnyNode(1), id="any"),
        param(Any, 1.0, AnyNode(1.0), id="any"),
        param(Any, b"123", AnyNode(b"123"), id="any"),
        param(Any, Color.RED, AnyNode(Color.RED), id="any"),
        param(Any, {}, DictConfig(content={}), id="any_as_dict"),
        param(Any, [], ListConfig(content=[]), id="any_as_list"),
        # int
        param(int, "foo", ValidationError, id="int"),
        param(int, b"binary", ValidationError, id="int"),
        param(int, Path("hello.txt"), ValidationError, id="int"),
        param(int, True, ValidationError, id="int"),
        param(int, 1, IntegerNode(1), id="int"),
        param(int, 1.0, ValidationError, id="int"),
        param(int, Color.RED, ValidationError, id="int"),
        param(int, b"123", ValidationError, id="int"),
        # float
        param(float, "foo", ValidationError, id="float"),
        param(float, b"binary", ValidationError, id="float"),
        param(float, Path("hello.txt"), ValidationError, id="float"),
        param(float, True, ValidationError, id="float"),
Example #11
def test_get_value_container(content: Any) -> None:
    cfg = DictConfig({})
    cfg._set_value(content)
    assert _get_value(cfg) == content
Example #12
    def test_interpolation(self, node_type: Any, values: Any,
                           restore_resolvers: Any) -> None:
        resolver_output = 9999
        OmegaConf.register_resolver("func", lambda: resolver_output)
        values = copy.deepcopy(values)
        for value in values:
            node = {
                "reg": node_type(value=value, is_optional=False),
                "opt": node_type(value=value, is_optional=True),
            }
            cfg = OmegaConf.create({
                "const":
                10,
                "primitive_missing":
                "???",
                "resolver":
                StringNode(value="${func:}", is_optional=False),
                "opt_resolver":
                StringNode(value="${func:}", is_optional=True),
                "node":
                DictConfig(content=node, is_optional=False),
                "opt_node":
                DictConfig(content=node, is_optional=True),
                "reg":
                node_type(value=value, is_optional=False),
                "opt":
                node_type(value=value, is_optional=True),
                "opt_none":
                node_type(value=None, is_optional=True),
                "missing":
                node_type(value="???", is_optional=False),
                "opt_missing":
                node_type(value="???", is_optional=True),
                # Interpolations
                "int_reg":
                "${reg}",
                "int_opt":
                "${opt}",
                "int_opt_none":
                "${opt_none}",
                "int_missing":
                "${missing}",
                "int_opt_missing":
                "${opt_missing}",
                "str_int_const":
                StringNode(value="foo_${const}", is_optional=False),
                "opt_str_int_const":
                StringNode(value="foo_${const}", is_optional=True),
                "str_int_with_primitive_missing":
                StringNode(value="foo_${primitive_missing}",
                           is_optional=False),
                "opt_str_int_with_primitive_missing":
                StringNode(value="foo_${primitive_missing}", is_optional=True),
                "int_node":
                "${node}",
                "int_opt_node":
                "${opt_node}",
                "int_resolver":
                "${resolver}",
                "int_opt_resolver":
                "${opt_resolver}",
            })

            verify(cfg,
                   "const",
                   none=False,
                   opt=True,
                   missing=False,
                   inter=False,
                   exp=10)

            verify(
                cfg,
                "resolver",
                none=False,
                # Note, resolvers are always optional because the underlying function may return None
                opt=True,
                missing=False,
                inter=True,
                exp=resolver_output,
            )

            verify(
                cfg,
                "opt_resolver",
                none=False,
                opt=True,
                missing=False,
                inter=True,
                exp=resolver_output,
            )

            verify(
                cfg,
                "reg",
                none=False,
                opt=False,
                missing=False,
                inter=False,
                exp=value,
            )

            verify(cfg,
                   "opt",
                   none=False,
                   opt=True,
                   missing=False,
                   inter=False,
                   exp=value)
            verify(
                cfg,
                "opt_none",
                none=True,
                opt=True,
                missing=False,
                inter=False,
                exp=None,
            )
            verify(cfg,
                   "missing",
                   none=False,
                   opt=False,
                   missing=True,
                   inter=False)
            verify(cfg,
                   "opt_missing",
                   none=False,
                   opt=True,
                   missing=True,
                   inter=False)

            verify(
                cfg,
                "int_reg",
                none=False,
                opt=False,
                missing=False,
                inter=True,
                exp=value,
            )
            verify(
                cfg,
                "int_opt",
                none=False,
                opt=True,
                missing=False,
                inter=True,
                exp=value,
            )
            verify(
                cfg,
                "int_opt_none",
                none=True,
                opt=True,
                missing=False,
                inter=True,
                exp=None,
            )
            verify(cfg,
                   "int_missing",
                   none=False,
                   opt=False,
                   missing=True,
                   inter=True)
            verify(cfg,
                   "int_opt_missing",
                   none=False,
                   opt=True,
                   missing=True,
                   inter=True)

            verify(
                cfg,
                "str_int_const",
                none=False,
                opt=False,
                missing=False,
                inter=True,
                exp="foo_10",
            )
            verify(
                cfg,
                "opt_str_int_const",
                none=False,
                opt=True,
                missing=False,
                inter=True,
                exp="foo_10",
            )
            verify(
                cfg,
                "int_node",
                none=False,
                opt=False,
                missing=False,
                inter=True,
                exp=node,
            )

            verify(
                cfg,
                "int_opt_node",
                none=False,
                opt=True,
                missing=False,
                inter=True,
                exp=node,
            )

            verify(
                cfg,
                "int_resolver",
                none=False,
                opt=True,
                missing=False,
                inter=True,
                exp=resolver_output,
            )

            verify(
                cfg,
                "int_opt_resolver",
                none=False,
                opt=True,
                missing=False,
                inter=True,
                exp=resolver_output,
            )

            verify(
                cfg,
                "str_int_with_primitive_missing",
                none=False,
                opt=False,
                missing=True,
                inter=True,
            )

            verify(
                cfg,
                "opt_str_int_with_primitive_missing",
                none=False,
                opt=True,
                missing=True,
                inter=True,
            )
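
OmegaConf.register_resolver, used at the top of this test, is the legacy registration API; recent OmegaConf releases expose register_new_resolver for the same purpose. A minimal sketch of resolver-based interpolation, with an illustrative resolver name:

from omegaconf import OmegaConf

OmegaConf.register_new_resolver("plus_one", lambda x: int(x) + 1)

cfg = OmegaConf.create({"x": "${plus_one:9998}"})
assert cfg.x == 9999
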
Example #13
 "node_type, values",
 [
     (BooleanNode, [True, False]),
     (FloatNode, [3.1415]),
     (IntegerNode, [42]),
     (StringNode, ["hello"]),
     # EnumNode
     (
         lambda value, is_optional, key=None: EnumNode(
             enum_type=Color, value=value, is_optional=is_optional, key=key
         ),
         [Color.RED],
     ),
     # DictConfig
     (
         lambda value, is_optional, key=None: DictConfig(
             is_optional=is_optional, content=value, key=key),
         [{}, {
             "foo": "bar"
         }],
     ),
     # ListConfig
     (
         lambda value, is_optional, key=None: ListConfig(
             is_optional=is_optional, content=value, key=key),
         [[], [1, 2, 3]],
     ),
     # dataclass
     (
         lambda value, is_optional, key=None: DictConfig(ref_type=Group,
                                                         is_optional=
                                                         is_optional,
Example #14
from envs.env_wrapper import (
    PettingZooEnvWrapper,
    NumpyStateMixin,
    petting_zoo_random_player,
)
from models import GenericLinearModel
from settings import device


class TicTacToeEnvWrapper(PettingZooEnvWrapper, NumpyStateMixin):
    def __init__(self):
        super(TicTacToeEnvWrapper, self).__init__(
            env=tictactoe_v3.env(), opponent_policy=petting_zoo_random_player
        )


if __name__ == "__main__":

    hp = DictConfig({})

    hp.steps = 20
    hp.batch_size = 2
    hp.max_steps = 10
    hp.lr = 1e-3
    hp.epsilon_exploration = 0.1
    hp.gamma_discount = 0.9

    model = GenericLinearModel(18, [10], 9, flatten=True).float().to(device)

    train_dqn(TicTacToeEnvWrapper, model, hp, name="TicTacToe")
Example #15
 ),
 pytest.param(
     Expected(
         create=lambda: create_readonly({"foo": "bar"}),
         op=lambda cfg: setattr(cfg, "foo", 20),
         exception_type=ReadonlyConfigError,
         msg="Cannot change read-only config container",
         key="foo",
         child_node=lambda cfg: cfg.foo,
     ),
     id="dict,readonly:set_attribute",
 ),
 pytest.param(
     Expected(
         create=lambda: OmegaConf.create(
             {"foo": DictConfig(is_optional=False, content={})}),
         op=lambda cfg: setattr(cfg, "foo", None),
         exception_type=ValidationError,
         msg="child 'foo' is not Optional",
         key="foo",
         full_key="foo",
         child_node=lambda cfg: cfg.foo,
     ),
     id="dict:setattr:not_optional:set_none",
 ),
 pytest.param(
     Expected(
         create=lambda: OmegaConf.structured(ConcretePlugin),
         op=lambda cfg: cfg.params.__setattr__("foo", "bar"),
         exception_type=ValidationError,
         msg="Value 'bar' could not be converted to Integer",
Example #16
def train(config: DictConfig) -> Optional[float]:
    """Contains training pipeline.
    Instantiates all PyTorch Lightning objects from config.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Metric score for hyperparameter optimization.
    """

    # Set seed for random number generators in pytorch, numpy and python.random
    if "seed" in config:
        seed_everything(config.seed)

    # Init augmentations, they require primitive types to be instantiated
    train_augs: Compose = None
    valid_augs: Compose = None
    if "augmentations" in config:
        train_augs = Compose(
            utils.instantiate_list(config.augmentations.train,
                                   group="train augs.",
                                   primitive=True))
        valid_augs = Compose(
            utils.instantiate_list(config.augmentations.valid,
                                   group="valid augs.",
                                   primitive=True))

    # Init Lightning datamodule
    LOG.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(
        config.datamodule,
        train_transforms=train_augs,
        valid_transforms=valid_augs)

    # Init Lightning model
    LOG.info(f"Instantiating model <{config.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(config.model)

    # Init Lightning callbacks
    callbacks: List[Callback] = []
    if "callbacks" in config:
        callbacks = utils.instantiate_list(config.callbacks, group="callback")

    # Init Lightning loggers
    loggers: List[LightningLoggerBase] = []
    if "logger" in config:
        loggers = utils.instantiate_list(config.logger, group="logger")

    # Init trainer plugins
    plugins: List[Plugin] = []
    if "plugin" in config:
        plugins = utils.instantiate_list(config.plugin, group="plugin")

    # Init Lightning trainer
    LOG.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(config.trainer,
                                               callbacks=callbacks,
                                               logger=loggers,
                                               plugins=plugins,
                                               _convert_="partial")

    # Send some parameters from config to all lightning loggers
    LOG.info("Logging hyperparameters")
    utils.log_hyperparameters(config=config,
                              model=model,
                              datamodule=datamodule,
                              trainer=trainer,
                              callbacks=callbacks,
                              logger=loggers)

    # Train the model
    LOG.info("Starting training")
    trainer.fit(model=model, datamodule=datamodule)

    # Evaluate model on test set after training
    if not config.trainer.get("fast_dev_run"):
        LOG.info("Starting testing...")
        trainer.test()

    # Make sure everything closed properly
    LOG.info("Finalizing")
    utils.finish(config=config,
                 model=model,
                 datamodule=datamodule,
                 trainer=trainer,
                 callbacks=callbacks,
                 logger=loggers)

    # Print path to best checkpoint
    LOG.info(
        f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}"
    )

    # Return metric score for Optuna optimization
    optimized_metric = config.get("optimized_metric")
    if optimized_metric:
        return trainer.callback_metrics[optimized_metric]
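
hydra.utils.instantiate, used throughout this pipeline, builds an object from a config node whose _target_ names an importable class or callable; the remaining keys become constructor arguments. A minimal sketch using a standard-library class as the target for illustration:

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({"_target_": "datetime.timedelta", "minutes": 5})
td = instantiate(cfg)
assert td.total_seconds() == 300
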
Example #17
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params:\n {OmegaConf.to_yaml(cfg)}')
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))

    # initialize the model using the config file
    model = MultiLabelIntentSlotClassificationModel(cfg.model, trainer=trainer)

    # training
    logging.info(
        "================================================================================================"
    )
    logging.info('Starting training...')
    trainer.fit(model)
    logging.info('Training finished!')

    # Stop further testing as fast_dev_run does not save checkpoints
    if trainer.fast_dev_run:
        return

    # after model training is done, you can load the model from the saved checkpoint
    # and evaluate it on a data file or on given queries.
    logging.info(
        "================================================================================================"
    )
    logging.info("Starting the testing of the trained model on test set...")
    logging.info(
        "We will load the latest model saved checkpoint from the training...")

    # for evaluation and inference you can load the previously trained model saved in .nemo file
    # like this in your code, but we will just reuse the trained model here
    # eval_model = MultiLabelIntentSlotClassificationModel.restore_from(restore_path=checkpoint_path)
    eval_model = model

    # we will setup testing data reusing the same config (test section)
    eval_model.update_data_dir_for_testing(data_dir=cfg.model.data_dir)
    eval_model.setup_test_data(test_data_config=cfg.model.test_ds)

    trainer.test(model=eval_model, ckpt_path=None, verbose=False)
    logging.info("Testing finished!")

    # Optimize Threshold
    eval_model.optimize_threshold(cfg.model.test_ds, 'dev')

    # run an inference on a few examples
    logging.info(
        "======================================================================================"
    )
    logging.info("Evaluate the model on the given queries...")

    # this will work well if you train the model on ATIS dataset
    # for your own dataset change the examples appropriately
    queries = [
        'i would like to find a flight from charlotte to las vegas that makes a stop in st. louis',
        'on april first i need a ticket from tacoma to san jose departing before 7 am',
        'how much is the limousine service in boston',
    ]

    # We use the optimized threshold for predictions
    pred_intents, pred_slots, pred_list = eval_model.predict_from_examples(
        queries, cfg.model.test_ds)
    logging.info(
        'The prediction results of some sample queries with the trained model:'
    )

    for query, intent, slots in zip(queries, pred_intents, pred_slots):
        logging.info(f'Query : {query}')
        logging.info(f'Predicted Intents: {intent}')
        logging.info(f'Predicted Slots: {slots}')

    logging.info("Inference finished!")
Example #18
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Base class from which all NeMo models should inherit

        Args:
            cfg (DictConfig):  configuration object.
                The cfg object should have (optionally) the following sub-configs:

                * train_ds - to instantiate training dataset
                * validation_ds - to instantiate validation dataset
                * test_ds - to instantiate testing dataset
                * optim - to instantiate optimizer with learning rate scheduler

            trainer (Optional): Pytorch Lightning Trainer instance
        """
        if not isinstance(cfg, DictConfig):
            raise ValueError(f"cfg constructor argument must be of type DictConfig but got {type(cfg)} instead.")
        if trainer is not None and not isinstance(trainer, Trainer):
            raise ValueError(
                f"trainer constructor argument must be either None or pytroch_lightning.Trainer. But got {type(trainer)} instead."
            )
        super().__init__()
        if 'target' not in cfg:
            # This is for Jarvis service.
            OmegaConf.set_struct(cfg, False)
            cfg.target = "{0}.{1}".format(self.__class__.__module__, self.__class__.__name__)
            OmegaConf.set_struct(cfg, True)

        config = OmegaConf.to_container(cfg, resolve=True)
        config = OmegaConf.create(config)
        OmegaConf.set_struct(config, True)

        self._cfg = config

        self.save_hyperparameters(self._cfg)
        self._train_dl = None
        self._validation_dl = None
        self._test_dl = None
        self._optimizer = None
        self._scheduler = None
        self._trainer = trainer

        # Set device_id in AppState
        if torch.cuda.current_device() is not None:
            app_state = AppState()
            app_state.device_id = torch.cuda.current_device()

        if self._cfg is not None and not self.__is_model_being_restored():
            if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
                self.setup_training_data(self._cfg.train_ds)

            if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
                self.setup_multiple_validation_data(val_data_config=None)

            if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
                self.setup_multiple_test_data(test_data_config=None)

        else:
            if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
                logging.warning(
                    f"Please call the ModelPT.setup_training_data() method "
                    f"and provide a valid configuration file to setup the train data loader.\n"
                    f"Train config : \n{OmegaConf.to_yaml(self._cfg.train_ds)}"
                )

            if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
                logging.warning(
                    f"Please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method "
                    f"and provide a valid configuration file to setup the validation data loader(s). \n"
                    f"Validation config : \n{OmegaConf.to_yaml(self._cfg.validation_ds)}"
                )

            if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
                logging.warning(
                    f"Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method "
                    f"and provide a valid configuration file to setup the test data loader(s).\n"
                    f"Test config : \n{OmegaConf.to_yaml(self._cfg.test_ds)}"
                )
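
The to_container / create / set_struct round-trip in the constructor above is a common pattern for baking in all interpolations and then freezing the config so that later typos raise errors instead of silently adding keys. A minimal sketch of that pattern:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"hidden": 256, "head": {"in_dim": "${hidden}"}})

resolved = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
OmegaConf.set_struct(resolved, True)

assert resolved.head.in_dim == 256   # the interpolation is now a literal value
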
Example #19
def citrinet_rnnt_model():
    labels = list(chr(i % 28) for i in range(0, 1024))
    model_defaults = {
        'enc_hidden': 640,
        'pred_hidden': 256,
        'joint_hidden': 320
    }

    preprocessor = {
        'cls':
        'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
        'params': dict({})
    }
    encoder = {
        '_target_':
        'nemo.collections.asr.modules.ConvASREncoder',
        'feat_in':
        80,
        'activation':
        'relu',
        'conv_mask':
        True,
        'jasper': [
            {
                'filters': 512,
                'repeat': 1,
                'kernel': [5],
                'stride': [1],
                'dilation': [1],
                'dropout': 0.0,
                'residual': False,
                'separable': True,
                'se': True,
                'se_context_size': -1,
            },
            {
                'filters': 512,
                'repeat': 5,
                'kernel': [11],
                'stride': [2],
                'dilation': [1],
                'dropout': 0.1,
                'residual': True,
                'separable': True,
                'se': True,
                'se_context_size': -1,
                'stride_last': True,
                'residual_mode': 'stride_add',
            },
            {
                'filters': 512,
                'repeat': 5,
                'kernel': [13],
                'stride': [1],
                'dilation': [1],
                'dropout': 0.1,
                'residual': True,
                'separable': True,
                'se': True,
                'se_context_size': -1,
            },
            {
                'filters': 640,
                'repeat': 1,
                'kernel': [41],
                'stride': [1],
                'dilation': [1],
                'dropout': 0.0,
                'residual': True,
                'separable': True,
                'se': True,
                'se_context_size': -1,
            },
        ],
    }

    decoder = {
        '_target_': 'nemo.collections.asr.modules.RNNTDecoder',
        'prednet': {
            'pred_hidden': 256,
            'pred_rnn_layers': 1,
            'dropout': 0.0
        },
    }

    joint = {
        '_target_': 'nemo.collections.asr.modules.RNNTJoint',
        'fuse_loss_wer': False,
        'jointnet': {
            'joint_hidden': 320,
            'activation': 'relu',
            'dropout': 0.0
        },
    }

    decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 5}}

    modelConfig = DictConfig({
        'preprocessor': DictConfig(preprocessor),
        'labels': labels,
        'model_defaults': DictConfig(model_defaults),
        'encoder': DictConfig(encoder),
        'decoder': DictConfig(decoder),
        'joint': DictConfig(joint),
        'decoding': DictConfig(decoding),
    })
    citri_model = EncDecRNNTModel(cfg=modelConfig)
    return citri_model
Example #20
def main(cfg: DictConfig) -> None:
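    # Note: cfg.pretty() is deprecated in recent OmegaConf releases; OmegaConf.to_yaml(cfg) is the replacement.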
    print(cfg.pretty())
    from train import Workspace as W
    workspace = W(cfg)
    workspace.run()
Example #21
    assert c2.a.b == 10


# There was a bug caused by the interaction between the three: deepcopy, merge, and flags
def test_deepcopy_and_merge_and_flags() -> None:
    c1 = OmegaConf.create(
        {"dataset": {"name": "imagenet", "path": "/datasets/imagenet"}, "defaults": []}
    )
    OmegaConf.set_struct(c1, True)
    c2 = copy.deepcopy(c1)
    with raises(ConfigKeyError):
        OmegaConf.merge(c2, OmegaConf.from_dotlist(["dataset.bad_key=yes"]))


@mark.parametrize(
    "cfg", [ListConfig(element_type=int, content=[]), DictConfig(content={})]
)
def test_deepcopy_preserves_container_type(cfg: Container) -> None:
    cp: Container = copy.deepcopy(cfg)
    assert cp._metadata.element_type == cfg._metadata.element_type


@mark.parametrize(
    "src, flag_name, func, expectation",
    [
        param(
            {},
            "struct",
            lambda c: c.__setitem__("foo", 1),
            raises(KeyError),
            id="struct_setiitem",
Example #22
    def _build_tfm(conf_tfm: DictConfig):
        tfm = [instantiate(v) for k, v in conf_tfm.items()]
        tfm = A.Compose(tfm,
                        bbox_params=A.BboxParams(**conf.bbox_params,
                                                 label_fields=["cls"]))
        return tfm
Example #23
def test_resolve_interpolation_without_parent() -> None:
    with raises(
        InterpolationResolutionError,
        match=re.escape("Cannot resolve interpolation for a node without a parent"),
    ):
        DictConfig(content="${foo}")._dereference_node()
Example #24
def train_dqn_double(
    env_class: Type[EnvWrapper],
    model: nn.Module,
    config: DictConfig,
    project_name=None,
    run_name=None,
):
    env = BatchEnvWrapper(env_class, config.batch_size)
    env.reset()
    optim = torch.optim.Adam(model.parameters(), lr=config.lr)
    epsilon_scheduler = decay_functions[config.epsilon_decay_function]

    target_model = deepcopy(model)
    target_model.load_state_dict(model.state_dict())
    target_model.eval()

    wandb.init(
        name=f"{run_name}_{str(datetime.now().timestamp())[5:10]}",
        project=project_name or "testing_dqn",
        config=dict(config),
        save_code=True,
        group=None,
        tags=None,  # List of string tags
        notes=None,  # longer description of run
        dir=BASE_DIR,
    )
    wandb.watch(model)
    replay = PrioritizedReplay(
        buffer_size=config.replay_size,
        batch_size=config.replay_batch,
        delete_freq=config.delete_freq,
        delete_percentage=config.delete_percentage,
        transform=state_action_reward_state_2_transform,
    )
    env_recorder = EnvRecorder(config.env_record_freq, config.env_record_duration)
    sample_actions = EpsilonRandomActionSampler()

    cumulative_reward = 0
    cumulative_done = 0

    # ======= Start training ==========

    # We need _some_ initial replay buffer to start with.
    store_initial_replay(env, replay)

    for step in range(config.steps):
        log = DictConfig({"step": step})

        (
            states_replay,
            actions_replay,
            rewards_replay,
            states2_replay,
        ) = replay.get_batch()
        states = _combine(env.get_state_batch(), states_replay)

        q_pred = model(states)

        epsilon_exploration = epsilon_scheduler(config, log)
        actions_live = sample_actions(
            valid_actions=env.get_legal_actions(),
            q_values=q_pred[: config.batch_size],
            epsilon=epsilon_exploration,
        )

        # ============ Observe the reward && predict value of next state ==============

        states2, actions, rewards, dones_live = step_with_replay(
            env, actions_live, actions_replay, states2_replay, rewards_replay
        )

        with torch.no_grad():
            q_next_target = target_model(states2)
            model.eval()
            q_next_primary = model(states2)
            model.train()

        # Bellman equation
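        # As implemented below, this is the Double DQN target: the primary network picks
        # the best next action and the target network evaluates it, which reduces
        # Q-value overestimation:
        #   value = reward + gamma_discount * Q_target(s2, argmax_a Q_primary(s2, a))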
        state2_primary_actions = torch.argmax(q_next_primary, dim=1)
        state2_value = q_next_target[range(len(q_next_target)), state2_primary_actions]
        value = rewards + config.gamma_discount * state2_value

        q_select_actions = q_pred[range(len(q_pred)), actions]

        # =========== LEARN ===============

        loss = F.mse_loss(q_select_actions, value, reduction="none")

        replay.add_batch(loss, (states, actions, rewards, states2))
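        # The unreduced per-sample losses above are handed to the replay buffer, which
        # (as the name PrioritizedReplay suggests) presumably uses them as sampling
        # priorities; only the scalar mean below is used for backpropagation.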
        loss = torch.mean(loss)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # Copy parameters to the target network every so often
        if step % config.target_model_sync_freq == 0:
            target_model.load_state_dict(model.state_dict())

        # ============ Logging =============

        log.loss = loss.item()

        max_reward = torch.amax(rewards, 0).item()
        min_reward = torch.amin(rewards, 0).item()
        mean_reward = torch.mean(rewards, 0).item()
        log.max_reward = max_reward
        log.min_reward = min_reward
        log.mean_reward = mean_reward

        cumulative_done += dones_live.sum()  # number of dones
        log.cumulative_done = int(cumulative_done)

        cumulative_reward += mean_reward
        log.cumulative_reward = cumulative_reward

        log.epsilon_exploration = epsilon_exploration

        env_recorder.record(step, env.envs, wandb)

        wandb.log(log)
Beispiel #25
0
    def __init__(self):
        super(MockModel_, self).__init__(DictConfig({"conv_type": "Dummy"}))

        self._channels = [12, 12, 12, 17]
        self.nn = MLP(self._channels)
Beispiel #26
0
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sample

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for the dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers":
        datamodule_config.get("num_workers",
                              training_config.get("num_workers", 4)),
        "pin_memory":
        datamodule_config.get("pin_memory",
                              training_config.get("pin_memory", False)),
        "shuffle":
        datamodule_config.get("shuffle", None),
        "batch_size":
        datamodule_config.get("batch_size", None),
    }

    # IterableDataset returns batches directly, so no need to add Sampler
    # or batch size as user is expected to control those. This is a fine
    # assumption for now to not support single item based IterableDataset
    # as it will add unnecessary complexity and config parameters
    # to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance,
                                                    other_args)
    else:
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(dataset_instance.dataset_name,
                                 dataset_instance.dataset_type),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
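# A minimal usage sketch (not from the original source): "my_dataset" is a hypothetical
# map-style torch Dataset exposing the dataset_name / dataset_type attributes that the
# BatchCollator above expects, and the config keys mirror the ones read above.
#
#   datamodule_config = OmegaConf.create(
#       {"num_workers": 2, "pin_memory": True, "shuffle": True, "batch_size": 32}
#   )
#   loader, sampler = build_dataloader_and_sampler(my_dataset, datamodule_config)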
Beispiel #27
0
def my_app(cfg: DictConfig) -> None:
    print(cfg.pretty())
Beispiel #28
0
def main_as_plugin(path_plugin: str, path_wav: Union[str, None]) -> str:
    """
    Create an audio file from a UtauPlugin object
    """
    # Read the settings written in UTAU's temporary file
    print(f'{datetime.now()} : reading settings in TMP')
    path_ust, voice_dir, _ = get_project_path(path_plugin)
    path_enuconfig = join(voice_dir, 'enuconfig.yaml')

    # Check whether the config file exists; raise an exception if it does not
    if not exists(path_enuconfig):
        raise Exception('音源フォルダに enuconfig.yaml が見つかりません。'
                        'UTAU音源選択でENUNU用モデルを指定してください。')
    # Change the current directory to the voicebank folder
    chdir(voice_dir)

    # Read the config file
    print(f'{datetime.now()} : reading enuconfig')
    config = DictConfig(OmegaConf.load(path_enuconfig))
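    # For readability, an illustrative (not official) outline of enuconfig.yaml, limited
    # to the keys this script reads below: table_path, the *_editor hooks (lists of
    # external tools, or null) and the ust_converter / timing_calculator /
    # acoustic_calculator / wav_synthesizer hooks ('built-in', an external tool, or null).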

    # Get the current date and time
    str_now = datetime.now().strftime('%Y%m%d_%H%M%S')

    # If no WAV output path is given (i.e. running as a plugin)
    if path_wav is None:
        # Set the input/output paths
        if path_ust is not None:
            songname = splitext(basename(path_ust))[0]
            out_dir = dirname(path_ust)
            temp_dir = join(out_dir, f'{songname}_enutemp')
            path_wav = abspath(join(out_dir, f'{songname}__{str_now}.wav'))
        # No WAV output path given and the UST has not been saved
        else:
            print('USTが保存されていないので一時フォルダにWAV出力します。')
            songname = f'temp__{str_now}'
            out_dir = mkdtemp(prefix='enunu-')
            temp_dir = join(out_dir, f'{songname}_enutemp')
            path_wav = abspath(join(out_dir, f'{songname}__{str_now}.wav'))
    # If a WAV output path is given
    else:
        songname = splitext(basename(path_wav))[0]
        out_dir = dirname(path_wav)
        temp_dir = join(out_dir, f'{songname}_enutemp')
        path_wav = abspath(path_wav)

    # Create the temporary output folder if it does not exist
    makedirs(temp_dir, exist_ok=True)
    # Set the paths of the various output files
    # path_plugin = path_plugin
    path_temp_ust = abspath(join(temp_dir, 'temp.ust'))
    path_temp_table = abspath(join(temp_dir, 'temp.table'))
    path_full_score = abspath(join(temp_dir, 'score.full'))
    path_mono_score = abspath(join(temp_dir, 'score.lab'))
    path_full_timing = abspath(join(temp_dir, 'timing.full'))
    path_mono_timing = abspath(join(temp_dir, 'timing.lab'))
    path_acoustic = abspath(join(temp_dir, 'acoustic.csv'))
    path_f0 = abspath(join(temp_dir, 'f0.csv'))
    path_spectrogram = abspath(join(temp_dir, 'spectrogram.csv'))
    path_aperiodicity = abspath(join(temp_dir, 'aperiodicity.csv'))

    # Copy the UST into the temporary folder
    print(f'{datetime.now()} : copying UST')
    copy(path_plugin, path_temp_ust)
    print(f'{datetime.now()} : copying Table')
    copy(config.table_path, path_temp_table)

    # Pre-process the UST ------------------------------------------------------------
    extension_list = get_extension_path_list(config, 'ust_editor')
    if extension_list is not None:
        for path_extension in extension_list:
            print(f'{datetime.now()} : editing UST with {path_extension}')
            enulib.extensions.run_extension(path_extension, ust=path_temp_ust)

    # Generate the full label (score) ------------------------------------------------
    converter = get_standard_function_config(config, 'ust_converter')
    # If no full-label generation is requested
    if converter is None:
        pass
    # Convert UST to LAB with ENUNU's built-in function
    elif converter == 'built-in':
        print(
            f'{datetime.now()} : converting UST to score with built-in function'
        )
        enulib.utauplugin2score.utauplugin2score(path_temp_ust,
                                                 path_temp_table,
                                                 path_full_score,
                                                 strict_sinsy_style=False)
        # Generate mono_score from full_score
        enulib.common.full2mono(path_full_score, path_mono_score)
    # Convert UST to LAB with an external tool
    else:
        print(
            f'{datetime.now()} : converting UST to score with {converter}'
        )
        enulib.extensions.run_extension(converter,
                                        ust=path_temp_ust,
                                        table=path_temp_table,
                                        full_score=path_full_score,
                                        mono_score=path_mono_score)

    # Edit the full label (score) ----------------------------------------------------
    extension_list = get_extension_path_list(config, 'score_editor')

    # If score editors are configured, run each of them
    if extension_list is not None:
        for path_extension in extension_list:
            print(f'{datetime.now()} : editing score with {path_extension}')
            # Read the mono label before editing
            with open(path_mono_score, encoding='utf-8') as f:
                str_mono_old = f.read()
            # Run the external tool
            enulib.extensions.run_extension(path_extension,
                                            ust=path_temp_ust,
                                            table=path_temp_table,
                                            full_score=path_full_score,
                                            mono_score=path_mono_score)
            # Read the mono label after editing
            with open(path_mono_score, encoding='utf-8') as f:
                str_mono_new = f.read()

            # If the mono-label timings changed, copy them to the full label;
            # otherwise copy the full-label timings to the mono label.
            # NOTE: process on the assumption that the lyrics may have been changed.
            # If the mono label was updated
            if enulib.extensions.str_has_been_changed(str_mono_old,
                                                      str_mono_new):
                # Copy the mono-label timings to the full label.
                enulib.extensions.merge_mono_time_change_to_full(
                    path_mono_score, path_full_score)
                # Copy the mono-label phoneme symbols to the full label.
                enulib.extensions.merge_mono_contexts_change_to_full(
                    path_mono_score, path_full_score)
            # If the full label was updated, copy the full-label timings to the mono label.
            else:
                enulib.extensions.merge_full_time_change_to_mono(
                    path_full_score, path_mono_score)

    # Generate the full label (timing): score.full -> timing.full -----------------
    calculator = get_standard_function_config(config, 'timing_calculator')
    # If no duration calculation is requested
    if calculator is None:
        print(f'{datetime.now()} : skipped timing calculation')
    # Calculate with ENUNU's built-in function
    elif calculator == 'built-in':
        print(f'{datetime.now()} : calculating timing with built-in function')
        enulib.timing.score2timing(config, path_full_score, path_full_timing)
        # Generate the mono label from the full label
        enulib.common.full2mono(path_full_timing, path_mono_timing)
    # Calculate with an external tool
    else:
        print(f'{datetime.now()} : calculating timing with {calculator}')
        enulib.extensions.run_extension(calculator,
                                        ust=path_temp_ust,
                                        table=path_temp_table,
                                        full_score=path_full_score,
                                        mono_score=path_mono_score,
                                        full_timing=path_full_timing,
                                        mono_timing=path_mono_timing)

    # Edit the full label (timing): timing.full -> timing.full ----------------------
    extension_list = get_extension_path_list(config, 'timing_editor')
    if extension_list is not None:
        # Run every configured tool
        for path_extension in extension_list:
            print(f'{datetime.now()} : editing timing with {path_extension}')
            # Read the mono label before editing
            with open(path_mono_timing, encoding='utf-8') as f:
                str_mono_old = f.read()
            enulib.extensions.run_extension(path_extension,
                                            ust=path_temp_ust,
                                            table=path_temp_table,
                                            full_score=path_full_score,
                                            mono_score=path_mono_score,
                                            full_timing=path_full_timing,
                                            mono_timing=path_mono_timing)
            # Read the mono label after editing
            with open(path_mono_timing, encoding='utf-8') as f:
                str_mono_new = f.read()
            # If the mono-label timings changed, copy them to the full label;
            # otherwise copy the full-label timings to the mono label.
            # NOTE: process on the assumption that the lyrics have not been edited here.
            if enulib.extensions.str_has_been_changed(str_mono_old,
                                                      str_mono_new):
                enulib.extensions.merge_mono_time_change_to_full(
                    path_mono_timing, path_full_timing)
            else:
                enulib.extensions.merge_full_time_change_to_mono(
                    path_full_timing, path_mono_timing)

    # Estimate the acoustic parameters: timing.full -> acoustic ---------------------------
    calculator = get_standard_function_config(config, 'acoustic_calculator')
    # If no calculation is requested
    if calculator is None:
        print(f'{datetime.now()} : skipped acoustic calculation')
    elif calculator == 'built-in':
        print(
            f'{datetime.now()} : calculating acoustic with built-in function')
        # Create acoustic.csv from timing.full.
        enulib.acoustic.timing2acoustic(config, path_full_timing,
                                        path_acoustic)
        # From the acoustic file, output the f0, spectrogram, and aperiodicity files
        enulib.world.acoustic2world(config, path_full_timing, path_acoustic,
                                    path_f0, path_spectrogram,
                                    path_aperiodicity)
    else:
        print(f'{datetime.now()} : calculating acoustic with {calculator}')
        enulib.extensions.run_extension(calculator,
                                        ust=path_temp_ust,
                                        table=path_temp_table,
                                        full_score=path_full_score,
                                        mono_score=path_mono_score,
                                        full_timing=path_full_timing,
                                        mono_timing=path_mono_timing,
                                        acoustic=path_acoustic,
                                        f0=path_f0,
                                        spectrogram=path_spectrogram,
                                        aperiodicity=path_aperiodicity)

    # Edit the acoustic parameters: acoustic.csv -> acoustic.csv -------------------------
    extension_list = get_extension_path_list(config, 'acoustic_editor')
    if extension_list is not None:
        for path_extension in extension_list:
            print(f'{datetime.now()} : editing acoustic with {path_extension}')
            enulib.extensions.run_extension(path_extension,
                                            ust=path_temp_ust,
                                            table=path_temp_table,
                                            full_score=path_full_score,
                                            mono_score=path_mono_score,
                                            full_timing=path_full_timing,
                                            mono_timing=path_mono_timing,
                                            acoustic=path_acoustic,
                                            f0=path_f0,
                                            spectrogram=path_spectrogram,
                                            aperiodicity=path_aperiodicity)

    # Generate the audio file with WORLD: acoustic.csv -> <songname>.wav --------------
    synthesizer = get_standard_function_config(config, 'wav_synthesizer')

    # If synthesis is skipped here
    if synthesizer is None:
        print(f'{datetime.now()} : skipped synthesizing WAV')

    # Synthesize with the built-in WORLD
    elif synthesizer == 'built-in':
        print(f'{datetime.now()} : synthesizing WAV with built-in function')
        # Write the WAV file
        enulib.world.world2wav(config, path_f0, path_spectrogram,
                               path_aperiodicity, path_wav)

    # Synthesize with a separately specified tool
    else:
        print(f'{datetime.now()} : synthesizing WAV with {synthesizer}')
        enulib.extensions.run_extension(synthesizer,
                                        ust=path_temp_ust,
                                        table=path_temp_table,
                                        full_score=path_full_score,
                                        mono_score=path_mono_score,
                                        full_timing=path_full_timing,
                                        mono_timing=path_mono_timing,
                                        acoustic=path_acoustic,
                                        f0=path_f0,
                                        spectrogram=path_spectrogram,
                                        aperiodicity=path_aperiodicity)

    # Edit the audio file: <songname>.wav -> <songname>.wav
    extension_list = get_extension_path_list(config, 'wav_editor')
    if extension_list is not None:
        for path_extension in extension_list:
            print(f'{datetime.now()} : editing WAV with {path_extension}')
            enulib.extensions.run_extension(path_extension,
                                            ust=path_temp_ust,
                                            table=path_temp_table,
                                            full_score=path_full_score,
                                            mono_score=path_mono_score,
                                            full_timing=path_full_timing,
                                            mono_timing=path_mono_timing,
                                            acoustic=path_acoustic,
                                            f0=path_f0,
                                            spectrogram=path_spectrogram,
                                            aperiodicity=path_aperiodicity)

    # print(f'{datetime.now()} : converting LAB to JSON')
    # hts2json(path_full_score, path_json)

    # Play back the audio.
    if exists(path_wav):
        startfile(path_wav)

    return path_wav
Beispiel #29
0
    def test_from_config_dict_without_cls(self):
        """Here we test that instantiation works for configs without cls class path in them.
        IMPORTANT: in this case, the correct class type should call from_config_dict. This should work for Models."""
        preprocessor = {
            'cls':
            'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
            'params': dict({})
        }
        encoder = {
            'cls': 'nemo.collections.asr.modules.ConvASREncoder',
            'params': {
                'feat_in':
                64,
                'activation':
                'relu',
                'conv_mask':
                True,
                'jasper': [{
                    'filters': 1024,
                    'repeat': 1,
                    'kernel': [1],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                }],
            },
        }

        decoder = {
            'cls': 'nemo.collections.asr.modules.ConvASRDecoder',
            'params': {
                'feat_in':
                1024,
                'num_classes':
                28,
                'vocabulary': [
                    ' ',
                    'a',
                    'b',
                    'c',
                    'd',
                    'e',
                    'f',
                    'g',
                    'h',
                    'i',
                    'j',
                    'k',
                    'l',
                    'm',
                    'n',
                    'o',
                    'p',
                    'q',
                    'r',
                    's',
                    't',
                    'u',
                    'v',
                    'w',
                    'x',
                    'y',
                    'z',
                    "'",
                ],
            },
        }
        modelConfig = DictConfig({
            'preprocessor': DictConfig(preprocessor),
            'encoder': DictConfig(encoder),
            'decoder': DictConfig(decoder)
        })
        obj = EncDecCTCModel.from_config_dict(config=modelConfig)
        assert isinstance(obj, EncDecCTCModel)
Beispiel #30
0
def run(args: DictConfig) -> None:
    # Load datasets
    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.RandomCrop(32, padding=4)])

    preprocess = transforms.ToTensor()
    test_transform = preprocess

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    if args.dataset == 'cifar10':
        train_data = datasets.CIFAR10(
            data_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(
            data_dir, train=False, transform=test_transform, download=True)
        base_c_path = os.path.join(data_dir, 'CIFAR-10-C/')
        # args.n_classes = 10
    else:
        train_data = datasets.CIFAR100(
            data_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(
            data_dir, train=False, transform=test_transform, download=True)

        base_c_path = os.path.join(data_dir, 'CIFAR-100-C/')
        # args.n_classes = 100

    train_data = AugMixDataset(train_data, preprocess, args, args.no_jsd)
    train_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    test_loader = DataLoader(
        test_data,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True)

    n_classes = args.get(args.dataset).n_classes
    classifier = resnet18(n_classes=n_classes).to(args.device)
    logger.info('Model resnet18, # parameters: {}'.format(cal_parameters(classifier)))

    cudnn.benchmark = True

    if args.inference:
        classifier.load_state_dict(torch.load('resnet18_c.pth'))
        test_loss, test_acc = eval_epoch(classifier, test_loader, args)
        logger.info('Clean Test CE:{:.4f}, acc:{:.4f}'.format(test_loss, test_acc))
    else:
        optimizer = torch.optim.SGD(
            classifier.parameters(),
            args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=True)

        best_loss = 1e5
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
                step,
                args.epochs * len(train_loader),
                1,  # lr_lambda computes multiplicative factor
                1e-6 / args.learning_rate))
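        # As the inline comment notes, lr_lambda returns a multiplicative factor applied
        # to the base learning rate each step, so the 1e-6 / args.learning_rate endpoint
        # anneals the effective learning rate toward roughly 1e-6 over training (the
        # exact curve depends on get_lr).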

        for epoch in range(args.epochs):
            loss, ce_loss, js_loss, acc = train_epoch(classifier, train_loader,  args, optimizer, scheduler)

            lr = scheduler.get_lr()[0]
            logger.info('Epoch {}, lr:{:.4f}, loss:{:.4f}, CE:{:.4f}, JS:{:.4f}, Acc:{:.4f}'
                        .format(epoch + 1, lr, loss, ce_loss, js_loss, acc))

            test_loss, test_acc = eval_epoch(classifier, test_loader, args)
            logger.info('Clean test CE:{:.4f}, acc:{:.4f}'.format(test_loss, test_acc))

            if loss < best_loss:
                best_loss = loss
                logger.info('===> New optimal, save checkpoint ...')
                torch.save(classifier.state_dict(), 'resnet18_c.pth')

    test_c_acc = eval_c(classifier, base_c_path, args)
    logger.info('Mean Corruption Error:{:.4f}'.format(1 - test_c_acc))