示例#1
0
def test_blob_to_dir(tmpdir):
    """Round-trip a directory through dir_to_blob/blob_to_dir and check its file survives."""
    source_dir = Path(tmpdir) / "test"
    source_dir.mkdir()

    name = "test.txt"
    payload = "test"
    (source_dir / name).write_text(payload)

    # The target is only a Path object here (not created), so the blob built
    # below still contains just the single text file.
    target_dir = source_dir / "test2"
    blob_to_dir(dir_to_blob(source_dir), target_dir)

    restored = target_dir / name
    assert restored.exists()
    assert restored.read_text() == payload
示例#2
0
    def get_checkpoint(self, base_path: Optional[Path] = None) -> Path:
        """
        Return a filesystem path to our checkpoint, which can be used to initialize
        future models from the same state. If a base_path is provided, copy/extract
        the checkpoint under that path.

        NOTE: If no base_path is provided and the checkpoint comes from a remote
        worker, the checkpoint will be extracted to a temporary directory, and a
        warning will be emitted.  gobbli will make no effort to ensure the temporary
        directory is cleaned up after creation.

        Args:
          base_path: Optional directory to extract/copy the checkpoint to.  A ``str``
            is also accepted and converted to a :class:`pathlib.Path`.  If not provided
            and the checkpoint already exists on the current machine's filesystem, the
            returned path is located under the original checkpoint directory.  If the
            checkpoint is a bytes object, a temporary directory will be created.  The
            directory must not already exist.

        Returns:
          The path to the checkpoint: ``<directory> / self.best_model_checkpoint_name``,
          where ``<directory>`` is ``base_path``, the temporary directory, or the
          original checkpoint directory.

        Raises:
          TypeError: If ``self.best_model_checkpoint`` is neither ``bytes`` nor a
            :class:`pathlib.Path`.
          FileExistsError: If a local checkpoint is copied to a ``base_path`` that
            already exists (raised by :func:`shutil.copytree`).
        """
        if base_path is not None:
            # Normalize so that the ``/`` joins below also work for str inputs.
            base_path = Path(base_path)

        if isinstance(self.best_model_checkpoint, bytes):
            if base_path is None:
                # stacklevel=2 attributes the warning to the caller, who is the one
                # able to fix it by passing base_path.
                warnings.warn(
                    "No base_path provided; checkpoint extracting to temporary "
                    "directory.",
                    stacklevel=2,
                )
                base_path = Path(tempfile.mkdtemp())

            blob_to_dir(self.best_model_checkpoint, base_path)
            return base_path / self.best_model_checkpoint_name

        elif isinstance(self.best_model_checkpoint, Path):
            # NOTE(review): this branch treats a Path checkpoint as the *directory*
            # containing ``best_model_checkpoint_name`` — confirm against wherever
            # best_model_checkpoint is assigned.
            if base_path is None:
                base_path = self.best_model_checkpoint
            else:
                # Copy the checkpoint to the user-provided base path
                shutil.copytree(self.best_model_checkpoint, base_path)
            return base_path / self.best_model_checkpoint_name
        else:
            raise TypeError(
                f"unsupported checkpoint type: '{type(self.best_model_checkpoint)}'"
            )
示例#3
0
        def predict(
            X_test: List[str],
            test_batch_size: int,
            model_cls: Any,
            model_params: Dict[str, Any],
            labels: List[str],
            checkpoint: Optional[Union[bytes, Path]],
            checkpoint_name: Optional[str],
            master_ip: str,
            gobbli_dir: Optional[Path] = None,
            log_level: Union[int, str] = logging.WARNING,
            distributed: bool = False,
        ) -> pd.DataFrame:
            """
            Build a model on this worker and return its predicted class probabilities.

            Args:
              X_test: Texts to generate predictions for.
              test_batch_size: Batch size passed to the model's prediction input.
              model_cls: Model class to instantiate.
              model_params: Keyword arguments used to construct ``model_cls``.
              labels: Labels passed through to the prediction input.
              checkpoint: Checkpoint to initialize the model from: a blob, which is
                extracted into a temporary directory with ``blob_to_dir``; a path
                used as-is; or None to predict without a checkpoint.
              checkpoint_name: Name of the checkpoint inside the extracted blob.  Only
                used when ``checkpoint`` is bytes.
              master_ip: IP address compared against this worker's IP; a mismatch
                means this task is running on a remote worker.
              gobbli_dir: Passed to ``init_worker_env``.
              log_level: Passed to ``init_worker_env``.
              distributed: Must be True for the task to run on a remote worker.

            Returns:
              The ``y_pred_proba`` attribute of the model's prediction output.

            Raises:
              RuntimeError: If running on a remote worker without ``distributed=True``.
              TypeError: If ``checkpoint`` is not bytes, a Path, or None.
            """
            logger = init_worker_env(gobbli_dir=gobbli_dir,
                                     log_level=log_level)
            use_gpu, nvidia_visible_devices = init_gpu_config()

            worker_ip = get_worker_ip()
            if not distributed and worker_ip != master_ip:
                raise RuntimeError(
                    "Experiments must be started with distributed = True to run "
                    "tasks on remote workers.")

            clf = model_cls(
                **model_params,
                use_gpu=use_gpu,
                nvidia_visible_devices=nvidia_visible_devices,
                logger=logger,
            )

            # This step isn't necessary in all cases if the build step just downloads
            # pretrained weights we weren't going to use anyway, but sometimes it's needed
            # Ex. for BERT to download vocabulary files and config
            clf.build()

            # Use the current working directory (CWD) as the base for the tempdir, under the
            # assumption that the CWD is included in any bind mounts/volumes the user may have
            # created if they're running this in a Docker container
            # If it's not part of a host mount, the files won't be mounted properly in the container
            with tempfile.TemporaryDirectory(dir=".") as tempdir:
                tempdir_path = Path(tempdir)

                checkpoint_path = None  # type: Optional[Path]
                if isinstance(checkpoint, bytes):
                    if checkpoint_name is None:
                        # Without a name the checkpoint can't be located inside the
                        # extracted blob; make the fallback visible instead of silent.
                        logger.warning(
                            "Checkpoint blob provided without a checkpoint name; "
                            "predicting without a checkpoint.")
                    else:
                        blob_to_dir(checkpoint, tempdir_path)
                        checkpoint_path = tempdir_path / checkpoint_name
                elif isinstance(checkpoint, Path):
                    checkpoint_path = checkpoint
                elif checkpoint is None:
                    pass
                else:
                    raise TypeError(
                        f"invalid checkpoint type: '{type(checkpoint)}'")

                predict_input = gobbli.io.PredictInput(
                    X=X_test,
                    labels=labels,
                    checkpoint=checkpoint_path,
                    predict_batch_size=test_batch_size,
                )
                # Predict inside the `with` so extracted checkpoint files still exist.
                predict_output = clf.predict(predict_input)

            return predict_output.y_pred_proba