def start(args): api = None try: ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) api = ctx.apis_id_map[args.api] local_cache["api"] = api local_cache["ctx"] = ctx if api.get("onnx") is None: raise CortexException(api["name"], "onnx key not configured") _, prefix = ctx.storage.deconstruct_s3_path(api["onnx"]["model"]) model_path = os.path.join(args.model_dir, os.path.basename(prefix)) if api["onnx"].get("request_handler") is not None: local_cache["request_handler"] = ctx.get_request_handler_impl( api["name"], args.project_dir) request_handler = local_cache.get("request_handler") if request_handler is not None and util.has_function( request_handler, "pre_inference"): cx_logger().info( "using pre_inference request handler provided in {}".format( api["onnx"]["request_handler"])) else: cx_logger().info("pre_inference request handler not found") if request_handler is not None and util.has_function( request_handler, "post_inference"): cx_logger().info( "using post_inference request handler provided in {}".format( api["onnx"]["request_handler"])) else: cx_logger().info("post_inference request handler not found") sess = rt.InferenceSession(model_path) local_cache["sess"] = sess local_cache["input_metadata"] = sess.get_inputs() cx_logger().info("input_metadata: {}".format( truncate(extract_signature(local_cache["input_metadata"])))) local_cache["output_metadata"] = sess.get_outputs() cx_logger().info("output_metadata: {}".format( truncate(extract_signature(local_cache["output_metadata"])))) except Exception as e: cx_logger().exception("failed to start api") sys.exit(1) if api.get("tracker") is not None and api["tracker"].get( "model_type") == "classification": try: local_cache["class_set"] = api_utils.get_classes(ctx, api["name"]) except Exception as e: cx_logger().warn( "an error occurred while attempting to load classes", exc_info=True) cx_logger().info("API is ready") serve(app, listen="*:{}".format(args.port))
def start(args): api = None try: ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) api = ctx.apis_id_map[args.api] local_cache["api"] = api local_cache["ctx"] = ctx if api.get("request_handler_impl_key") is not None: package.install_packages(ctx.python_packages, ctx.storage) local_cache["request_handler"] = ctx.get_request_handler_impl( api["name"]) model_cache_path = os.path.join(args.model_dir, args.api) if not os.path.exists(model_cache_path): ctx.storage.download_file_external(api["model"], model_cache_path) sess = rt.InferenceSession(model_cache_path) local_cache["sess"] = sess local_cache["input_metadata"] = sess.get_inputs() local_cache["output_metadata"] = sess.get_outputs() except CortexException as e: e.wrap("error") logger.error(str(e)) if api is not None: logger.exception( "An error occured starting the api, see `cx logs -v api {}` for more details" .format(api["name"])) sys.exit(1) serve(app, listen="*:{}".format(args.port))
def start(args): api = None try: ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) api = ctx.apis_id_map[args.api] local_cache["api"] = api local_cache["ctx"] = ctx if api.get("request_handler") is not None: local_cache["request_handler"] = ctx.get_request_handler_impl( api["name"], args.project_dir) except Exception as e: logger.exception("failed to start api") sys.exit(1) try: validate_model_dir(args.model_dir) except Exception as e: logger.exception("failed to validate model") sys.exit(1) if api.get("tracker") is not None and api["tracker"].get( "model_type") == "classification": try: local_cache["class_set"] = api_utils.get_classes(ctx, api["name"]) except Exception as e: logger.warn("an error occurred while attempting to load classes", exc_info=True) channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port)) local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub( channel) # wait a bit for tf serving to start before querying metadata limit = 60 for i in range(limit): try: local_cache["metadata"] = run_get_model_metadata() break except Exception as e: if i > 6: logger.warn( "unable to read model metadata - model is still loading. Retrying..." ) if i == limit - 1: logger.exception("retry limit exceeded") sys.exit(1) time.sleep(5) logger.info("model_signature: {}".format( extract_signature( local_cache["metadata"]["signatureDef"], local_cache["api"]["tf_serving"]["signature_key"], ))) serve(app, listen="*:{}".format(args.port))
def start(args): api = None try: ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) api = ctx.apis_id_map[args.api] local_cache["api"] = api local_cache["ctx"] = ctx if api["predictor"]["type"] != "onnx": raise CortexException(api["name"], "predictor type is not onnx") cx_logger().info("loading the predictor from {}".format(api["predictor"]["path"])) _, prefix = ctx.storage.deconstruct_s3_path(api["predictor"]["model"]) model_path = os.path.join(args.model_dir, os.path.basename(prefix)) local_cache["client"] = ONNXClient(model_path) predictor_class = ctx.get_predictor_class(api["name"], args.project_dir) try: local_cache["predictor"] = predictor_class( local_cache["client"], api["predictor"]["config"] ) except Exception as e: raise UserRuntimeException(api["predictor"]["path"], "__init__", str(e)) from e finally: refresh_logger() except Exception as e: cx_logger().exception("failed to start api") sys.exit(1) if api.get("tracker") is not None and api["tracker"].get("model_type") == "classification": try: local_cache["class_set"] = api_utils.get_classes(ctx, api["name"]) except Exception as e: cx_logger().warn("an error occurred while attempting to load classes", exc_info=True) cx_logger().info("ONNX model signature: {}".format(local_cache["client"].input_signature)) waitress_kwargs = {} if api["predictor"].get("config") is not None: for key, value in api["predictor"]["config"].items(): if key.startswith("waitress_"): waitress_kwargs[key[len("waitress_") :]] = value if len(waitress_kwargs) > 0: cx_logger().info("waitress parameters: {}".format(waitress_kwargs)) waitress_kwargs["listen"] = "*:{}".format(args.port) cx_logger().info("{} api is live".format(api["name"])) open("/health_check.txt", "a").close() serve(app, **waitress_kwargs)
def start(args): api = None try: ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) api = ctx.apis_id_map[args.api] local_cache["api"] = api local_cache["ctx"] = ctx if api.get("predictor") is None: raise CortexException(api["name"], "predictor key not configured") cx_logger().info("loading the predictor from {}".format( api["predictor"]["path"])) local_cache["predictor"] = ctx.get_predictor_impl( api["name"], args.project_dir) if util.has_function(local_cache["predictor"], "init"): try: model_path = None if api["predictor"].get("model") is not None: _, prefix = ctx.storage.deconstruct_s3_path( api["predictor"]["model"]) model_path = os.path.join( args.model_dir, os.path.basename(os.path.normpath(prefix))) cx_logger().info("calling the predictor's init() function") local_cache["predictor"].init(model_path, api["predictor"]["metadata"]) except Exception as e: raise UserRuntimeException(api["predictor"]["path"], "init", str(e)) from e finally: refresh_logger() except: cx_logger().exception("failed to start api") sys.exit(1) if api.get("tracker") is not None and api["tracker"].get( "model_type") == "classification": try: local_cache["class_set"] = api_utils.get_classes(ctx, api["name"]) except Exception as e: cx_logger().warn( "an error occurred while attempting to load classes", exc_info=True) cx_logger().info("{} api is live".format(api["name"])) serve(app, listen="*:{}".format(args.port))
def main(): parser = argparse.ArgumentParser() na = parser.add_argument_group("required named arguments") na.add_argument("--workload-id", required=True, help="Workload ID") na.add_argument( "--context", required=True, help="S3 path to context (e.g. s3://bucket/path/to/context.json") na.add_argument("--cache-dir", required=True, help="Local path for the context cache") na.add_argument("--python-packages", help="Resource ids of packages to build") na.add_argument("--build", action="store_true", help="Flag to determine mode (build vs install)") args, _ = parser.parse_known_args() if args.build: build(args) else: ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) install_packages(ctx.python_packages, ctx.storage)
def run_job(args):
    should_ingest, cols_to_validate, cols_to_aggregate, cols_to_transform, training_datasets = parse_args(args)

    resource_id_list = cols_to_validate + cols_to_aggregate + cols_to_transform + training_datasets

    try:
        ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id)
    except Exception as e:
        logger.exception("An error occurred, see the logs for more details.")
        sys.exit(1)

    try:
        spark = None  # For the finally clause
        spark = get_spark_session(ctx.workload_id)
        spark.sparkContext.parallelize([1, 2, 3, 4, 5]).count()  # test that executors are allocated
        raw_df = ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest)

        if len(cols_to_aggregate) > 0:
            run_aggregators(spark, ctx, cols_to_aggregate, raw_df)
        if len(cols_to_transform) > 0:
            validate_transformers(spark, ctx, cols_to_transform, raw_df)

        create_training_datasets(spark, ctx, training_datasets, raw_df)

        util.log_job_finished(ctx.workload_id)
    except CortexException as e:
        e.wrap("error")
        logger.error(str(e))
        logger.exception(
            "An error occurred, see `cortex logs -v {} {}` for more details.".format(
                ctx.id_map[resource_id_list[0]]["resource_type"],
                ctx.id_map[resource_id_list[0]]["name"],
            )
        )
        sys.exit(1)
    except Exception as e:
        logger.exception(
            "An error occurred, see `cortex logs -v {} {}` for more details.".format(
                ctx.id_map[resource_id_list[0]]["resource_type"],
                ctx.id_map[resource_id_list[0]]["name"],
            )
        )
        sys.exit(1)
    finally:
        if spark is not None:
            spark.stop()

def start(args): api = None try: ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) api = ctx.apis_id_map[args.api] local_cache["api"] = api local_cache["ctx"] = ctx _, prefix = ctx.storage.deconstruct_s3_path(api["model"]) model_path = os.path.join(args.model_dir, os.path.basename(prefix)) if api.get("request_handler") is not None: local_cache["request_handler"] = ctx.get_request_handler_impl( api["name"], args.project_dir) sess = rt.InferenceSession(model_path) local_cache["sess"] = sess local_cache["input_metadata"] = sess.get_inputs() logger.info("input_metadata: {}".format( truncate(extract_signature(local_cache["input_metadata"])))) local_cache["output_metadata"] = sess.get_outputs() logger.info("output_metadata: {}".format( truncate(extract_signature(local_cache["output_metadata"])))) except Exception as e: logger.exception("failed to start api") sys.exit(1) if api.get("tracker") is not None and api["tracker"].get( "model_type") == "classification": try: local_cache["class_set"] = api_utils.get_classes(ctx, api["name"]) except Exception as e: logger.warn("an error occurred while attempting to load classes", exc_info=True) serve(app, listen="*:{}".format(args.port))
def build(args):
    ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id)
    python_packages_list = [ctx.pp_id_map[id] for id in args.python_packages.split(",")]
    python_packages = {
        python_package["name"]: python_package for python_package in python_packages_list
    }
    ctx.upload_resource_status_start(*python_packages_list)
    try:
        build_packages(python_packages, ctx.storage)
        util.log_job_finished(ctx.workload_id)
    except CortexException as e:
        e.wrap("error")
        logger.exception(e)
        ctx.upload_resource_status_failed(*python_packages_list)
    except Exception as e:
        logger.exception(e)
        ctx.upload_resource_status_failed(*python_packages_list)
    else:
        ctx.upload_resource_status_success(*python_packages_list)

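# Hedged sketch (not Cortex code): build() above follows a start -> success/failed status
# bookkeeping pattern around the unit of work. A generic version of that shape, with hypothetical
# callback parameters, looks like this; like build(), it records failure instead of re-raising.
def run_with_status(work, on_start, on_success, on_failure, logger=None):
    """Run work(), reporting start before, success after, and failure on any exception."""
    on_start()
    try:
        result = work()
    except Exception as e:
        if logger is not None:
            logger.exception(e)
        on_failure()
        return None
    else:
        on_success()
        return result
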
def start(args): ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) api = ctx.apis_id_map[args.api] local_cache["api"] = api local_cache["ctx"] = ctx if api.get("request_handler_impl_key") is not None: package.install_packages(ctx.python_packages, ctx.storage) local_cache["request_handler"] = ctx.get_request_handler_impl( api["name"]) model_cache_path = os.path.join(args.model_dir, args.api) if not os.path.exists(model_cache_path): ctx.storage.download_file_external(api["model"], model_cache_path) sess = rt.InferenceSession(model_cache_path) local_cache["sess"] = sess local_cache["input_metadata"] = sess.get_inputs() local_cache["output_metadata"] = sess.get_outputs() logger.info("Serving model: {}".format( util.remove_resource_ref(api["model"]))) serve(app, listen="*:{}".format(args.port))
def start(args): api = None try: ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) api = ctx.apis_id_map[args.api] local_cache["api"] = api local_cache["ctx"] = ctx if api.get("tensorflow") is None: raise CortexException(api["name"], "tensorflow key not configured") if api["tensorflow"].get("request_handler") is not None: cx_logger().info("loading the request handler from {}".format( api["tensorflow"]["request_handler"])) local_cache["request_handler"] = ctx.get_request_handler_impl( api["name"], args.project_dir) request_handler = local_cache.get("request_handler") if request_handler is not None and util.has_function( request_handler, "pre_inference"): cx_logger().info( "using pre_inference request handler defined in {}".format( api["tensorflow"]["request_handler"])) else: cx_logger().info("pre_inference request handler not defined") if request_handler is not None and util.has_function( request_handler, "post_inference"): cx_logger().info( "using post_inference request handler defined in {}".format( api["tensorflow"]["request_handler"])) else: cx_logger().info("post_inference request handler not defined") except Exception as e: cx_logger().exception("failed to start api") sys.exit(1) try: validate_model_dir(args.model_dir) except Exception as e: cx_logger().exception("failed to validate model") sys.exit(1) if api.get("tracker") is not None and api["tracker"].get( "model_type") == "classification": try: local_cache["class_set"] = api_utils.get_classes(ctx, api["name"]) except Exception as e: cx_logger().warn( "an error occurred while attempting to load classes", exc_info=True) channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port)) local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub( channel) # wait a bit for tf serving to start before querying metadata limit = 60 for i in range(limit): try: local_cache["model_metadata"] = run_get_model_metadata() break except Exception as e: if i > 6: cx_logger().warn( "unable to read model metadata - model is still loading. Retrying..." ) if i == limit - 1: cx_logger().exception("retry limit exceeded") sys.exit(1) time.sleep(5) signature_key, parsed_signature = extract_signature( local_cache["model_metadata"]["signatureDef"], api["tensorflow"]["signature_key"]) local_cache["signature_key"] = signature_key local_cache["parsed_signature"] = parsed_signature cx_logger().info("model_signature: {}".format( local_cache["parsed_signature"])) cx_logger().info("{} API is live".format(api["name"])) serve(app, listen="*:{}".format(args.port))
def train(args):
    ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id)

    package.install_packages(ctx.python_packages, ctx.storage)

    model = ctx.models_id_map[args.model]

    logger.info("Training")

    with util.Tempdir(ctx.cache_dir) as temp_dir:
        model_dir = os.path.join(temp_dir, "model_dir")
        ctx.upload_resource_status_start(model)

        try:
            estimator_impl, _ = ctx.get_estimator_impl(model["name"])
            train_util.train(model["name"], estimator_impl, ctx, model_dir)
            ctx.upload_resource_status_success(model)

            logger.info("Caching")
            logger.info("Caching model " + model["name"])
            model_export_dir = os.path.join(model_dir, "export", "estimator")
            model_zip_path = os.path.join(temp_dir, "model.zip")
            util.zip_dir(model_export_dir, model_zip_path)
            ctx.storage.upload_file(model_zip_path, model["key"])
            util.log_job_finished(ctx.workload_id)
        except CortexException as e:
            ctx.upload_resource_status_failed(model)
            e.wrap("error")
            logger.error(str(e))
            logger.exception(
                "An error occurred, see `cortex logs -v model {}` for more details.".format(model["name"])
            )
            sys.exit(1)
        except Exception as e:
            ctx.upload_resource_status_failed(model)
            logger.exception(
                "An error occurred, see `cortex logs -v model {}` for more details.".format(model["name"])
            )
            sys.exit(1)

def start(args): ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) api = ctx.apis_id_map[args.api] local_cache["api"] = api local_cache["ctx"] = ctx try: if api.get("request_handler_impl_key") is not None: local_cache["request_handler"] = ctx.get_request_handler_impl( api["name"]) if not util.is_resource_ref(api["model"]): if api.get("request_handler") is not None: package.install_packages(ctx.python_packages, ctx.storage) if not os.path.isdir(args.model_dir): ctx.storage.download_and_unzip_external( api["model"], args.model_dir) else: package.install_packages(ctx.python_packages, ctx.storage) model_name = util.get_resource_ref(api["model"]) model = ctx.models[model_name] estimator = ctx.estimators[model["estimator"]] local_cache["model"] = model local_cache["estimator"] = estimator local_cache["target_col"] = ctx.columns[util.get_resource_ref( model["target_column"])] local_cache["target_col_type"] = ctx.get_inferred_column_type( util.get_resource_ref(model["target_column"])) log_level = "DEBUG" if ctx.environment is not None and ctx.environment.get( "log_level") is not None: log_level = ctx.environment["log_level"].get( "tensorflow", "DEBUG") tf_lib.set_logging_verbosity(log_level) if not os.path.isdir(args.model_dir): ctx.storage.download_and_unzip(model["key"], args.model_dir) for column_name in ctx.extract_column_names( [model["input"], model["target_column"]]): if ctx.is_transformed_column(column_name): trans_impl, _ = ctx.get_transformer_impl(column_name) local_cache["trans_impls"][column_name] = trans_impl transformed_column = ctx.transformed_columns[column_name] # cache aggregate values for resource_name in util.extract_resource_refs( transformed_column["input"]): if resource_name in ctx.aggregates: ctx.get_obj(ctx.aggregates[resource_name]["key"]) local_cache["required_inputs"] = tf_lib.get_base_input_columns( model["name"], ctx) if util.is_dict(model["input"]) and model["input"].get( "target_vocab") is not None: local_cache["target_vocab_populated"] = ctx.populate_values( model["input"]["target_vocab"], None, False) except CortexException as e: e.wrap("error") logger.error(str(e)) logger.exception( "An error occurred, see `cortex logs -v api {}` for more details.". format(api["name"])) sys.exit(1) except Exception as e: logger.exception( "An error occurred, see `cortex logs -v api {}` for more details.". format(api["name"])) sys.exit(1) try: validate_model_dir(args.model_dir) except Exception as e: logger.exception(e) sys.exit(1) channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port)) local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub( channel) # wait a bit for tf serving to start before querying metadata limit = 300 for i in range(limit): try: local_cache["metadata"] = run_get_model_metadata() break except Exception as e: if i == limit - 1: logger.exception( "An error occurred, see `cortex logs -v api {}` for more details." .format(api["name"])) sys.exit(1) time.sleep(1) serve(app, listen="*:{}".format(args.port))
def _get_context(d): return Context(obj=d, cache_dir=".")
def test_simple_end_to_end(spark): local_storage_path = Path("/workspace/local_storage") local_storage_path.mkdir(parents=True, exist_ok=True) should_ingest = True input_data_path = os.path.join(str(local_storage_path), "insurance.csv") raw_ctx = insurance_context.get(input_data_path) workload_id = raw_ctx["raw_columns"]["raw_string_columns"]["smoker"][ "workload_id"] cols_to_validate = [] for column_type in raw_ctx["raw_columns"].values(): for raw_column in column_type.values(): cols_to_validate.append(raw_column["id"]) insurance_data_string = "\n".join(",".join(str(val) for val in line) for line in insurance_data) Path(os.path.join(str(local_storage_path), "insurance.csv")).write_text(insurance_data_string) ctx = Context(raw_obj=raw_ctx, cache_dir="/workspace/cache", local_storage_path=str(local_storage_path)) storage = ctx.storage raw_df = spark_job.ingest_raw_dataset(spark, ctx, cols_to_validate, should_ingest) assert raw_df.count() == 15 assert ctx.get_metadata(ctx.raw_dataset["key"])["dataset_size"] == 15 for raw_column_id in cols_to_validate: path = os.path.join(raw_ctx["status_prefix"], raw_column_id, workload_id) status = storage.get_json(str(path)) status["resource_id"] = raw_column_id status["exist_code"] = "succeeded" cols_to_aggregate = [r["id"] for r in raw_ctx["aggregates"].values()] spark_job.run_aggregators(spark, ctx, cols_to_aggregate, raw_df) for aggregate_id in cols_to_aggregate: for aggregate_resource in raw_ctx["aggregates"].values(): if aggregate_resource["id"] == aggregate_id: assert local_storage_path.joinpath( aggregate_resource["key"]).exists() path = os.path.join(raw_ctx["status_prefix"], aggregate_id, workload_id) status = storage.get_json(str(path)) status["resource_id"] = aggregate_id status["exist_code"] = "succeeded" cols_to_transform = [ r["id"] for r in raw_ctx["transformed_columns"].values() ] spark_job.validate_transformers(spark, ctx, cols_to_transform, raw_df) for transformed_id in cols_to_transform: path = os.path.join(raw_ctx["status_prefix"], transformed_id, workload_id) status = storage.get_json(str(path)) status["resource_id"] = transformed_id status["exist_code"] = "succeeded" training_datasets = [raw_ctx["models"]["dnn"]["dataset"]["id"]] spark_job.create_training_datasets(spark, ctx, training_datasets, raw_df) for dataset_id in training_datasets: path = os.path.join(raw_ctx["status_prefix"], transformed_id, workload_id) status = storage.get_json(str(path)) status["resource_id"] = transformed_id status["exist_code"] = "succeeded" dataset = raw_ctx["models"]["dnn"]["dataset"] metadata = ctx.get_metadata(dataset["id"]) assert metadata["training_size"] + metadata["eval_size"] == 15 assert local_storage_path.joinpath(dataset["train_key"], "_SUCCESS").exists() assert local_storage_path.joinpath(dataset["eval_key"], "_SUCCESS").exists()