def get_model_detection_efficientdet(model_name, num_classes, target_dim, freeze_batch_norm=False):
    print("Using EffDet detection model")
    config = effdet.get_efficientdet_config(model_name)
    efficientDetModel = EfficientDet(config, pretrained_backbone=False)
    load_pretrained(efficientDetModel, config.url)

    import omegaconf
    with omegaconf.read_write(config):
        config.num_classes = num_classes
        # config.image_size = target_dim

    efficientDetModel.class_net = HeadNet(config, num_outputs=num_classes)

    if freeze_batch_norm:
        # we only freeze BN layers in backbone and the BiFPN
        print("Freezing batch normalization weights")
        freeze_bn(efficientDetModel.backbone)

    with omegaconf.read_write(efficientDetModel.config):
        efficientDetModel.config.num_classes = num_classes

    # print(DetBenchTrain(efficientDetModel, config))
    return DetBenchTrain(efficientDetModel, config)
def launch(
    launcher: RayAWSLauncher,
    job_overrides: Sequence[Sequence[str]],
    initial_job_idx: int,
) -> Sequence[JobReturn]:
    setup_globals()
    assert launcher.config is not None
    assert launcher.config_loader is not None
    assert launcher.task_function is not None

    setup_commands = launcher.env_setup.commands
    with read_write(setup_commands):
        setup_commands.extend(
            [
                f"pip install {package}=={version}"
                for package, version in launcher.env_setup.pip_packages.items()
            ]
        )
        setup_commands.extend(launcher.ray_cfg.cluster.setup_commands)

    with read_write(launcher.ray_cfg.cluster):
        launcher.ray_cfg.cluster.setup_commands = setup_commands

    configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose)

    log.info(f"Ray Launcher is launching {len(job_overrides)} jobs, ")

    with tempfile.TemporaryDirectory() as local_tmp_dir:
        sweep_configs = []
        for idx, overrides in enumerate(job_overrides):
            idx = initial_job_idx + idx
            ostr = " ".join(filter_overrides(overrides))
            log.info(f"\t#{idx} : {ostr}")
            sweep_config = launcher.config_loader.load_sweep_config(
                launcher.config, list(overrides)
            )
            with open_dict(sweep_config):
                # job.id will be set on the EC2 instance before running the job.
                sweep_config.hydra.job.num = idx
            sweep_configs.append(sweep_config)

        _pickle_jobs(
            tmp_dir=local_tmp_dir,
            sweep_configs=sweep_configs,  # type: ignore
            task_function=launcher.task_function,
            singleton_state=Singleton.get_state(),
        )

        with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
            with open(f.name, "w") as file:
                OmegaConf.save(
                    config=launcher.ray_cfg.cluster, f=file.name, resolve=True
                )
            launcher.ray_yaml_path = f.name
            log.info(
                f"Saving RayClusterConf in a temp yaml file: {launcher.ray_yaml_path}."
            )

        return launch_jobs(launcher, local_tmp_dir, Path(HydraConfig.get().sweep.dir))
def launch(
    launcher: RayAWSLauncher,
    job_overrides: Sequence[Sequence[str]],
    initial_job_idx: int,
) -> Sequence[JobReturn]:
    setup_globals()
    assert launcher.config is not None
    assert launcher.hydra_context is not None
    assert launcher.task_function is not None

    setup_commands = launcher.env_setup.commands
    packages = filter(
        lambda x: x[1] is not None, launcher.env_setup.pip_packages.items()
    )
    with read_write(setup_commands):
        setup_commands.extend(
            [f"pip install {package}=={version}" for package, version in packages]
        )
        setup_commands.extend(launcher.ray_cfg.cluster.setup_commands)

    with read_write(launcher.ray_cfg.cluster):
        launcher.ray_cfg.cluster.setup_commands = setup_commands

    configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose)

    logging_config = OmegaConf.to_container(
        launcher.logging, resolve=True, enum_to_str=True
    )
    sdk.configure_logging(**logging_config)

    log.info(f"Ray Launcher is launching {len(job_overrides)} jobs, ")

    with tempfile.TemporaryDirectory() as local_tmp_dir:
        sweep_configs = []
        for idx, overrides in enumerate(job_overrides):
            idx = initial_job_idx + idx
            ostr = " ".join(filter_overrides(overrides))
            log.info(f"\t#{idx} : {ostr}")
            sweep_config = launcher.hydra_context.config_loader.load_sweep_config(
                launcher.config, list(overrides)
            )
            with open_dict(sweep_config):
                # job.id will be set on the EC2 instance before running the job.
                sweep_config.hydra.job.num = idx
            sweep_configs.append(sweep_config)

        _pickle_jobs(
            tmp_dir=local_tmp_dir,
            hydra_context=launcher.hydra_context,
            sweep_configs=sweep_configs,  # type: ignore
            task_function=launcher.task_function,
            singleton_state=Singleton.get_state(),
        )

        return launch_jobs(
            launcher, local_tmp_dir, Path(launcher.config.hydra.sweep.dir)
        )
def _extract_defaults_list(self, config_path: str, cfg: Container) -> ListConfig:
    empty = OmegaConf.create([])
    if not OmegaConf.is_dict(cfg):
        return empty
    assert isinstance(cfg, DictConfig)
    with read_write(cfg):
        with open_dict(cfg):
            if not cfg._is_typed():
                defaults = cfg.pop("defaults", empty)
            else:
                # If the node is backed by a Structured Config, flag it and temporarily keep the defaults list in.
                # It will be removed later.
                # This is addressing an edge case where the defaults list re-appears once the dataclass is used
                # as a prototype during OmegaConf merge.
                cfg["__HYDRA_REMOVE_TOP_LEVEL_DEFAULTS__"] = True
                defaults = cfg.get("defaults", empty)

    if not isinstance(defaults, ListConfig):
        if isinstance(defaults, DictConfig):
            type_str = "mapping"
        else:
            type_str = type(defaults).__name__
        raise ValueError(
            f"Invalid defaults list in '{config_path}', defaults must be a list (got {type_str})"
        )

    return defaults
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
    # copy config to avoid mutating it when merging with kwargs
    config_copy = copy.deepcopy(config)

    # Manually set parent as deepcopy does not currently handle it
    # (https://github.com/omry/omegaconf/issues/130)
    # noinspection PyProtectedMember
    config_copy._set_parent(config._get_parent())  # type: ignore
    config = config_copy

    params = config.params if "params" in config else OmegaConf.create()

    assert isinstance(
        params, DictConfig
    ), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"

    primitives = {}
    rest = {}
    for k, v in kwargs.items():
        if _utils.is_primitive_type(v) or isinstance(v, (dict, list)):
            primitives[k] = v
        else:
            rest[k] = v

    final_kwargs = {}
    with read_write(params):
        params.merge_with(OmegaConf.create(primitives))

    for k, v in params.items():
        final_kwargs[k] = v

    for k, v in rest.items():
        final_kwargs[k] = v
    return final_kwargs
def __init__(self, parent: Optional[Container], value: Any, metadata: Metadata):
    from omegaconf import read_write

    super().__init__(parent=parent, metadata=metadata)
    with read_write(self):
        self._set_value(value)
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
    if isinstance(config, ObjectConf):
        config = OmegaConf.structured(config)
    else:
        config = copy.deepcopy(config)

    params = config.params if hasattr(config, "params") else {}

    assert isinstance(
        params, MutableMapping
    ), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"

    if isinstance(config, DictConfig):
        assert isinstance(params, DictConfig)
        params._set_parent(config)

    primitives = {}
    rest = {}
    for k, v in kwargs.items():
        if _utils.is_primitive_type(v) or isinstance(v, (dict, list)):
            primitives[k] = v
        else:
            rest[k] = v

    final_kwargs = {}
    with read_write(params):
        params.merge_with(primitives)

    for k, v in params.items():
        final_kwargs[k] = v

    for k, v in rest.items():
        final_kwargs[k] = v
    return final_kwargs
def run_job(
    config: DictConfig,
    task_function: TaskFunction,
    job_dir_key: str,
    job_subdir_key: Optional[str],
    configure_logging: bool = True,
) -> "JobReturn":
    old_cwd = os.getcwd()
    working_dir = str(OmegaConf.select(config, job_dir_key))
    orig_hydra_cfg = HydraConfig.instance().cfg
    if job_subdir_key is not None:
        # evaluate job_subdir_key lazily.
        # this is running on the client side in sweep and contains things such as job:id which
        # are only available there.
        subdir = str(OmegaConf.select(config, job_subdir_key))
        working_dir = os.path.join(working_dir, subdir)
    try:
        ret = JobReturn()
        ret.working_dir = working_dir
        task_cfg = copy.deepcopy(config)
        hydra_cfg = OmegaConf.masked_copy(task_cfg, "hydra")
        # maintain parent to preserve interpolation links from hydra_cfg to job_cfg
        hydra_cfg._set_parent(task_cfg)
        with read_write(task_cfg):
            with open_dict(task_cfg):
                del task_cfg["hydra"]

        HydraConfig.instance().cfg = hydra_cfg  # type: ignore

        ret.cfg = task_cfg
        ret.hydra_cfg = hydra_cfg
        overrides = OmegaConf.to_container(config.hydra.overrides.task)
        assert isinstance(overrides, list)
        ret.overrides = overrides
        # handle output directories here
        Path(str(working_dir)).mkdir(parents=True, exist_ok=True)
        os.chdir(working_dir)

        if configure_logging:
            configure_log(config.hydra.job_logging, config.hydra.verbose)

        if config.hydra.output_subdir is not None:
            hydra_output = Path(config.hydra.output_subdir)
            _save_config(task_cfg, "config.yaml", hydra_output)
            _save_config(hydra_cfg, "hydra.yaml", hydra_output)
            _save_config(config.hydra.overrides.task, "overrides.yaml", hydra_output)

        with env_override(hydra_cfg.hydra.job.env_set):
            ret.return_value = task_function(task_cfg)
            ret.task_name = JobRuntime.instance().get("name")

        _flush_loggers()

        return ret
    finally:
        HydraConfig.instance().cfg = orig_hydra_cfg
        os.chdir(old_cwd)
def test_read_write_override(src: Any, func: Any, expectation: Any) -> None:
    c = OmegaConf.create(src)
    OmegaConf.set_readonly(c, True)

    with expectation:
        func(c)

    with does_not_raise():
        with read_write(c):
            func(c)
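# A minimal, self-contained sketch (not taken from any of the snippets here) of the
# behavior the test above exercises: read_write temporarily clears the read-only flag
# on a config node and restores it when the context exits.
from omegaconf import OmegaConf, read_write
from omegaconf.errors import ReadonlyConfigError

cfg = OmegaConf.create({"a": 1})
OmegaConf.set_readonly(cfg, True)

try:
    cfg.a = 2  # rejected: the node is read-only
except ReadonlyConfigError:
    pass

with read_write(cfg):
    cfg.a = 2  # allowed while the context manager is active

assert cfg.a == 2
assert OmegaConf.is_readonly(cfg)  # the read-only flag is restored on exit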
def run_job(
    config: DictConfig,
    task_function: TaskFunction,
    job_dir_key: str,
    job_subdir_key: Optional[str],
) -> "JobReturn":
    old_cwd = os.getcwd()
    working_dir = str(OmegaConf.select(config, job_dir_key))
    if job_subdir_key is not None:
        # evaluate job_subdir_key lazily.
        # this is running on the client side in sweep and contains things such as job:id which
        # are only available there.
        subdir = str(OmegaConf.select(config, job_subdir_key))
        working_dir = os.path.join(working_dir, subdir)
    try:
        ret = JobReturn()
        ret.working_dir = working_dir
        task_cfg = copy.deepcopy(config)
        with read_write(task_cfg):
            with open_dict(task_cfg):
                del task_cfg["hydra"]
        ret.cfg = task_cfg
        ret.hydra_cfg = OmegaConf.create({"hydra": HydraConfig.get()})
        overrides = OmegaConf.to_container(config.hydra.overrides.task)
        assert isinstance(overrides, list)
        ret.overrides = overrides
        # handle output directories here
        Path(str(working_dir)).mkdir(parents=True, exist_ok=True)
        os.chdir(working_dir)
        configure_log(config.hydra.job_logging, config.hydra.verbose)
        hydra_cfg = OmegaConf.masked_copy(config, "hydra")
        assert isinstance(hydra_cfg, DictConfig)

        if config.hydra.output_subdir is not None:
            hydra_output = Path(config.hydra.output_subdir)
            _save_config(task_cfg, "config.yaml", hydra_output)
            _save_config(hydra_cfg, "hydra.yaml", hydra_output)
            _save_config(config.hydra.overrides.task, "overrides.yaml", hydra_output)

        with env_override(hydra_cfg.hydra.job.env_set):
            ret.return_value = task_function(task_cfg)
            ret.task_name = JobRuntime.instance().get("name")

        # shut down logging to ensure job log files are closed.
        # If logging is still required after run_job caller is responsible to re-initialize it.
        logging.shutdown()
        return ret
    finally:
        os.chdir(old_cwd)
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
    if isinstance(config, ObjectConf):
        config = OmegaConf.structured(config)
        if config.params is not None:
            params = config.params
        else:
            params = OmegaConf.create()
    else:
        config = copy.deepcopy(config)
        if "params" in config:
            msg = (
                "\nField 'params' is deprecated since Hydra 1.0 and will be removed in Hydra 1.1."
                "\nInline the content of params directly at the containing node."
                "\nSee https://hydra.cc/docs/next/upgrades/0.11_to_1.0/object_instantiation_changes"
            )
            warnings.warn(category=UserWarning, message=msg)
            params = config.params
        else:
            params = config

    assert isinstance(
        params, DictConfig
    ), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"

    config_overrides = {}
    passthrough = {}
    for k, v in kwargs.items():
        if k in params and not (
            get_ref_type(params, k) is Any and OmegaConf.is_missing(params, k)
        ):
            config_overrides[k] = v
        else:
            passthrough[k] = v

    final_kwargs = {}
    with read_write(params):
        params.merge_with(config_overrides)

    for k in params.keys():
        if k == "_target_":
            continue
        if k not in passthrough:
            final_kwargs[k] = params[k]

    for k, v in passthrough.items():
        final_kwargs[k] = v
    return final_kwargs
def _extract_defaults_list(
    config_path: Optional[str], cfg: Container
) -> List[DefaultElement]:
    if not OmegaConf.is_dict(cfg):
        return []
    assert isinstance(cfg, DictConfig)
    with read_write(cfg):
        with open_dict(cfg):
            defaults = cfg.pop("defaults", OmegaConf.create([]))
    if len(defaults) > 0:
        return ConfigSource._create_defaults_list(
            config_path=config_path, defaults=defaults
        )
    else:
        return []
def test_experimental_save_job_info_callback(tmpdir: Path, multirun: bool) -> None:
    app_path = "tests/test_apps/app_with_pickle_job_info_callback/my_app.py"

    cmd = [
        app_path,
        "hydra.run.dir=" + str(tmpdir),
        "hydra.sweep.dir=" + str(tmpdir),
        "hydra.job.chdir=True",
    ]
    if multirun:
        cmd.append("-m")
    _, _err = run_python_script(cmd)

    def load_pickle(path: Path) -> Any:
        with open(str(path), "rb") as input:
            obj = pickle.load(input)  # nosec
        return obj

    # load pickles from callbacks
    callback_output = (
        tmpdir / Path("0") / ".hydra" if multirun else tmpdir / ".hydra"
    )
    config_on_job_start = load_pickle(callback_output / "config.pickle")
    job_return_on_job_end: JobReturn = load_pickle(
        callback_output / "job_return.pickle"
    )

    task_cfg_from_callback = copy.deepcopy(config_on_job_start)
    with read_write(task_cfg_from_callback):
        with open_dict(task_cfg_from_callback):
            del task_cfg_from_callback["hydra"]

    # load pickles generated from the application
    app_output_dir = tmpdir / "0" if multirun else tmpdir
    task_cfg_from_app = load_pickle(app_output_dir / "task_cfg.pickle")
    hydra_cfg_from_app = load_pickle(app_output_dir / "hydra_cfg.pickle")

    # verify the cfg pickles are the same on_job_start
    assert task_cfg_from_callback == task_cfg_from_app
    assert config_on_job_start.hydra == hydra_cfg_from_app

    # verify pickled objects are the same on_job_end
    assert job_return_on_job_end.cfg == task_cfg_from_app
    assert job_return_on_job_end.hydra_cfg.hydra == hydra_cfg_from_app  # type: ignore
    assert job_return_on_job_end.return_value == "hello world"
    assert job_return_on_job_end.status == JobStatus.COMPLETED
def _extract_defaults_list(self, config_path: str, cfg: Container) -> ListConfig:
    empty = OmegaConf.create([])
    if not OmegaConf.is_dict(cfg):
        return empty
    assert isinstance(cfg, DictConfig)
    with read_write(cfg):
        with open_dict(cfg):
            defaults = cfg.pop("defaults", empty)

    if not isinstance(defaults, ListConfig):
        if isinstance(defaults, DictConfig):
            type_str = "mapping"
        else:
            type_str = type(defaults).__name__
        raise ValueError(
            f"Invalid defaults list in '{config_path}', defaults must be a list (got {type_str})"
        )

    return defaults
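# A minimal sketch (illustrative only, not part of the Hydra code above) of the
# read_write + open_dict pattern _extract_defaults_list relies on: read_write lifts the
# read-only flag and open_dict lifts struct mode, so a key can be popped from an
# otherwise locked config.
from omegaconf import OmegaConf, open_dict, read_write

cfg = OmegaConf.create({"defaults": ["db"], "x": 1})
OmegaConf.set_struct(cfg, True)
OmegaConf.set_readonly(cfg, True)

with read_write(cfg):
    with open_dict(cfg):
        defaults = cfg.pop("defaults", OmegaConf.create([]))

assert list(defaults) == ["db"]
assert "defaults" not in cfg
assert OmegaConf.is_readonly(cfg)  # both flags are restored once the contexts exit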
def _get_rerun_conf(file_path: str, overrides: List[str]) -> DictConfig:
    msg = "Experimental rerun CLI option, other command line args are ignored."
    warnings.warn(msg, UserWarning)
    file = Path(file_path)
    if not file.exists():
        raise ValueError(f"File {file} does not exist!")

    if len(overrides) > 0:
        msg = "Config overrides are not supported as of now."
        warnings.warn(msg, UserWarning)

    with open(str(file), "rb") as input:
        config = pickle.load(input)  # nosec
    configure_log(config.hydra.job_logging, config.hydra.verbose)
    HydraConfig.instance().set_config(config)
    task_cfg = copy.deepcopy(config)
    with read_write(task_cfg):
        with open_dict(task_cfg):
            del task_cfg["hydra"]
    assert isinstance(task_cfg, DictConfig)
    return task_cfg
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
    if isinstance(config, ObjectConf):
        config = OmegaConf.structured(config)
    else:
        config = copy.deepcopy(config)

    params = (
        config.params
        if hasattr(config, "params") and config.params is not None
        else OmegaConf.create()
    )

    assert isinstance(
        params, MutableMapping
    ), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"

    if isinstance(config, DictConfig):
        assert isinstance(params, DictConfig)
        params._set_parent(config)

    config_overrides = {}
    passthrough = {}
    for k, v in kwargs.items():
        if k in params:
            config_overrides[k] = v
        else:
            passthrough[k] = v

    final_kwargs = {}
    with read_write(params):
        params.merge_with(config_overrides)

    for k, v in params.items():
        final_kwargs[k] = v

    for k, v in passthrough.items():
        final_kwargs[k] = v
    return final_kwargs
def run_job(
    task_function: TaskFunction,
    config: DictConfig,
    job_dir_key: str,
    job_subdir_key: Optional[str],
    configure_logging: bool = True,
    hydra_context: Optional[HydraContext] = None,
) -> "JobReturn":
    callbacks = _get_callbacks_for_run_job(hydra_context)

    old_cwd = os.getcwd()
    orig_hydra_cfg = HydraConfig.instance().cfg
    HydraConfig.instance().set_config(config)
    working_dir = str(OmegaConf.select(config, job_dir_key))
    if job_subdir_key is not None:
        # evaluate job_subdir_key lazily.
        # this is running on the client side in sweep and contains things such as job:id which
        # are only available there.
        subdir = str(OmegaConf.select(config, job_subdir_key))
        working_dir = os.path.join(working_dir, subdir)
    try:
        ret = JobReturn()
        ret.working_dir = working_dir
        task_cfg = copy.deepcopy(config)
        with read_write(task_cfg):
            with open_dict(task_cfg):
                del task_cfg["hydra"]

        ret.cfg = task_cfg
        hydra_cfg = copy.deepcopy(HydraConfig.instance().cfg)
        assert isinstance(hydra_cfg, DictConfig)
        ret.hydra_cfg = hydra_cfg
        overrides = OmegaConf.to_container(config.hydra.overrides.task)
        assert isinstance(overrides, list)
        ret.overrides = overrides
        # handle output directories here
        Path(str(working_dir)).mkdir(parents=True, exist_ok=True)
        os.chdir(working_dir)

        if configure_logging:
            configure_log(config.hydra.job_logging, config.hydra.verbose)

        if config.hydra.output_subdir is not None:
            hydra_output = Path(config.hydra.output_subdir)
            _save_config(task_cfg, "config.yaml", hydra_output)
            _save_config(hydra_cfg, "hydra.yaml", hydra_output)
            _save_config(config.hydra.overrides.task, "overrides.yaml", hydra_output)

        with env_override(hydra_cfg.hydra.job.env_set):
            callbacks.on_job_start(config=config)
            try:
                ret.return_value = task_function(task_cfg)
                ret.status = JobStatus.COMPLETED
            except Exception as e:
                ret.return_value = e
                ret.status = JobStatus.FAILED

        ret.task_name = JobRuntime.instance().get("name")

        _flush_loggers()

        callbacks.on_job_end(config=config, job_return=ret)

        return ret
    finally:
        HydraConfig.instance().cfg = orig_hydra_cfg
        os.chdir(old_cwd)
def run_job(
    task_function: TaskFunction,
    config: DictConfig,
    job_dir_key: str,
    job_subdir_key: Optional[str],
    configure_logging: bool = True,
    hydra_context: Optional[HydraContext] = None,
) -> "JobReturn":
    callbacks = _get_callbacks_for_run_job(hydra_context)

    old_cwd = os.getcwd()
    orig_hydra_cfg = HydraConfig.instance().cfg

    output_dir = str(OmegaConf.select(config, job_dir_key))
    if job_subdir_key is not None:
        # evaluate job_subdir_key lazily.
        # this is running on the client side in sweep and contains things such as job:id which
        # are only available there.
        subdir = str(OmegaConf.select(config, job_subdir_key))
        output_dir = os.path.join(output_dir, subdir)

    with read_write(config.hydra.runtime):
        with open_dict(config.hydra.runtime):
            config.hydra.runtime.output_dir = os.path.abspath(output_dir)

    HydraConfig.instance().set_config(config)
    _chdir = None
    try:
        ret = JobReturn()
        task_cfg = copy.deepcopy(config)
        with read_write(task_cfg):
            with open_dict(task_cfg):
                del task_cfg["hydra"]

        ret.cfg = task_cfg
        hydra_cfg = copy.deepcopy(HydraConfig.instance().cfg)
        assert isinstance(hydra_cfg, DictConfig)
        ret.hydra_cfg = hydra_cfg
        overrides = OmegaConf.to_container(config.hydra.overrides.task)
        assert isinstance(overrides, list)
        ret.overrides = overrides
        # handle output directories here
        Path(str(output_dir)).mkdir(parents=True, exist_ok=True)

        _chdir = hydra_cfg.hydra.job.chdir

        if _chdir is None:
            url = "https://hydra.cc/docs/upgrades/1.1_to_1.2/changes_to_job_working_dir"
            deprecation_warning(
                message=dedent(
                    f"""\
                    Hydra 1.3 will no longer change working directory at job runtime by default.
                    See {url} for more information."""
                ),
                stacklevel=2,
            )
            _chdir = True

        if _chdir:
            os.chdir(output_dir)
            ret.working_dir = output_dir
        else:
            ret.working_dir = os.getcwd()

        if configure_logging:
            configure_log(config.hydra.job_logging, config.hydra.verbose)

        if config.hydra.output_subdir is not None:
            hydra_output = Path(config.hydra.runtime.output_dir) / Path(
                config.hydra.output_subdir
            )
            _save_config(task_cfg, "config.yaml", hydra_output)
            _save_config(hydra_cfg, "hydra.yaml", hydra_output)
            _save_config(config.hydra.overrides.task, "overrides.yaml", hydra_output)

        with env_override(hydra_cfg.hydra.job.env_set):
            callbacks.on_job_start(config=config)
            try:
                ret.return_value = task_function(task_cfg)
                ret.status = JobStatus.COMPLETED
            except Exception as e:
                ret.return_value = e
                ret.status = JobStatus.FAILED

        ret.task_name = JobRuntime.instance().get("name")

        _flush_loggers()

        callbacks.on_job_end(config=config, job_return=ret)

        return ret
    finally:
        HydraConfig.instance().cfg = orig_hydra_cfg
        if _chdir:
            os.chdir(old_cwd)
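# A minimal sketch (the "runtime" node below is a stand-in, not the real Hydra config
# schema) of the pattern run_job uses above to record hydra.runtime.output_dir:
# writing a new key into a nested, struct-flagged node via read_write + open_dict.
import os

from omegaconf import OmegaConf, open_dict, read_write

cfg = OmegaConf.create({"hydra": {"runtime": {}}})
OmegaConf.set_struct(cfg, True)

with read_write(cfg.hydra.runtime):
    with open_dict(cfg.hydra.runtime):
        cfg.hydra.runtime.output_dir = os.path.abspath("outputs")

assert "output_dir" in cfg.hydra.runtime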
def __init__(self, parent: Optional[Box], value: Any, metadata: Metadata):
    from omegaconf import read_write

    super().__init__(parent=parent, metadata=metadata)
    with read_write(self):
        self._set_value(value)  # lgtm [py/init-calls-subclass]
def _main(cfg: DictConfig, output_file):
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )

    if 'label_dir' in cfg.task:
        manifest_dir, _ = os.path.split(cfg.dataset.gen_subset)
        with read_write(cfg):
            cfg.task.label_dir = os.path.join(cfg.task.data, manifest_dir)
        print('cfg.task.data', cfg.task.label_dir)

    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    token_type = None
    if (
        type(models[0]) == Wav2Bart
        or type(models[0]) == WavTransBart
        or type(models[0]) == WavLinearBart
        or type(models[0]) == WavBart2Bart
    ):
        token_type = 'bart'
    elif type(models[0]) == Wav2BartChr:
        token_type = 'chr'
    elif (
        type(models[0]) == Wav2VecCtc
        or type(models[0]) == Wav2BertChr
        or type(models[0]) == Wav2BertMixChr
    ):
        token_type = 'chrctc'
    elif type(models[0]) == Wav2Bert:
        token_type = 'bert'
    else:
        raise ValueError(f'token_type not defined for {type(models[0])}')
    print(f'token_type is {token_type}')

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None
            )
        except:
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise
        assert len(lms) == 1
    else:
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
    print('cfg.generation', cfg.generation)
    # print(cfg.task._name == 'audio_pretraining')
    if cfg.task._name != 'audio_pretraining' and cfg.task._name != 'audio_pretraining_bertbpe':
        generator = task.build_generator(
            models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )
    else:
        print('use W2lViterbiDecoder')
        from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
        from easydict import EasyDict as edict
        args = edict({
            'criterion': 'ctc',
            'nbest': 1,
        })
        generator = W2lViterbiDecoder(args, task.target_dictionary)

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for si, sample in enumerate(progress):
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        # print('hypos', hypos)
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                )

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
                    sample_id
                )
                target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
                    sample_id
                )
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    if token_type == 'chr':
                        target_str = tgt_dict.string(
                            target_tokens,
                            cfg.common_eval.post_process,
                            escape_unk=True,
                            extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                                generator
                            ),
                        )
                    elif token_type == 'bart':
                        target_str = task.bart.decode(target_tokens.int().cpu())
                    elif token_type == 'bert':
                        target_str = task.bert.decode(target_tokens.int().cpu())
                    elif token_type == 'chrctc':
                        target_str = tgt_dict.string(
                            target_tokens,
                            cfg.common_eval.post_process,
                            escape_unk=True,
                        )
                    else:
                        raise ValueError(f'token_type not defined for {type(models[0])}')

            src_str = decode_fn(src_str)
            if has_target and token_type == 'chr':
                target_str = decode_fn(target_str)
            elif has_target and token_type == 'chrctc':
                target_str = ''.join(target_str.split()).replace('|', ' ')

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str), file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str), file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
                # print('align', hypo["alignment"])
                if token_type == 'bart':
                    hypo_tokens = hypo["tokens"].int().cpu()
                    hypo_str = task.bart.decode(hypo["tokens"].int().cpu())
                    alignment = hypo["alignment"]
                elif token_type == 'chr':
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=cfg.common_eval.post_process,
                        # extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                    )
                elif token_type == 'chrctc':
                    hypo_tokens = hypo["tokens"].int().cpu()
                    hypo_str = task.target_dictionary.string(hypo_tokens)
                    hypo["positional_scores"] = torch.FloatTensor([0.])
                elif token_type == 'bert':
                    hypo_tokens = hypo["tokens"].int().cpu()
                    hypo_str = task.bert.decode(hypo["tokens"].int().cpu())
                    alignment = hypo["alignment"]
                else:
                    raise ValueError(f'token_type not defined for {type(models[0])}')
                detok_hypo_str = decode_fn(hypo_str)

                if token_type == 'chr' or token_type == 'chrctc':
                    print('target_str', ''.join(target_str.split()).replace('|', ' '))
                    print('typo_str', ''.join(detok_hypo_str.split()).replace('|', ' '))
                    detok_hypo_str = ''.join(detok_hypo_str.split()).replace('|', ' ')
                    # target_str = ''.join(target_str.split()).replace('|', ' ')
                elif token_type == 'bart':
                    print('target_str', target_str)
                    print('typo_str', detok_hypo_str)
                # elif token_type == 'chrctc':
                #     print('target_str', ''.join(target_str.split()).replace('|', ' '))
                #     print('typo_str', ''.join(detok_hypo_str.split()).replace('|', ' '))

                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print(
                        "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                        file=output_file,
                    )
                    # detokenized hypothesis
                    print(
                        "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
                        file=output_file,
                    )
                    print(
                        "P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    # convert from base e to base 2
                                    hypo["positional_scores"]
                                    .div_(math.log(2))
                                    .tolist(),
                                )
                            ),
                        ),
                        file=output_file,
                    )

                    if cfg.generation.print_alignment == "hard":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [
                                        "{}-{}".format(src_idx, tgt_idx)
                                        for src_idx, tgt_idx in alignment
                                    ]
                                ),
                            ),
                            file=output_file,
                        )
                    if cfg.generation.print_alignment == "soft":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [",".join(src_probs) for src_probs in alignment]
                                ),
                            ),
                            file=output_file,
                        )

                    if cfg.generation.print_step:
                        print(
                            "I-{}\t{}".format(sample_id, hypo["steps"]),
                            file=output_file,
                        )

                    if cfg.generation.retain_iter_history:
                        for step, h in enumerate(hypo["history"]):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h["tokens"].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print(
                                "E-{}_{}\t{}".format(sample_id, step, h_str),
                                file=output_file,
                            )

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or cfg.common_eval.post_process is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True
                        )
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True
                        )
                    if hasattr(scorer, "add_string"):
                        # print('add_string 1', target_str, '2', detok_hypo_str)
                        # if si > 2:
                        #     raise
                        print('2', target_str, detok_hypo_str)
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (
            sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
        )

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(
        "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        )
    )

    if has_target:
        if cfg.bpe and not cfg.generation.sacrebleu:
            if cfg.common_eval.post_process:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
                )
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            "Generate {} with beam={}: {}".format(
                cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
            ),
            file=output_file,
        )

    return scorer