Example #1
def shap_explain(booster,
                 datasource,
                 dataset,
                 summary_params,
                 result_table="",
                 is_pai=False,
                 oss_dest=None,
                 oss_ak=None,
                 oss_sk=None,
                 oss_endpoint=None,
                 oss_bucket_name=None):
    tree_explainer = shap.TreeExplainer(booster)
    shap_values = tree_explainer.shap_values(dataset)
    if result_table:
        if is_pai:
            conn = PaiIOConnection.from_table(result_table)
        else:
            conn = db.connect_with_data_source(datasource)
        # TODO(typhoonzero): shap_values may be a list of shape
        # [3, num_samples, num_features]; we use the first element
        # here and should find out when to use the other two. When
        # shap_values is not a list it can be used directly.
        if isinstance(shap_values, list):
            to_write = shap_values[0]
        else:
            to_write = shap_values

        columns = list(dataset.columns)
        with db.buffered_db_writer(conn, result_table, columns) as w:
            for row in to_write:
                w.write(list(row))
        conn.close()

    if summary_params.get("plot_type") == "decision":
        shap_interaction_values = tree_explainer.shap_interaction_values(
            dataset)
        expected_value = tree_explainer.expected_value
        if isinstance(shap_interaction_values, list):
            shap_interaction_values = shap_interaction_values[0]
        if isinstance(expected_value, list):
            expected_value = expected_value[0]
        plot_func = lambda: shap.decision_plot(  # noqa: E731
            expected_value,
            shap_interaction_values,
            dataset,
            show=False,
            feature_display_range=slice(None, -40, -1),
            alpha=1)
    else:
        plot_func = lambda: shap.summary_plot(  # noqa: E731
            shap_values, dataset, show=False, **summary_params)

    explainer.plot_and_save(plot_func,
                            oss_dest=oss_dest,
                            oss_ak=oss_ak,
                            oss_sk=oss_sk,
                            oss_endpoint=oss_endpoint,
                            oss_bucket_name=oss_bucket_name,
                            filename='summary')
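
All of these snippets are excerpted from SQLFlow's runtime, so helpers such as db, explainer, and PaiIOConnection are imported at the top of the source files and omitted here. For orientation, here is a minimal standalone sketch of the TreeExplainer flow above; the toy data, feature names, and parameters are assumptions for illustration, and only public shap and xgboost APIs are used.

# Minimal sketch (assumed toy setup, not from the original repo): train an
# XGBoost booster and compute SHAP values for it directly.
import numpy as np
import pandas as pd
import shap
import xgboost as xgb

X = pd.DataFrame(np.random.rand(100, 3), columns=["f0", "f1", "f2"])
y = (X["f0"] + X["f1"] > 1).astype(int)
booster = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y))

tree_explainer = shap.TreeExplainer(booster)
shap_values = tree_explainer.shap_values(X)  # (num_samples, num_features)
shap.summary_plot(shap_values, X, show=False)
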
Example #2
def shap_explain(datasource,
                 select,
                 feature_field_meta,
                 feature_column_names,
                 label_meta,
                 summary_params,
                 result_table="",
                 is_pai=False,
                 pai_explain_table="",
                 oss_dest=None,
                 oss_ak=None,
                 oss_sk=None,
                 oss_endpoint=None,
                 oss_bucket_name=None,
                 transform_fn=None,
                 feature_column_code=""):
    x = xgb_shap_dataset(datasource,
                         select,
                         feature_column_names,
                         label_meta,
                         feature_field_meta,
                         is_pai,
                         pai_explain_table,
                         transform_fn=transform_fn,
                         feature_column_code=feature_column_code)
    shap_values, shap_interaction_values, expected_value = xgb_shap_values(x)
    if result_table != "":
        if is_pai:
            from runtime.dbapi.paiio import PaiIOConnection
            conn = PaiIOConnection.from_table(result_table)
        else:
            conn = db.connect_with_data_source(datasource)
        # TODO(typhoonzero): shap_values may be a list of shape
        # [3, num_samples, num_features]; we use the first element
        # here and should find out when to use the other two. When
        # shap_values is not a list it can be used directly.
        if isinstance(shap_values, list):
            to_write = shap_values[0]
        else:
            to_write = shap_values
        write_shap_values(to_write, conn, result_table, feature_column_names)

    if summary_params.get("plot_type") == "decision":
        explainer.plot_and_save(
            lambda: shap.decision_plot(expected_value,
                                       shap_interaction_values,
                                       x,
                                       show=False,
                                       feature_display_range=slice(
                                           None, -40, -1),
                                       alpha=1), oss_dest, oss_ak, oss_sk,
            oss_endpoint, oss_bucket_name)
    else:
        explainer.plot_and_save(
            lambda: shap.summary_plot(
                shap_values, x, show=False, **summary_params), oss_dest,
            oss_ak, oss_sk, oss_endpoint, oss_bucket_name)
Example #3
def shap_explain(booster, datasource, dataset, summary_params, result_table):

    tree_explainer = shap.TreeExplainer(booster)
    shap_values = tree_explainer.shap_values(dataset)
    if result_table:
        conn = db.connect_with_data_source(datasource)
        # TODO(typhoonzero): shap_values may be a list of shape
        # [3, num_samples, num_features]; we use the first element
        # here and should find out when to use the other two. When
        # shap_values is not a list it can be used directly.
        if isinstance(shap_values, list):
            to_write = shap_values[0]
        else:
            to_write = shap_values

        columns = list(dataset.columns)
        dtypes = [DataType.to_db_field_type(conn.driver, DataType.FLOAT32)
                  ] * len(columns)
        _create_table(conn, result_table, columns, dtypes)
        with db.buffered_db_writer(conn, result_table, columns) as w:
            for row in to_write:
                w.write(list(row))

        conn.close()

    if summary_params.get("plot_type") == "decision":
        shap_interaction_values = tree_explainer.shap_interaction_values(
            dataset)
        expected_value = tree_explainer.expected_value
        if isinstance(shap_interaction_values, list):
            shap_interaction_values = shap_interaction_values[0]
        if isinstance(expected_value, list):
            expected_value = expected_value[0]

        plot_func = lambda: shap.decision_plot(  # noqa: E731
            expected_value,
            shap_interaction_values,
            dataset,
            show=False,
            feature_display_range=slice(None, -40, -1),
            alpha=1)
    else:
        plot_func = lambda: shap.summary_plot(  # noqa: E731
            shap_values, dataset, show=False, **summary_params)

    filename = 'summary.png'
    with temp_file.TemporaryDirectory(as_cwd=True):
        explainer.plot_and_save(plot_func, filename=filename)
        with open(filename, 'rb') as f:
            img = f.read()

    img = base64.b64encode(img)
    if six.PY3:
        img = img.decode('utf-8')
    img = "<div align='center'><img src='data:image/png;base64,%s' /></div>" \
          % img
    print(img)
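
The summary_params dict does double duty above: "plot_type": "decision" selects the shap.decision_plot branch, while any other contents are forwarded verbatim to shap.summary_plot as keyword arguments. Two hypothetical shapes the dict might take (the values here are assumptions; the original callers presumably build it from the EXPLAIN statement's attributes):

# Hypothetical summary_params values for the two branches above.
summary_params = {"plot_type": "decision"}                 # decision-plot branch
summary_params = {"plot_type": "bar", "max_display": 10}   # passed to shap.summary_plot
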
Example #4
File: explain.py  Project: hsjung6/sqlflow
def explain_boosted_trees(datasource, estimator, input_fn, plot_type,
                          result_table, feature_column_names, conn, oss_dest,
                          oss_ak, oss_sk, oss_endpoint, oss_bucket_name):
    result = estimator.experimental_predict_with_explanations(input_fn)
    pred_dicts = list(result)
    df_dfc = pd.DataFrame([pred['dfc'] for pred in pred_dicts])
    dfc_mean = df_dfc.abs().mean()
    gain = estimator.experimental_feature_importances(normalize=True)
    if result_table != "":
        write_dfc_result(dfc_mean, gain, result_table, conn,
                         feature_column_names)
    explainer.plot_and_save(lambda: eval(plot_type)(df_dfc), oss_dest, oss_ak,
                            oss_sk, oss_endpoint, oss_bucket_name)
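
Each pred['dfc'] row holds the directional feature contributions (DFC) for one prediction; averaging their absolute values gives a global, gain-like importance per feature. A toy illustration of that reduction (values assumed):

# Toy illustration of the dfc_mean computation above.
import pandas as pd

df_dfc = pd.DataFrame({"age": [0.2, -0.1], "fare": [0.05, 0.3]})
print(df_dfc.abs().mean())  # age: 0.15, fare: 0.175
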
Example #5
def explain_dnns(datasource, estimator, shap_dataset, plot_type, result_table,
                 feature_column_names, is_pai, pai_table, hdfs_namenode_addr,
                 hive_location, hdfs_user, hdfs_pass, oss_dest, oss_ak, oss_sk,
                 oss_endpoint, oss_bucket_name):
    def predict(d):
        if len(d) == 1:
            # This is to make sure SHAP's progress bar displays properly:
            # 1. The newline makes the progress bar string get captured in the pipe
            # 2. The ANSI escape code moves the cursor up twice for alignment
            print("\033[A" * 2)

        def input_fn():
            return tf.data.Dataset.from_tensor_slices(
                dict(pd.DataFrame(d,
                                  columns=shap_dataset.columns))).batch(1000)

        if plot_type == 'bar':
            predictions = [
                p['logits'] if 'logits' in p else p['predictions']
                for p in estimator.predict(input_fn)
            ]
        else:
            predictions = [
                p['logits'][-1] if 'logits' in p else p['predictions'][-1]
                for p in estimator.predict(input_fn)
            ]
        return np.array(predictions)

    if len(shap_dataset) > 100:
        # Reduce to 16 weighted samples to speed up the SHAP computation
        shap_dataset_summary = shap.kmeans(shap_dataset, 16)
    else:
        shap_dataset_summary = shap_dataset
    shap_values = shap.KernelExplainer(
        predict, shap_dataset_summary).shap_values(shap_dataset, l1_reg="aic")
    if result_table != "":
        if is_pai:
            write_shap_values(shap_values, "pai_maxcompute", None,
                              result_table, feature_column_names,
                              hdfs_namenode_addr, hive_location, hdfs_user,
                              hdfs_pass)
        else:
            conn = connect_with_data_source(datasource)
            write_shap_values(shap_values, conn.driver, conn, result_table,
                              feature_column_names, hdfs_namenode_addr,
                              hive_location, hdfs_user, hdfs_pass)
    explainer.plot_and_save(
        lambda: shap.summary_plot(
            shap_values, shap_dataset, show=False, plot_type=plot_type),
        is_pai, oss_dest, oss_ak, oss_sk, oss_endpoint, oss_bucket_name)
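
Unlike the tree-based examples, a DNN has no tree structure for TreeExplainer to exploit, so this example uses the model-agnostic shap.KernelExplainer and compresses the background data to 16 weighted centroids with shap.kmeans to bound the cost. A minimal standalone sketch with an assumed toy model standing in for the estimator:

# Minimal sketch (toy model and data assumed) of the KernelExplainer flow above.
import numpy as np
import shap

def predict(d):
    # Any callable mapping a (n_samples, n_features) array to predictions works.
    return np.asarray(d).sum(axis=1)

background = np.random.rand(200, 4)
# Summarize the background into 16 weighted centroids, as the code above does
# when the dataset has more than 100 rows.
background_summary = shap.kmeans(background, 16)
shap_values = shap.KernelExplainer(predict, background_summary).shap_values(
    background[:10], l1_reg="aic")
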
Example #6
def explain_boosted_trees(datasource, estimator, input_fn, plot_type,
                          result_table, feature_column_names, is_pai,
                          pai_table, hdfs_namenode_addr, hive_location,
                          hdfs_user, hdfs_pass, oss_dest, oss_ak, oss_sk,
                          oss_endpoint, oss_bucket_name):
    result = estimator.experimental_predict_with_explanations(input_fn)
    pred_dicts = list(result)
    df_dfc = pd.DataFrame([pred['dfc'] for pred in pred_dicts])
    dfc_mean = df_dfc.abs().mean()
    gain = estimator.experimental_feature_importances(normalize=True)
    if result_table != "":
        if is_pai:
            write_dfc_result(dfc_mean, gain, result_table, "pai_maxcompute",
                             None, feature_column_names, hdfs_namenode_addr,
                             hive_location, hdfs_user, hdfs_pass)
        else:
            conn = connect_with_data_source(datasource)
            write_dfc_result(dfc_mean, gain, result_table, conn.driver, conn,
                             feature_column_names, hdfs_namenode_addr,
                             hive_location, hdfs_user, hdfs_pass)
    explainer.plot_and_save(lambda: eval(plot_type)(df_dfc), is_pai, oss_dest,
                            oss_ak, oss_sk, oss_endpoint, oss_bucket_name)
Example #7
File: explain.py  Project: vmnet04/sqlflow
def explain(datasource,
            select,
            feature_field_meta,
            feature_column_names,
            label_meta,
            summary_params,
            result_table="",
            is_pai=False,
            pai_explain_table="",
            hdfs_namenode_addr="",
            hive_location="",
            hdfs_user="",
            hdfs_pass="",
            oss_dest=None,
            oss_ak=None,
            oss_sk=None,
            oss_endpoint=None,
            oss_bucket_name=None,
            transform_fn=None,
            feature_column_code=""):
    x = xgb_shap_dataset(datasource,
                         select,
                         feature_column_names,
                         label_meta,
                         feature_field_meta,
                         is_pai,
                         pai_explain_table,
                         transform_fn=transform_fn,
                         feature_column_code=feature_column_code)

    shap_values, shap_interaction_values, expected_value = xgb_shap_values(x)

    if result_table != "":
        if is_pai:
            # TODO(typhoonzero): the shape of shap_values is
            # (3, num_samples, num_features); we use the first
            # dimension here and should find out how to use the
            # other two.
            write_shap_values(shap_values[0], "pai_maxcompute", None,
                              result_table, feature_column_names,
                              hdfs_namenode_addr, hive_location, hdfs_user,
                              hdfs_pass)
        else:
            conn = db.connect_with_data_source(datasource)
            write_shap_values(shap_values[0], conn.driver, conn, result_table,
                              feature_column_names, hdfs_namenode_addr,
                              hive_location, hdfs_user, hdfs_pass)
        return

    if summary_params.get("plot_type") == "decision":
        explainer.plot_and_save(
            lambda: shap.decision_plot(expected_value,
                                       shap_interaction_values,
                                       x,
                                       show=False,
                                       feature_display_range=slice(
                                           None, -40, -1),
                                       alpha=1), oss_dest, oss_ak, oss_sk,
            oss_endpoint, oss_bucket_name)
    else:
        explainer.plot_and_save(
            lambda: shap.summary_plot(
                shap_values, x, show=False, **summary_params), oss_dest,
            oss_ak, oss_sk, oss_endpoint, oss_bucket_name)
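
The feature_display_range=slice(None, -40, -1) argument that the decision-plot branches pass limits the plot to the top of the feature ranking. What that slice selects, on an assumed toy list:

# slice(None, -40, -1) walks a ranked list backwards, keeping up to 39
# entries, so shap.decision_plot shows only the highest-ranked features.
ranks = list(range(100))
print(ranks[slice(None, -40, -1)])  # [99, 98, ..., 61] -- 39 items
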
Example #8
def shap_explain(booster, datasource, select, summary_params, result_table,
                 model):
    train_fc_map = model.get_meta("features")
    label_meta = model.get_meta("label").get_field_desc()[0].to_dict(
        dtype_to_string=True)

    field_descs = get_ordered_field_descs(train_fc_map)
    feature_column_names = [fd.name for fd in field_descs]
    feature_metas = dict([(fd.name, fd.to_dict(dtype_to_string=True))
                          for fd in field_descs])

    # NOTE: in the current implementation, we are generating a transform_fn
    # from the COLUMN clause. The transform_fn is executed during the process
    # of dumping the original data into DMatrix SVM file.
    compiled_fc = compile_ir_feature_columns(train_fc_map, model.get_type())
    transform_fn = xgboost_extended.feature_column.ComposedColumnTransformer(
        feature_column_names, *compiled_fc["feature_columns"])

    dataset = xgb_shap_dataset(datasource, select, feature_column_names,
                               label_meta, feature_metas, transform_fn)

    tree_explainer = shap.TreeExplainer(booster)
    shap_values = tree_explainer.shap_values(dataset)
    if result_table:
        conn = db.connect_with_data_source(datasource)
        # TODO(typhoonzero): shap_values may be a list of shape
        # [3, num_samples, num_features]; we use the first element
        # here and should find out when to use the other two. When
        # shap_values is not a list it can be used directly.
        if isinstance(shap_values, list):
            to_write = shap_values[0]
        else:
            to_write = shap_values

        columns = list(dataset.columns)
        dtypes = [DataType.to_db_field_type(conn.driver, DataType.FLOAT32)
                  ] * len(columns)
        _create_table(conn, result_table, columns, dtypes)
        with db.buffered_db_writer(conn, result_table, columns) as w:
            for row in to_write:
                w.write(list(row))

        conn.close()

    if summary_params.get("plot_type") == "decision":
        shap_interaction_values = tree_explainer.shap_interaction_values(
            dataset)
        expected_value = tree_explainer.expected_value
        if isinstance(shap_interaction_values, list):
            shap_interaction_values = shap_interaction_values[0]
        if isinstance(expected_value, list):
            expected_value = expected_value[0]

        plot_func = lambda: shap.decision_plot(  # noqa: E731
            expected_value,
            shap_interaction_values,
            dataset,
            show=False,
            feature_display_range=slice(None, -40, -1),
            alpha=1)
    else:
        plot_func = lambda: shap.summary_plot(  # noqa: E731
            shap_values, dataset, show=False, **summary_params)

    filename = 'summary.png'
    with temp_file.TemporaryDirectory(as_cwd=True):
        explainer.plot_and_save(plot_func, filename=filename)
        with open(filename, 'rb') as f:
            img = f.read()

    img = base64.b64encode(img)
    if six.PY3:
        img = img.decode('utf-8')
    img = "<div align='center'><img src='data:image/png;base64,%s' /></div>" \
          % img
    print(img)
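
Examples #3 and #8 hand the rendered plot back to the client by printing it as a base64-encoded inline image. The same trick in isolation, with an assumed file name:

# Standalone sketch of the inline-image output above (file name assumed).
import base64

with open("summary.png", "rb") as f:
    img = base64.b64encode(f.read()).decode("utf-8")
print("<div align='center'><img src='data:image/png;base64,%s' /></div>" % img)
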