def load_model(model_uri,
               meta_graph_tags,
               signature_def_map_key,
               tf_session=None):
    """
    Load a tensorflow model from a specific path.

    *With TensorFlow version <2.0.0, this method must be called within a TensorFlow graph context.*

    :param model_uri: The location, in URI format, of the tensorflow model. For example:

                      - ``/Users/me/path/to/local/model``
                      - ``relative/path/to/local/model``

    :param meta_graph_tags: A list of tags identifying the model's metagraph within the
                               serialized ``SavedModel`` object. For more information, see the
                               ``tags`` parameter of the `tf.saved_model.builder.SavedModelBuilder
                               method <https://www.tensorflow.org/api_docs/python/tf/saved_model/
                               builder/SavedModelBuilder#add_meta_graph>`_.

    :param signature_def_map_key: A string identifying the input/output signature associated with the
                                 model. This is a key within the serialized ``SavedModel``'s
                                 signature definition mapping. For more information, see the
                                 ``signature_def_map`` parameter of the
                                 ``tf.saved_model.builder.SavedModelBuilder`` method.

    :param tf_session: The TensorFlow session in which to load the model. If using TensorFlow
                    version >= 2.0.0, this argument is ignored (a ``FutureWarning`` is emitted
                    if one is passed). If using TensorFlow <2.0.0 and no session is passed to
                    this function, aiflow will attempt to load the model using the default
                    TensorFlow session. If no default session is available, then the
                    function raises an exception.
    :return: For TensorFlow < 2.0.0, a TensorFlow signature definition of type:
             ``tensorflow.core.protobuf.meta_graph_pb2.SignatureDef``. This defines the input and
             output tensors for model inference.
             For TensorFlow >= 2.0.0, A callable graph (tf.function) that takes inputs and
             returns inferences.
    :raises AIFlowException: On TensorFlow < 2.0.0 when no session was passed and no default
             session is available, or (from ``_load_tensorflow_saved_model``) when
             ``signature_def_map_key`` is not present in the loaded model.
    """
    if LooseVersion(tensorflow.__version__) < LooseVersion('2.0.0'):
        if not tf_session:
            # Fall back to the session installed via ``session.as_default()``.
            tf_session = tensorflow.get_default_session()
            if not tf_session:
                raise AIFlowException(
                    "No TensorFlow session found while calling load_model(). "
                    "You can set the default Tensorflow session before calling"
                    " load_model via `session.as_default()`, or directly pass "
                    "a session in which to load the model via the tf_session "
                    "argument.")
    elif tf_session:
        # TF 2.x loads eagerly; a session has no effect, so tell the caller.
        warnings.warn(
            "A TensorFlow session was passed into load_model, but the "
            "currently used version is TF 2.0 where sessions are deprecated. "
            "The tf_session argument will be ignored.", FutureWarning)
    return _load_tensorflow_saved_model(
        model_uri=model_uri,
        meta_graph_tags=meta_graph_tags,
        signature_def_map_key=signature_def_map_key,
        tf_session=tf_session)
示例#2
0
def _unwrap_job_response(response):
    """Decode a single-job RPC response.

    :param response: RPC response exposing ``return_code``, ``return_msg`` and a
        ``data`` payload that is parsed with ``Parse`` into a ``JobProto``.
    :return: The result of ``ProtoToMeta.proto_to_job_meta`` on success, or
        ``None`` when the return code is ``RESOURCE_DOES_NOT_EXIST``.
    :raises AIFlowException: For any other return code, carrying ``return_msg``.
    """
    code = response.return_code
    if code == str(SUCCESS):
        job_proto = Parse(response.data, JobProto())
        return ProtoToMeta.proto_to_job_meta(job_proto)
    if code == str(RESOURCE_DOES_NOT_EXIST):
        return None
    raise AIFlowException(response.return_msg)
示例#3
0
def _unwrap_update_response(response):
    """Decode the response of an update RPC.

    :param response: RPC response exposing ``return_code``, ``return_msg`` and
        ``data``.
    :return: ``int(response.data)`` on success, or ``Status.ERROR`` when the
        return code is ``INTERNAL_ERROR``.
    :raises AIFlowException: For any other return code, carrying ``return_msg``.
    """
    status_code = response.return_code
    if status_code == str(SUCCESS):
        return int(response.data)
    if status_code == str(INTERNAL_ERROR):
        return Status.ERROR
    raise AIFlowException(response.return_msg)
示例#4
0
def _unwrap_job_list_response(response):
    """Decode a job-list RPC response.

    :param response: RPC response exposing ``return_code``, ``return_msg`` and a
        ``data`` payload that is parsed with ``Parse`` into a ``JobListProto``.
    :return: The result of ``ProtoToMeta.proto_to_job_meta_list`` over the
        parsed ``jobs`` on success, or ``None`` when the return code is
        ``RESOURCE_DOES_NOT_EXIST``.
    :raises AIFlowException: For any other return code, carrying ``return_msg``.
    """
    code = response.return_code
    if code == str(SUCCESS):
        jobs = Parse(response.data, JobListProto()).jobs
        return ProtoToMeta.proto_to_job_meta_list(jobs)
    if code == str(RESOURCE_DOES_NOT_EXIST):
        return None
    raise AIFlowException(response.return_msg)
示例#5
0
 def example_meta_list_to_table(example_meta_list: List[ExampleMeta],
                                store_type='SqlAlchemyStore'):
     """Convert a list of example metadata objects into table objects.

     :param example_meta_list: Example metadata to convert, in order.
     :param store_type: Forwarded unchanged to
         ``MetaToTable.example_meta_to_table``.
     :return: A list with one ``MetaToTable.example_meta_to_table`` result per
         input element, in the same order.
     :raises AIFlowException: If an example has a schema whose ``name_list`` and
         ``type_list`` are not both ``None`` or both present with equal length.
     """

     def _checked_schema(meta):
         # Return (name_list, type_list) for ``meta``; both are None when the
         # example carries no schema at all.
         if meta.schema is None:
             return None, None
         names = meta.schema.name_list
         types = meta.schema.type_list
         exactly_one_missing = (names is None) != (types is None)
         if exactly_one_missing or (names is not None
                                    and len(names) != len(types)):
             raise AIFlowException(
                 "the length of name list and type list should be the same"
             )
         return names, types

     tables = []
     for meta in example_meta_list:
         names, types = _checked_schema(meta)
         tables.append(
             MetaToTable.example_meta_to_table(
                 name=meta.name,
                 support_type=meta.support_type,
                 data_type=meta.data_type,
                 data_format=meta.data_format,
                 description=meta.description,
                 batch_uri=meta.batch_uri,
                 stream_uri=meta.stream_uri,
                 create_time=meta.create_time,
                 update_time=meta.update_time,
                 properties=meta.properties,
                 name_list=names,
                 type_list=types,
                 catalog_name=meta.catalog_name,
                 catalog_type=meta.catalog_type,
                 catalog_database=meta.catalog_database,
                 catalog_connection_uri=meta.catalog_connection_uri,
                 catalog_table=meta.catalog_table,
                 catalog_version=meta.catalog_version,
                 store_type=store_type))
     return tables
示例#6
0
def _unwrap_example_list_response(response):
    """Decode an example-list RPC response.

    :param response: RPC response exposing ``return_code``, ``return_msg`` and a
        ``data`` payload that is parsed with ``Parse`` into an
        ``ExampleListProto``.
    :return: The result of ``ProtoToMeta.proto_to_example_meta_list`` over the
        parsed ``examples`` on success, or ``None`` when the return code is
        ``RESOURCE_DOES_NOT_EXIST``.
    :raises AIFlowException: For any other return code, carrying ``return_msg``.
    """
    code = response.return_code
    if code == str(SUCCESS):
        examples = Parse(response.data, ExampleListProto()).examples
        return ProtoToMeta.proto_to_example_meta_list(examples)
    if code == str(RESOURCE_DOES_NOT_EXIST):
        return None
    raise AIFlowException(response.return_msg)
示例#7
0
def _parse_response(response, message):
    """Parse the payload of a successful RPC response into ``message``.

    :param response: RPC response exposing ``return_code``, ``return_msg`` and
        ``data``.
    :param message: Protobuf message instance that ``Parse`` fills in.
    :return: ``None`` when the payload is the empty string, otherwise the
        result of ``Parse`` (unknown fields are rejected).
    :raises AIFlowException: When ``return_code`` is not ``SUCCESS``; the
        exception carries both the return code and the return message.
    """
    if response.return_code != str(SUCCESS):
        raise AIFlowException(error_code=response.return_code,
                              error_msg=response.return_msg)
    payload = response.data
    if payload == '':
        return None
    return Parse(payload, message, ignore_unknown_fields=False)
示例#8
0
def _unwrap_artifact_list_response(response):
    """Decode an artifact-list RPC response.

    :param response: RPC response exposing ``return_code``, ``return_msg`` and a
        ``data`` payload that is parsed with ``Parse`` into an
        ``ArtifactListProto``.
    :return: The result of ``ProtoToMeta.proto_to_artifact_meta_list`` over the
        parsed ``artifacts`` on success, or ``None`` when the return code is
        ``RESOURCE_DOES_NOT_EXIST``.
    :raises AIFlowException: For any other return code, carrying ``return_msg``.
    """
    code = response.return_code
    if code == str(SUCCESS):
        artifacts = Parse(response.data, ArtifactListProto()).artifacts
        return ProtoToMeta.proto_to_artifact_meta_list(artifacts)
    if code == str(RESOURCE_DOES_NOT_EXIST):
        return None
    raise AIFlowException(response.return_msg)
示例#9
0
def _unwrap_model_relation_list_response(response):
    """Decode a model-relation-list RPC response.

    :param response: RPC response exposing ``return_code``, ``return_msg`` and a
        ``data`` payload that is parsed with ``Parse`` into a
        ``ModelRelationListProto``.
    :return: The result of ``ProtoToMeta.proto_to_model_relation_meta_list``
        over the parsed ``model_relations`` on success, or ``None`` when the
        return code is ``RESOURCE_DOES_NOT_EXIST``.
    :raises AIFlowException: For any other return code, carrying ``return_msg``.
    """
    code = response.return_code
    if code == str(SUCCESS):
        relations = Parse(response.data, ModelRelationListProto()).model_relations
        return ProtoToMeta.proto_to_model_relation_meta_list(relations)
    if code == str(RESOURCE_DOES_NOT_EXIST):
        return None
    raise AIFlowException(response.return_msg)
示例#10
0
def _unwrap_workflow_execution_list_response(response):
    """Decode a workflow-execution-list RPC response.

    :param response: RPC response exposing ``return_code``, ``return_msg`` and a
        ``data`` payload that is parsed with ``Parse`` into a
        ``WorkFlowExecutionListProto``.
    :return: The result of ``ProtoToMeta.proto_to_execution_meta_list`` over the
        parsed ``workflow_executions`` on success, or ``None`` when the return
        code is ``RESOURCE_DOES_NOT_EXIST``.
    :raises AIFlowException: For any other return code, carrying ``return_msg``.
    """
    code = response.return_code
    if code == str(SUCCESS):
        parsed = Parse(response.data, WorkFlowExecutionListProto())
        return ProtoToMeta.proto_to_execution_meta_list(
            parsed.workflow_executions)
    if code == str(RESOURCE_DOES_NOT_EXIST):
        return None
    raise AIFlowException(response.return_msg)
def _get_store(db_uri=''):
    """Build a ``MongoStore`` from a database URI.

    :param db_uri: Connection URI; it is split into username, password, host,
        port and database name by ``parse_mongo_uri``.
    :return: A ``MongoStore`` constructed from the parsed components (the port
        is converted to ``int``).
    :raises AIFlowException: If parsing the URI or constructing the store raises
        any ``Exception``. The original error is chained as ``__cause__`` so its
        traceback is not lost.
    """
    try:
        username, password, host, port, db = parse_mongo_uri(db_uri)
        return MongoStore(host=host,
                          port=int(port),
                          username=username,
                          password=password,
                          db=db)
    except Exception as e:
        # Chain the cause: without ``from e`` the real failure location is
        # hidden behind the generic wrapper.
        raise AIFlowException(str(e)) from e
示例#12
0
def extract_db_engine_from_uri(db_uri):
    """
    Return the database engine named by the scheme of ``db_uri``.

    The URI scheme is expected to be either ``engine`` or ``engine+driver``;
    the part before the ``+`` is taken as the engine and must be a member of
    ``DATABASE_ENGINES``. The driver part, if any, is not inspected.

    :param db_uri: Database URI to inspect.
    :return: The engine portion of the URI scheme.
    :raises AIFlowException: If the scheme contains more than one ``+``, or if
        the extracted engine is not in ``DATABASE_ENGINES``.
    """
    scheme = urllib.parse.urlparse(db_uri).scheme
    # More than one '+' separator cannot be "engine" or "engine+driver".
    if scheme.count('+') > 1:
        raise AIFlowException("Invalid database URI: '%s'." % db_uri)
    # With zero separators this is the whole scheme; with one, the engine half.
    db_engine = scheme.partition('+')[0]
    if db_engine not in DATABASE_ENGINES:
        raise AIFlowException("Invalid database engine: '%s'." % db_engine)
    return db_engine
示例#13
0
 def make_managed_session():
     """Provide transactional scope around series of session operations.

     Generator that yields a session created by ``SessionMaker()``. When it is
     resumed after the yield, the session is committed. If an exception reaches
     it (at the yield or during commit), the session is rolled back. The session
     is always closed. NOTE(review): presumably wrapped with
     ``contextlib.contextmanager`` outside this view; confirm.

     :raises AIFlowException: Re-raised unchanged if the failure was already an
         ``AIFlowException``; any other ``Exception`` is wrapped in one with
         ``INTERNAL_ERROR`` as its code and the original chained as its cause.
     """
     session = SessionMaker()
     try:
         yield session
         session.commit()
     except AIFlowException:
         session.rollback()
         raise
     except Exception as e:
         session.rollback()
         # Pass the message as text (other call sites pass string messages) and
         # chain the original exception so its traceback is preserved.
         raise AIFlowException(error_msg=str(e),
                               error_code=INTERNAL_ERROR) from e
     finally:
         session.close()
示例#14
0
def _load_tensorflow_saved_model(model_uri, meta_graph_tags,
                                 signature_def_map_key, tf_session):
    """
    Load a TensorFlow ``SavedModel`` and return one of its signatures.

    :param model_uri: The local filesystem path or run-relative artifact path to the model.
    :param meta_graph_tags: A list of tags identifying the model's metagraph within the
                               serialized ``SavedModel`` object. For more information, see the
                               ``tags`` parameter of the `tf.saved_model.builder.SavedModelBuilder
                               method <https://www.tensorflow.org/api_docs/python/tf/saved_model/
                               builder/SavedModelBuilder#add_meta_graph>`_.
    :param signature_def_map_key: Key of the wanted entry in the model's signature
                                 mapping (see the ``signature_def_map`` parameter of
                                 ``tf.saved_model.builder.SavedModelBuilder``).
    :param tf_session: The TensorFlow session in which to load the metagraph.
                    Required in TensorFlow versions < 2.0.0. Unused in TensorFlow versions >= 2.0.0
    :return: For TensorFlow versions < 2.0.0, the
             ``tensorflow.core.protobuf.meta_graph_pb2.SignatureDef`` stored under
             ``signature_def_map_key``. For TensorFlow versions >= 2.0.0, the callable
             (tensorflow.function) stored under that key in ``signatures``.
    :raises AIFlowException: If ``signature_def_map_key`` is not one of the model's
             signature keys; the message lists the keys that are available.
    """
    legacy_tf = LooseVersion(tensorflow.__version__) < LooseVersion('2.0.0')
    if legacy_tf:
        # TF1: loading into the session yields a MetaGraphDef whose
        # ``signature_def`` map holds the signatures.
        meta_graph = tensorflow.saved_model.loader.load(sess=tf_session,
                                                        tags=meta_graph_tags,
                                                        export_dir=model_uri)
        signatures = meta_graph.signature_def
    else:
        # TF2: the loaded object exposes its signatures as callables.
        restored = tensorflow.saved_model.load(  # pylint: disable=no-value-for-parameter
            tags=meta_graph_tags,
            export_dir=model_uri)
        signatures = restored.signatures
    if signature_def_map_key in signatures:
        return signatures[signature_def_map_key]
    available = list(signatures.keys())
    raise AIFlowException(
        "Could not find signature def key %s. Available keys are: %s" %
        (signature_def_map_key, available))