def _explain(datasource,
             estimator_string,
             select,
             feature_columns,
             feature_column_names,
             feature_metas=None,
             label_meta=None,
             model_params=None,
             save="",
             pai_table="",
             plot_type='bar',
             result_table="",
             oss_dest=None,
             oss_ak=None,
             oss_sk=None,
             oss_endpoint=None,
             oss_bucket_name=None):
    """Explain a trained TensorFlow model using data read from a PAI table.

    The estimator class is resolved from ``estimator_string``, instantiated
    with ``model_params`` (extended with the checkpoint directory taken from
    the TensorFlow flags and with ``feature_columns``), and then handed to
    either ``explain_boosted_trees`` or ``explain_dnns``.

    Args:
        datasource: Passed to ``input_fn`` and to the explain helpers.
        estimator_string: Name resolved to an estimator class by
            ``import_model``.
        select: Unused in this function; ``input_fn`` is always called with
            an empty query string.
        feature_columns: Mapping merged into ``model_params`` through
            ``dict.update``.
        feature_column_names: Column names used for the input function and
            as the columns of the SHAP data frame.
        feature_metas: Feature metadata passed to ``input_fn``. Defaults to
            an empty dict.
        label_meta: Label metadata passed to ``input_fn``. Defaults to an
            empty dict.
        model_params: Estimator parameters. A dict supplied by the caller is
            updated in place (``model_dir`` plus the feature columns).
            Defaults to a fresh empty dict on every call.
        save: Unused in this function.
        pai_table: Name of the PAI table the input function reads from.
        plot_type: Plot type forwarded to the explain helpers.
        result_table: Result table name forwarded to the explain helpers.
        oss_dest: Forwarded unchanged to the explain helpers.
        oss_ak: Forwarded unchanged to the explain helpers.
        oss_sk: Forwarded unchanged to the explain helpers.
        oss_endpoint: Forwarded unchanged to the explain helpers.
        oss_bucket_name: Forwarded unchanged to the explain helpers.
    """
    # Use None sentinels instead of ``{}`` defaults: ``model_params`` is
    # mutated below, and a shared default dict would leak ``model_dir`` and
    # feature columns from one call into the next.
    feature_metas = {} if feature_metas is None else feature_metas
    label_meta = {} if label_meta is None else label_meta
    model_params = {} if model_params is None else model_params

    estimator_cls = import_model(estimator_string)
    FLAGS = tf.app.flags.FLAGS
    model_params["model_dir"] = FLAGS.checkpointDir
    model_params.update(feature_columns)

    def _input_fn():
        # Always read from the PAI table; no SELECT statement is used here.
        dataset = input_fn("",
                           datasource,
                           feature_column_names,
                           feature_metas,
                           label_meta,
                           is_pai=True,
                           pai_table=pai_table)
        return dataset.batch(1).cache()

    estimator = init_model_with_feature_column(estimator_cls, model_params)
    driver = "pai_maxcompute"
    conn = None
    if estimator_cls in (tf.estimator.BoostedTreesClassifier,
                         tf.estimator.BoostedTreesRegressor):
        explain_boosted_trees(datasource, estimator, _input_fn, plot_type,
                              result_table, feature_column_names, driver,
                              conn, "", "", "", "", oss_dest, oss_ak, oss_sk,
                              oss_endpoint, oss_bucket_name)
    else:
        # Materialize the dataset row by row into a data frame for SHAP.
        shap_dataset = pd.DataFrame(columns=feature_column_names)
        for i, (features, _label) in enumerate(_input_fn()):
            # Batch size is 1, so take the first element of each feature.
            shap_dataset.loc[i] = [
                item.numpy()[0][0] for item in features.values()
            ]
        explain_dnns(datasource, estimator, shap_dataset, plot_type,
                     result_table, feature_column_names, driver, conn, "",
                     "", "", "", oss_dest, oss_ak, oss_sk, oss_endpoint,
                     oss_bucket_name)
def _explain(datasource,
             estimator_string,
             select,
             feature_columns,
             feature_column_names,
             feature_metas=None,
             label_meta=None,
             model_params=None,
             save="",
             pai_table="",
             plot_type='bar',
             result_table="",
             oss_dest=None,
             oss_ak=None,
             oss_sk=None,
             oss_endpoint=None,
             oss_bucket_name=None):
    """Explain a trained TensorFlow estimator or Keras model.

    The model class is resolved from ``estimator_string`` and instantiated
    with ``model_params``. TensorFlow estimators get ``model_dir`` set to
    ``save``; other models have their weights loaded from ``save`` through
    ``load_keras_model_weights``. The model is then handed to either
    ``explain_boosted_trees`` or ``explain_dnns``.

    Args:
        datasource: Passed to ``input_fn``, to the explain helpers, and used
            to open the result connection when not running on a PAI table.
        estimator_string: Name resolved to a model class by
            ``import_model``.
        select: Query passed to ``input_fn``.
        feature_columns: Mapping merged into ``model_params`` through
            ``dict.update``.
        feature_column_names: Column names used for the input function and
            as the columns of the SHAP data frame.
        feature_metas: Feature metadata passed to ``input_fn``. Defaults to
            an empty dict.
        label_meta: Label metadata passed to ``input_fn``. Defaults to an
            empty dict.
        model_params: Model parameters. A dict supplied by the caller is
            modified in place. Defaults to a fresh empty dict on every call.
        save: Model directory (estimators) or the location Keras weights are
            loaded from.
        pai_table: PAI table name; a non-empty value switches the input
            function and the result connection to PAI mode.
        plot_type: Plot type forwarded to the explain helpers.
        result_table: Result table name. When empty, no connection is
            opened and ``None`` is passed as the connection.
        oss_dest: Forwarded unchanged to the explain helpers.
        oss_ak: Forwarded unchanged to the explain helpers.
        oss_sk: Forwarded unchanged to the explain helpers.
        oss_endpoint: Forwarded unchanged to the explain helpers.
        oss_bucket_name: Forwarded unchanged to the explain helpers.
    """
    # Use None sentinels instead of ``{}`` defaults: ``model_params`` is
    # mutated below, and a shared default dict would carry state from one
    # call into the next.
    feature_metas = {} if feature_metas is None else feature_metas
    label_meta = {} if label_meta is None else label_meta
    model_params = {} if model_params is None else model_params

    estimator_cls = import_model(estimator_string)
    if is_tf_estimator(estimator_cls):
        model_params['model_dir'] = save
    model_params.update(feature_columns)
    pop_optimizer_and_loss(model_params)

    is_pai = bool(pai_table)

    def _input_fn():
        dataset = input_fn(select,
                           datasource,
                           feature_column_names,
                           feature_metas,
                           label_meta,
                           is_pai=is_pai,
                           pai_table=pai_table)
        return dataset.batch(1).cache()

    estimator = init_model_with_feature_column(estimator_cls, model_params)
    if not is_tf_estimator(estimator_cls):
        load_keras_model_weights(estimator, save)

    # A connection is only needed when results are written to a table.
    if result_table:
        if is_pai:
            conn = PaiIOConnection.from_table(result_table)
        else:
            conn = db.connect_with_data_source(datasource)
    else:
        conn = None

    # Close the connection even if explaining raises, so it cannot leak.
    try:
        if estimator_cls in (tf.estimator.BoostedTreesClassifier,
                             tf.estimator.BoostedTreesRegressor):
            explain_boosted_trees(datasource, estimator, _input_fn,
                                  plot_type, result_table,
                                  feature_column_names, conn, oss_dest,
                                  oss_ak, oss_sk, oss_endpoint,
                                  oss_bucket_name)
        else:
            # Materialize the dataset row by row into a data frame for SHAP.
            shap_dataset = pd.DataFrame(columns=feature_column_names)
            for i, (features, _label) in enumerate(_input_fn()):
                # Batch size is 1, so take the first element of each feature.
                shap_dataset.loc[i] = [
                    item.numpy()[0][0] for item in features.values()
                ]
            explain_dnns(datasource, estimator, shap_dataset, plot_type,
                         result_table, feature_column_names, conn, oss_dest,
                         oss_ak, oss_sk, oss_endpoint, oss_bucket_name)
    finally:
        if conn is not None:
            conn.close()