예제 #1
0
파일: __main__.py 프로젝트: youjp/gluon-ts
def train_command(data_path: str, forecaster: Optional[str]) -> None:
    """Run the ``train`` shell command: train a forecaster, then test it.

    The forecaster class is resolved by name via ``forecaster_type_by_name``.
    When ``forecaster`` is ``None``, the name is taken from the
    ``forecaster_name`` entry of ``env.hyperparameters``.

    If anything inside the guarded section fails, the error message and the
    formatted traceback are written (best effort) to a ``failure`` file under
    ``TrainPaths(Path(data_path)).output``. The original exception is then
    re-raised.

    Parameters
    ----------
    data_path
        Root path handed to ``TrainEnv`` and ``TrainPaths``.
    forecaster
        Name of the forecaster to train, or ``None`` to read it from the
        hyperparameters.

    Raises
    ------
    GluonTSForecasterNotFoundError
        If ``forecaster`` is ``None`` and the hyperparameters have no
        ``forecaster_name`` key.
    Exception
        Any error raised while building the environment, training or testing
        is re-raised after the failure report has been attempted.
    """
    from gluonts.shell import train

    logger.info("Run 'train' command")

    try:
        env = TrainEnv(Path(data_path))
        if forecaster is None:
            try:
                forecaster = env.hyperparameters["forecaster_name"]
            except KeyError:
                msg = (
                    "Forecaster shell parameter is `None`, but "
                    "the `forecaster_name` key is not defined in the "
                    "hyperparameters.json dictionary."
                )
                raise GluonTSForecasterNotFoundError(msg)
        train.run_train_and_test(env, forecaster_type_by_name(forecaster))
    except Exception as error:
        # Writing the report is best effort: a failure here must not replace
        # the real training error as the exception that propagates.
        try:
            failure_path = TrainPaths(Path(data_path)).output / "failure"
            # Explicit UTF-8 (with escaping) so a non-ASCII error message
            # cannot fail to encode under a restrictive default locale.
            with open(
                failure_path,
                "w",
                encoding="utf-8",
                errors="backslashreplace",
            ) as out_file:
                out_file.write(str(error))
                out_file.write("\n\n")
                out_file.write(traceback.format_exc())
        except OSError as write_error:
            logger.warning("Could not write failure file: %s", write_error)
        # Bare ``raise`` re-raises ``error``: the inner handler has finished,
        # so the outer exception is the one being handled again.
        raise
예제 #2
0
def test_train_shell(train_env: TrainEnv, caplog) -> None:
    """Train and test ``MeanPredictor`` via the shell, then inspect the logs.

    Every captured log message containing one of the markers below must end
    with the paired suffix. A message matching several markers is checked
    against each of them, in table order.
    """
    run_train_and_test(env=train_env, forecaster_type=MeanPredictor)

    # (substring to look for in a log message, required ending of it)
    expectations = (
        ("#test_score (local, QuantileLoss", "0.0"),
        ("local, wQuantileLoss", "0.0"),
        ("local, Coverage", "0.0"),
        ("MASE", "0.0"),
        ("MSIS", "0.0"),
        ("abs_target_sum", "270.0"),
    )

    messages = [record[2] for record in caplog.record_tuples]
    for message in messages:
        for marker, suffix in expectations:
            if marker in message:
                assert message.endswith(suffix)
예제 #3
0
def train_command(data_path: str, forecaster: Optional[str]) -> None:
    """Train and test a forecaster in a ``TrainEnv`` built from ``data_path``.

    Parameters
    ----------
    data_path
        Root path handed to ``TrainEnv``.
    forecaster
        Name of the forecaster to train. If ``None``, the name is read from
        the ``forecaster_name`` entry of ``env.hyperparameters``.

    Raises
    ------
    GluonTSForecasterNotFoundError
        If ``forecaster`` is ``None`` and ``forecaster_name`` is missing from
        the hyperparameters.
    """
    from gluonts.shell import train

    train_env = TrainEnv(Path(data_path))

    name = forecaster
    if name is None:
        try:
            name = train_env.hyperparameters['forecaster_name']
        except KeyError:
            raise GluonTSForecasterNotFoundError(
                "Forecaster shell parameter is `None`, but "
                "the `forecaster_name` key is not defined in the "
                "hyperparameters.json dictionary."
            )

    # Narrows the Optional type for static checkers.
    assert name is not None
    train.run_train_and_test(train_env, forecaster_type_by_name(name))
예제 #4
0
def train_command(data_path: str, forecaster: str) -> None:
    """Train and test a forecaster in a ``SageMakerEnv`` built from ``data_path``.

    Parameters
    ----------
    data_path
        Root path handed to ``SageMakerEnv``.
    forecaster
        Name of the forecaster to train. The literal value
        ``'%from_hyperparameters%'`` means the name is read from the
        ``forecaster_name`` entry of ``env.hyperparameters`` instead.

    Raises
    ------
    GluonTSForecasterNotFoundError
        If the name is to be read from the hyperparameters but the
        ``forecaster_name`` key is missing.
    """
    from gluonts.shell import train

    sagemaker_env = SageMakerEnv(Path(data_path))

    name = forecaster
    if name == '%from_hyperparameters%':
        try:
            name = sagemaker_env.hyperparameters['forecaster_name']
        except KeyError:
            raise GluonTSForecasterNotFoundError(
                "Forecaster shell parameter is '%from_hyperparameters%', but "
                "the `forecaster_name` key is not defined in the "
                "hyperparameters.json dictionary."
            )

    chosen_type = forecaster_type_by_name(name)
    train.run_train_and_test(sagemaker_env, chosen_type)