Example #1
    def configure_device(self) -> None:
        if self.config.training.get("device", "cuda") == "xla":
            import torch_xla.core.xla_model as xm

            self.device = xm.xla_device()
            self.distributed = True
            self.local_rank = xm.get_local_ordinal()
            is_xla = True
        else:
            is_xla = False
            if "device_id" not in self.config:
                warnings.warn(
                    "No 'device_id' in 'config', setting to -1. "
                    "This can cause issues later in training. Ensure that "
                    "distributed setup is properly initialized."
                )
                self.local_rank = -1
            else:
                self.local_rank = self.config.device_id
            self.device = self.local_rank
            self.distributed = False

        # Will be updated later based on distributed setup
        registry.register("global_device", self.device)

        if self.config.distributed.init_method is not None:
            self.distributed = True
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.local_rank)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
            torch.cuda.set_device(0)
        elif not is_xla:
            self.device = torch.device("cpu")

        if "rank" not in self.config.distributed:
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                global_rank = torch.distributed.get_rank()
            else:
                global_rank = -1
            with open_dict(self.config.distributed):
                self.config.distributed.rank = global_rank

        registry.register("global_device", self.config.distributed.rank)
Example #2
    def launch(self,
               job_overrides: Sequence[Sequence[str]]) -> Sequence[JobReturn]:
        """
        :param job_overrides: a List of List<String>, where each inner list is the arguments for one job run.
        :return: an array of return values from run_job with indexes corresponding to the input list indexes.
        """
        setup_globals()
        assert self.config is not None
        assert self.config_loader is not None
        assert self.task_function is not None

        configure_log(self.config.hydra.hydra_logging,
                      self.config.hydra.verbose)
        sweep_dir = Path(str(self.config.hydra.sweep.dir))
        sweep_dir.mkdir(parents=True, exist_ok=True)
        log.info(
            "Example Launcher(foo={}, bar={}) is launching {} jobs locally".
            format(self.foo, self.bar, len(job_overrides)))
        log.info("Sweep output dir : {}".format(sweep_dir))
        runs = []

        for idx, overrides in enumerate(job_overrides):
            log.info("\t#{} : {}".format(idx, " ".join(
                filter_overrides(overrides))))
            sweep_config = self.config_loader.load_sweep_config(
                self.config, list(overrides))
            with open_dict(sweep_config):
                # This typically comes from the underlying scheduler (SLURM_JOB_ID, for instance).
                # In that case, it will not be available here because we are still in the main process;
                # it should instead be populated remotely before calling the task_function.
                sweep_config.hydra.job.id = "job_id_for_{}".format(idx)
                sweep_config.hydra.job.num = idx
            HydraConfig.instance().set_config(sweep_config)

            ret = run_job(
                config=sweep_config,
                task_function=self.task_function,
                job_dir_key="hydra.sweep.dir",
                job_subdir_key="hydra.sweep.subdir",
            )
            runs.append(ret)
            configure_log(self.config.hydra.hydra_logging,
                          self.config.hydra.verbose)
        return runs
Example #3
    def test_load_config_with_schema(self, hydra_restore_singletons: Any,
                                     path: str) -> None:

        ConfigStore.instance().store(name="config",
                                     node=TopLevelConfig,
                                     provider="this_test")
        ConfigStore.instance().store(
            group="db",
            name="mysql",
            node=MySQLConfig,
            provider="this_test",
        )

        config_loader = ConfigLoaderImpl(
            config_search_path=create_config_search_path(path))

        cfg = config_loader.load_configuration(config_name="config",
                                               overrides=["+db=mysql"])
        with open_dict(cfg):
            del cfg["hydra"]
        assert cfg == {
            "normal_yaml_config": True,
            "db": {
                "driver": "mysql",
                "host": "???",
                "port": "???",
                "user": "******",
                "password": "******",
            },
        }

        expected = hydra_load_list.copy()
        expected.append(LoadTrace("config", path, "main", "this_test"))
        expected.append(LoadTrace("db/mysql", path, "main", "this_test"))
        assert config_loader.get_load_history() == expected

        # verify illegal modification is rejected at runtime
        with pytest.raises(ValidationError):
            cfg.db.port = "fail"

        # verify illegal override is rejected during load
        with pytest.raises(HydraException):
            config_loader.load_configuration(config_name="db/mysql",
                                             overrides=["db.port=fail"])
Example #4
def test_experimental_save_job_info_callback(tmpdir: Path, multirun: bool) -> None:
    app_path = "tests/test_apps/app_with_pickle_job_info_callback/my_app.py"

    cmd = [
        app_path,
        "hydra.run.dir=" + str(tmpdir),
        "hydra.sweep.dir=" + str(tmpdir),
        "hydra.job.chdir=True",
    ]
    if multirun:
        cmd.append("-m")
    _, _err = run_python_script(cmd)

    def load_pickle(path: Path) -> Any:
        with open(str(path), "rb") as input:
            obj = pickle.load(input)  # nosec
        return obj

    # load pickles from callbacks
    callback_output = tmpdir / Path("0") / ".hydra" if multirun else tmpdir / ".hydra"
    config_on_job_start = load_pickle(callback_output / "config.pickle")
    job_return_on_job_end: JobReturn = load_pickle(
        callback_output / "job_return.pickle"
    )

    task_cfg_from_callback = copy.deepcopy(config_on_job_start)
    with read_write(task_cfg_from_callback):
        with open_dict(task_cfg_from_callback):
            del task_cfg_from_callback["hydra"]

    # load pickles generated from the application
    app_output_dir = tmpdir / "0" if multirun else tmpdir
    task_cfg_from_app = load_pickle(app_output_dir / "task_cfg.pickle")
    hydra_cfg_from_app = load_pickle(app_output_dir / "hydra_cfg.pickle")

    # verify the cfg pickles are the same on_job_start
    assert task_cfg_from_callback == task_cfg_from_app
    assert config_on_job_start.hydra == hydra_cfg_from_app

    # verify the pickled objects are the same on_job_end
    assert job_return_on_job_end.cfg == task_cfg_from_app
    assert job_return_on_job_end.hydra_cfg.hydra == hydra_cfg_from_app  # type: ignore
    assert job_return_on_job_end.return_value == "hello world"
    assert job_return_on_job_end.status == JobStatus.COMPLETED
Example #5
    def test_load_yml_file(self, path: str) -> None:
        config_loader = ConfigLoaderImpl(
            config_search_path=create_config_search_path(path)
        )
        with pytest.warns(
            UserWarning,
            match="Support for .yml files is deprecated. Use .yaml extension for Hydra config files",
        ):
            cfg = config_loader.load_configuration(
                config_name="config.yml",
                overrides=[],
                strict=False,
                run_mode=RunMode.RUN,
            )

        with open_dict(cfg):
            del cfg["hydra"]

        assert cfg == {"yml_file_here": True}
Example #6
def launch(
    launcher: RayLocalLauncher,
    job_overrides: Sequence[Sequence[str]],
    initial_job_idx: int,
) -> Sequence[JobReturn]:
    setup_globals()
    assert launcher.config is not None
    assert launcher.config_loader is not None
    assert launcher.task_function is not None

    configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose)
    sweep_dir = Path(str(launcher.config.hydra.sweep.dir))
    sweep_dir.mkdir(parents=True, exist_ok=True)
    log.info(
        f"Ray Launcher is launching {len(job_overrides)} jobs, "
        f"sweep output dir: {sweep_dir}"
    )

    start_ray(launcher.ray_init_cfg)

    runs = []
    for idx, overrides in enumerate(job_overrides):
        idx = initial_job_idx + idx
        ostr = " ".join(filter_overrides(overrides))
        log.info(f"\t#{idx} : {ostr}")
        sweep_config = launcher.config_loader.load_sweep_config(
            launcher.config, list(overrides)
        )
        with open_dict(sweep_config):
            # This typically comes from the underlying scheduler (SLURM_JOB_ID, for instance).
            # In that case, it will not be available here because we are still in the main process;
            # it should instead be populated remotely before calling the task_function.
            sweep_config.hydra.job.id = f"job_id_for_{idx}"
            sweep_config.hydra.job.num = idx
            ray_obj = launch_job_on_ray(
                launcher.ray_remote_cfg,
                sweep_config,
                launcher.task_function,
                Singleton.get_state(),
            )
            runs.append(ray_obj)

    return [ray.get(run) for run in runs]
Example #7
    def _extract_defaults_list(self, config_path: str,
                               cfg: Container) -> ListConfig:
        empty = OmegaConf.create([])
        if not OmegaConf.is_dict(cfg):
            return empty
        assert isinstance(cfg, DictConfig)
        with read_write(cfg):
            with open_dict(cfg):
                defaults = cfg.pop("defaults", empty)
        if not isinstance(defaults, ListConfig):
            if isinstance(defaults, DictConfig):
                type_str = "mapping"
            else:
                type_str = type(defaults).__name__
            raise ValueError(
                f"Invalid defaults list in '{config_path}', defaults must be a list (got {type_str})"
            )

        return defaults
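
As a small aside, the sketch below (with made-up keys) shows why both context managers are combined above: read_write lifts the read-only flag, open_dict lifts the struct flag, and only then does DictConfig.pop succeed.

from omegaconf import OmegaConf, open_dict, read_write

cfg = OmegaConf.create({"defaults": ["db/mysql"], "x": 1})
OmegaConf.set_struct(cfg, True)
OmegaConf.set_readonly(cfg, True)

with read_write(cfg):
    with open_dict(cfg):
        defaults = cfg.pop("defaults", OmegaConf.create([]))

assert list(defaults) == ["db/mysql"]
assert "defaults" not in cfg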
Example #8
    def load_sweep_config(self, master_config: DictConfig,
                          sweep_overrides: List[str]) -> DictConfig:
        # Recreate the config for this sweep instance with the appropriate overrides
        overrides = OmegaConf.to_container(master_config.hydra.overrides.hydra)
        assert isinstance(overrides, list)
        overrides = overrides + sweep_overrides
        sweep_config = self.load_configuration(
            config_file=master_config.hydra.job.config_file,
            strict=self.default_strict,
            overrides=overrides,
        )

        with open_dict(sweep_config):
            sweep_config.hydra.runtime.merge_with(master_config.hydra.runtime)

        # Copy old config cache to ensure we get the same resolved values (for things like timestamps etc)
        OmegaConf.copy_cache(from_config=master_config, to_config=sweep_config)

        return sweep_config
Example #9
    def load_model(self):
        logger.info("Loading model")
        if self.config.model in self.config.model_config:
            attributes = self.config.model_config[self.config.model]
        else:
            warnings.warn(
                f"Model {self.config.model}'s config not present. "
                + "Continuing with empty config"
            )
            attributes = OmegaConf.create()
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_config[attributes]

        with omegaconf.open_dict(attributes):
            attributes.model = self.config.model

        self.model = build_model(attributes)
        self.model = self.model.to(self.device)
Example #10
def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, Any]):
    # this will be deprecated when we get rid of argparse and model_overrides logic

    from fairseq.registry import REGISTRIES

    with open_dict(cfg):
        for k in cfg.keys():
            # "k in cfg" will return false if its a "mandatory value (e.g. ???)"
            if k in cfg and isinstance(cfg[k], DictConfig):
                overwrite_args_by_name(cfg[k], overrides)
            elif k in overrides:
                if (k in REGISTRIES and overrides[k]
                        in REGISTRIES[k]["dataclass_registry"]):
                    cfg[k] = DictConfig(
                        REGISTRIES[k]["dataclass_registry"][overrides[k]])
                    overwrite_args_by_name(cfg[k], overrides)
                    cfg[k]._name = overrides[k]
                else:
                    cfg[k] = overrides[k]
Example #11
def merge_with_parent(dc: FairseqDataclass,
                      cfg: DictConfig,
                      remove_missing=True):
    if remove_missing:

        if is_dataclass(dc):
            target_keys = set(dc.__dataclass_fields__.keys())
        else:
            target_keys = set(dc.keys())

        with open_dict(cfg):
            for k in list(cfg.keys()):
                if k not in target_keys:
                    del cfg[k]

    merged_cfg = OmegaConf.merge(dc, cfg)
    merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg
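
A standalone illustration of the same idea, using a hypothetical dataclass rather than a FairseqDataclass: keys unknown to the target schema are deleted under open_dict before OmegaConf.merge validates the remainder against the schema.

from dataclasses import dataclass
from omegaconf import OmegaConf, open_dict

@dataclass
class EncoderConfig:
    hidden_dim: int = 512
    dropout: float = 0.1

cfg = OmegaConf.create({"hidden_dim": 1024, "legacy_flag": True})
OmegaConf.set_struct(cfg, True)

target_keys = set(EncoderConfig.__dataclass_fields__.keys())
with open_dict(cfg):
    for k in list(cfg.keys()):
        if k not in target_keys:
            del cfg[k]  # drop keys the schema does not know about

merged = OmegaConf.merge(OmegaConf.structured(EncoderConfig), cfg)
assert merged.hidden_dim == 1024 and merged.dropout == 0.1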
Example #12
def test_setdefault() -> None:
    cfg = OmegaConf.create({})
    assert cfg.setdefault("foo", 10) == 10
    assert cfg["foo"] == 10
    assert cfg.setdefault("foo", 20) == 10
    assert cfg["foo"] == 10

    cfg = OmegaConf.create({})
    OmegaConf.set_struct(cfg, True)
    with pytest.raises(ConfigKeyError):
        assert cfg.setdefault("foo", 10) == 10
    assert cfg == {}
    with open_dict(cfg):
        assert cfg.setdefault("foo", 10) == 10
    assert cfg["foo"] == 10

    assert cfg.setdefault("foo", 20) == 10
    assert cfg["foo"] == 10
Example #13
    def __init__(self, cfg: DictConfig, trainer=None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        if 'tokenizer' not in cfg:
            raise ValueError(
                "`cfg` must have `tokenizer` config to create a tokenizer !")

        # Setup the tokenizer
        self._setup_tokenizer(cfg.tokenizer)

        # Initialize a dummy vocabulary
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        # Set the new vocabulary
        with open_dict(cfg):
            # sidestepping the potential overlapping tokens issue in aggregate tokenizers
            if self.tokenizer_type == "agg":
                cfg.decoder.vocabulary = ListConfig(vocabulary)
            else:
                cfg.decoder.vocabulary = ListConfig(list(vocabulary.keys()))

        # Override number of classes if placeholder provided
        num_classes = cfg.decoder["num_classes"]

        if num_classes < 1:
            logging.info(
                "\nReplacing placeholder number of classes ({}) with actual number of classes - {}"
                .format(num_classes, len(vocabulary)))
            cfg.decoder["num_classes"] = len(vocabulary)

        super().__init__(cfg=cfg, trainer=trainer)

        # Setup metric objects
        self._wer = WERBPE(
            tokenizer=self.tokenizer,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
        )
Example #14
    def _get_cfg(
        self,
        config_name: Optional[str],
        overrides: List[str],
        cfg_type: str,
        with_log_configuration: bool,
    ) -> DictConfig:
        assert cfg_type in ["job", "hydra", "all"]
        cfg = self.compose_config(
            config_name=config_name,
            overrides=overrides,
            with_log_configuration=with_log_configuration,
        )
        if cfg_type == "job":
            with open_dict(cfg):
                del cfg["hydra"]
        elif cfg_type == "hydra":
            cfg = self.get_sanitized_hydra_cfg(cfg)
        return cfg
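
Below is the deletion pattern used here (and in several tests in this collection) in a tiny standalone form, with invented config contents: a struct config refuses del, so the key is removed inside open_dict.

from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"hydra": {"run": {"dir": "."}}, "lr": 0.1})
OmegaConf.set_struct(cfg, True)

with open_dict(cfg):
    del cfg["hydra"]

assert cfg == {"lr": 0.1}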
Example #15
    def test_load_config_with_schema(
        self, hydra_restore_singletons: Any, path: str
    ) -> None:

        ConfigStore.instance().store(
            name="config_with_schema", node=TopLevelConfig, provider="this_test"
        )
        ConfigStore.instance().store(
            group="db", name="base_mysql", node=MySQLConfig, provider="this_test"
        )

        config_loader = ConfigLoaderImpl(
            config_search_path=create_config_search_path(path)
        )

        cfg = config_loader.load_configuration(
            config_name="config",
            overrides=["+db=validated_mysql"],
            run_mode=RunMode.RUN,
        )

        with open_dict(cfg):
            del cfg["hydra"]
        assert cfg == {
            "normal_yaml_config": True,
            "db": {
                "driver": "mysql",
                "host": "???",
                "port": "???",
                "user": "******",
                "password": "******",
            },
        }

        # verify illegal modification is rejected at runtime
        with raises(ValidationError):
            cfg.db.port = "fail"

        # verify illegal override is rejected during load
        with raises(HydraException):
            config_loader.load_configuration(
                config_name="db/mysql", overrides=["db.port=fail"], run_mode=RunMode.RUN
            )
Example #16
    def change_decoding_strategy(self, decoding_cfg: DictConfig):
        """
        Changes decoding strategy used during RNNT decoding process.

        Args:
            decoding_cfg: A config for the decoder, which is optional. If the decoding type
                needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.
        """
        if decoding_cfg is None:
            # Assume same decoding config as before
            logging.info(
                "No `decoding_cfg` passed when changing decoding strategy, using internal config"
            )
            decoding_cfg = self.cfg.decoding

        self.decoding = RNNTDecoding(
            decoding_cfg=decoding_cfg,
            decoder=self.decoder,
            joint=self.joint,
            vocabulary=self.joint.vocabulary,
        )

        self.wer = RNNTWER(
            decoding=self.decoding,
            batch_dim_index=self.wer.batch_dim_index,
            use_cer=self.wer.use_cer,
            log_prediction=self.wer.log_prediction,
            dist_sync_on_step=True,
        )

        # Setup fused Joint step
        if self.joint.fuse_loss_wer:
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)

        # Update config
        with open_dict(self.cfg.decoding):
            self.cfg.decoding = decoding_cfg

        logging.info(
            f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}"
        )
Example #17
def change_conv_asr_se_context_window(model: 'ASRModel',
                                      context_window: int,
                                      update_config: bool = True):
    """
    Update the context window of the SqueezeExcitation module if the provided model contains an
    `encoder` which is an instance of `ConvASREncoder`.

    Args:
        model: A subclass of `ASRModel`, itself a subclass of `ModelPT`.
        context_window:  An integer representing the number of input timeframes that will be used
            to compute the context. Each timeframe corresponds to a single window stride of the
            STFT features.

            Say the window_stride = 0.01s, then a context window of 128 represents 128 * 0.01 s
            of context to compute the Squeeze step.
        update_config: Whether to update the config or not with the new context window.
    """
    if update_config and not hasattr(model.cfg, 'encoder'):
        logging.info(
            "Could not change the context window in SqueezeExcite module "
            "since the model provided does not contain an `encoder` module in its config."
        )
        return

    if not isinstance(model.encoder, conv_asr.ConvASREncoder):
        logging.info(
            f"Could not change the context window in SqueezeExcite module "
            f"since the `encoder` module is not an instance of `ConvASREncoder`.\n"
            f"Provided encoder class = {model.encoder.__class__.__name__}")
        return

    enc_cfg = model.cfg.encoder if update_config else None

    if enc_cfg is not None:
        with open_dict(enc_cfg):
            _update_se_context_window(model, context_window, cfg=enc_cfg)
    else:
        _update_se_context_window(model, context_window)

    # Update model config
    if update_config:
        model.cfg.encoder = enc_cfg
Example #18
    def change_decoding_strategy(self, decoding_cfg: DictConfig):
        """
        Changes decoding strategy used during RNNT decoding process.

        Args:
            decoding_cfg: A config for the decoder, which is optional. If the decoding type
                needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.
        """
        if decoding_cfg is None:
            # Assume same decoding config as before
            logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config")
            decoding_cfg = self.cfg.decoding

        # Assert the decoding config with all hyper parameters
        decoding_cls = OmegaConf.structured(RNNTBPEDecodingConfig)
        decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
        decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

        self.decoding = RNNTBPEDecoding(
            decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
        )

        self.wer = RNNTBPEWER(
            decoding=self.decoding,
            batch_dim_index=self.wer.batch_dim_index,
            use_cer=self.wer.use_cer,
            log_prediction=self.wer.log_prediction,
            dist_sync_on_step=True,
        )

        # Setup fused Joint step
        if self.joint.fuse_loss_wer or (
            self.decoding.joint_fused_batch_size is not None and self.decoding.joint_fused_batch_size > 0
        ):
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)

        # Update config
        with open_dict(self.cfg.decoding):
            self.cfg.decoding = decoding_cfg

        logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")
Example #19
def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
    add_defaults(cfg)

    if cfg.common.reset_logging:
        reset_logging()  # Hydra hijacks logging, fix that
    else:
        # check if directly called or called through hydra_main
        if HydraConfig.initialized():
            with open_dict(cfg):
                # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
                cfg.job_logging_cfg = OmegaConf.to_container(
                    HydraConfig.get().job_logging, resolve=True)

    with omegaconf_no_object_check():
        cfg = OmegaConf.create(
            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
    OmegaConf.set_struct(cfg, True)

    try:
        if cfg.common.profile:
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main, **kwargs)
        else:
            distributed_utils.call_main(cfg, pre_main, **kwargs)
    except BaseException as e:
        if not cfg.common.suppress_crashes:
            raise
        else:
            logger.error("Crashed! " + str(e))

    # get best val and return - useful for sweepers
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric)
    except Exception:
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val
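
A rough standalone sketch (config values invented) of the create/to_container round trip used above: resolve=True bakes interpolations into plain values, OmegaConf.create wraps the result in a fresh DictConfig, and struct mode is then re-enabled; enum_to_str=True in the original additionally converts enum values to strings.

from omegaconf import OmegaConf

cfg = OmegaConf.create({"dir": "/data", "cache": "${dir}/cache"})
cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
OmegaConf.set_struct(cfg, True)
assert cfg.cache == "/data/cache"  # interpolation resolved into a plain value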
Example #20
    def __init__(self, cfg, dictionary, embed_tokens, no_encoder_attn=False):
        transformer_cfg = copy.deepcopy(cfg)
        with open_dict(transformer_cfg):
            transformer_cfg.dropout = transformer_cfg.decoder_dropout
            transformer_cfg.attention_dropout = (
                transformer_cfg.decoder_attention_dropout
            )
            transformer_cfg.activation_dropout = (
                transformer_cfg.decoder_activation_dropout
            )
            transformer_cfg.layernorm_embedding = True
            transformer_cfg.adaptive_input = False
            transformer_cfg.no_scale_embedding = False
            transformer_cfg.quant_noise_pq = 0.0
            transformer_cfg.adaptive_softmax_cutoff = None
        super().__init__(transformer_cfg, dictionary, embed_tokens, no_encoder_attn)
        if cfg.decoder_enc_attention_dropout is not None:
            for layer in self.layers:
                layer.encoder_attn.dropout_module.p = \
                    cfg.decoder_enc_attention_dropout
Example #21
    def _print_config_info(
        self, config_name: Optional[str], overrides: List[str]
    ) -> None:
        assert log is not None
        self._print_search_path()
        self._print_defaults_tree(config_name=config_name, overrides=overrides)
        self._print_defaults_list(config_name=config_name, overrides=overrides)

        cfg = run_and_report(
            lambda: self._get_cfg(
                config_name=config_name,
                overrides=overrides,
                cfg_type="all",
                with_log_configuration=False,
            )
        )
        self._log_header(header="Config", filler="*")
        with open_dict(cfg):
            del cfg["hydra"]
        log.info(OmegaConf.to_yaml(cfg))
Example #22
    def run_task(job):
        idx, overrides = job
        LOGGER.info("\t#{} : {}".format(
            idx, " ".join(filter_overrides(overrides))))
        sweep_config = self.config_loader.load_sweep_config(
            self.config, list(overrides))
        with open_dict(sweep_config):
            # id is concatenated overrides here
            sweep_config.hydra.job.id = '_'.join(sorted(overrides))
            sweep_config.hydra.job.num = idx
        HydraConfig().set_config(sweep_config)
        ret = run_job(
            config=sweep_config,
            task_function=self.task_function,
            job_dir_key="hydra.sweep.dir",
            job_subdir_key="hydra.sweep.subdir",
        )
        configure_log(self.config.hydra.hydra_logging,
                      self.config.hydra.verbose)
        return (idx, ret)
Example #23
def main(hydra_cfg):
    hydra_cfg.device = hydra_cfg.device.lower()
    with open_dict(hydra_cfg):
        hydra_cfg.job_logging_cfg = HydraConfig.get().job_logging
    # random seed
    if hydra_cfg.random_seed is None:
        hydra_cfg.random_seed = random.randint(1, 10000)
    set_random_seed(hydra_cfg.random_seed)

    if hydra_cfg.dist.gpus < 0:
        hydra_cfg.dist.gpus = torch.cuda.device_count()
        hydra_cfg.dist.master_port = os.environ["MASTER_PORT"]
        hydra_cfg.dist.master_addr = os.environ["MASTER_ADDR"]
        print(hydra_cfg.dist)

    if hydra_cfg.device == "cpu" or hydra_cfg.dist.gpus == 0:
        hydra_cfg.dist.gpus = 0
        train_loop(0, hydra_cfg)
    else:
        distributed_run(train_loop, hydra_cfg)
Example #24
def hydra_main(cfg):
    with open_dict(cfg):
        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
        cfg.job_logging_cfg = OmegaConf.to_container(
            HydraConfig.get().job_logging, resolve=True
        )

    cfg = OmegaConf.create(
        OmegaConf.to_container(cfg, resolve=False, enum_to_str=False)
    )
    OmegaConf.set_struct(cfg, True)
    logger.info(cfg)

    utils.import_user_module(cfg.fairseq.common)

    _, score = main(cfg)

    if cfg.is_ax:
        return score, None
    return score
Example #25
def fix_task(cfg: DictConfig) -> None:
    with open_dict(cfg):
        semantic = 'semantic_criterion' in cfg
        instance = 'embed_criterion' in cfg
        if semantic and instance:
            cfg.task = 'panoptic'
        elif semantic:
            cfg.task = 'semantic'
        elif instance:
            cfg.task = 'instance'
        else:
            raise RuntimeError(
                'semantic_criterion and/or embed_criterion must be set!')

        if instance:
            requires_semantic = (
                'method' not in cfg.embed_criterion
                or cfg.embed_criterion.method in ['ignore', 'separate']
            )
            cfg.requires_semantic = requires_semantic
Example #26
    def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")  # TODO
        if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")  # TODO
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True"
                )
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
        elif not shuffle_should_be and cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
Example #27
    def __init__(
        self,
        cfg: DictConfig,
        trainer: Trainer = None,
    ):

        self.cfg = cfg
        self.data_prepared = False
        self.epoch_number = 0
        if self.cfg.library == "huggingface":
            self.setup_tokenizer(cfg.tokenizer)
        elif self.cfg.library == "megatron":
            # supporting MegatronT5Model in precision = fp16
            t5_cfg = MegatronT5Model.restore_from(
                restore_path=cfg.language_model.lm_checkpoint,
                trainer=trainer,
                return_config=True)
            # Override the T5 configuration with the one from the config file.
            OmegaConf.set_struct(t5_cfg, True)
            with open_dict(t5_cfg):
                t5_cfg.masked_softmax_fusion = False
                t5_cfg.precision = 16

            language_model = MegatronT5Model.restore_from(
                restore_path=cfg.language_model.lm_checkpoint,
                trainer=trainer,
                override_config_path=t5_cfg)
            self.tokenizer = language_model.tokenizer

        super().__init__(cfg=cfg, trainer=trainer, no_lm_init=True)

        if self.cfg.library == "huggingface":
            self.language_model = AutoModelForSeq2SeqLM.from_pretrained(
                cfg.language_model.pretrained_model_name)
            self.language_model.resize_token_embeddings(
                len(self.tokenizer.tokenizer))
            if self.cfg.language_model.lm_checkpoint:
                self.language_model.load_state_dict(
                    torch.load(self.cfg.language_model.lm_checkpoint))
        elif self.cfg.library == "megatron":
            self.language_model = language_model
Example #28
def test_load_schema_as_config(hydra_restore_singletons: Any) -> None:
    """
    Load structured config as a configuration
    """
    ConfigStore.instance().store(
        name="config", node=TopLevelConfig, provider="this_test"
    )

    ConfigStore.instance().store(
        name="db/mysql", node=MySQLConfig, provider="this_test"
    )

    config_loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))
    cfg = config_loader.load_configuration(
        config_name="config", overrides=[], run_mode=RunMode.RUN
    )

    expected = deepcopy(hydra_load_list)
    expected.append(
        LoadTrace(
            config_path="config",
            package="",
            parent="<root>",
            search_path="structured://",
            provider="this_test",
        )
    )
    assert_same_composition_trace(cfg.hydra.composition_trace, expected)

    with open_dict(cfg):
        del cfg["hydra"]
    assert cfg == {
        "normal_yaml_config": "???",
        "db": {
            "driver": MISSING,
            "host": MISSING,
            "port": MISSING,
            "user": MISSING,
            "password": MISSING,
        },
    }
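
As a side note, the "???" values asserted above are OmegaConf's MISSING marker. A small standalone sketch (keys invented) of how missing mandatory values behave:

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"db": {"port": MISSING}})
assert OmegaConf.is_missing(cfg.db, "port")  # stored internally as the literal "???"

try:
    _ = cfg.db.port  # reading a missing mandatory value raises
except MissingMandatoryValue:
    print("db.port must be provided via an override or a schema default")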
Example #29
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # global_rank and local_rank are set by LightningModule in Lightning 1.2.0
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        super().__init__(cfg=cfg, trainer=trainer)
        self.preprocessor = EncDecCTCModel.from_config_dict(self._cfg.preprocessor)
        self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder)

        with open_dict(self._cfg):
            if "feat_in" not in self._cfg.decoder or (
                not self._cfg.decoder.feat_in and hasattr(self.encoder, '_feat_out')
            ):
                self._cfg.decoder.feat_in = self.encoder._feat_out
            if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in:
                raise ValueError("param feat_in of the decoder's config is not set!")

        self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder)

        self.loss = CTCLoss(
            num_classes=self.decoder.num_classes_with_blank - 1,
            zero_infinity=True,
            reduction=self._cfg.get("ctc_reduction", "mean_batch"),
        )

        if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None:
            self.spec_augmentation = EncDecCTCModel.from_config_dict(self._cfg.spec_augment)
        else:
            self.spec_augmentation = None

        # Setup metric objects
        self._wer = WER(
            vocabulary=self.decoder.vocabulary,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
        )
Example #30
def _get_rerun_conf(file_path: str, overrides: List[str]) -> DictConfig:
    msg = "Experimental rerun CLI option, other command line args are ignored."
    warnings.warn(msg, UserWarning)
    file = Path(file_path)
    if not file.exists():
        raise ValueError(f"File {file} does not exist!")

    if len(overrides) > 0:
        msg = "Config overrides are not supported as of now."
        warnings.warn(msg, UserWarning)

    with open(str(file), "rb") as input:
        config = pickle.load(input)  # nosec
    configure_log(config.hydra.job_logging, config.hydra.verbose)
    HydraConfig.instance().set_config(config)
    task_cfg = copy.deepcopy(config)
    with read_write(task_cfg):
        with open_dict(task_cfg):
            del task_cfg["hydra"]
    assert isinstance(task_cfg, DictConfig)
    return task_cfg