def configure_device(self) -> None:
    if self.config.training.get("device", "cuda") == "xla":
        import torch_xla.core.xla_model as xm

        self.device = xm.xla_device()
        self.distributed = True
        self.local_rank = xm.get_local_ordinal()
        is_xla = True
    else:
        is_xla = False
        if "device_id" not in self.config:
            warnings.warn(
                "No 'device_id' in 'config', setting to -1. "
                "This can cause issues later in training. Ensure that "
                "distributed setup is properly initialized."
            )
            self.local_rank = -1
        else:
            self.local_rank = self.config.device_id

        self.device = self.local_rank
        self.distributed = False

    # Will be updated later based on distributed setup
    registry.register("global_device", self.device)

    if self.config.distributed.init_method is not None:
        self.distributed = True
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.local_rank)
    elif torch.cuda.is_available():
        self.device = torch.device("cuda")
        torch.cuda.set_device(0)
    elif not is_xla:
        self.device = torch.device("cpu")

    if "rank" not in self.config.distributed:
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            global_rank = torch.distributed.get_rank()
        else:
            global_rank = -1
        with open_dict(self.config.distributed):
            self.config.distributed.rank = global_rank

    registry.register("global_device", self.config.distributed.rank)
def launch(self, job_overrides: Sequence[Sequence[str]]) -> Sequence[JobReturn]:
    """
    :param job_overrides: a List of List<String>, where each inner list is the arguments for one job run.
    :return: an array of return values from run_job with indexes corresponding to the input list indexes.
    """
    setup_globals()
    assert self.config is not None
    assert self.config_loader is not None
    assert self.task_function is not None
    configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose)
    sweep_dir = Path(str(self.config.hydra.sweep.dir))
    sweep_dir.mkdir(parents=True, exist_ok=True)
    log.info(
        "Example Launcher(foo={}, bar={}) is launching {} jobs locally".format(
            self.foo, self.bar, len(job_overrides)
        )
    )
    log.info("Sweep output dir : {}".format(sweep_dir))
    runs = []

    for idx, overrides in enumerate(job_overrides):
        log.info("\t#{} : {}".format(idx, " ".join(filter_overrides(overrides))))
        sweep_config = self.config_loader.load_sweep_config(
            self.config, list(overrides)
        )
        with open_dict(sweep_config):
            # The job id typically comes from the underlying scheduler (SLURM_JOB_ID for instance).
            # In that case, it will not be available here because we are still in the main process;
            # it should instead be populated remotely before calling the task_function.
            sweep_config.hydra.job.id = "job_id_for_{}".format(idx)
            sweep_config.hydra.job.num = idx
        HydraConfig.instance().set_config(sweep_config)
        ret = run_job(
            config=sweep_config,
            task_function=self.task_function,
            job_dir_key="hydra.sweep.dir",
            job_subdir_key="hydra.sweep.subdir",
        )
        runs.append(ret)
        configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose)
    return runs
def test_load_config_with_schema(self, hydra_restore_singletons: Any, path: str) -> None:
    ConfigStore.instance().store(name="config", node=TopLevelConfig, provider="this_test")
    ConfigStore.instance().store(
        group="db", name="mysql", node=MySQLConfig, provider="this_test",
    )

    config_loader = ConfigLoaderImpl(
        config_search_path=create_config_search_path(path)
    )
    cfg = config_loader.load_configuration(config_name="config", overrides=["+db=mysql"])
    with open_dict(cfg):
        del cfg["hydra"]
    assert cfg == {
        "normal_yaml_config": True,
        "db": {
            "driver": "mysql",
            "host": "???",
            "port": "???",
            "user": "******",
            "password": "******",
        },
    }

    expected = hydra_load_list.copy()
    expected.append(LoadTrace("config", path, "main", "this_test"))
    expected.append(LoadTrace("db/mysql", path, "main", "this_test"))
    assert config_loader.get_load_history() == expected

    # verify illegal modification is rejected at runtime
    with pytest.raises(ValidationError):
        cfg.db.port = "fail"

    # verify illegal override is rejected during load
    with pytest.raises(HydraException):
        config_loader.load_configuration(
            config_name="db/mysql", overrides=["db.port=fail"]
        )
def test_experimental_save_job_info_callback(tmpdir: Path, multirun: bool) -> None:
    app_path = "tests/test_apps/app_with_pickle_job_info_callback/my_app.py"

    cmd = [
        app_path,
        "hydra.run.dir=" + str(tmpdir),
        "hydra.sweep.dir=" + str(tmpdir),
        "hydra.job.chdir=True",
    ]
    if multirun:
        cmd.append("-m")
    _, _err = run_python_script(cmd)

    def load_pickle(path: Path) -> Any:
        with open(str(path), "rb") as input:
            obj = pickle.load(input)  # nosec
        return obj

    # load pickles from callbacks
    callback_output = tmpdir / Path("0") / ".hydra" if multirun else tmpdir / ".hydra"
    config_on_job_start = load_pickle(callback_output / "config.pickle")
    job_return_on_job_end: JobReturn = load_pickle(
        callback_output / "job_return.pickle"
    )

    task_cfg_from_callback = copy.deepcopy(config_on_job_start)
    with read_write(task_cfg_from_callback):
        with open_dict(task_cfg_from_callback):
            del task_cfg_from_callback["hydra"]

    # load pickles generated from the application
    app_output_dir = tmpdir / "0" if multirun else tmpdir
    task_cfg_from_app = load_pickle(app_output_dir / "task_cfg.pickle")
    hydra_cfg_from_app = load_pickle(app_output_dir / "hydra_cfg.pickle")

    # verify the cfg pickles are the same on_job_start
    assert task_cfg_from_callback == task_cfg_from_app
    assert config_on_job_start.hydra == hydra_cfg_from_app

    # verify the pickled objects are the same on_job_end
    assert job_return_on_job_end.cfg == task_cfg_from_app
    assert job_return_on_job_end.hydra_cfg.hydra == hydra_cfg_from_app  # type: ignore
    assert job_return_on_job_end.return_value == "hello world"
    assert job_return_on_job_end.status == JobStatus.COMPLETED
def test_load_yml_file(self, path: str) -> None:
    config_loader = ConfigLoaderImpl(
        config_search_path=create_config_search_path(path)
    )
    with pytest.warns(
        UserWarning,
        match="Support for .yml files is deprecated. Use .yaml extension for Hydra config files",
    ):
        cfg = config_loader.load_configuration(
            config_name="config.yml",
            overrides=[],
            strict=False,
            run_mode=RunMode.RUN,
        )
    with open_dict(cfg):
        del cfg["hydra"]
    assert cfg == {"yml_file_here": True}
def launch(
    launcher: RayLocalLauncher,
    job_overrides: Sequence[Sequence[str]],
    initial_job_idx: int,
) -> Sequence[JobReturn]:
    setup_globals()
    assert launcher.config is not None
    assert launcher.config_loader is not None
    assert launcher.task_function is not None

    configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose)
    sweep_dir = Path(str(launcher.config.hydra.sweep.dir))
    sweep_dir.mkdir(parents=True, exist_ok=True)
    log.info(
        f"Ray Launcher is launching {len(job_overrides)} jobs, "
        f"sweep output dir: {sweep_dir}"
    )

    start_ray(launcher.ray_init_cfg)

    runs = []
    for idx, overrides in enumerate(job_overrides):
        idx = initial_job_idx + idx
        ostr = " ".join(filter_overrides(overrides))
        log.info(f"\t#{idx} : {ostr}")
        sweep_config = launcher.config_loader.load_sweep_config(
            launcher.config, list(overrides)
        )
        with open_dict(sweep_config):
            # The job id typically comes from the underlying scheduler (SLURM_JOB_ID for instance).
            # In that case, it will not be available here because we are still in the main process;
            # it should instead be populated remotely before calling the task_function.
            sweep_config.hydra.job.id = f"job_id_for_{idx}"
            sweep_config.hydra.job.num = idx
        ray_obj = launch_job_on_ray(
            launcher.ray_remote_cfg,
            sweep_config,
            launcher.task_function,
            Singleton.get_state(),
        )
        runs.append(ray_obj)

    return [ray.get(run) for run in runs]
def _extract_defaults_list(self, config_path: str, cfg: Container) -> ListConfig:
    empty = OmegaConf.create([])
    if not OmegaConf.is_dict(cfg):
        return empty
    assert isinstance(cfg, DictConfig)
    with read_write(cfg):
        with open_dict(cfg):
            defaults = cfg.pop("defaults", empty)

    if not isinstance(defaults, ListConfig):
        if isinstance(defaults, DictConfig):
            type_str = "mapping"
        else:
            type_str = type(defaults).__name__
        raise ValueError(
            f"Invalid defaults list in '{config_path}', defaults must be a list (got {type_str})"
        )

    return defaults
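# Illustrative sketch (not Hydra code): the read_write + open_dict pattern from
# _extract_defaults_list above, shown on a self-contained config. Only stock OmegaConf
# APIs are used; the config contents here are made up.
from omegaconf import OmegaConf, open_dict, read_write


def pop_defaults_sketch():
    cfg = OmegaConf.create({"defaults": [{"db": "mysql"}], "foo": 1})
    OmegaConf.set_struct(cfg, True)    # reject unknown keys
    OmegaConf.set_readonly(cfg, True)  # reject mutation

    # Both flags must be lifted to pop a key: read_write temporarily clears the readonly
    # flag, open_dict temporarily clears struct mode; both are restored on exit.
    with read_write(cfg):
        with open_dict(cfg):
            defaults = cfg.pop("defaults", OmegaConf.create([]))

    assert "defaults" not in cfg
    return defaults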
def load_sweep_config(
    self, master_config: DictConfig, sweep_overrides: List[str]
) -> DictConfig:
    # Recreate the config for this sweep instance with the appropriate overrides
    overrides = OmegaConf.to_container(master_config.hydra.overrides.hydra)
    assert isinstance(overrides, list)
    overrides = overrides + sweep_overrides
    sweep_config = self.load_configuration(
        config_file=master_config.hydra.job.config_file,
        strict=self.default_strict,
        overrides=overrides,
    )

    with open_dict(sweep_config):
        sweep_config.hydra.runtime.merge_with(master_config.hydra.runtime)

    # Copy the old config cache to ensure we get the same resolved values
    # (for things like timestamps etc.)
    OmegaConf.copy_cache(from_config=master_config, to_config=sweep_config)

    return sweep_config
def load_model(self):
    logger.info("Loading model")
    if self.config.model in self.config.model_config:
        attributes = self.config.model_config[self.config.model]
    else:
        warnings.warn(
            f"Model {self.config.model}'s config not present. "
            + "Continuing with empty config"
        )
        attributes = OmegaConf.create()

    # Easy way to point to config for other model
    if isinstance(attributes, str):
        attributes = self.config.model_config[attributes]

    with omegaconf.open_dict(attributes):
        attributes.model = self.config.model

    self.model = build_model(attributes)
    self.model = self.model.to(self.device)
def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, Any]):
    # this will be deprecated when we get rid of argparse and model_overrides logic
    from fairseq.registry import REGISTRIES

    with open_dict(cfg):
        for k in cfg.keys():
            # "k in cfg" will return False if it is a mandatory value (e.g. "???")
            if k in cfg and isinstance(cfg[k], DictConfig):
                overwrite_args_by_name(cfg[k], overrides)
            elif k in overrides:
                if (
                    k in REGISTRIES
                    and overrides[k] in REGISTRIES[k]["dataclass_registry"]
                ):
                    cfg[k] = DictConfig(REGISTRIES[k]["dataclass_registry"][overrides[k]])
                    overwrite_args_by_name(cfg[k], overrides)
                    cfg[k]._name = overrides[k]
                else:
                    cfg[k] = overrides[k]
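# Illustrative sketch (not fairseq code): a stripped-down version of the recursive
# override above, with the REGISTRIES handling omitted. All key names are made up.
from typing import Any, Dict

from omegaconf import DictConfig, OmegaConf, open_dict


def overwrite_by_name_sketch(cfg: DictConfig, overrides: Dict[str, Any]) -> None:
    # Recurse into nested DictConfigs and overwrite any leaf whose key appears in overrides.
    with open_dict(cfg):
        for k in cfg.keys():
            if k in cfg and isinstance(cfg[k], DictConfig):
                overwrite_by_name_sketch(cfg[k], overrides)
            elif k in overrides:
                cfg[k] = overrides[k]


_cfg = OmegaConf.create({"model": {"dropout": 0.1, "encoder": {"layers": 6}}})
OmegaConf.set_struct(_cfg, True)
overwrite_by_name_sketch(_cfg, {"dropout": 0.3, "layers": 12})
assert _cfg.model.dropout == 0.3 and _cfg.model.encoder.layers == 12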
def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=True):
    if remove_missing:
        if is_dataclass(dc):
            target_keys = set(dc.__dataclass_fields__.keys())
        else:
            target_keys = set(dc.keys())

        with open_dict(cfg):
            for k in list(cfg.keys()):
                if k not in target_keys:
                    del cfg[k]

    merged_cfg = OmegaConf.merge(dc, cfg)
    merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg
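# Illustrative sketch (not fairseq code): the same "drop keys the schema does not know,
# then merge" idea as merge_with_parent, using a plain dataclass in place of
# FairseqDataclass and skipping the _parent bookkeeping. All names here are made up.
from dataclasses import dataclass, fields

from omegaconf import OmegaConf, open_dict


@dataclass
class EncoderSchemaSketch:
    layers: int = 6
    dropout: float = 0.1


def merge_into_schema_sketch(cfg):
    target_keys = {f.name for f in fields(EncoderSchemaSketch)}
    with open_dict(cfg):
        for k in list(cfg.keys()):
            if k not in target_keys:
                del cfg[k]
    # Overlay the filtered config on the schema defaults; the result is type-checked.
    merged = OmegaConf.merge(OmegaConf.structured(EncoderSchemaSketch), cfg)
    OmegaConf.set_struct(merged, True)
    return merged


_cfg = OmegaConf.create({"layers": 12, "legacy_flag": True})
_merged = merge_into_schema_sketch(_cfg)
assert _merged.layers == 12 and "legacy_flag" not in _merged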
def test_setdefault() -> None:
    cfg = OmegaConf.create({})
    assert cfg.setdefault("foo", 10) == 10
    assert cfg["foo"] == 10
    assert cfg.setdefault("foo", 20) == 10
    assert cfg["foo"] == 10

    cfg = OmegaConf.create({})
    OmegaConf.set_struct(cfg, True)
    with pytest.raises(ConfigKeyError):
        assert cfg.setdefault("foo", 10) == 10
    assert cfg == {}
    with open_dict(cfg):
        assert cfg.setdefault("foo", 10) == 10
        assert cfg.setdefault("foo", 20) == 10
        assert cfg["foo"] == 10

    assert cfg["foo"] == 10
def __init__(self, cfg: DictConfig, trainer=None):
    # Convert to Hydra 1.0 compatible DictConfig
    cfg = model_utils.convert_model_config_to_dict_config(cfg)
    cfg = model_utils.maybe_update_config_version(cfg)

    if 'tokenizer' not in cfg:
        raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !")

    # Setup the tokenizer
    self._setup_tokenizer(cfg.tokenizer)

    # Initialize a dummy vocabulary
    vocabulary = self.tokenizer.tokenizer.get_vocab()

    # Set the new vocabulary
    with open_dict(cfg):
        # sidestepping the potential overlapping tokens issue in aggregate tokenizers
        if self.tokenizer_type == "agg":
            cfg.decoder.vocabulary = ListConfig(vocabulary)
        else:
            cfg.decoder.vocabulary = ListConfig(list(vocabulary.keys()))

        # Override number of classes if placeholder provided
        num_classes = cfg.decoder["num_classes"]

        if num_classes < 1:
            logging.info(
                "\nReplacing placeholder number of classes ({}) with actual number of classes - {}".format(
                    num_classes, len(vocabulary)
                )
            )
            cfg.decoder["num_classes"] = len(vocabulary)

    super().__init__(cfg=cfg, trainer=trainer)

    # Setup metric objects
    self._wer = WERBPE(
        tokenizer=self.tokenizer,
        batch_dim_index=0,
        use_cer=self._cfg.get('use_cer', False),
        ctc_decode=True,
        dist_sync_on_step=True,
        log_prediction=self._cfg.get("log_prediction", False),
    )
def _get_cfg(
    self,
    config_name: Optional[str],
    overrides: List[str],
    cfg_type: str,
    with_log_configuration: bool,
) -> DictConfig:
    assert cfg_type in ["job", "hydra", "all"]
    cfg = self.compose_config(
        config_name=config_name,
        overrides=overrides,
        with_log_configuration=with_log_configuration,
    )
    if cfg_type == "job":
        with open_dict(cfg):
            del cfg["hydra"]
    elif cfg_type == "hydra":
        cfg = self.get_sanitized_hydra_cfg(cfg)
    return cfg
def test_load_config_with_schema(
    self, hydra_restore_singletons: Any, path: str
) -> None:
    ConfigStore.instance().store(
        name="config_with_schema", node=TopLevelConfig, provider="this_test"
    )
    ConfigStore.instance().store(
        group="db", name="base_mysql", node=MySQLConfig, provider="this_test"
    )
    config_loader = ConfigLoaderImpl(
        config_search_path=create_config_search_path(path)
    )

    cfg = config_loader.load_configuration(
        config_name="config",
        overrides=["+db=validated_mysql"],
        run_mode=RunMode.RUN,
    )

    with open_dict(cfg):
        del cfg["hydra"]
    assert cfg == {
        "normal_yaml_config": True,
        "db": {
            "driver": "mysql",
            "host": "???",
            "port": "???",
            "user": "******",
            "password": "******",
        },
    }

    # verify illegal modification is rejected at runtime
    with raises(ValidationError):
        cfg.db.port = "fail"

    # verify illegal override is rejected during load
    with raises(HydraException):
        config_loader.load_configuration(
            config_name="db/mysql", overrides=["db.port=fail"], run_mode=RunMode.RUN
        )
def change_decoding_strategy(self, decoding_cfg: DictConfig):
    """
    Changes decoding strategy used during RNNT decoding process.

    Args:
        decoding_cfg: A config for the decoder, which is optional. If the decoding type
            needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.
    """
    if decoding_cfg is None:
        # Assume same decoding config as before
        logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config")
        decoding_cfg = self.cfg.decoding

    self.decoding = RNNTDecoding(
        decoding_cfg=decoding_cfg,
        decoder=self.decoder,
        joint=self.joint,
        vocabulary=self.joint.vocabulary,
    )

    self.wer = RNNTWER(
        decoding=self.decoding,
        batch_dim_index=self.wer.batch_dim_index,
        use_cer=self.wer.use_cer,
        log_prediction=self.wer.log_prediction,
        dist_sync_on_step=True,
    )

    # Setup fused Joint step
    if self.joint.fuse_loss_wer:
        self.joint.set_loss(self.loss)
        self.joint.set_wer(self.wer)

    # Update config
    with open_dict(self.cfg.decoding):
        self.cfg.decoding = decoding_cfg

    logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")
def change_conv_asr_se_context_window(model: 'ASRModel', context_window: int, update_config: bool = True):
    """
    Update the context window of the SqueezeExcitation module if the provided model contains an
    `encoder` which is an instance of `ConvASREncoder`.

    Args:
        model: A subclass of `ASRModel`, itself a subclass of `ModelPT`.
        context_window: An integer representing the number of input timeframes that will be used
            to compute the context. Each timeframe corresponds to a single window stride of the
            STFT features. Say the window_stride = 0.01s, then a context window of 128 represents
            128 * 0.01 s of context to compute the Squeeze step.
        update_config: Whether to update the config or not with the new context window.
    """
    if update_config and not hasattr(model.cfg, 'encoder'):
        logging.info(
            "Could not change the context window in SqueezeExcite module "
            "since the model provided does not contain an `encoder` module in its config."
        )
        return

    if not isinstance(model.encoder, conv_asr.ConvASREncoder):
        logging.info(
            f"Could not change the context window in SqueezeExcite module "
            f"since the `encoder` module is not an instance of `ConvASREncoder`.\n"
            f"Provided encoder class = {model.encoder.__class__.__name__}"
        )
        return

    enc_cfg = model.cfg.encoder if update_config else None

    if enc_cfg is not None:
        with open_dict(enc_cfg):
            _update_se_context_window(model, context_window, cfg=enc_cfg)
    else:
        _update_se_context_window(model, context_window)

    # Update model config
    if update_config:
        model.cfg.encoder = enc_cfg
def change_decoding_strategy(self, decoding_cfg: DictConfig):
    """
    Changes decoding strategy used during RNNT decoding process.

    Args:
        decoding_cfg: A config for the decoder, which is optional. If the decoding type
            needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.
    """
    if decoding_cfg is None:
        # Assume same decoding config as before
        logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config")
        decoding_cfg = self.cfg.decoding

    # Assert the decoding config with all hyper parameters
    decoding_cls = OmegaConf.structured(RNNTBPEDecodingConfig)
    decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
    decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

    self.decoding = RNNTBPEDecoding(
        decoding_cfg=decoding_cfg,
        decoder=self.decoder,
        joint=self.joint,
        tokenizer=self.tokenizer,
    )

    self.wer = RNNTBPEWER(
        decoding=self.decoding,
        batch_dim_index=self.wer.batch_dim_index,
        use_cer=self.wer.use_cer,
        log_prediction=self.wer.log_prediction,
        dist_sync_on_step=True,
    )

    # Setup fused Joint step
    if self.joint.fuse_loss_wer or (
        self.decoding.joint_fused_batch_size is not None and self.decoding.joint_fused_batch_size > 0
    ):
        self.joint.set_loss(self.loss)
        self.joint.set_wer(self.wer)

    # Update config
    with open_dict(self.cfg.decoding):
        self.cfg.decoding = decoding_cfg

    logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")
def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
    add_defaults(cfg)

    if cfg.common.reset_logging:
        reset_logging()  # Hydra hijacks logging, fix that
    else:
        # check if directly called or called through hydra_main
        if HydraConfig.initialized():
            with open_dict(cfg):
                # make hydra logging work with ddp
                # (see https://github.com/facebookresearch/hydra/issues/1126)
                cfg.job_logging_cfg = OmegaConf.to_container(
                    HydraConfig.get().job_logging, resolve=True
                )

    with omegaconf_no_object_check():
        cfg = OmegaConf.create(
            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
        )
    OmegaConf.set_struct(cfg, True)

    try:
        if cfg.common.profile:
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main, **kwargs)
        else:
            distributed_utils.call_main(cfg, pre_main, **kwargs)
    except BaseException as e:
        if not cfg.common.suppress_crashes:
            raise
        else:
            logger.error("Crashed! " + str(e))

    # get best val and return - useful for sweepers
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric
        )
    except:
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val
def __init__(self, cfg, dictionary, embed_tokens, no_encoder_attn=False):
    transformer_cfg = copy.deepcopy(cfg)
    with open_dict(transformer_cfg):
        transformer_cfg.dropout = transformer_cfg.decoder_dropout
        transformer_cfg.attention_dropout = (
            transformer_cfg.decoder_attention_dropout
        )
        transformer_cfg.activation_dropout = (
            transformer_cfg.decoder_activation_dropout
        )
        transformer_cfg.layernorm_embedding = True
        transformer_cfg.adaptive_input = False
        transformer_cfg.no_scale_embedding = False
        transformer_cfg.quant_noise_pq = 0.0
        transformer_cfg.adaptive_softmax_cutoff = None

    super().__init__(transformer_cfg, dictionary, embed_tokens, no_encoder_attn)

    if cfg.decoder_enc_attention_dropout is not None:
        for layer in self.layers:
            layer.encoder_attn.dropout_module.p = cfg.decoder_enc_attention_dropout
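# Illustrative sketch (not fairseq code): deriving a second config from an existing one,
# as in the decoder __init__ above. Work happens on a deep copy so the source config is
# left untouched; the field names below are made up.
import copy

from omegaconf import OmegaConf, open_dict

_cfg = OmegaConf.create({"decoder_dropout": 0.2, "decoder_attention_dropout": 0.1})
OmegaConf.set_struct(_cfg, True)

_transformer_cfg = copy.deepcopy(_cfg)
with open_dict(_transformer_cfg):
    # remap / add keys on the copy; the new keys do not exist yet, hence open_dict
    _transformer_cfg.dropout = _transformer_cfg.decoder_dropout
    _transformer_cfg.attention_dropout = _transformer_cfg.decoder_attention_dropout
    _transformer_cfg.layernorm_embedding = True

assert "dropout" not in _cfg and _transformer_cfg.dropout == 0.2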
def _print_config_info(
    self, config_name: Optional[str], overrides: List[str]
) -> None:
    assert log is not None
    self._print_search_path()
    self._print_defaults_tree(config_name=config_name, overrides=overrides)
    self._print_defaults_list(config_name=config_name, overrides=overrides)
    cfg = run_and_report(
        lambda: self._get_cfg(
            config_name=config_name,
            overrides=overrides,
            cfg_type="all",
            with_log_configuration=False,
        )
    )
    self._log_header(header="Config", filler="*")
    with open_dict(cfg):
        del cfg["hydra"]
    log.info(OmegaConf.to_yaml(cfg))
def run_task(job):
    idx, overrides = job
    LOGGER.info("\t#{} : {}".format(idx, " ".join(filter_overrides(overrides))))
    sweep_config = self.config_loader.load_sweep_config(self.config, list(overrides))
    with open_dict(sweep_config):
        # id is the concatenated overrides here
        sweep_config.hydra.job.id = '_'.join(sorted(overrides))
        sweep_config.hydra.job.num = idx
    HydraConfig().set_config(sweep_config)
    ret = run_job(
        config=sweep_config,
        task_function=self.task_function,
        job_dir_key="hydra.sweep.dir",
        job_subdir_key="hydra.sweep.subdir",
    )
    configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose)
    return (idx, ret)
def main(hydra_cfg):
    hydra_cfg.device = hydra_cfg.device.lower()
    with open_dict(hydra_cfg):
        hydra_cfg.job_logging_cfg = HydraConfig.get().job_logging

    # random seed
    if hydra_cfg.random_seed is None:
        hydra_cfg.random_seed = random.randint(1, 10000)
    set_random_seed(hydra_cfg.random_seed)

    if hydra_cfg.dist.gpus < 0:
        hydra_cfg.dist.gpus = torch.cuda.device_count()
    hydra_cfg.dist.master_port = os.environ["MASTER_PORT"]
    hydra_cfg.dist.master_addr = os.environ["MASTER_ADDR"]
    print(hydra_cfg.dist)

    if hydra_cfg.device == "cpu" or hydra_cfg.dist.gpus == 0:
        hydra_cfg.dist.gpus = 0
        train_loop(0, hydra_cfg)
    else:
        distributed_run(train_loop, hydra_cfg)
def hydra_main(cfg):
    with open_dict(cfg):
        # make hydra logging work with ddp
        # (see https://github.com/facebookresearch/hydra/issues/1126)
        cfg.job_logging_cfg = OmegaConf.to_container(
            HydraConfig.get().job_logging, resolve=True
        )

    cfg = OmegaConf.create(
        OmegaConf.to_container(cfg, resolve=False, enum_to_str=False)
    )
    OmegaConf.set_struct(cfg, True)
    logger.info(cfg)

    utils.import_user_module(cfg.fairseq.common)
    _, score = main(cfg)

    if cfg.is_ax:
        return score, None

    return score
def fix_task(cfg: DictConfig) -> None:
    with open_dict(cfg):
        semantic = 'semantic_criterion' in cfg
        instance = 'embed_criterion' in cfg
        if semantic and instance:
            cfg.task = 'panoptic'
        elif semantic:
            cfg.task = 'semantic'
        elif instance:
            cfg.task = 'instance'
        else:
            raise RuntimeError('semantic_criterion and/or embed_criterion must be set!')

        if instance:
            requires_semantic = (
                'method' not in cfg.embed_criterion
            ) or cfg.embed_criterion.method in ['ignore', 'separate']
            cfg.requires_semantic = requires_semantic
def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
    if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
        raise ValueError(f"No dataset for {name}")  # TODO
    if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
        raise ValueError(f"No dataloader_params for {name}")  # TODO

    if shuffle_should_be:
        if 'shuffle' not in cfg.dataloader_params:
            logging.warning(
                f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                "config. Manually setting to True"
            )
            with open_dict(cfg["dataloader_params"]):
                cfg.dataloader_params.shuffle = True
        elif not cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
    elif not shuffle_should_be and cfg.dataloader_params.shuffle:
        logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

    dataset = instantiate(cfg.dataset)

    return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
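# Illustrative sketch (not NeMo code): defaulting a missing key inside a nested node, as
# in __setup_dataloader_from_config above. open_dict accepts any node, so only the nested
# dict needs to be opened. The keys and values here are made up.
from omegaconf import OmegaConf, open_dict

_cfg = OmegaConf.create(
    {"dataset": {"manifest_filepath": "train.json"}, "dataloader_params": {"batch_size": 32}}
)
OmegaConf.set_struct(_cfg, True)

if "shuffle" not in _cfg.dataloader_params:
    with open_dict(_cfg.dataloader_params):
        _cfg.dataloader_params.shuffle = True

assert _cfg.dataloader_params.shuffle is True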
def __init__(
    self,
    cfg: DictConfig,
    trainer: Trainer = None,
):
    self.cfg = cfg
    self.data_prepared = False
    self.epoch_number = 0

    if self.cfg.library == "huggingface":
        self.setup_tokenizer(cfg.tokenizer)
    elif self.cfg.library == "megatron":
        # supporting MegatronT5Model in precision = fp16
        t5_cfg = MegatronT5Model.restore_from(
            restore_path=cfg.language_model.lm_checkpoint, trainer=trainer, return_config=True
        )
        # Override the T5 configuration with the one from the config file.
        OmegaConf.set_struct(t5_cfg, True)
        with open_dict(t5_cfg):
            t5_cfg.masked_softmax_fusion = False
            t5_cfg.precision = 16

        language_model = MegatronT5Model.restore_from(
            restore_path=cfg.language_model.lm_checkpoint, trainer=trainer, override_config_path=t5_cfg
        )
        self.tokenizer = language_model.tokenizer

    super().__init__(cfg=cfg, trainer=trainer, no_lm_init=True)

    if self.cfg.library == "huggingface":
        self.language_model = AutoModelForSeq2SeqLM.from_pretrained(
            cfg.language_model.pretrained_model_name
        )
        self.language_model.resize_token_embeddings(len(self.tokenizer.tokenizer))
        if self.cfg.language_model.lm_checkpoint:
            self.language_model.load_state_dict(
                torch.load(self.cfg.language_model.lm_checkpoint)
            )
    elif self.cfg.library == "megatron":
        self.language_model = language_model
def test_load_schema_as_config(hydra_restore_singletons: Any) -> None:
    """
    Load structured config as a configuration
    """
    ConfigStore.instance().store(
        name="config", node=TopLevelConfig, provider="this_test"
    )
    ConfigStore.instance().store(
        name="db/mysql", node=MySQLConfig, provider="this_test"
    )

    config_loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None))
    cfg = config_loader.load_configuration(
        config_name="config", overrides=[], run_mode=RunMode.RUN
    )

    expected = deepcopy(hydra_load_list)
    expected.append(
        LoadTrace(
            config_path="config",
            package="",
            parent="<root>",
            search_path="structured://",
            provider="this_test",
        )
    )
    assert_same_composition_trace(cfg.hydra.composition_trace, expected)

    with open_dict(cfg):
        del cfg["hydra"]
    assert cfg == {
        "normal_yaml_config": "???",
        "db": {
            "driver": MISSING,
            "host": MISSING,
            "port": MISSING,
            "user": MISSING,
            "password": MISSING,
        },
    }
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
    # global_rank and local_rank are set by LightningModule in Lightning 1.2.0
    self.world_size = 1
    if trainer is not None:
        self.world_size = trainer.num_nodes * trainer.num_gpus

    super().__init__(cfg=cfg, trainer=trainer)
    self.preprocessor = EncDecCTCModel.from_config_dict(self._cfg.preprocessor)
    self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder)

    with open_dict(self._cfg):
        if "feat_in" not in self._cfg.decoder or (
            not self._cfg.decoder.feat_in and hasattr(self.encoder, '_feat_out')
        ):
            self._cfg.decoder.feat_in = self.encoder._feat_out
        if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in:
            raise ValueError("param feat_in of the decoder's config is not set!")

    self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder)

    self.loss = CTCLoss(
        num_classes=self.decoder.num_classes_with_blank - 1,
        zero_infinity=True,
        reduction=self._cfg.get("ctc_reduction", "mean_batch"),
    )

    if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None:
        self.spec_augmentation = EncDecCTCModel.from_config_dict(self._cfg.spec_augment)
    else:
        self.spec_augmentation = None

    # Setup metric objects
    self._wer = WER(
        vocabulary=self.decoder.vocabulary,
        batch_dim_index=0,
        use_cer=self._cfg.get('use_cer', False),
        ctc_decode=True,
        dist_sync_on_step=True,
        log_prediction=self._cfg.get("log_prediction", False),
    )
def _get_rerun_conf(file_path: str, overrides: List[str]) -> DictConfig:
    msg = "Experimental rerun CLI option, other command line args are ignored."
    warnings.warn(msg, UserWarning)
    file = Path(file_path)
    if not file.exists():
        raise ValueError(f"File {file} does not exist!")

    if len(overrides) > 0:
        msg = "Config overrides are not supported as of now."
        warnings.warn(msg, UserWarning)

    with open(str(file), "rb") as input:
        config = pickle.load(input)  # nosec
    configure_log(config.hydra.job_logging, config.hydra.verbose)
    HydraConfig.instance().set_config(config)
    task_cfg = copy.deepcopy(config)
    with read_write(task_cfg):
        with open_dict(task_cfg):
            del task_cfg["hydra"]
    assert isinstance(task_cfg, DictConfig)
    return task_cfg