Пример #1
0
def xgb_native_explain(booster, datasource, result_table):
    """
    Write the native XGBoost feature importances into a result table.

    One row of (feature, fscore, gain) is written for every feature key
    found in the booster's "gain" importance map, in sorted key order.

    Args:
        booster: the trained booster; must provide ``get_score`` and
            ``get_fscore``.
        datasource: the data source passed to
            ``db.connect_with_data_source``.
        result_table (str): the output data table. Must not be empty.

    Raises:
        ValueError: if ``result_table`` is empty.
    """
    if not result_table:
        raise ValueError(
            "XGBoostExplainer must use with INTO to output result to a table.")

    gain_map = booster.get_score(importance_type="gain")
    fscore_map = booster.get_fscore()
    # Sort the keys so rows are written in a deterministic order.
    all_feature_keys = sorted(gain_map)

    conn = db.connect_with_data_source(datasource)
    try:
        columns = ["feature", "fscore", "gain"]
        float_field_type = DataType.to_db_field_type(conn.driver,
                                                     DataType.FLOAT32)
        dtypes = [
            DataType.to_db_field_type(conn.driver, DataType.STRING),
            float_field_type,
            float_field_type,
        ]
        _create_table(conn, result_table, columns, dtypes)

        with db.buffered_db_writer(conn, result_table, columns) as w:
            for fkey in all_feature_keys:
                # NOTE(review): assumes every key of the gain map is also
                # present in the fscore map -- confirm against XGBoost.
                w.write([fkey, fscore_map[fkey], gain_map[fkey]])
    finally:
        # Release the connection even if table creation or writing fails.
        conn.close()
def create_predict_table(conn, select, result_table, train_label_desc,
                         pred_label_name):
    """
    (Re)create the table that receives the prediction result.

    The table holds every column of the ``select`` statement except the
    trained label column (when it is selected), followed by one column
    named ``pred_label_name``.

    Args:
        conn: the database connection object.
        select (str): the input data to predict.
        result_table (str): the output data table.
        train_label_desc (FieldDesc): the FieldDesc of the trained label.
            May be empty, e.g. for clustering models.
        pred_label_name (str): the output label name to predict.

    Returns:
        A tuple of (result_column_names, train_label_index), where
        train_label_index is -1 if the trained label is not selected.
    """
    selected = db.selected_columns_and_types(conn, select)

    # Position of the trained label among the selected columns, or -1.
    label_pos = -1
    if train_label_desc:
        label_pos = next((pos
                          for pos, (col_name, _) in enumerate(selected)
                          if col_name == train_label_desc.name), -1)

    # The trained label must not be copied into the prediction table.
    if label_pos != -1:
        del selected[label_pos]

    column_defs = [
        "%s %s" % (col_name, db.to_db_field_type(conn.driver, col_type))
        for col_name, col_type in selected
    ]

    if train_label_desc and train_label_desc.format == DataFormat.PLAIN:
        label_dtype = train_label_desc.dtype
    else:
        # Without a plain label description (e.g. clustering), the
        # predicted column is stored as a string.
        label_dtype = DataType.STRING
    pred_field_type = DataType.to_db_field_type(conn.driver, label_dtype)
    column_defs.append("%s %s" % (pred_label_name, pred_field_type))

    drop_stmt = "DROP TABLE IF EXISTS %s;" % result_table
    create_stmt = "CREATE TABLE %s (%s);" % (result_table,
                                             ",".join(column_defs))
    for stmt in (drop_stmt, create_stmt):
        conn.execute(stmt)

    output_names = [col[0] for col in selected] + [pred_label_name]
    return output_names, label_pos
def create_evaluate_table(conn, result_table, validation_metrics):
    """
    (Re)create the table that stores the evaluation result.

    The table has a ``loss`` column followed by one column per
    validation metric; every column uses the float field type.

    Args:
        conn: the database connection object.
        result_table (str): the output data table.
        validation_metrics (list[str]): the evaluation metric names.

    Returns:
        The column names of the created table.
    """
    metric_columns = ['loss'] + validation_metrics
    field_type = DataType.to_db_field_type(conn.driver, DataType.FLOAT32)

    column_defs = []
    for metric in metric_columns:
        column_defs.append("%s %s" % (metric, field_type))

    drop_stmt = "DROP TABLE IF EXISTS %s;" % result_table
    create_stmt = "CREATE TABLE %s (%s);" % (result_table,
                                             ",".join(column_defs))
    for stmt in (drop_stmt, create_stmt):
        conn.execute(stmt)

    return metric_columns
Пример #4
0
def _create_result_table(datasource, select, variables, result_value_name,
                         variable_type, result_table):
    """
    (Re)create the table that stores the solved variable values.

    The table has one column per entry of ``variables`` (typed like the
    matching column of ``select``) plus one result value column.

    Args:
        datasource: the data source passed to
            ``db.connect_with_data_source``.
        select (str): the statement whose columns provide the variable
            column types.
        variables (list[str]): the variable column names.
        result_value_name (str): the name of the result value column.
            "_value" is appended when it collides (case-insensitively)
            with the single variable name.
        variable_type (str): the variable type; must end with "Integers"
            or "Reals", or be "Binary".
        result_table (str): the output data table.

    Raises:
        ValueError: if ``variable_type`` is not supported.
    """
    if variable_type.endswith('Integers') or variable_type == "Binary":
        result_type = DataType.INT64
    elif variable_type.endswith('Reals'):
        result_type = DataType.FLOAT32
    else:
        raise ValueError("unsupported variable type %s" % variable_type)

    conn = db.connect_with_data_source(datasource)
    try:
        name_and_types = dict(db.selected_columns_and_types(conn, select))
        # NOTE(review): a variable missing from the selected columns
        # yields None here and is passed on as-is -- confirm how
        # db.to_db_field_type handles that.
        columns = [
            "%s %s" % (var,
                       db.to_db_field_type(conn.driver,
                                           name_and_types.get(var)))
            for var in variables
        ]

        # Avoid a duplicated column name when the only variable shares
        # its name with the result value column.
        if len(variables) == 1 and \
                variables[0].lower() == result_value_name.lower():
            result_value_name += "_value"

        columns.append("%s %s" %
                       (result_value_name,
                        DataType.to_db_field_type(conn.driver, result_type)))
        column_str = ",".join(columns)

        conn.execute("DROP TABLE IF EXISTS %s" % result_table)
        conn.execute("CREATE TABLE %s (%s)" % (result_table, column_str))
    finally:
        # Release the connection even if a query or type lookup fails.
        conn.close()
Пример #5
0
def shap_explain(booster, datasource, dataset, summary_params, result_table):
    """
    Explain a tree model with SHAP and print the plot as an HTML image.

    The SHAP values are optionally written to ``result_table``. A plot is
    then rendered to a PNG file in a temporary directory, base64-encoded
    and printed inside an HTML ``<div>``.

    Args:
        booster: the trained model handed to ``shap.TreeExplainer``.
        datasource: the data source passed to
            ``db.connect_with_data_source``.
        dataset: the data to explain; its ``columns`` are used as the
            column names of the result table.
        summary_params (dict): the plotting parameters. If its
            "plot_type" is "decision" a decision plot is drawn; otherwise
            all entries are forwarded to ``shap.summary_plot``.
        result_table (str): the output data table. If empty, the SHAP
            values are not written to the database.

    Returns:
        None.
    """
    tree_explainer = shap.TreeExplainer(booster)
    shap_values = tree_explainer.shap_values(dataset)
    if result_table:
        conn = db.connect_with_data_source(datasource)
        try:
            # TODO(typhoonzero): the shap_values may be a
            # list of shape [3, num_samples, num_features],
            # use the first dimension here, should find out
            # when to use the other two. When shap_values is
            # not a list it can be directly used.
            if isinstance(shap_values, list):
                to_write = shap_values[0]
            else:
                to_write = shap_values

            columns = list(dataset.columns)
            dtypes = [
                DataType.to_db_field_type(conn.driver, DataType.FLOAT32)
            ] * len(columns)
            _create_table(conn, result_table, columns, dtypes)
            with db.buffered_db_writer(conn, result_table, columns) as w:
                for row in to_write:
                    w.write(list(row))
        finally:
            # Release the connection even if creating or filling the
            # table fails.
            conn.close()

    if summary_params.get("plot_type") == "decision":
        shap_interaction_values = tree_explainer.shap_interaction_values(
            dataset)
        expected_value = tree_explainer.expected_value
        # As with shap_values above, only the first list entry is used.
        if isinstance(shap_interaction_values, list):
            shap_interaction_values = shap_interaction_values[0]
        if isinstance(expected_value, list):
            expected_value = expected_value[0]

        def plot_func():
            return shap.decision_plot(
                expected_value,
                shap_interaction_values,
                dataset,
                show=False,
                feature_display_range=slice(None, -40, -1),
                alpha=1)
    else:

        def plot_func():
            return shap.summary_plot(shap_values,
                                     dataset,
                                     show=False,
                                     **summary_params)

    filename = 'summary.png'
    # The image is read back before the temporary directory goes away.
    with temp_file.TemporaryDirectory(as_cwd=True):
        explainer.plot_and_save(plot_func, filename=filename)
        with open(filename, 'rb') as f:
            img = f.read()

    img = base64.b64encode(img)
    if six.PY3:
        # b64encode returns bytes on Python 3; decode before formatting.
        img = img.decode('utf-8')
    img = "<div align='center'><img src='data:image/png;base64,%s' /></div>" \
          % img
    print(img)
Пример #6
0
def create_explain_table(conn, model_type, explainer, estimator_string,
                         result_table, feature_column_names):
    """
    Drop and, unless the model is a PAI ML model, recreate the table that
    stores the explanation result.

    Args:
        conn: the database connection object.
        model_type: the estimator type, compared against the
            ``EstimatorType`` members.
        explainer (str): the explainer name.
        estimator_string (str): the estimator class name.
        result_table (str): the output data table.
        feature_column_names (list[str]): the feature column names, used
            as the table columns for the SHAP explanation result.

    Returns:
        None.
    """
    conn.execute("DROP TABLE IF EXISTS %s;" % result_table)

    # For PAI ML models the table is only dropped, not recreated here.
    if model_type == EstimatorType.PAIML:
        return

    def _field_type(dtype):
        return DataType.to_db_field_type(conn.driver, dtype)

    # Name of the middle column of a feature-importance table, or None
    # when the result is a SHAP table with one column per feature.
    if model_type == EstimatorType.TENSORFLOW and \
            estimator_string in ("BoostedTreesClassifier",
                                 "BoostedTreesRegressor"):
        # Tensorflow boosted trees model explain.
        importance_column = "dfc"
    elif model_type == EstimatorType.XGBOOST and \
            explainer == "XGBoostExplainer":
        importance_column = "fscore"
    else:
        importance_column = None

    if importance_column is None:
        # shap explain result
        names = feature_column_names
        field_types = [_field_type(DataType.FLOAT32)] * len(names)
    else:
        names = ["feature", importance_column, "gain"]
        field_types = [
            _field_type(DataType.STRING),
            _field_type(DataType.FLOAT32),
            _field_type(DataType.FLOAT32),
        ]

    column_defs = []
    for col_name, field_type in zip(names, field_types):
        column_defs.append("%s %s" % (col_name, field_type))

    conn.execute("CREATE TABLE %s (%s);" %
                 (result_table, ",".join(column_defs)))
Пример #7
0
def shap_explain(booster, datasource, select, summary_params, result_table,
                 model):
    """
    Explain a trained model with SHAP and print the plot as an HTML image.

    The dataset to explain is built from ``select`` using the feature and
    label metadata stored in ``model``. The SHAP values are optionally
    written to ``result_table``. A plot is then rendered to a PNG file in
    a temporary directory, base64-encoded and printed inside an HTML
    ``<div>``.

    Args:
        booster: the trained model handed to ``shap.TreeExplainer``.
        datasource: the data source passed to
            ``db.connect_with_data_source`` and ``xgb_shap_dataset``.
        select (str): the statement selecting the data to explain.
        summary_params (dict): the plotting parameters. If its
            "plot_type" is "decision" a decision plot is drawn; otherwise
            all entries are forwarded to ``shap.summary_plot``.
        result_table (str): the output data table. If empty, the SHAP
            values are not written to the database.
        model: the model object providing ``get_meta`` and ``get_type``.

    Returns:
        None.
    """
    train_fc_map = model.get_meta("features")
    label_meta = model.get_meta("label").get_field_desc()[0].to_dict(
        dtype_to_string=True)

    field_descs = get_ordered_field_descs(train_fc_map)
    feature_column_names = [fd.name for fd in field_descs]
    feature_metas = {
        fd.name: fd.to_dict(dtype_to_string=True)
        for fd in field_descs
    }

    # NOTE: in the current implementation, we are generating a transform_fn
    # from the COLUMN clause. The transform_fn is executed during the process
    # of dumping the original data into DMatrix SVM file.
    compiled_fc = compile_ir_feature_columns(train_fc_map, model.get_type())
    transform_fn = xgboost_extended.feature_column.ComposedColumnTransformer(
        feature_column_names, *compiled_fc["feature_columns"])

    dataset = xgb_shap_dataset(datasource, select, feature_column_names,
                               label_meta, feature_metas, transform_fn)

    tree_explainer = shap.TreeExplainer(booster)
    shap_values = tree_explainer.shap_values(dataset)
    if result_table:
        conn = db.connect_with_data_source(datasource)
        try:
            # TODO(typhoonzero): the shap_values may be a
            # list of shape [3, num_samples, num_features],
            # use the first dimension here, should find out
            # when to use the other two. When shap_values is
            # not a list it can be directly used.
            if isinstance(shap_values, list):
                to_write = shap_values[0]
            else:
                to_write = shap_values

            columns = list(dataset.columns)
            dtypes = [
                DataType.to_db_field_type(conn.driver, DataType.FLOAT32)
            ] * len(columns)
            _create_table(conn, result_table, columns, dtypes)
            with db.buffered_db_writer(conn, result_table, columns) as w:
                for row in to_write:
                    w.write(list(row))
        finally:
            # Release the connection even if creating or filling the
            # table fails.
            conn.close()

    if summary_params.get("plot_type") == "decision":
        shap_interaction_values = tree_explainer.shap_interaction_values(
            dataset)
        expected_value = tree_explainer.expected_value
        # As with shap_values above, only the first list entry is used.
        if isinstance(shap_interaction_values, list):
            shap_interaction_values = shap_interaction_values[0]
        if isinstance(expected_value, list):
            expected_value = expected_value[0]

        def plot_func():
            return shap.decision_plot(
                expected_value,
                shap_interaction_values,
                dataset,
                show=False,
                feature_display_range=slice(None, -40, -1),
                alpha=1)
    else:

        def plot_func():
            return shap.summary_plot(shap_values,
                                     dataset,
                                     show=False,
                                     **summary_params)

    filename = 'summary.png'
    # The image is read back before the temporary directory goes away.
    with temp_file.TemporaryDirectory(as_cwd=True):
        explainer.plot_and_save(plot_func, filename=filename)
        with open(filename, 'rb') as f:
            img = f.read()

    img = base64.b64encode(img)
    if six.PY3:
        # b64encode returns bytes on Python 3; decode before formatting.
        img = img.decode('utf-8')
    img = "<div align='center'><img src='data:image/png;base64,%s' /></div>" \
          % img
    print(img)
Пример #8
0
def explain(datasource, select, explainer, model_params, result_table, model):
    """
    Do explanation to a trained TensorFlow model.

    When ``result_table`` is given, the table is (re)created first. The
    explanation itself is delegated to ``_explain``; afterwards the file
    ``summary.png`` is read from the current working directory,
    base64-encoded and printed inside an HTML ``<div>``.

    Args:
        datasource (str): the database connection string.
        select (str): the input data to explain.
        explainer (str): the explainer to explain the model.
                         Not used in TensorFlow models.
        model_params (dict): the parameters for explanation. The keys
            "summary.plot_type" (default "bar") and "label_col" are read.
        result_table (str): the output data table.
        model (Model|str): the model object or where to load the model.

    Returns:
        None.

    Raises:
        ValueError: if a selected column is not a trained feature column.
    """
    if isinstance(model, six.string_types):
        model = Model.load_from_db(datasource, model)
    else:
        assert isinstance(model,
                          Model), "not supported model type %s" % type(model)

    plot_type = model_params.get("summary.plot_type", "bar")

    train_attributes = model.get_meta("attributes")
    train_fc_map = model.get_meta("features")
    train_label_desc = model.get_meta("label").get_field_desc()[0]
    estimator_string = model.get_meta("class_name")
    save = "model_save"

    field_descs = get_ordered_field_descs(train_fc_map)
    feature_column_names = [fd.name for fd in field_descs]
    feature_metas = {
        fd.name: fd.to_dict(dtype_to_string=True)
        for fd in field_descs
    }
    feature_columns = compile_ir_feature_columns(train_fc_map,
                                                 model.get_type())

    # The label column may be overridden by the "label_col" parameter.
    label_name = model_params.get("label_col", train_label_desc.name)
    train_label_desc.name = label_name
    label_meta = train_label_desc.to_dict(dtype_to_string=True)

    if result_table:
        conn = db.connect_with_data_source(datasource)
        try:
            if estimator_string.startswith("BoostedTrees"):
                column_defs = [
                    "feature %s" %
                    DataType.to_db_field_type(conn.driver, DataType.STRING),
                    "dfc %s" %
                    DataType.to_db_field_type(conn.driver, DataType.FLOAT32),
                    "gain %s" %
                    DataType.to_db_field_type(conn.driver, DataType.FLOAT32),
                ]
            else:
                selected_cols = db.selected_cols(conn, select)
                if label_name in selected_cols:
                    selected_cols.remove(label_name)

                name_to_shape = {fd.name: fd.shape for fd in field_descs}
                column_defs = []
                float_field_type = DataType.to_db_field_type(
                    conn.driver, DataType.FLOAT32)
                for name in selected_cols:
                    shape = name_to_shape.get(name, None)
                    if shape is None:
                        raise ValueError("cannot find column %s" % name)

                    # A multi-element feature is flattened into one
                    # column per element, named <name>_<index>.
                    size = int(np.prod(shape))
                    if size == 1:
                        column_defs.append("%s %s" %
                                           (name, float_field_type))
                    else:
                        for i in six.moves.range(size):
                            column_defs.append("%s_%d %s" %
                                               (name, i, float_field_type))

            drop_sql = "DROP TABLE IF EXISTS %s;" % result_table
            create_sql = "CREATE TABLE %s (%s);" % (result_table,
                                                    ",".join(column_defs))
            conn.execute(drop_sql)
            conn.execute(create_sql)
        finally:
            # Release the connection even if a column lookup or a
            # statement fails.
            conn.close()

    _explain(datasource=datasource,
             estimator_string=estimator_string,
             select=select,
             feature_columns=feature_columns,
             feature_column_names=feature_column_names,
             feature_metas=feature_metas,
             label_meta=label_meta,
             model_params=train_attributes,
             save=save,
             plot_type=plot_type,
             result_table=result_table)

    # NOTE(review): presumably _explain saves the plot as summary.png in
    # the current working directory -- confirm.
    with open('summary.png', 'rb') as f:
        img = f.read()

    img = base64.b64encode(img)
    if six.PY3:
        # b64encode returns bytes on Python 3; decode before formatting.
        img = img.decode('utf-8')
    img = "<div align='center'><img src='data:image/png;base64,%s' /></div>" \
          % img
    print(img)