Example #1
def get_model_detection_efficientdet(model_name,
                                     num_classes,
                                     target_dim,
                                     freeze_batch_norm=False):
    print("Using EffDet detection model")

    config = effdet.get_efficientdet_config(model_name)
    efficientDetModel = EfficientDet(config, pretrained_backbone=False)
    load_pretrained(efficientDetModel, config.url)
    import omegaconf
    with omegaconf.read_write(config):
        config.num_classes = num_classes
        # config.image_size = target_dim
    efficientDetModel.class_net = HeadNet(config, num_outputs=num_classes)

    if freeze_batch_norm:
        # we only freeze BN layers in backbone and the BiFPN
        print("Freezing batch normalization weights")
        freeze_bn(efficientDetModel.backbone)

    with omegaconf.read_write(efficientDetModel.config):
        efficientDetModel.config.num_classes = num_classes

    # print(DetBenchTrain(efficientDetModel, config))
    return DetBenchTrain(efficientDetModel, config)
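The read_write context manager used here temporarily lifts OmegaConf's read-only flag so a frozen config node can be edited, then restores the flag on exit. A minimal standalone sketch of that behavior, assuming only omegaconf is installed (config keys and values are illustrative):

from omegaconf import OmegaConf, read_write

config = OmegaConf.create({"num_classes": 90, "image_size": 512})
OmegaConf.set_readonly(config, True)

with read_write(config):
    config.num_classes = 10  # writable inside the context

assert config.num_classes == 10
# The same assignment outside the context would raise
# omegaconf.errors.ReadonlyConfigError.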
Example #2
def launch(
    launcher: RayAWSLauncher,
    job_overrides: Sequence[Sequence[str]],
    initial_job_idx: int,
) -> Sequence[JobReturn]:
    setup_globals()
    assert launcher.config is not None
    assert launcher.config_loader is not None
    assert launcher.task_function is not None

    setup_commands = launcher.env_setup.commands
    with read_write(setup_commands):
        setup_commands.extend([
            f"pip install {package}=={version}"
            for package, version in launcher.env_setup.pip_packages.items()
        ])
        setup_commands.extend(launcher.ray_cfg.cluster.setup_commands)

    with read_write(launcher.ray_cfg.cluster):
        launcher.ray_cfg.cluster.setup_commands = setup_commands

    configure_log(launcher.config.hydra.hydra_logging,
                  launcher.config.hydra.verbose)

    log.info(f"Ray Launcher is launching {len(job_overrides)} jobs, ")

    with tempfile.TemporaryDirectory() as local_tmp_dir:
        sweep_configs = []
        for idx, overrides in enumerate(job_overrides):
            idx = initial_job_idx + idx
            ostr = " ".join(filter_overrides(overrides))
            log.info(f"\t#{idx} : {ostr}")
            sweep_config = launcher.config_loader.load_sweep_config(
                launcher.config, list(overrides))
            with open_dict(sweep_config):
                # job.id will be set on the EC2 instance before running the job.
                sweep_config.hydra.job.num = idx

            sweep_configs.append(sweep_config)

        _pickle_jobs(
            tmp_dir=local_tmp_dir,
            sweep_configs=sweep_configs,  # type: ignore
            task_function=launcher.task_function,
            singleton_state=Singleton.get_state(),
        )

        with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
            with open(f.name, "w") as file:
                OmegaConf.save(config=launcher.ray_cfg.cluster,
                               f=file.name,
                               resolve=True)
            launcher.ray_yaml_path = f.name
            log.info(
                f"Saving RayClusterConf in a temp yaml file: {launcher.ray_yaml_path}."
            )

            return launch_jobs(launcher, local_tmp_dir,
                               Path(HydraConfig.get().sweep.dir))
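read_write works on list nodes as well as dict nodes, which is what lets the launcher extend a read-only setup_commands ListConfig in place. A minimal sketch, assuming only omegaconf (the package pin is illustrative):

from omegaconf import OmegaConf, read_write

commands = OmegaConf.create(["apt-get update"])
OmegaConf.set_readonly(commands, True)

with read_write(commands):
    # ListConfig supports extend; blocked while read-only, allowed here
    commands.extend(["pip install hydra-core==1.1.0"])

assert len(commands) == 2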
Example #3
def launch(
    launcher: RayAWSLauncher,
    job_overrides: Sequence[Sequence[str]],
    initial_job_idx: int,
) -> Sequence[JobReturn]:
    setup_globals()
    assert launcher.config is not None
    assert launcher.hydra_context is not None
    assert launcher.task_function is not None

    setup_commands = launcher.env_setup.commands
    packages = filter(
        lambda x: x[1] is not None, launcher.env_setup.pip_packages.items()
    )
    with read_write(setup_commands):
        setup_commands.extend(
            [f"pip install {package}=={version}" for package, version in packages]
        )
        setup_commands.extend(launcher.ray_cfg.cluster.setup_commands)

    with read_write(launcher.ray_cfg.cluster):
        launcher.ray_cfg.cluster.setup_commands = setup_commands

    configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose)
    logging_config = OmegaConf.to_container(
        launcher.logging, resolve=True, enum_to_str=True
    )
    sdk.configure_logging(**logging_config)

    log.info(f"Ray Launcher is launching {len(job_overrides)} jobs, ")

    with tempfile.TemporaryDirectory() as local_tmp_dir:
        sweep_configs = []
        for idx, overrides in enumerate(job_overrides):
            idx = initial_job_idx + idx
            ostr = " ".join(filter_overrides(overrides))
            log.info(f"\t#{idx} : {ostr}")
            sweep_config = launcher.hydra_context.config_loader.load_sweep_config(
                launcher.config, list(overrides)
            )
            with open_dict(sweep_config):
                # job.id will be set on the EC2 instance before running the job.
                sweep_config.hydra.job.num = idx

            sweep_configs.append(sweep_config)

        _pickle_jobs(
            tmp_dir=local_tmp_dir,
            hydra_context=launcher.hydra_context,
            sweep_configs=sweep_configs,  # type: ignore
            task_function=launcher.task_function,
            singleton_state=Singleton.get_state(),
        )
        return launch_jobs(
            launcher, local_tmp_dir, Path(launcher.config.hydra.sweep.dir)
        )
Example #4
    def _extract_defaults_list(self, config_path: str,
                               cfg: Container) -> ListConfig:
        empty = OmegaConf.create([])
        if not OmegaConf.is_dict(cfg):
            return empty
        assert isinstance(cfg, DictConfig)
        with read_write(cfg):
            with open_dict(cfg):
                if not cfg._is_typed():
                    defaults = cfg.pop("defaults", empty)
                else:
                    # If node is a backed by Structured Config, flag it and temporarily keep the defaults list in.
                    # It will be removed later.
                    # This is addressing an edge case where the defaults list re-appears once the dataclass is used
                    # as a prototype during OmegaConf merge.
                    cfg["__HYDRA_REMOVE_TOP_LEVEL_DEFAULTS__"] = True
                    defaults = cfg.get("defaults", empty)
        if not isinstance(defaults, ListConfig):
            if isinstance(defaults, DictConfig):
                type_str = "mapping"
            else:
                type_str = type(defaults).__name__
            raise ValueError(
                f"Invalid defaults list in '{config_path}', defaults must be a list (got {type_str})"
            )

        return defaults
Example #5
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
    # copy config to avoid mutating it when merging with kwargs
    config_copy = copy.deepcopy(config)

    # Manually set parent as deepcopy does not currently handle it (https://github.com/omry/omegaconf/issues/130)
    # noinspection PyProtectedMember
    config_copy._set_parent(config._get_parent())  # type: ignore
    config = config_copy

    params = config.params if "params" in config else OmegaConf.create()
    assert isinstance(
        params, DictConfig
    ), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"
    primitives = {}
    rest = {}
    for k, v in kwargs.items():
        if _utils.is_primitive_type(v) or isinstance(v, (dict, list)):
            primitives[k] = v
        else:
            rest[k] = v
    final_kwargs = {}
    with read_write(params):
        params.merge_with(OmegaConf.create(primitives))

    for k, v in params.items():
        final_kwargs[k] = v

    for k, v in rest.items():
        final_kwargs[k] = v
    return final_kwargs
Example #6
    def __init__(self, parent: Optional[Container], value: Any,
                 metadata: Metadata):
        # local import, likely deferred to avoid a circular import at module load time
        from omegaconf import read_write

        super().__init__(parent=parent, metadata=metadata)
        with read_write(self):
            self._set_value(value)
Example #7
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:

    if isinstance(config, ObjectConf):
        config = OmegaConf.structured(config)
    else:
        config = copy.deepcopy(config)

    params = config.params if hasattr(config, "params") else {}

    assert isinstance(
        params, MutableMapping
    ), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"

    if isinstance(config, DictConfig):
        assert isinstance(params, DictConfig)
        params._set_parent(config)

    primitives = {}
    rest = {}
    for k, v in kwargs.items():
        if _utils.is_primitive_type(v) or isinstance(v, (dict, list)):
            primitives[k] = v
        else:
            rest[k] = v
    final_kwargs = {}

    with read_write(params):
        params.merge_with(primitives)

    for k, v in params.items():
        final_kwargs[k] = v

    for k, v in rest.items():
        final_kwargs[k] = v
    return final_kwargs
Example #8
def run_job(
    config: DictConfig,
    task_function: TaskFunction,
    job_dir_key: str,
    job_subdir_key: Optional[str],
    configure_logging: bool = True,
) -> "JobReturn":
    old_cwd = os.getcwd()
    working_dir = str(OmegaConf.select(config, job_dir_key))
    orig_hydra_cfg = HydraConfig.instance().cfg
    if job_subdir_key is not None:
        # evaluate job_subdir_key lazily.
        # this is running on the client side in sweep and contains things such as job:id which
        # are only available there.
        subdir = str(OmegaConf.select(config, job_subdir_key))
        working_dir = os.path.join(working_dir, subdir)
    try:
        ret = JobReturn()
        ret.working_dir = working_dir
        task_cfg = copy.deepcopy(config)
        hydra_cfg = OmegaConf.masked_copy(task_cfg, "hydra")
        # maintain parent to preserve interpolation links from hydra_cfg to job_cfg
        hydra_cfg._set_parent(task_cfg)
        with read_write(task_cfg):
            with open_dict(task_cfg):
                del task_cfg["hydra"]
        HydraConfig.instance().cfg = hydra_cfg  # type: ignore

        ret.cfg = task_cfg
        ret.hydra_cfg = hydra_cfg
        overrides = OmegaConf.to_container(config.hydra.overrides.task)
        assert isinstance(overrides, list)
        ret.overrides = overrides
        # handle output directories here
        Path(str(working_dir)).mkdir(parents=True, exist_ok=True)
        os.chdir(working_dir)

        if configure_logging:
            configure_log(config.hydra.job_logging, config.hydra.verbose)

        if config.hydra.output_subdir is not None:
            hydra_output = Path(config.hydra.output_subdir)
            _save_config(task_cfg, "config.yaml", hydra_output)
            _save_config(hydra_cfg, "hydra.yaml", hydra_output)
            _save_config(config.hydra.overrides.task, "overrides.yaml",
                         hydra_output)

        with env_override(hydra_cfg.hydra.job.env_set):
            ret.return_value = task_function(task_cfg)
        ret.task_name = JobRuntime.instance().get("name")

        _flush_loggers()

        return ret
    finally:
        HydraConfig.instance().cfg = orig_hydra_cfg
        os.chdir(old_cwd)
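The nested read_write/open_dict blocks above are the usual way to delete a key from a config that is both read-only and in struct mode: read_write lifts the read-only flag, open_dict lifts struct mode. A minimal sketch of the same deletion, assuming only omegaconf (config contents are illustrative):

import copy

from omegaconf import OmegaConf, open_dict, read_write

cfg = OmegaConf.create({"hydra": {"job": {"num": 0}}, "lr": 0.1})
OmegaConf.set_struct(cfg, True)
OmegaConf.set_readonly(cfg, True)

task_cfg = copy.deepcopy(cfg)  # flags are preserved by deepcopy
with read_write(task_cfg):
    with open_dict(task_cfg):
        del task_cfg["hydra"]  # would fail outside the two contexts

assert "hydra" not in task_cfg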
Example #9
def test_read_write_override(src: Any, func: Any, expectation: Any) -> None:
    c = OmegaConf.create(src)
    OmegaConf.set_readonly(c, True)

    with expectation:
        func(c)

    with does_not_raise():
        with read_write(c):
            func(c)
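This test exercises the two behaviors the other examples rely on implicitly: writes to a read-only config raise, and the same writes succeed inside read_write. A hedged standalone sketch of the same check (test name and values are illustrative):

import pytest
from omegaconf import OmegaConf, read_write
from omegaconf.errors import ReadonlyConfigError

def test_read_write_lifts_readonly() -> None:
    c = OmegaConf.create({"a": 1})
    OmegaConf.set_readonly(c, True)

    with pytest.raises(ReadonlyConfigError):
        c.a = 2  # blocked while the config is read-only

    with read_write(c):
        c.a = 2  # allowed inside the context

    assert c.a == 2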
Example #10
def run_job(
    config: DictConfig,
    task_function: TaskFunction,
    job_dir_key: str,
    job_subdir_key: Optional[str],
) -> "JobReturn":
    old_cwd = os.getcwd()
    working_dir = str(OmegaConf.select(config, job_dir_key))
    if job_subdir_key is not None:
        # evaluate job_subdir_key lazily.
        # this is running on the client side in sweep and contains things such as job:id which
        # are only available there.
        subdir = str(OmegaConf.select(config, job_subdir_key))
        working_dir = os.path.join(working_dir, subdir)
    try:
        ret = JobReturn()
        ret.working_dir = working_dir
        task_cfg = copy.deepcopy(config)
        with read_write(task_cfg):
            with open_dict(task_cfg):
                del task_cfg["hydra"]
        ret.cfg = task_cfg
        ret.hydra_cfg = OmegaConf.create({"hydra": HydraConfig.get()})
        overrides = OmegaConf.to_container(config.hydra.overrides.task)
        assert isinstance(overrides, list)
        ret.overrides = overrides
        # handle output directories here
        Path(str(working_dir)).mkdir(parents=True, exist_ok=True)
        os.chdir(working_dir)

        configure_log(config.hydra.job_logging, config.hydra.verbose)

        hydra_cfg = OmegaConf.masked_copy(config, "hydra")
        assert isinstance(hydra_cfg, DictConfig)

        if config.hydra.output_subdir is not None:
            hydra_output = Path(config.hydra.output_subdir)
            _save_config(task_cfg, "config.yaml", hydra_output)
            _save_config(hydra_cfg, "hydra.yaml", hydra_output)
            _save_config(config.hydra.overrides.task, "overrides.yaml",
                         hydra_output)

        with env_override(hydra_cfg.hydra.job.env_set):
            ret.return_value = task_function(task_cfg)
        ret.task_name = JobRuntime.instance().get("name")

        # shut down logging to ensure job log files are closed.
        # If logging is still required after run_job caller is responsible to re-initialize it.
        logging.shutdown()

        return ret
    finally:
        os.chdir(old_cwd)
Example #11
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:

    if isinstance(config, ObjectConf):
        config = OmegaConf.structured(config)
        if config.params is not None:
            params = config.params
        else:
            params = OmegaConf.create()
    else:
        config = copy.deepcopy(config)
        if "params" in config:
            msg = (
                "\nField 'params' is deprecated since Hydra 1.0 and will be removed in Hydra 1.1."
                "\nInline the content of params directly at the containing node."
                "\nSee https://hydra.cc/docs/next/upgrades/0.11_to_1.0/object_instantiation_changes"
            )
            warnings.warn(category=UserWarning, message=msg)
            params = config.params
        else:
            params = config

    assert isinstance(
        params, DictConfig
    ), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"

    config_overrides = {}
    passthrough = {}
    for k, v in kwargs.items():
        if k in params and not (
            get_ref_type(params, k) is Any and OmegaConf.is_missing(params, k)
        ):
            config_overrides[k] = v
        else:
            passthrough[k] = v
    final_kwargs = {}

    with read_write(params):
        params.merge_with(config_overrides)

    for k in params.keys():
        if k == "_target_":
            continue
        if k not in passthrough:
            final_kwargs[k] = params[k]

    for k, v in passthrough.items():
        final_kwargs[k] = v

    return final_kwargs
Example #12
    def _extract_defaults_list(
        config_path: Optional[str], cfg: Container
    ) -> List[DefaultElement]:
        if not OmegaConf.is_dict(cfg):
            return []

        assert isinstance(cfg, DictConfig)
        with read_write(cfg):
            with open_dict(cfg):
                defaults = cfg.pop("defaults", OmegaConf.create([]))

        if len(defaults) > 0:
            return ConfigSource._create_defaults_list(
                config_path=config_path, defaults=defaults
            )
        else:
            return []
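Here open_dict is doing the real work: a DictConfig in struct mode rejects pop, so the defaults list can only be removed inside the nested contexts. A minimal sketch, assuming only omegaconf (config contents are illustrative; read_write mirrors the source pattern and also covers read-only inputs):

from omegaconf import OmegaConf, open_dict, read_write

cfg = OmegaConf.create({"defaults": ["db/mysql"], "db": {"host": "localhost"}})
OmegaConf.set_struct(cfg, True)

with read_write(cfg):
    with open_dict(cfg):
        # pop raises ConfigTypeError on a struct-mode DictConfig
        # unless open_dict is active
        defaults = cfg.pop("defaults", OmegaConf.create([]))

assert list(defaults) == ["db/mysql"]
assert "defaults" not in cfg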
Example #13
def test_experimental_save_job_info_callback(tmpdir: Path, multirun: bool) -> None:
    app_path = "tests/test_apps/app_with_pickle_job_info_callback/my_app.py"

    cmd = [
        app_path,
        "hydra.run.dir=" + str(tmpdir),
        "hydra.sweep.dir=" + str(tmpdir),
        "hydra.job.chdir=True",
    ]
    if multirun:
        cmd.append("-m")
    _, _err = run_python_script(cmd)

    def load_pickle(path: Path) -> Any:
        with open(str(path), "rb") as input:
            obj = pickle.load(input)  # nosec
        return obj

    # load pickles from callbacks
    callback_output = tmpdir / Path("0") / ".hydra" if multirun else tmpdir / ".hydra"
    config_on_job_start = load_pickle(callback_output / "config.pickle")
    job_return_on_job_end: JobReturn = load_pickle(
        callback_output / "job_return.pickle"
    )

    task_cfg_from_callback = copy.deepcopy(config_on_job_start)
    with read_write(task_cfg_from_callback):
        with open_dict(task_cfg_from_callback):
            del task_cfg_from_callback["hydra"]

    # load pickles generated from the application
    app_output_dir = tmpdir / "0" if multirun else tmpdir
    task_cfg_from_app = load_pickle(app_output_dir / "task_cfg.pickle")
    hydra_cfg_from_app = load_pickle(app_output_dir / "hydra_cfg.pickle")

    # verify the cfg pickles are the same on_job_start
    assert task_cfg_from_callback == task_cfg_from_app
    assert config_on_job_start.hydra == hydra_cfg_from_app

    # verify pickled object are the same on_job_end
    assert job_return_on_job_end.cfg == task_cfg_from_app
    assert job_return_on_job_end.hydra_cfg.hydra == hydra_cfg_from_app  # type: ignore
    assert job_return_on_job_end.return_value == "hello world"
    assert job_return_on_job_end.status == JobStatus.COMPLETED
Example #14
    def _extract_defaults_list(self, config_path: str,
                               cfg: Container) -> ListConfig:
        empty = OmegaConf.create([])
        if not OmegaConf.is_dict(cfg):
            return empty
        assert isinstance(cfg, DictConfig)
        with read_write(cfg):
            with open_dict(cfg):
                defaults = cfg.pop("defaults", empty)
        if not isinstance(defaults, ListConfig):
            if isinstance(defaults, DictConfig):
                type_str = "mapping"
            else:
                type_str = type(defaults).__name__
            raise ValueError(
                f"Invalid defaults list in '{config_path}', defaults must be a list (got {type_str})"
            )

        return defaults
Example #15
def _get_rerun_conf(file_path: str, overrides: List[str]) -> DictConfig:
    msg = "Experimental rerun CLI option, other command line args are ignored."
    warnings.warn(msg, UserWarning)
    file = Path(file_path)
    if not file.exists():
        raise ValueError(f"File {file} does not exist!")

    if len(overrides) > 0:
        msg = "Config overrides are not supported as of now."
        warnings.warn(msg, UserWarning)

    with open(str(file), "rb") as input:
        config = pickle.load(input)  # nosec
    configure_log(config.hydra.job_logging, config.hydra.verbose)
    HydraConfig.instance().set_config(config)
    task_cfg = copy.deepcopy(config)
    with read_write(task_cfg):
        with open_dict(task_cfg):
            del task_cfg["hydra"]
    assert isinstance(task_cfg, DictConfig)
    return task_cfg
Example #16
def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:

    if isinstance(config, ObjectConf):
        config = OmegaConf.structured(config)
    else:
        config = copy.deepcopy(config)

    params = (
        config.params
        if hasattr(config, "params") and config.params is not None
        else OmegaConf.create()
    )

    assert isinstance(
        params, MutableMapping
    ), f"Input config params are expected to be a mapping, found {type(config.params).__name__}"

    if isinstance(config, DictConfig):
        assert isinstance(params, DictConfig)
        params._set_parent(config)

    config_overrides = {}
    passthrough = {}
    for k, v in kwargs.items():
        if k in params:
            config_overrides[k] = v
        else:
            passthrough[k] = v
    final_kwargs = {}

    with read_write(params):
        params.merge_with(config_overrides)

    for k, v in params.items():
        final_kwargs[k] = v

    for k, v in passthrough.items():
        final_kwargs[k] = v
    return final_kwargs
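The merge step in these _get_kwargs variants is wrapped in read_write because params may be read-only, for example when it points into a larger frozen config. A minimal sketch of merging override kwargs, assuming only omegaconf (keys and values are illustrative):

from omegaconf import OmegaConf, read_write

params = OmegaConf.create({"lr": 0.1, "epochs": 10})
OmegaConf.set_readonly(params, True)

overrides = {"lr": 0.01}
with read_write(params):
    params.merge_with(overrides)  # merge_with also accepts plain dicts

assert params.lr == 0.01 and params.epochs == 10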
Example #17
def run_job(
    task_function: TaskFunction,
    config: DictConfig,
    job_dir_key: str,
    job_subdir_key: Optional[str],
    configure_logging: bool = True,
    hydra_context: Optional[HydraContext] = None,
) -> "JobReturn":
    callbacks = _get_callbacks_for_run_job(hydra_context)

    old_cwd = os.getcwd()
    orig_hydra_cfg = HydraConfig.instance().cfg
    HydraConfig.instance().set_config(config)
    working_dir = str(OmegaConf.select(config, job_dir_key))
    if job_subdir_key is not None:
        # evaluate job_subdir_key lazily.
        # this is running on the client side in sweep and contains things such as job:id which
        # are only available there.
        subdir = str(OmegaConf.select(config, job_subdir_key))
        working_dir = os.path.join(working_dir, subdir)
    try:
        ret = JobReturn()
        ret.working_dir = working_dir
        task_cfg = copy.deepcopy(config)
        with read_write(task_cfg):
            with open_dict(task_cfg):
                del task_cfg["hydra"]

        ret.cfg = task_cfg
        hydra_cfg = copy.deepcopy(HydraConfig.instance().cfg)
        assert isinstance(hydra_cfg, DictConfig)
        ret.hydra_cfg = hydra_cfg
        overrides = OmegaConf.to_container(config.hydra.overrides.task)
        assert isinstance(overrides, list)
        ret.overrides = overrides
        # handle output directories here
        Path(str(working_dir)).mkdir(parents=True, exist_ok=True)
        os.chdir(working_dir)

        if configure_logging:
            configure_log(config.hydra.job_logging, config.hydra.verbose)

        if config.hydra.output_subdir is not None:
            hydra_output = Path(config.hydra.output_subdir)
            _save_config(task_cfg, "config.yaml", hydra_output)
            _save_config(hydra_cfg, "hydra.yaml", hydra_output)
            _save_config(config.hydra.overrides.task, "overrides.yaml",
                         hydra_output)

        with env_override(hydra_cfg.hydra.job.env_set):
            callbacks.on_job_start(config=config)
            try:
                ret.return_value = task_function(task_cfg)
                ret.status = JobStatus.COMPLETED
            except Exception as e:
                ret.return_value = e
                ret.status = JobStatus.FAILED

        ret.task_name = JobRuntime.instance().get("name")

        _flush_loggers()

        callbacks.on_job_end(config=config, job_return=ret)

        return ret
    finally:
        HydraConfig.instance().cfg = orig_hydra_cfg
        os.chdir(old_cwd)
Example #18
def run_job(
    task_function: TaskFunction,
    config: DictConfig,
    job_dir_key: str,
    job_subdir_key: Optional[str],
    configure_logging: bool = True,
    hydra_context: Optional[HydraContext] = None,
) -> "JobReturn":
    callbacks = _get_callbacks_for_run_job(hydra_context)

    old_cwd = os.getcwd()
    orig_hydra_cfg = HydraConfig.instance().cfg

    output_dir = str(OmegaConf.select(config, job_dir_key))
    if job_subdir_key is not None:
        # evaluate job_subdir_key lazily.
        # this is running on the client side in sweep and contains things such as job:id which
        # are only available there.
        subdir = str(OmegaConf.select(config, job_subdir_key))
        output_dir = os.path.join(output_dir, subdir)

    with read_write(config.hydra.runtime):
        with open_dict(config.hydra.runtime):
            config.hydra.runtime.output_dir = os.path.abspath(output_dir)

    HydraConfig.instance().set_config(config)
    _chdir = None
    try:
        ret = JobReturn()
        task_cfg = copy.deepcopy(config)
        with read_write(task_cfg):
            with open_dict(task_cfg):
                del task_cfg["hydra"]

        ret.cfg = task_cfg
        hydra_cfg = copy.deepcopy(HydraConfig.instance().cfg)
        assert isinstance(hydra_cfg, DictConfig)
        ret.hydra_cfg = hydra_cfg
        overrides = OmegaConf.to_container(config.hydra.overrides.task)
        assert isinstance(overrides, list)
        ret.overrides = overrides
        # handle output directories here
        Path(str(output_dir)).mkdir(parents=True, exist_ok=True)

        _chdir = hydra_cfg.hydra.job.chdir

        if _chdir is None:
            url = "https://hydra.cc/docs/upgrades/1.1_to_1.2/changes_to_job_working_dir"
            deprecation_warning(
                message=dedent(f"""\
                    Hydra 1.3 will no longer change working directory at job runtime by default.
                    See {url} for more information."""),
                stacklevel=2,
            )
            _chdir = True

        if _chdir:
            os.chdir(output_dir)
            ret.working_dir = output_dir
        else:
            ret.working_dir = os.getcwd()

        if configure_logging:
            configure_log(config.hydra.job_logging, config.hydra.verbose)

        if config.hydra.output_subdir is not None:
            hydra_output = Path(config.hydra.runtime.output_dir) / Path(
                config.hydra.output_subdir)
            _save_config(task_cfg, "config.yaml", hydra_output)
            _save_config(hydra_cfg, "hydra.yaml", hydra_output)
            _save_config(config.hydra.overrides.task, "overrides.yaml",
                         hydra_output)

        with env_override(hydra_cfg.hydra.job.env_set):
            callbacks.on_job_start(config=config)
            try:
                ret.return_value = task_function(task_cfg)
                ret.status = JobStatus.COMPLETED
            except Exception as e:
                ret.return_value = e
                ret.status = JobStatus.FAILED

        ret.task_name = JobRuntime.instance().get("name")

        _flush_loggers()

        callbacks.on_job_end(config=config, job_return=ret)

        return ret
    finally:
        HydraConfig.instance().cfg = orig_hydra_cfg
        if _chdir:
            os.chdir(old_cwd)
Example #19
    def __init__(self, parent: Optional[Box], value: Any, metadata: Metadata):
        from omegaconf import read_write

        super().__init__(parent=parent, metadata=metadata)
        with read_write(self):
            self._set_value(value)  # lgtm [py/init-calls-subclass]
Example #20
def _main(cfg: DictConfig, output_file):
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    
    if 'label_dir' in cfg.task:
        manifest_dir, _ = os.path.split(cfg.dataset.gen_subset)
        with read_write(cfg):
            cfg.task.label_dir = os.path.join(cfg.task.data, manifest_dir)
        print('cfg.task.label_dir', cfg.task.label_dir)
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)


    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    token_type = None
    if type(models[0]) == Wav2Bart or type(models[0]) == WavTransBart or type(models[0]) == WavLinearBart or type(models[0]) == WavBart2Bart:
        token_type = 'bart'
    elif type(models[0]) == Wav2BartChr:
        token_type = 'chr'
    elif type(models[0]) == Wav2VecCtc or type(models[0]) == Wav2BertChr or type(models[0]) == Wav2BertMixChr:
        token_type = 'chrctc'
    elif type(models[0]) == Wav2Bert:
        token_type = 'bert'
    else:
        raise ValueError(f'token_type not defined for {type(models[0])}')
    print(f'token_type is {token_type}')


    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None
            )
        except:
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
    print('cfg.generation', cfg.generation)

    
    # print(cfg.task._name == 'audio_pretraining')
    if cfg.task._name != 'audio_pretraining' and cfg.task._name != 'audio_pretraining_bertbpe':
        generator = task.build_generator(
            models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )
    else:
        print('use W2lViterbiDecoder')
        from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
        from easydict import EasyDict as edict
        args = edict({
            'criterion': 'ctc',
            'nbest': 1,
        })  
        generator = W2lViterbiDecoder(args, task.target_dictionary)
    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for si, sample in enumerate(progress):
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()

        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        # print('hypos', hypos)
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)
        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                )

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
                    sample_id
                )
                target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
                    sample_id
                )
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    if token_type == 'chr':
                        target_str = tgt_dict.string(
                            target_tokens,
                            cfg.common_eval.post_process,
                            escape_unk=True,
                            extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                                generator
                            ),
                        )
                    elif token_type == 'bart':
                        target_str = task.bart.decode(target_tokens.int().cpu())
                    elif token_type == 'bert':
                        target_str = task.bert.decode(target_tokens.int().cpu())
                    elif token_type == 'chrctc':
                        target_str = tgt_dict.string(
                            target_tokens,
                            cfg.common_eval.post_process,
                            escape_unk=True,
                        )
                    else:
                        raise ValueError(f'token_type not defined for {type(models[0])}')

            src_str = decode_fn(src_str)
            
            if has_target and token_type == 'chr':
                target_str = decode_fn(target_str)
            elif has_target and token_type == 'chrctc':
                target_str = ''.join(target_str.split()).replace('|', ' ')

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str), file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str), file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
                # print('align', hypo["alignment"])
                if token_type == 'bart':
                    hypo_tokens = hypo["tokens"].int().cpu()
                    hypo_str = task.bart.decode(hypo["tokens"].int().cpu())
                    alignment = hypo["alignment"]
                elif token_type == 'chr':
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=cfg.common_eval.post_process,
                        # extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                    )
                elif token_type == 'chrctc':
                    hypo_tokens = hypo["tokens"].int().cpu()
                    hypo_str = task.target_dictionary.string(hypo_tokens)
                    hypo["positional_scores"] = torch.FloatTensor([0.])
                elif token_type == 'bert':
                    hypo_tokens = hypo["tokens"].int().cpu()
                    hypo_str = task.bert.decode(hypo["tokens"].int().cpu())
                    alignment = hypo["alignment"]
                else:
                    raise ValueError(f'token_type not defined for {type(models[0])}')

                detok_hypo_str = decode_fn(hypo_str)

                if token_type == 'chr' or token_type == 'chrctc':
                    print('target_str', ''.join(target_str.split()).replace('|', ' '))
                    print('hypo_str', ''.join(detok_hypo_str.split()).replace('|', ' '))
                    detok_hypo_str = ''.join(detok_hypo_str.split()).replace('|', ' ')
                    # target_str = ''.join(target_str.split()).replace('|', ' ')
                elif token_type == 'bart':
                    print('target_str', target_str)
                    print('hypo_str', detok_hypo_str)

                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print(
                        "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                        file=output_file,
                    )
                    # detokenized hypothesis
                    print(
                        "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
                        file=output_file,
                    )
                    print(
                        "P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    # convert from base e to base 2
                                    hypo["positional_scores"]
                                    .div_(math.log(2))
                                    .tolist(),
                                )
                            ),
                        ),
                        file=output_file,
                    )

                    if cfg.generation.print_alignment == "hard":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [
                                        "{}-{}".format(src_idx, tgt_idx)
                                        for src_idx, tgt_idx in alignment
                                    ]
                                ),
                            ),
                            file=output_file,
                        )
                    if cfg.generation.print_alignment == "soft":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [
                                        ",".join(src_probs)
                                        for src_probs in alignment
                                    ]
                                ),
                            ),
                            file=output_file,
                        )

                    if cfg.generation.print_step:
                        print(
                            "I-{}\t{}".format(sample_id, hypo["steps"]),
                            file=output_file,
                        )

                    if cfg.generation.retain_iter_history:
                        for step, h in enumerate(hypo["history"]):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h["tokens"].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print(
                                "E-{}_{}\t{}".format(sample_id, step, h_str),
                                file=output_file,
                            )

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or cfg.common_eval.post_process is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True
                        )
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True
                        )
                    if hasattr(scorer, "add_string"):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (
            sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
        )

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(
        "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        )
    )
    if has_target:
        if cfg.bpe and not cfg.generation.sacrebleu:
            if cfg.common_eval.post_process:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            "Generate {} with beam={}: {}".format(
                cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
            ),
            file=output_file,
        )

    return scorer