示例#1
0
def predict(datasource, select, data_table, result_table, label_column,
            oss_model_path):
    """PAI TensorFlow prediction wrapper.

    This function does some preparation for the local prediction, say,
    downloading the model from OSS, extracting metadata and so on, and
    then delegates the actual work to ``_predict``.

    Args:
        datasource: the datasource from which to get data
        select: data selection SQL statement
        data_table: tmp table which holds the data from select
        result_table: table to save prediction result
        label_column: prediction label column
        oss_model_path: the model path on OSS
    """

    # Enabling eager execution is best-effort: a failure here must not
    # abort the prediction. Only ordinary errors are ignored, so that
    # KeyboardInterrupt / SystemExit still propagate to the caller.
    try:
        tf.enable_eager_execution()
    except Exception:
        pass

    (estimator, feature_column_names, feature_column_names_map, feature_metas,
     label_meta, model_params,
     feature_columns_code) = oss.load_metas(oss_model_path,
                                            "tensorflow_model_desc")

    # NOTE(review): eval() executes code stored with the model metadata in
    # this function's scope; it is only safe while the content at
    # oss_model_path is trusted.
    feature_columns = eval(feature_columns_code)

    # NOTE(typhoonzero): No need to eval model_params["optimizer"] and
    # model_params["loss"] because predicting does not need these parameters.

    is_estimator = is_tf_estimator(import_model(estimator))

    # Keras single node is using h5 format to save the model, no need to deal
    # with export model format. Keras distributed mode will use estimator, so
    # this is also needed.
    if is_estimator:
        oss.load_file(oss_model_path, "exported_path")
        # NOTE(typhoonzero): directory "model_save" is hardcoded in
        # codegen/tensorflow/codegen.go
        oss.load_dir("%s/model_save" % oss_model_path)
    else:
        oss.load_file(oss_model_path, "model_save")

    _predict(datasource=datasource,
             estimator_string=estimator,
             select=select,
             result_table=result_table,
             feature_columns=feature_columns,
             feature_column_names=feature_column_names,
             feature_column_names_map=feature_column_names_map,
             train_label_name=label_meta["feature_name"],
             result_col_name=label_column,
             feature_metas=feature_metas,
             model_params=model_params,
             save="model_save",
             batch_size=1,
             pai_table=data_table)
示例#2
0
def evaluate(datasource, select, data_table, result_table, oss_model_path,
             metrics):
    """PAI TensorFlow evaluation wrapper.

    Prepares a local evaluation run: loads the model description from
    OSS, rebuilds the feature columns, fetches the saved model and then
    hands everything over to ``_evaluate``.

    Args:
        datasource: the datasource from which to get data
        select: data selection SQL statement
        data_table: tmp table which holds the data from select
        result_table: table to save the evaluation result
        oss_model_path: the model path on OSS
        metrics: metrics to evaluate
    """
    # The names bound here are intentionally left untouched: the eval()
    # below runs in this function's scope and may refer to them.
    (estimator, feature_column_names, feature_column_names_map,
     feature_metas, label_meta, model_params,
     feature_columns_code) = oss.load_metas(
         oss_model_path, "tensorflow_model_desc")

    # model_params["optimizer"] and model_params["loss"] are not evaluated
    # here; only the feature columns are rebuilt from their stored code.
    feature_columns = eval(feature_columns_code)

    uses_estimator = is_tf_estimator(import_model(estimator))

    if not uses_estimator:
        # Single-node Keras models are saved in h5 format, so there is no
        # exported-model layout to deal with.
        oss.load_file(oss_model_path, "model_save")
    else:
        # Estimators (also used by distributed Keras) need the exported
        # path plus the "model_save" directory, whose name is hardcoded in
        # codegen/tensorflow/codegen.go.
        oss.load_file(oss_model_path, "exported_path")
        oss.load_dir("%s/model_save" % oss_model_path)

    evaluate_kwargs = {
        "datasource": datasource,
        "estimator_string": estimator,
        "select": select,
        "result_table": result_table,
        "feature_columns": feature_columns,
        "feature_column_names": feature_column_names,
        "feature_metas": feature_metas,
        "label_meta": label_meta,
        "model_params": model_params,
        "validation_metrics": metrics,
        "save": "model_save",
        "batch_size": 1,
        "validation_steps": None,
        "verbose": 0,
        "is_pai": True,
        "pai_table": data_table,
    }
    _evaluate(**evaluate_kwargs)
示例#3
0
def explain(datasource, select, data_table, result_table, label_column,
            oss_model_path):
    """PAI TensorFlow explain wrapper.

    This function does some preparation for the local explanation, say,
    downloading the model from OSS, extracting metadata and so on, and
    then delegates the actual work to ``_explain``.

    Args:
        datasource: the datasource from which to get data
        select: data selection SQL statement
        data_table: tmp table which holds the data from select
        result_table: table to save the explanation result
        label_column: label column; accepted but not used in this function
        oss_model_path: the model path on OSS
    """
    # Enabling eager execution is best-effort: report the failure and go on.
    try:
        tf.enable_eager_execution()
    except Exception as e:
        # Terminate the warning with a newline so it does not run into
        # whatever is written to stderr next.
        sys.stderr.write(
            "warning: failed to enable_eager_execution: %s\n" % e)

    (estimator, feature_column_names, feature_column_names_map, feature_metas,
     label_meta, model_params,
     feature_columns_code) = oss.load_metas(oss_model_path,
                                            "tensorflow_model_desc")

    # NOTE(review): eval() executes code stored with the model metadata in
    # this function's scope; it is only safe while the content at
    # oss_model_path is trusted.
    feature_columns = eval(feature_columns_code)
    # NOTE(typhoonzero): No need to eval model_params["optimizer"] and
    # model_params["loss"] because explaining does not need these parameters.

    is_estimator = is_tf_estimator(import_model(estimator))

    # Keras single node is using h5 format to save the model, no need to deal
    # with export model format. Keras distributed mode will use estimator, so
    # this is also needed.
    if is_estimator:
        oss.load_file(oss_model_path, "exported_path")
        # NOTE(typhoonzero): directory "model_save" is hardcoded in
        # codegen/tensorflow/codegen.go
        oss.load_dir("%s/model_save" % oss_model_path)
    else:
        oss.load_file(oss_model_path, "model_save")

    # (TODO: lhw) use oss to store result image
    _explain(datasource=datasource,
             estimator_string=estimator,
             select=select,
             feature_columns=feature_columns,
             feature_column_names=feature_column_names,
             feature_metas=feature_metas,
             label_meta=label_meta,
             model_params=model_params,
             save="model_save",
             result_table=result_table,
             pai_table=data_table,
             oss_dest=None,
             oss_ak=None,
             oss_sk=None,
             oss_endpoint=None,
             oss_bucket_name=None)