Example #1
0
def create_tf_model(model_path, flags):
    """Build a prediction Model backed by the engine selected via env vars.

    The engine name is read from the ``prediction_engine`` environment
    variable and defaults to ``MODEL_SERVER_ENGINE_NAME`` when unset.

    Args:
      model_path: Location of the model to load; also handed to
        ``mlprediction.create_model``.
      flags: Forwarded unchanged to the engine-specific loader.

    Returns:
      Whatever ``mlprediction.create_model`` returns for the client built
      for the selected engine.

    Raises:
      mlprediction.PredictionError: with FAILED_TO_LOAD_MODEL when the model
        server fails to start, when it reports an empty signature map, or
        when the engine name is not recognized.
    """
    selected = os.environ.get("prediction_engine", MODEL_SERVER_ENGINE_NAME)

    if selected == MODEL_SERVER_ENGINE_NAME:
        logging.debug("Starting model server from %s", model_path)
        try:
            # Only the stub is used below; the first element is discarded.
            _server, server_stub = _start_model_server(model_path, flags)
        except Exception as err:  # pylint: disable=broad-except
            reason = str(err)
            logging.critical("Could not load ModelServer.\n%s", reason)
            raise mlprediction.PredictionError(
                mlprediction.PredictionError.FAILED_TO_LOAD_MODEL, reason)
        signatures = _get_model_signature_map(server_stub)
        # An empty/None map means the server exposed nothing callable.
        if not signatures:
            raise mlprediction.PredictionError(
                mlprediction.PredictionError.FAILED_TO_LOAD_MODEL,
                "Could not get signature map from the model. ")
        prediction_client = ModelServerClient(server_stub, signatures)
    elif selected == mlprediction.SESSION_RUN_ENGINE_NAME:
        tf_session, signatures = _get_session_and_signature_map(
            model_path, flags)
        prediction_client = mlprediction.SessionClient(tf_session, signatures)
    else:
        logging.critical("Illegal prediction engine %s", selected)
        raise mlprediction.PredictionError(
            mlprediction.PredictionError.FAILED_TO_LOAD_MODEL,
            "Illegal prediction engine %s" % selected)

    return mlprediction.create_model(prediction_client, model_path)
Example #2
0
def create_model_from_server(model_server, model_path):
    """Create a Model whose predictions are served by ``model_server``.

    The SavedModel at ``model_path`` is loaded with the SERVING tag so that
    its signature map can be obtained. The session returned by the load is
    not used for prediction, so it is closed immediately.

    Args:
      model_server: Passed as the first argument to
        ``tf_prediction_server_lib.ModelServerClient``.
      model_path: Location of the SavedModel to load.

    Returns:
      The result of ``mlprediction.create_model`` for the server client.
    """
    serving_tags = [tf.saved_model.tag_constants.SERVING]
    unused_session, signatures = mlprediction.load_model(
        model_path, tags=serving_tags)

    # Only the signature map is needed; release the session right away.
    unused_session.close()

    server_client = tf_prediction_server_lib.ModelServerClient(
        model_server, signatures)
    return mlprediction.create_model(server_client, model_path)
Example #3
0
def create_model_using_session(model_path):
    """Create a Model that runs predictions through an in-process session.

    Args:
      model_path: Location of the SavedModel, loaded with the SERVING tag.

    Returns:
      The result of ``mlprediction.create_model`` for a ``SessionClient``
      wrapping the loaded session and its signature map.
    """
    serving_tags = [tf.saved_model.tag_constants.SERVING]
    tf_session, signatures = mlprediction.load_model(
        model_path, tags=serving_tags)
    session_client = mlprediction.SessionClient(tf_session, signatures)
    return mlprediction.create_model(session_client, model_path)