Exemplo n.º 1
0
def train_command(data_path: str, forecaster: Optional[str]) -> None:
    """
    Run the shell 'train' command.

    Builds a ``TrainEnv`` from ``data_path``, resolves the forecaster type via
    ``forecaster_type_by_name`` and passes both to
    ``train.run_train_and_test``.

    Parameters
    ----------
    data_path
        Root path used to build the ``TrainEnv`` / ``TrainPaths``.
    forecaster
        Name of the forecaster to train. When ``None``, the name is read from
        the ``forecaster_name`` key of ``env.hyperparameters``.

    Raises
    ------
    GluonTSForecasterNotFoundError
        If ``forecaster`` is ``None`` and ``forecaster_name`` is not among the
        hyperparameters.

    Any exception raised while setting up or training is re-raised unchanged,
    after its message and traceback have been written (best-effort) to
    ``TrainPaths(data_path).output / "failure"``.
    """
    from gluonts.shell import train

    logger.info("Run 'train' command")

    try:
        env = TrainEnv(Path(data_path))
        if forecaster is None:
            try:
                forecaster = env.hyperparameters["forecaster_name"]
            except KeyError:
                msg = (
                    "Forecaster shell parameter is `None`, but "
                    "the `forecaster_name` key is not defined in the "
                    "hyperparameters.json dictionary."
                )
                raise GluonTSForecasterNotFoundError(msg)
        train.run_train_and_test(env, forecaster_type_by_name(forecaster))
    except Exception as error:
        # Recording the failure is best-effort. Any problem here (e.g. the
        # output path cannot be created or written) is logged, and it must
        # never replace the original training error.
        try:
            failure_path = TrainPaths(Path(data_path)).output / "failure"
            with open(failure_path, "w", encoding="utf-8") as out_file:
                out_file.write(str(error))
                out_file.write("\n\n")
                out_file.write(traceback.format_exc())
        except Exception:
            logger.exception("Could not write the 'failure' file")
        # A bare `raise` re-raises `error`, the exception handled by the
        # enclosing `except`. The inner handler has already finished.
        raise
Exemplo n.º 2
0
def forecaster_type_by_name(name: str) -> Forecaster:
    """
    Resolve a forecaster by name.

    The `gluonts_forecasters` entry_points namespace is searched first. The
    first entry point whose name equals `name` is loaded and returned. If no
    entry point is registered under that name, `name` is treated as a dotted
    path and resolved with `pydoc.locate`.

    Third-party libraries can register their forecasters by defining a
    `gluonts_forecasters` group in the `entry_points` section of their
    `setup.py`::

        entry_points={
            'gluonts_forecasters': [
                'model_a = my_models.model_a:MyEstimator',
                'model_b = my_models.model_b:MyPredictor',
            ]
        }

    Raises
    ------
    GluonTSForecasterNotFoundError
        If the selected lookup produces `None`.
    """
    matches = (
        candidate
        for candidate in pkg_resources.iter_entry_points("gluonts_forecasters")
        if candidate.name == name
    )
    registered = next(matches, None)

    # Fall back to locating `name` as an import path only when it was not
    # registered as an entry point.
    if registered is not None:
        resolved = registered.load()
    else:
        resolved = pydoc.locate(name)

    if resolved is None:
        raise GluonTSForecasterNotFoundError(
            f'Cannot locate estimator with classname "{name}".'
        )

    return cast(Forecaster, resolved)
Exemplo n.º 3
0
def forecaster_type_by_name(name: str) -> Type[Union[Estimator, Predictor]]:
    """
    Load a forecaster from the `gluonts_forecasters` entry_points namespace
    by name.

    Third-party libraries can register their forecasters by defining a
    `gluonts_forecasters` group in the `entry_points` section of their
    `setup.py`::

        entry_points={
            'gluonts_forecasters': [
                'model_a = my_models.model_a:MyEstimator',
                'model_b = my_models.model_b:MyPredictor',
            ]
        }

    Parameters
    ----------
    name
        Entry-point name to look up (the left-hand side of a registration,
        e.g. ``'model_a'``).

    Returns
    -------
    The object loaded from the first entry point registered under ``name``.

    Raises
    ------
    GluonTSForecasterNotFoundError
        If no entry point in the `gluonts_forecasters` group is named
        ``name``, or if loading it yields ``None``.
    """
    # `iter_entry_points(group, name)` yields only the entry points whose
    # name matches, in the same order as a manual scan of the group.
    entry_point = next(
        pkg_resources.iter_entry_points('gluonts_forecasters', name), None
    )
    forecaster = entry_point.load() if entry_point is not None else None

    if forecaster is None:
        msg = f'Cannot locate estimator with classname "{name}".'
        raise GluonTSForecasterNotFoundError(msg)

    # `cast` only informs the type checker; no runtime type check is done.
    return cast(Type[Union[Estimator, Predictor]], forecaster)
Exemplo n.º 4
0
def train_command(data_path: str, forecaster: Optional[str]) -> None:
    """
    Run the shell 'train' command.

    Builds a ``TrainEnv`` from ``data_path``, resolves the forecaster type via
    ``forecaster_type_by_name`` and passes both to
    ``train.run_train_and_test``.

    Parameters
    ----------
    data_path
        Root path used to build the ``TrainEnv``.
    forecaster
        Name of the forecaster to train. When ``None``, the name is read from
        the ``forecaster_name`` key of ``env.hyperparameters``.

    Raises
    ------
    GluonTSForecasterNotFoundError
        If ``forecaster`` is ``None`` and ``forecaster_name`` is missing from
        the hyperparameters or maps to ``None``.
    """
    from gluonts.shell import train

    env = TrainEnv(Path(data_path))

    if forecaster is None:
        try:
            forecaster = env.hyperparameters['forecaster_name']
        except KeyError:
            msg = ("Forecaster shell parameter is `None`, but "
                   "the `forecaster_name` key is not defined in the "
                   "hyperparameters.json dictionary.")
            raise GluonTSForecasterNotFoundError(msg)

    # Use an explicit check rather than `assert`. Assertions are stripped
    # under `python -O`, and a `forecaster_name` hyperparameter whose value
    # is `None` would otherwise reach `forecaster_type_by_name`.
    if forecaster is None:
        raise GluonTSForecasterNotFoundError(
            "The `forecaster_name` key in the hyperparameters.json "
            "dictionary is `None`."
        )

    train.run_train_and_test(env, forecaster_type_by_name(forecaster))
Exemplo n.º 5
0
def train_command(data_path: str, forecaster: str) -> None:
    """
    Execute the shell 'train' command in a ``SageMakerEnv``.

    Parameters
    ----------
    data_path
        Root path used to build the ``SageMakerEnv``.
    forecaster
        Forecaster name, or the sentinel ``'%from_hyperparameters%'`` to take
        the name from the ``forecaster_name`` hyperparameter instead.

    Raises
    ------
    GluonTSForecasterNotFoundError
        If the sentinel is given but ``forecaster_name`` is missing from the
        hyperparameters.
    """
    from gluonts.shell import train

    sagemaker_env = SageMakerEnv(Path(data_path))

    forecaster_name = forecaster
    if forecaster_name == '%from_hyperparameters%':
        try:
            forecaster_name = sagemaker_env.hyperparameters['forecaster_name']
        except KeyError:
            raise GluonTSForecasterNotFoundError(
                "Forecaster shell parameter is '%from_hyperparameters%', but "
                "the `forecaster_name` key is not defined in the "
                "hyperparameters.json dictionary."
            )

    train.run_train_and_test(
        sagemaker_env, forecaster_type_by_name(forecaster_name)
    )