Beispiel #1
0
    def _handle_artifacts(self, model, nemo_file_folder):
        """Copy every artifact registered on ``model`` into ``nemo_file_folder``.

        Artifacts whose path type is ``LOCAL_PATH`` are copied under a fresh
        ``<uuid4 hex>_<basename>`` name, and that name (prefixed with ``nemo:``)
        is stored on the item as ``hashed_path``.

        Artifacts whose path type is ``TAR_PATH`` are collected and, when the
        model metadata carries a ``restoration_path``, extracted from that
        archive in a single pass and copied over under their existing name.
        Their registry entry is replaced by a new ``TAR_PATH`` item whose path
        is ``"nemo:" + name``.

        Args:
            model: Object exposing ``artifacts`` (mapping of config path to
                artifact item) and ``model_guid``.
            nemo_file_folder: Directory the artifact files are copied into.

        Raises:
            FileNotFoundError: A ``LOCAL_PATH`` artifact does not exist on disk.
            ValueError: An artifact has a path type other than ``LOCAL_PATH``
                or ``TAR_PATH``.
        """
        tarfile_artifacts = []
        app_state = AppState()
        for conf_path, artiitem in model.artifacts.items():
            if artiitem.path_type == model_utils.ArtifactPathType.LOCAL_PATH:
                if not os.path.exists(artiitem.path):
                    raise FileNotFoundError(f"Artifact {conf_path} not found at location: {artiitem.path}")

                # Generate a new unique artifact name and copy the file to nemo_file_folder.
                # Note: uuid.uuid4().hex is guaranteed to be 32 characters long.
                artifact_base_name = os.path.basename(artiitem.path)
                artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
                shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name))

                # Update artifacts registry (same key, so the dict size is unchanged
                # while it is being iterated).
                artiitem.hashed_path = "nemo:" + artifact_uniq_name
                model.artifacts[conf_path] = artiitem

            elif artiitem.path_type == model_utils.ArtifactPathType.TAR_PATH:
                # Process all tarfile artifacts in one go, so preserve the key-value pair.
                tarfile_artifacts.append((conf_path, artiitem))

            else:
                raise ValueError("Directly referencing artifacts from other nemo files isn't supported yet")

        # Process current tarfile artifacts by unpacking the previous tarfile and
        # extracting the artifacts that are currently required.
        model_metadata = app_state.get_model_metadata_from_guid(model.model_guid)
        if not tarfile_artifacts or model_metadata.restoration_path is None:
            return

        # Resolve archive members against the extraction directory explicitly
        # instead of os.chdir()-ing into it. Changing the process-wide working
        # directory made a relative nemo_file_folder resolve inside the temporary
        # directory, and left the process sitting in that directory while
        # TemporaryDirectory tried to delete it (which fails on Windows).
        with tempfile.TemporaryDirectory() as archive_dir:
            self._unpack_nemo_file(path2file=model_metadata.restoration_path, out_folder=archive_dir)
            for conf_path, artiitem in tarfile_artifacts:
                # Get the member name: everything after the "nemo:" marker if present,
                # otherwise the basename of the stored path.
                if 'nemo:' in artiitem.path:
                    artifact_base_name = artiitem.path.split('nemo:')[1]
                else:
                    artifact_base_name = os.path.basename(artiitem.path)
                # No need to hash here: tarfile artifacts already carry their hashed name.
                artifact_uniq_name = artifact_base_name
                # os.path.join leaves an absolute member name untouched, matching
                # what the previous chdir-based lookup did for absolute paths.
                shutil.copy2(
                    os.path.join(archive_dir, artifact_base_name),
                    os.path.join(nemo_file_folder, artifact_uniq_name),
                )

                # Update artifacts registry
                new_artiitem = model_utils.ArtifactItem()
                new_artiitem.path = "nemo:" + artifact_uniq_name
                new_artiitem.path_type = model_utils.ArtifactPathType.TAR_PATH
                model.artifacts[conf_path] = new_artiitem
    def register_artifact(self,
                          model,
                          config_path: str,
                          src: str,
                          verify_src_exists: bool = True):
        """ Register model artifacts with this function. These artifacts (files) will be included inside .nemo file
            when model.save_to("mymodel.nemo") is called.

            How it works:
            1. It always returns an existing absolute path which can be used during the Model constructor call.
                EXCEPTION: src is None or "" in which case nothing will be done and src will be returned.
            2. It will add a (config_path, model_utils.ArtifactItem()) pair to model.artifacts.

            If "src" is a local existing path, then it will be returned in absolute path form.
            elif "src" starts with "nemo:" (i.e. "nemo:unique_artifact_name"):
                the path of that name inside the application's nemo_file_folder will be returned.
            elif a file with the basename of "src" exists inside nemo_file_folder (backward compatibility):
                that path will be returned.
            else an error will be raised (or None returned if verify_src_exists is False).

            WARNING: use .register_artifact calls in your models' constructors.
            The returned path is not guaranteed to exist after you have exited your model's constructor.

            Args:
                model: ModelPT object to register artifact for.
                config_path (str): Artifact key. Usually corresponds to the model config.
                src (str): Path to artifact.
                verify_src_exists (bool): If set to False, then the artifact is optional and register_artifact will
                                          return None if src is not found. Defaults to True.

            Returns:
                str: If src is not None or empty, the absolute path of the artifact (or None for an optional
                     artifact that was not found). If src is None or empty, src itself is returned unchanged.

            Raises:
                FileNotFoundError: src cannot be resolved and verify_src_exists is True.
                AssertionError: the resolved path does not exist on disk.
        """
        # Documented no-op. Without this guard, None crashes in os.path.basename and
        # "" resolves to the current working directory, which always "exists" and
        # would be registered as an artifact.
        if src is None or src == "":
            return src

        app_state = AppState()

        artifact_item = model_utils.ArtifactItem()
        abs_src = os.path.abspath(src)

        # This is for backward compatibility: if the src object exists simply inside of the tarfile
        # without its key having been overridden, this pathway will be used.
        src_obj_name = os.path.basename(src)
        if app_state.nemo_file_folder is not None:
            src_obj_path = os.path.abspath(
                os.path.join(app_state.nemo_file_folder, src_obj_name))
        else:
            src_obj_path = src_obj_name

        # src is a local existing path - register artifact and return exact same path for usage by the model
        if os.path.exists(abs_src):
            return_path = abs_src
            artifact_item.path_type = model_utils.ArtifactPathType.LOCAL_PATH

        # This is the case when the artifact must be retrieved from the nemo file;
        # the part after the "nemo:" prefix is looked up inside app_state.nemo_file_folder.
        elif src.startswith("nemo:"):
            return_path = os.path.abspath(
                os.path.join(app_state.nemo_file_folder, src[5:]))
            artifact_item.path_type = model_utils.ArtifactPathType.TAR_PATH

        # backward compatibility implementation
        elif os.path.exists(src_obj_path):
            return_path = src_obj_path
            artifact_item.path_type = model_utils.ArtifactPathType.TAR_PATH
        else:
            if verify_src_exists:
                raise FileNotFoundError(
                    f"src path does not exist or it is not a path in nemo file. src value I got was: {src}. Absolute: {abs_src}"
                )
            else:
                # artifact is optional and we simply return None
                return None

        assert os.path.exists(return_path), f"Resolved artifact path does not exist: {return_path}"

        # The registry stores the absolute form of src; for "nemo:" sources the
        # "nemo:" marker stays embedded in that path, which _handle_artifacts
        # later splits on to recover the member name.
        artifact_item.path = abs_src
        model.artifacts[config_path] = artifact_item
        # we were called by ModelPT
        if hasattr(model, "cfg"):
            with open_dict(model._cfg):
                OmegaConf.update(model.cfg, config_path, return_path)
        return return_path