Example #1
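# Serves an ONNX model: loads the session from the model path derived from
# the api's "onnx" config, wires up optional pre/post-inference request
# handlers, and logs the input/output signatures before serving.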
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("onnx") is None:
            raise CortexException(api["name"], "onnx key not configured")

        _, prefix = ctx.storage.deconstruct_s3_path(api["onnx"]["model"])
        model_path = os.path.join(args.model_dir, os.path.basename(prefix))
        if api["onnx"].get("request_handler") is not None:
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"], args.project_dir)
        request_handler = local_cache.get("request_handler")

        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            cx_logger().info(
                "using pre_inference request handler provided in {}".format(
                    api["onnx"]["request_handler"]))
        else:
            cx_logger().info("pre_inference request handler not found")

        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            cx_logger().info(
                "using post_inference request handler provided in {}".format(
                    api["onnx"]["request_handler"]))
        else:
            cx_logger().info("post_inference request handler not found")

        sess = rt.InferenceSession(model_path)
        local_cache["sess"] = sess
        local_cache["input_metadata"] = sess.get_inputs()
        cx_logger().info("input_metadata: {}".format(
            truncate(extract_signature(local_cache["input_metadata"]))))
        local_cache["output_metadata"] = sess.get_outputs()
        cx_logger().info("output_metadata: {}".format(
            truncate(extract_signature(local_cache["output_metadata"]))))

    except Exception as e:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    cx_logger().info("API is ready")
    serve(app, listen="*:{}".format(args.port))
Example #2
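# ONNX variant that installs Python packages when a request handler is
# configured, downloads the model only if it is not already cached, and
# wraps failures in CortexException before exiting.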
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]

        local_cache["api"] = api
        local_cache["ctx"] = ctx
        if api.get("request_handler_impl_key") is not None:
            package.install_packages(ctx.python_packages, ctx.storage)
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"])

        model_cache_path = os.path.join(args.model_dir, args.api)
        if not os.path.exists(model_cache_path):
            ctx.storage.download_file_external(api["model"], model_cache_path)

        sess = rt.InferenceSession(model_cache_path)
        local_cache["sess"] = sess
        local_cache["input_metadata"] = sess.get_inputs()
        local_cache["output_metadata"] = sess.get_outputs()
    except CortexException as e:
        e.wrap("error")
        logger.error(str(e))
        if api is not None:
            logger.exception(
                "An error occured starting the api, see `cx logs -v api {}` for more details"
                .format(api["name"]))
        sys.exit(1)

    serve(app, listen="*:{}".format(args.port))
Example #3
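# TensorFlow Serving variant: validates the model directory, opens a gRPC
# channel to TF Serving, and polls run_get_model_metadata() until the model
# finishes loading before serving.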
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("request_handler") is not None:
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"], args.project_dir)
    except Exception as e:
        logger.exception("failed to start api")
        sys.exit(1)

    try:
        validate_model_dir(args.model_dir)
    except Exception as e:
        logger.exception("failed to validate model")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            logger.warn("an error occurred while attempting to load classes",
                        exc_info=True)

    channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
    local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(
        channel)

    # wait a bit for tf serving to start before querying metadata
    limit = 60
    for i in range(limit):
        try:
            local_cache["metadata"] = run_get_model_metadata()
            break
        except Exception as e:
            if i > 6:
                logger.warn(
                    "unable to read model metadata - model is still loading. Retrying..."
                )
            if i == limit - 1:
                logger.exception("retry limit exceeded")
                sys.exit(1)

        time.sleep(5)
    logger.info("model_signature: {}".format(
        extract_signature(
            local_cache["metadata"]["signatureDef"],
            local_cache["api"]["tf_serving"]["signature_key"],
        )))
    serve(app, listen="*:{}".format(args.port))
Example #4
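# Predictor-based ONNX variant: instantiates the user's predictor class with
# an ONNXClient, forwards "waitress_*" config keys to the HTTP server, and
# touches /health_check.txt once the API is live.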
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api["predictor"]["type"] != "onnx":
            raise CortexException(api["name"], "predictor type is not onnx")

        cx_logger().info("loading the predictor from {}".format(api["predictor"]["path"]))

        _, prefix = ctx.storage.deconstruct_s3_path(api["predictor"]["model"])
        model_path = os.path.join(args.model_dir, os.path.basename(prefix))
        local_cache["client"] = ONNXClient(model_path)

        predictor_class = ctx.get_predictor_class(api["name"], args.project_dir)

        try:
            local_cache["predictor"] = predictor_class(
                local_cache["client"], api["predictor"]["config"]
            )
        except Exception as e:
            raise UserRuntimeException(api["predictor"]["path"], "__init__", str(e)) from e
        finally:
            refresh_logger()
    except Exception as e:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get("model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn("an error occurred while attempting to load classes", exc_info=True)

    cx_logger().info("ONNX model signature: {}".format(local_cache["client"].input_signature))

    waitress_kwargs = {}
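    # Collect waitress-specific settings from the predictor config by
    # stripping the "waitress_" prefix, leaving keys that can be passed
    # straight to serve() as keyword arguments.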
    if api["predictor"].get("config") is not None:
        for key, value in api["predictor"]["config"].items():
            if key.startswith("waitress_"):
                waitress_kwargs[key[len("waitress_") :]] = value

    if len(waitress_kwargs) > 0:
        cx_logger().info("waitress parameters: {}".format(waitress_kwargs))

    waitress_kwargs["listen"] = "*:{}".format(args.port)

    cx_logger().info("{} api is live".format(api["name"]))
    open("/health_check.txt", "a").close()
    serve(app, **waitress_kwargs)
Example #5
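# Generic predictor variant: loads the predictor implementation and, if it
# defines init(), calls it with the cached model path and the predictor
# metadata before serving.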
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("predictor") is None:
            raise CortexException(api["name"], "predictor key not configured")

        cx_logger().info("loading the predictor from {}".format(
            api["predictor"]["path"]))
        local_cache["predictor"] = ctx.get_predictor_impl(
            api["name"], args.project_dir)

        if util.has_function(local_cache["predictor"], "init"):
            try:
                model_path = None
                if api["predictor"].get("model") is not None:
                    _, prefix = ctx.storage.deconstruct_s3_path(
                        api["predictor"]["model"])
                    model_path = os.path.join(
                        args.model_dir,
                        os.path.basename(os.path.normpath(prefix)))

                cx_logger().info("calling the predictor's init() function")
                local_cache["predictor"].init(model_path,
                                              api["predictor"]["metadata"])
            except Exception as e:
                raise UserRuntimeException(api["predictor"]["path"], "init",
                                           str(e)) from e
            finally:
                refresh_logger()
    except Exception:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    cx_logger().info("{} api is live".format(api["name"]))
    serve(app, listen="*:{}".format(args.port))
Example #6
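# CLI entry point for Python package handling: the --build flag selects
# between building packages and installing them from an existing Context.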
def main():
    parser = argparse.ArgumentParser()
    na = parser.add_argument_group("required named arguments")
    na.add_argument("--workload-id", required=True, help="Workload ID")
    na.add_argument(
        "--context",
        required=True,
        help="S3 path to context (e.g. s3://bucket/path/to/context.json")
    na.add_argument("--cache-dir",
                    required=True,
                    help="Local path for the context cache")
    na.add_argument("--python-packages",
                    help="Resource ids of packages to build")
    na.add_argument("--build",
                    action="store_true",
                    help="Flag to determine mode (build vs install)")

    args, _ = parser.parse_known_args()
    if args.build:
        build(args)
    else:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        install_packages(ctx.python_packages, ctx.storage)
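A minimal sketch of driving this entry point programmatically; the script
name, workload ID, S3 path, and package resource id below are placeholders,
not values from the source:

import sys

sys.argv = [
    "package.py",                      # placeholder script name
    "--workload-id", "wabc123",        # placeholder workload ID
    "--context", "s3://my-bucket/path/to/context.json",  # placeholder path
    "--cache-dir", "/mnt/context",
    "--python-packages", "pp-abc123",  # placeholder resource id used by build
    "--build",                         # build mode; omit to install instead
]
main()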
Example #7
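# Spark job driver: ingests the raw dataset, then runs aggregators,
# transformer validation, and training dataset creation, stopping the Spark
# session in the finally clause regardless of outcome.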
def run_job(args):
    should_ingest, cols_to_validate, cols_to_aggregate, cols_to_transform, training_datasets = parse_args(
        args)

    resource_id_list = cols_to_validate + cols_to_aggregate + cols_to_transform + training_datasets

    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
    except Exception as e:
        logger.exception("An error occurred, see the logs for more details.")
        sys.exit(1)

    try:
        spark = None  # For the finally clause
        spark = get_spark_session(ctx.workload_id)
        spark.sparkContext.parallelize(
            [1, 2, 3, 4, 5]).count()  # test that executors are allocated
        raw_df = ingest_raw_dataset(spark, ctx, cols_to_validate,
                                    should_ingest)

        if len(cols_to_aggregate) > 0:
            run_aggregators(spark, ctx, cols_to_aggregate, raw_df)

        if len(cols_to_transform) > 0:
            validate_transformers(spark, ctx, cols_to_transform, raw_df)

        create_training_datasets(spark, ctx, training_datasets, raw_df)

        util.log_job_finished(ctx.workload_id)
    except CortexException as e:
        e.wrap("error")
        logger.error(str(e))
        logger.exception(
            "An error occurred, see `cortex logs -v {} {}` for more details.".
            format(
                ctx.id_map[resource_id_list[0]]["resource_type"],
                ctx.id_map[resource_id_list[0]]["name"],
            ))
        sys.exit(1)
    except Exception as e:
        logger.exception(
            "An error occurred, see `cortex logs -v {} {}` for more details.".
            format(
                ctx.id_map[resource_id_list[0]]["resource_type"],
                ctx.id_map[resource_id_list[0]]["name"],
            ))
        sys.exit(1)
    finally:
        if spark is not None:
            spark.stop()
Example #8
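# ONNX variant without a dedicated predictor: loads the session directly and
# logs the truncated input/output signatures via the module-level logger.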
def start(args):

    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        _, prefix = ctx.storage.deconstruct_s3_path(api["model"])
        model_path = os.path.join(args.model_dir, os.path.basename(prefix))
        if api.get("request_handler") is not None:
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"], args.project_dir)

        sess = rt.InferenceSession(model_path)
        local_cache["sess"] = sess
        local_cache["input_metadata"] = sess.get_inputs()
        logger.info("input_metadata: {}".format(
            truncate(extract_signature(local_cache["input_metadata"]))))
        local_cache["output_metadata"] = sess.get_outputs()
        logger.info("output_metadata: {}".format(
            truncate(extract_signature(local_cache["output_metadata"]))))

    except Exception as e:
        logger.exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            logger.warn("an error occurred while attempting to load classes",
                        exc_info=True)

    serve(app, listen="*:{}".format(args.port))
Example #9
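# Builds the requested Python packages and records per-resource status:
# started before the build, then failed or succeeded depending on outcome.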
def build(args):
    ctx = Context(s3_path=args.context,
                  cache_dir=args.cache_dir,
                  workload_id=args.workload_id)
    python_packages_list = [
        ctx.pp_id_map[id] for id in args.python_packages.split(",")
    ]
    python_packages = {
        python_package["name"]: python_package
        for python_package in python_packages_list
    }
    ctx.upload_resource_status_start(*python_packages_list)
    try:
        build_packages(python_packages, ctx.storage)
        util.log_job_finished(ctx.workload_id)
    except CortexException as e:
        e.wrap("error")
        logger.exception(e)
        ctx.upload_resource_status_failed(*python_packages_list)
    except Exception as e:
        logger.exception(e)
        ctx.upload_resource_status_failed(*python_packages_list)
    else:
        ctx.upload_resource_status_success(*python_packages_list)
Example #10
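# Minimal ONNX start() with no error handling: downloads the model if it is
# absent from the cache, creates the session, and serves immediately.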
def start(args):
    ctx = Context(s3_path=args.context,
                  cache_dir=args.cache_dir,
                  workload_id=args.workload_id)
    api = ctx.apis_id_map[args.api]

    local_cache["api"] = api
    local_cache["ctx"] = ctx
    if api.get("request_handler_impl_key") is not None:
        package.install_packages(ctx.python_packages, ctx.storage)
        local_cache["request_handler"] = ctx.get_request_handler_impl(
            api["name"])

    model_cache_path = os.path.join(args.model_dir, args.api)
    if not os.path.exists(model_cache_path):
        ctx.storage.download_file_external(api["model"], model_cache_path)

    sess = rt.InferenceSession(model_cache_path)
    local_cache["sess"] = sess
    local_cache["input_metadata"] = sess.get_inputs()
    local_cache["output_metadata"] = sess.get_outputs()
    logger.info("Serving model: {}".format(
        util.remove_resource_ref(api["model"])))
    serve(app, listen="*:{}".format(args.port))
Example #11
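# TensorFlow variant with request handlers: logs which pre/post-inference
# hooks are defined, waits for TF Serving to load the model, and caches the
# parsed signature before serving.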
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("tensorflow") is None:
            raise CortexException(api["name"], "tensorflow key not configured")

        if api["tensorflow"].get("request_handler") is not None:
            cx_logger().info("loading the request handler from {}".format(
                api["tensorflow"]["request_handler"]))
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"], args.project_dir)
        request_handler = local_cache.get("request_handler")

        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            cx_logger().info(
                "using pre_inference request handler defined in {}".format(
                    api["tensorflow"]["request_handler"]))
        else:
            cx_logger().info("pre_inference request handler not defined")

        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            cx_logger().info(
                "using post_inference request handler defined in {}".format(
                    api["tensorflow"]["request_handler"]))
        else:
            cx_logger().info("post_inference request handler not defined")

    except Exception as e:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    try:
        validate_model_dir(args.model_dir)
    except Exception as e:
        cx_logger().exception("failed to validate model")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
    local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(
        channel)

    # wait a bit for tf serving to start before querying metadata
    limit = 60
    for i in range(limit):
        try:
            local_cache["model_metadata"] = run_get_model_metadata()
            break
        except Exception as e:
            if i > 6:
                cx_logger().warn(
                    "unable to read model metadata - model is still loading. Retrying..."
                )
            if i == limit - 1:
                cx_logger().exception("retry limit exceeded")
                sys.exit(1)

        time.sleep(5)

    signature_key, parsed_signature = extract_signature(
        local_cache["model_metadata"]["signatureDef"],
        api["tensorflow"]["signature_key"])

    local_cache["signature_key"] = signature_key
    local_cache["parsed_signature"] = parsed_signature
    cx_logger().info("model_signature: {}".format(
        local_cache["parsed_signature"]))

    cx_logger().info("{} API is live".format(api["name"]))
    serve(app, listen="*:{}".format(args.port))
Example #12
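# Training entry point: trains the estimator, zips the exported model from a
# temp directory, uploads it, and reports resource status on success or
# failure.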
def train(args):
    ctx = Context(s3_path=args.context,
                  cache_dir=args.cache_dir,
                  workload_id=args.workload_id)

    package.install_packages(ctx.python_packages, ctx.storage)

    model = ctx.models_id_map[args.model]

    logger.info("Training")

    with util.Tempdir(ctx.cache_dir) as temp_dir:
        model_dir = os.path.join(temp_dir, "model_dir")
        ctx.upload_resource_status_start(model)

        try:
            estimator_impl, _ = ctx.get_estimator_impl(model["name"])
            train_util.train(model["name"], estimator_impl, ctx, model_dir)
            ctx.upload_resource_status_success(model)

            logger.info("Caching")
            logger.info("Caching model " + model["name"])
            model_export_dir = os.path.join(model_dir, "export", "estimator")
            model_zip_path = os.path.join(temp_dir, "model.zip")
            util.zip_dir(model_export_dir, model_zip_path)

            ctx.storage.upload_file(model_zip_path, model["key"])
            util.log_job_finished(ctx.workload_id)

        except CortexException as e:
            ctx.upload_resource_status_failed(model)
            e.wrap("error")
            logger.error(str(e))
            logger.exception(
                "An error occurred, see `cortex logs -v model {}` for more details."
                .format(model["name"]))
            sys.exit(1)
        except Exception as e:
            ctx.upload_resource_status_failed(model)
            logger.exception(
                "An error occurred, see `cortex logs -v model {}` for more details."
                .format(model["name"]))
            sys.exit(1)
Example #13
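# TF Serving variant for estimator-backed models: resolves the target column
# and transformer implementations, caches aggregate values, and waits up to
# 300 seconds for TF Serving to report model metadata.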
def start(args):
    ctx = Context(s3_path=args.context,
                  cache_dir=args.cache_dir,
                  workload_id=args.workload_id)

    api = ctx.apis_id_map[args.api]
    local_cache["api"] = api
    local_cache["ctx"] = ctx

    try:
        if api.get("request_handler_impl_key") is not None:
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"])

        if not util.is_resource_ref(api["model"]):
            if api.get("request_handler") is not None:
                package.install_packages(ctx.python_packages, ctx.storage)
            if not os.path.isdir(args.model_dir):
                ctx.storage.download_and_unzip_external(
                    api["model"], args.model_dir)
        else:
            package.install_packages(ctx.python_packages, ctx.storage)
            model_name = util.get_resource_ref(api["model"])
            model = ctx.models[model_name]
            estimator = ctx.estimators[model["estimator"]]

            local_cache["model"] = model
            local_cache["estimator"] = estimator
            local_cache["target_col"] = ctx.columns[util.get_resource_ref(
                model["target_column"])]
            local_cache["target_col_type"] = ctx.get_inferred_column_type(
                util.get_resource_ref(model["target_column"]))

            log_level = "DEBUG"
            if ctx.environment is not None and ctx.environment.get(
                    "log_level") is not None:
                log_level = ctx.environment["log_level"].get(
                    "tensorflow", "DEBUG")
            tf_lib.set_logging_verbosity(log_level)

            if not os.path.isdir(args.model_dir):
                ctx.storage.download_and_unzip(model["key"], args.model_dir)

            for column_name in ctx.extract_column_names(
                [model["input"], model["target_column"]]):
                if ctx.is_transformed_column(column_name):
                    trans_impl, _ = ctx.get_transformer_impl(column_name)
                    local_cache["trans_impls"][column_name] = trans_impl
                    transformed_column = ctx.transformed_columns[column_name]

                    # cache aggregate values
                    for resource_name in util.extract_resource_refs(
                            transformed_column["input"]):
                        if resource_name in ctx.aggregates:
                            ctx.get_obj(ctx.aggregates[resource_name]["key"])

            local_cache["required_inputs"] = tf_lib.get_base_input_columns(
                model["name"], ctx)

            if util.is_dict(model["input"]) and model["input"].get(
                    "target_vocab") is not None:
                local_cache["target_vocab_populated"] = ctx.populate_values(
                    model["input"]["target_vocab"], None, False)
    except CortexException as e:
        e.wrap("error")
        logger.error(str(e))
        logger.exception(
            "An error occurred, see `cortex logs -v api {}` for more details.".
            format(api["name"]))
        sys.exit(1)
    except Exception as e:
        logger.exception(
            "An error occurred, see `cortex logs -v api {}` for more details.".
            format(api["name"]))
        sys.exit(1)

    try:
        validate_model_dir(args.model_dir)
    except Exception as e:
        logger.exception(e)
        sys.exit(1)

    channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
    local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(
        channel)

    # wait a bit for tf serving to start before querying metadata
    limit = 300
    for i in range(limit):
        try:
            local_cache["metadata"] = run_get_model_metadata()
            break
        except Exception as e:
            if i == limit - 1:
                logger.exception(
                    "An error occurred, see `cortex logs -v api {}` for more details."
                    .format(api["name"]))
                sys.exit(1)

        time.sleep(1)

    serve(app, listen="*:{}".format(args.port))
Example #14
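# Test helper: builds a Context directly from an in-memory dict, using the
# current directory as the cache.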
def _get_context(d):
    return Context(obj=d, cache_dir=".")
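A minimal usage sketch for the helper above; the file name is a placeholder:

import json

# Build a Context from any dict-shaped context object loaded from disk.
with open("context.json") as f:
    ctx = _get_context(json.load(f))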
Example #15
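# End-to-end Spark job test over the insurance dataset: runs ingest,
# aggregation, transformer validation, and training dataset creation, and
# checks the resulting resource statuses and dataset sizes.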
def test_simple_end_to_end(spark):
    local_storage_path = Path("/workspace/local_storage")
    local_storage_path.mkdir(parents=True, exist_ok=True)
    should_ingest = True
    input_data_path = os.path.join(str(local_storage_path), "insurance.csv")

    raw_ctx = insurance_context.get(input_data_path)

    workload_id = raw_ctx["raw_columns"]["raw_string_columns"]["smoker"][
        "workload_id"]

    cols_to_validate = []

    for column_type in raw_ctx["raw_columns"].values():
        for raw_column in column_type.values():
            cols_to_validate.append(raw_column["id"])

    insurance_data_string = "\n".join(",".join(str(val) for val in line)
                                      for line in insurance_data)
    Path(os.path.join(str(local_storage_path),
                      "insurance.csv")).write_text(insurance_data_string)

    ctx = Context(raw_obj=raw_ctx,
                  cache_dir="/workspace/cache",
                  local_storage_path=str(local_storage_path))
    storage = ctx.storage

    raw_df = spark_job.ingest_raw_dataset(spark, ctx, cols_to_validate,
                                          should_ingest)

    assert raw_df.count() == 15
    assert ctx.get_metadata(ctx.raw_dataset["key"])["dataset_size"] == 15
    for raw_column_id in cols_to_validate:
        path = os.path.join(raw_ctx["status_prefix"], raw_column_id,
                            workload_id)
        status = storage.get_json(str(path))
        status["resource_id"] = raw_column_id
        status["exist_code"] = "succeeded"

    cols_to_aggregate = [r["id"] for r in raw_ctx["aggregates"].values()]

    spark_job.run_aggregators(spark, ctx, cols_to_aggregate, raw_df)

    for aggregate_id in cols_to_aggregate:
        for aggregate_resource in raw_ctx["aggregates"].values():
            if aggregate_resource["id"] == aggregate_id:
                assert local_storage_path.joinpath(
                    aggregate_resource["key"]).exists()
        path = os.path.join(raw_ctx["status_prefix"], aggregate_id,
                            workload_id)
        status = storage.get_json(str(path))
        status["resource_id"] = aggregate_id
        status["exist_code"] = "succeeded"

    cols_to_transform = [
        r["id"] for r in raw_ctx["transformed_columns"].values()
    ]
    spark_job.validate_transformers(spark, ctx, cols_to_transform, raw_df)

    for transformed_id in cols_to_transform:
        path = os.path.join(raw_ctx["status_prefix"], transformed_id,
                            workload_id)
        status = storage.get_json(str(path))
        status["resource_id"] = transformed_id
        status["exist_code"] = "succeeded"

    training_datasets = [raw_ctx["models"]["dnn"]["dataset"]["id"]]

    spark_job.create_training_datasets(spark, ctx, training_datasets, raw_df)

    for dataset_id in training_datasets:
        path = os.path.join(raw_ctx["status_prefix"], transformed_id,
                            workload_id)
        status = storage.get_json(str(path))
        status["resource_id"] = transformed_id
        status["exist_code"] = "succeeded"

        dataset = raw_ctx["models"]["dnn"]["dataset"]
        metadata = ctx.get_metadata(dataset["id"])
        assert metadata["training_size"] + metadata["eval_size"] == 15
        assert local_storage_path.joinpath(dataset["train_key"],
                                           "_SUCCESS").exists()
        assert local_storage_path.joinpath(dataset["eval_key"],
                                           "_SUCCESS").exists()