예제 #1
0
def temporary_train_env(
    hyperparameters: Dict[str, Any],
    dataset_name: str,
) -> ContextManager[TrainEnv]:
    """
    Build a throw-away training environment inside a temporary directory.

    A temporary directory is created, the given `hyperparameters` are
    written to the hyperparameters file of the resulting `TrainPaths`
    layout as JSON, and the dataset identified by `dataset_name` is linked
    into the data directory via symlinks (``metadata/metadata.json``,
    ``train`` and ``test``). The temporary directory is removed once the
    enclosing ``with tempfile.TemporaryDirectory(...)`` block is left.

    Parameters
    ----------
    hyperparameters
        The hyperparameters to use when instantiating the
        training environment. They are passed through
        `encode_sagemaker_parameters` before being serialized.
    dataset_name
        The name of the repository dataset to use when instantiating the
        training environment. It is handed to `materialize_dataset`.

    Yields
    ------
    gluonts.shell.env.TrainEnv
        The `TrainEnv` instance rooted at the temporary directory.
    """

    with tempfile.TemporaryDirectory(prefix="gluonts-train-env") as base:
        paths = TrainPaths(base=Path(base))

        # Encode first so that a failure while encoding does not leave an
        # empty hyperparameters file behind.
        hps_encoded = encode_sagemaker_parameters(hyperparameters)
        with paths.hyperparameters.open(mode="w") as fp:
            json.dump(hps_encoded, fp, indent=2, sort_keys=True)

        # Make the dataset available; only symlinks are created below, the
        # dataset files themselves are not copied.
        ds_path = materialize_dataset(dataset_name)

        path_metadata = paths.data / "metadata" / "metadata.json"
        path_train = paths.data / "train"
        path_test = paths.data / "test"

        # `parents=True` so this does not rely on `TrainPaths` having
        # already created the data directory.
        path_metadata.parent.mkdir(parents=True, exist_ok=True)

        path_metadata.symlink_to(ds_path / "metadata.json")
        path_train.symlink_to(ds_path / "train", target_is_directory=True)
        path_test.symlink_to(ds_path / "test", target_is_directory=True)

        # NOTE(review): this function is a generator, so the
        # `ContextManager` return annotation only holds when it is wrapped
        # by `contextlib.contextmanager`; the decorator is not visible in
        # this excerpt -- confirm it is applied.
        yield TrainEnv(path=paths.base)
예제 #2
0
 def _materialize(self, directory: Path, regenerate: bool = False) -> None:
     """
     Materialize this object's dataset into `directory`.

     Delegates to `materialize_dataset`, using `self._gluonts_name` as the
     dataset identifier. Any value returned by that call is discarded.

     Parameters
     ----------
     directory
         The location passed on to `materialize_dataset`.
     regenerate
         Forwarded unchanged as the `regenerate` keyword argument.
     """
     dataset_name = self._gluonts_name
     materialize_dataset(dataset_name, directory, regenerate=regenerate)