Example #1
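A PAI-only variant of a TensorFlow model-explanation routine: input rows are read straight from a MaxCompute table, the checkpoint directory comes from command-line flags, and the rendered explanation plot is uploaded to OSS.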
# pandas and TensorFlow are required by the body below; helpers such as
# import_model, input_fn, init_model_with_feature_column,
# explain_boosted_trees and explain_dnns are assumed to come from the
# surrounding project's runtime modules.
import pandas as pd
import tensorflow as tf


def _explain(datasource,
             estimator_string,
             select,
             feature_columns,
             feature_column_names,
             feature_metas=None,
             label_meta=None,
             model_params=None,
             save="",
             pai_table="",
             plot_type='bar',
             result_table="",
             oss_dest=None,
             oss_ak=None,
             oss_sk=None,
             oss_endpoint=None,
             oss_bucket_name=None):
    # Use None defaults instead of mutable ones: model_params is mutated
    # below, which would leak state across calls through a shared default.
    feature_metas = feature_metas or {}
    label_meta = label_meta or {}
    model_params = model_params or {}

    estimator_cls = import_model(estimator_string)
    # On PAI the checkpoint directory is passed in through TF flags.
    FLAGS = tf.app.flags.FLAGS
    model_params["model_dir"] = FLAGS.checkpointDir
    model_params.update(feature_columns)

    def _input_fn():
        # The SQL select is empty on PAI: rows are read directly from the
        # MaxCompute table named by pai_table.
        dataset = input_fn("",
                           datasource,
                           feature_column_names,
                           feature_metas,
                           label_meta,
                           is_pai=True,
                           pai_table=pai_table)
        # Batch size 1 makes it easy to unpack one row at a time when
        # building the SHAP dataset below.
        return dataset.batch(1).cache()

    estimator = init_model_with_feature_column(estimator_cls, model_params)
    driver = "pai_maxcompute"
    conn = None
    if estimator_cls in (tf.estimator.BoostedTreesClassifier,
                         tf.estimator.BoostedTreesRegressor):
        # Boosted-tree estimators expose built-in feature contributions, so
        # no SHAP background dataset is needed. The four empty strings fill
        # Hive/HDFS-related connection parameters that are unused on PAI.
        explain_boosted_trees(datasource, estimator, _input_fn, plot_type,
                              result_table, feature_column_names, driver, conn,
                              "", "", "", "", oss_dest, oss_ak, oss_sk,
                              oss_endpoint, oss_bucket_name)
    else:
        # For other models, materialize the input into a DataFrame that SHAP
        # can work on; each feature tensor is indexed [0][0] to pull the
        # scalar out of its (1, 1)-shaped batch.
        shap_dataset = pd.DataFrame(columns=feature_column_names)
        for i, (features, label) in enumerate(_input_fn()):
            shap_dataset.loc[i] = [
                item.numpy()[0][0] for item in features.values()
            ]
        explain_dnns(datasource, estimator, shap_dataset, plot_type,
                     result_table, feature_column_names, driver, conn, "", "",
                     "", "", oss_dest, oss_ak, oss_sk, oss_endpoint,
                     oss_bucket_name)
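
A minimal sketch of how this entry point might be invoked. Every literal below (project, table, column, and credential names) is a hypothetical placeholder, and the call is assumed to run inside a PAI job where FLAGS.checkpointDir is populated.

# Hypothetical invocation; all names and credentials are placeholders.
_explain(datasource="maxcompute://ak:sk@service.example.com/api?curr_project=demo",
         estimator_string="tf.estimator.BoostedTreesClassifier",
         select="",                        # unused here; rows come from pai_table
         feature_columns={"feature_columns": []},  # real feature columns built upstream
         feature_column_names=["sepal_length", "sepal_width"],
         feature_metas={},
         label_meta={"feature_name": "class"},
         model_params={"n_batches_per_layer": 1},
         pai_table="demo.iris_train",
         plot_type="bar",
         result_table="demo.iris_explain",
         oss_dest="explain_plots/iris",
         oss_ak="<oss-access-key>",
         oss_sk="<oss-secret-key>",
         oss_endpoint="oss-cn-hangzhou.aliyuncs.com",
         oss_bucket_name="demo-bucket")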
Example #2
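A more general variant that runs both on PAI and against an ordinary database: the model directory comes from the save argument rather than TF flags, Keras models load their weights explicitly, and a connection to the result table is opened and closed around the explanation.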
# Same assumed imports as Example #1, plus helpers such as is_tf_estimator,
# pop_optimizer_and_loss, load_keras_model_weights, db and PaiIOConnection
# from the surrounding project's runtime modules.
def _explain(datasource,
             estimator_string,
             select,
             feature_columns,
             feature_column_names,
             feature_metas=None,
             label_meta=None,
             model_params=None,
             save="",
             pai_table="",
             plot_type='bar',
             result_table="",
             oss_dest=None,
             oss_ak=None,
             oss_sk=None,
             oss_endpoint=None,
             oss_bucket_name=None):
    # Avoid mutable default arguments; model_params is mutated below.
    feature_metas = feature_metas or {}
    label_meta = label_meta or {}
    model_params = model_params or {}

    estimator_cls = import_model(estimator_string)
    # Estimators restore from a model_dir; Keras models instead load their
    # weights explicitly after construction (see below).
    if is_tf_estimator(estimator_cls):
        model_params['model_dir'] = save
    model_params.update(feature_columns)
    # Strip training-only parameters (optimizer, loss) before construction.
    pop_optimizer_and_loss(model_params)

    is_pai = bool(pai_table)

    def _input_fn():
        # On PAI the rows come from pai_table; otherwise the select statement
        # is executed against the configured datasource.
        dataset = input_fn(select,
                           datasource,
                           feature_column_names,
                           feature_metas,
                           label_meta,
                           is_pai=is_pai,
                           pai_table=pai_table)
        # Batch size 1 simplifies building the SHAP dataset row by row.
        return dataset.batch(1).cache()

    estimator = init_model_with_feature_column(estimator_cls, model_params)
    # Keras models do not restore from model_dir, so load weights explicitly.
    if not is_tf_estimator(estimator_cls):
        load_keras_model_weights(estimator, save)

    # Open a connection only when explanation results must be written back.
    if result_table:
        if is_pai:
            conn = PaiIOConnection.from_table(result_table)
        else:
            conn = db.connect_with_data_source(datasource)
    else:
        conn = None

    if estimator_cls in (tf.estimator.BoostedTreesClassifier,
                         tf.estimator.BoostedTreesRegressor):
        # Boosted-tree estimators expose built-in feature contributions, so
        # no SHAP background dataset is required.
        explain_boosted_trees(datasource, estimator, _input_fn, plot_type,
                              result_table, feature_column_names, conn,
                              oss_dest, oss_ak, oss_sk, oss_endpoint,
                              oss_bucket_name)
    else:
        # For other models, materialize the input into a DataFrame for SHAP;
        # each feature tensor is (1, 1)-shaped because of batch(1), hence the
        # [0][0] indexing.
        shap_dataset = pd.DataFrame(columns=feature_column_names)
        for i, (features, label) in enumerate(_input_fn()):
            shap_dataset.loc[i] = [
                item.numpy()[0][0] for item in features.values()
            ]
        explain_dnns(datasource, estimator, shap_dataset, plot_type,
                     result_table, feature_column_names, conn, oss_dest,
                     oss_ak, oss_sk, oss_endpoint, oss_bucket_name)

    if conn is not None:
        conn.close()
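
A sketch of a non-PAI invocation against a MySQL datasource. The connection string and all table, column, and directory names are hypothetical placeholders; save is assumed to point at a previously trained model.

# Hypothetical invocation; all names are placeholders.
_explain(datasource="mysql://user:password@tcp(127.0.0.1:3306)/?maxAllowedPacket=0",
         estimator_string="tf.estimator.DNNClassifier",
         select="SELECT * FROM iris.train",
         feature_columns={"feature_columns": []},  # real feature columns built upstream
         feature_column_names=["sepal_length", "sepal_width"],
         label_meta={"feature_name": "class"},
         model_params={"hidden_units": [10, 10], "n_classes": 3},
         save="trained_model_dir",
         plot_type="bar",
         result_table="iris.explain_result")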