def train_command(data_path: str, forecaster: Optional[str]) -> None:
    """Run the ``train`` shell command for the environment at ``data_path``.

    Builds a ``TrainEnv`` from ``data_path``, resolves the forecaster name
    and calls ``train.run_train_and_test``. If any step fails, the error
    message and the formatted traceback are written to the ``failure`` file
    under ``TrainPaths(Path(data_path)).output`` and the original exception
    is re-raised.

    Parameters
    ----------
    data_path
        Base directory passed to ``TrainEnv`` and ``TrainPaths``.
    forecaster
        Name of the forecaster to train. When ``None``, the name is read
        from the ``forecaster_name`` entry of ``env.hyperparameters``.

    Raises
    ------
    GluonTSForecasterNotFoundError
        If ``forecaster`` is ``None`` and the hyperparameters contain no
        ``forecaster_name`` key. The triggering ``KeyError`` is attached
        as ``__cause__``.
    Exception
        Any other error raised while building the environment or training
        is re-raised unchanged after the failure report has been attempted.
    """
    # Function-scope import: the training module is only loaded when the
    # command actually runs.
    from gluonts.shell import train

    logger.info("Run 'train' command")

    try:
        env = TrainEnv(Path(data_path))

        if forecaster is None:
            try:
                forecaster = env.hyperparameters["forecaster_name"]
            except KeyError as key_error:
                msg = (
                    "Forecaster shell parameter is `None`, but "
                    "the `forecaster_name` key is not defined in the "
                    "hyperparameters.json dictionary."
                )
                # Chain explicitly so the traceback marks the missing key
                # as the direct cause rather than an incidental context.
                raise GluonTSForecasterNotFoundError(msg) from key_error

        train.run_train_and_test(env, forecaster_type_by_name(forecaster))
    except Exception as error:
        failure_path = TrainPaths(Path(data_path)).output / "failure"
        try:
            with open(failure_path, "w") as out_file:
                out_file.write(str(error))
                out_file.write("\n\n")
                out_file.write(traceback.format_exc())
        except OSError:
            # Writing the report is best-effort: a missing or unwritable
            # output directory must not replace the real training error.
            logger.exception(
                "Could not write failure report to %s", failure_path
            )
        # Bare raise re-raises the original error: leaving the inner
        # handler restores it as the exception currently being handled.
        raise
def test_train_shell(train_env: TrainEnv, caplog) -> None:
    """Run train-and-test with ``MeanPredictor`` and check logged metrics.

    Every captured log message that contains one of the known metric
    markers must end with the value expected for that metric.
    """
    run_train_and_test(env=train_env, forecaster_type=MeanPredictor)

    # (markers, expected suffix): a message containing any of the markers
    # must end with the suffix. Rules are evaluated independently, so one
    # message may be checked against several of them.
    expectations = (
        (("#test_score (local, QuantileLoss",), "0.0"),
        (("local, wQuantileLoss",), "0.0"),
        (("local, Coverage",), "0.0"),
        (("MASE", "MSIS"), "0.0"),
        (("abs_target_sum",), "270.0"),
    )

    for _logger_name, _level, message in caplog.record_tuples:
        for markers, expected_suffix in expectations:
            if any(marker in message for marker in markers):
                assert message.endswith(expected_suffix)
def train_command(data_path: str, forecaster: Optional[str]) -> None:
    """Run the ``train`` shell command for the environment at ``data_path``.

    Builds a ``TrainEnv`` from ``data_path``, resolves the forecaster name
    and calls ``train.run_train_and_test`` with the matching forecaster
    type.

    Parameters
    ----------
    data_path
        Base directory passed to ``TrainEnv``.
    forecaster
        Name of the forecaster to train. When ``None``, the name is read
        from the ``forecaster_name`` entry of ``env.hyperparameters``.

    Raises
    ------
    GluonTSForecasterNotFoundError
        If ``forecaster`` is ``None`` and the hyperparameters contain no
        ``forecaster_name`` key. The triggering ``KeyError`` is attached
        as ``__cause__``.
    """
    # Function-scope import: the training module is only loaded when the
    # command actually runs.
    from gluonts.shell import train

    env = TrainEnv(Path(data_path))

    if forecaster is None:
        try:
            forecaster = env.hyperparameters["forecaster_name"]
        except KeyError as key_error:
            msg = (
                "Forecaster shell parameter is `None`, but "
                "the `forecaster_name` key is not defined in the "
                "hyperparameters.json dictionary."
            )
            # Chain explicitly so the traceback marks the missing key as
            # the direct cause rather than an incidental context.
            raise GluonTSForecasterNotFoundError(msg) from key_error

    # Kept from the original: guards against a value that is still None
    # here (presumably also narrows Optional[str] for type checkers).
    assert forecaster is not None
    train.run_train_and_test(env, forecaster_type_by_name(forecaster))
def train_command(data_path: str, forecaster: str) -> None:
    """Run the ``train`` shell command for the environment at ``data_path``.

    Builds a ``SageMakerEnv`` from ``data_path``, resolves the forecaster
    name and calls ``train.run_train_and_test`` with the matching
    forecaster type.

    Parameters
    ----------
    data_path
        Base directory passed to ``SageMakerEnv``.
    forecaster
        Name of the forecaster to train. The sentinel value
        ``'%from_hyperparameters%'`` means the name is read from the
        ``forecaster_name`` entry of ``env.hyperparameters`` instead.

    Raises
    ------
    GluonTSForecasterNotFoundError
        If the sentinel is given and the hyperparameters contain no
        ``forecaster_name`` key. The triggering ``KeyError`` is attached
        as ``__cause__``.
    """
    # Function-scope import: the training module is only loaded when the
    # command actually runs.
    from gluonts.shell import train

    env = SageMakerEnv(Path(data_path))

    if forecaster == "%from_hyperparameters%":
        try:
            forecaster = env.hyperparameters["forecaster_name"]
        except KeyError as key_error:
            msg = (
                "Forecaster shell parameter is '%from_hyperparameters%', but "
                "the `forecaster_name` key is not defined in the "
                "hyperparameters.json dictionary."
            )
            # Chain explicitly so the traceback marks the missing key as
            # the direct cause rather than an incidental context.
            raise GluonTSForecasterNotFoundError(msg) from key_error

    forecaster_type = forecaster_type_by_name(forecaster)
    train.run_train_and_test(env, forecaster_type)