def xgb_native_explain(booster, datasource, result_table):
    """Write XGBoost native feature importance (fscore and gain) for each
    feature of the trained booster into result_table."""
    if not result_table:
        raise ValueError(
            "XGBoostExplainer must be used with INTO to output the result "
            "to a table.")

    gain_map = booster.get_score(importance_type="gain")
    fscore_map = booster.get_fscore()
    conn = db.connect_with_data_source(datasource)

    all_feature_keys = list(gain_map.keys())
    all_feature_keys.sort()
    columns = ["feature", "fscore", "gain"]
    dtypes = [
        DataType.to_db_field_type(conn.driver, DataType.STRING),
        DataType.to_db_field_type(conn.driver, DataType.FLOAT32),
        DataType.to_db_field_type(conn.driver, DataType.FLOAT32),
    ]
    _create_table(conn, result_table, columns, dtypes)

    with db.buffered_db_writer(conn, result_table, columns) as w:
        for fkey in all_feature_keys:
            row = [fkey, fscore_map[fkey], gain_map[fkey]]
            w.write(list(row))

    conn.close()
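# Usage sketch for xgb_native_explain. The DSN, model file, and table name
# below are hypothetical placeholders, not values from this repo:
#
#   bst = xgboost.Booster(model_file="my_model.bin")
#   xgb_native_explain(
#       bst,
#       "mysql://root:root@tcp(127.0.0.1:3306)/iris",
#       "iris.explain_result")
#
# The result table then holds one row per feature with its fscore and gain.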
def create_predict_table(conn, select, result_table, train_label_desc,
                         pred_label_name):
    """
    Create the result prediction table.

    Args:
        conn: the database connection object.
        select (str): the input data to predict.
        result_table (str): the output data table.
        train_label_desc (FieldDesc): the FieldDesc of the trained label.
        pred_label_name (str): the output label name to predict.

    Returns:
        A tuple of (result_column_names, train_label_index).
    """
    name_and_types = db.selected_columns_and_types(conn, select)
    train_label_index = -1
    if train_label_desc:
        for i, (name, _) in enumerate(name_and_types):
            if name == train_label_desc.name:
                train_label_index = i
                break
    if train_label_index >= 0:
        del name_and_types[train_label_index]

    column_strs = []
    for name, typ in name_and_types:
        column_strs.append("%s %s" %
                           (name, db.to_db_field_type(conn.driver, typ)))

    if train_label_desc and train_label_desc.format == DataFormat.PLAIN:
        train_label_field_type = DataType.to_db_field_type(
            conn.driver, train_label_desc.dtype)
    else:
        # If no train label description is provided (e.g. clustering),
        # treat the predicted column type as string.
        train_label_field_type = DataType.to_db_field_type(
            conn.driver, DataType.STRING)
    column_strs.append("%s %s" % (pred_label_name, train_label_field_type))

    drop_sql = "DROP TABLE IF EXISTS %s;" % result_table
    create_sql = "CREATE TABLE %s (%s);" % (result_table,
                                            ",".join(column_strs))
    conn.execute(drop_sql)
    conn.execute(create_sql)

    result_column_names = [item[0] for item in name_and_types]
    result_column_names.append(pred_label_name)
    return result_column_names, train_label_index
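# Usage sketch for create_predict_table. The connection string, SELECT
# statement, and label name are hypothetical:
#
#   conn = db.connect_with_data_source(
#       "mysql://root:root@tcp(127.0.0.1:3306)/iris")
#   col_names, label_idx = create_predict_table(
#       conn, "SELECT * FROM iris.train", "iris.predict_result",
#       train_label_desc, "class")
#
# col_names lists the copied input columns plus the appended "class"
# prediction column; label_idx is the position of the training label in the
# SELECT output, or -1 if it is absent.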
def create_evaluate_table(conn, result_table, validation_metrics):
    """
    Create the result table to store the evaluation result.

    Args:
        conn: the database connection object.
        result_table (str): the output data table.
        validation_metrics (list[str]): the evaluation metric names.

    Returns:
        The column names of the created table.
    """
    result_columns = ['loss'] + validation_metrics
    float_field_type = DataType.to_db_field_type(conn.driver,
                                                 DataType.FLOAT32)
    column_strs = [
        "%s %s" % (name, float_field_type) for name in result_columns
    ]

    drop_sql = "DROP TABLE IF EXISTS %s;" % result_table
    create_sql = "CREATE TABLE %s (%s);" % (result_table,
                                            ",".join(column_strs))
    conn.execute(drop_sql)
    conn.execute(create_sql)
    return result_columns
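# Usage sketch for create_evaluate_table (hypothetical metric names):
#
#   cols = create_evaluate_table(conn, "iris.evaluate_result",
#                                ["accuracy", "auc"])
#   # cols == ["loss", "accuracy", "auc"]; each becomes a float column.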
def _create_result_table(datasource, select, variables, result_value_name,
                         variable_type, result_table):
    if variable_type.endswith('Integers') or variable_type == "Binary":
        result_type = DataType.INT64
    elif variable_type.endswith('Reals'):
        result_type = DataType.FLOAT32
    else:
        raise ValueError("unsupported variable type %s" % variable_type)

    conn = db.connect_with_data_source(datasource)
    name_and_types = dict(db.selected_columns_and_types(conn, select))
    columns = []
    for var in variables:
        field_type = db.to_db_field_type(conn.driver, name_and_types.get(var))
        columns.append("%s %s" % (var, field_type))

    # Rename the result value column to avoid a duplicate column name when
    # it collides with the single variable column.
    if len(variables) == 1 and \
            variables[0].lower() == result_value_name.lower():
        result_value_name += "_value"

    columns.append("%s %s" %
                   (result_value_name,
                    DataType.to_db_field_type(conn.driver, result_type)))
    column_str = ",".join(columns)

    conn.execute("DROP TABLE IF EXISTS %s" % result_table)
    create_sql = "CREATE TABLE %s (%s)" % (result_table, column_str)
    conn.execute(create_sql)
    conn.close()
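# Usage sketch for _create_result_table. The variable names and
# "NonNegativeIntegers" type are hypothetical examples that follow the
# suffix convention checked above ('Integers', 'Reals', or "Binary"):
#
#   _create_result_table(
#       "mysql://root:root@tcp(127.0.0.1:3306)/opt",
#       "SELECT * FROM opt.input", ["product"], "amount",
#       "NonNegativeIntegers", "opt.result")
#
# This creates opt.result with a "product" column typed after the input
# table and an INT64 "amount" result column.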
def shap_explain(booster, datasource, dataset, summary_params, result_table):
    tree_explainer = shap.TreeExplainer(booster)
    shap_values = tree_explainer.shap_values(dataset)
    if result_table:
        conn = db.connect_with_data_source(datasource)
        # TODO(typhoonzero): shap_values may be a list of shape
        # [3, num_samples, num_features]; we use the first dimension here
        # and should find out when to use the other two. When shap_values
        # is not a list, it can be used directly.
        if isinstance(shap_values, list):
            to_write = shap_values[0]
        else:
            to_write = shap_values

        columns = list(dataset.columns)
        dtypes = [DataType.to_db_field_type(conn.driver, DataType.FLOAT32)
                  ] * len(columns)
        _create_table(conn, result_table, columns, dtypes)
        with db.buffered_db_writer(conn, result_table, columns) as w:
            for row in to_write:
                w.write(list(row))
        conn.close()

    if summary_params.get("plot_type") == "decision":
        shap_interaction_values = tree_explainer.shap_interaction_values(
            dataset)
        expected_value = tree_explainer.expected_value
        if isinstance(shap_interaction_values, list):
            shap_interaction_values = shap_interaction_values[0]
        if isinstance(expected_value, list):
            expected_value = expected_value[0]
        plot_func = lambda: shap.decision_plot(  # noqa: E731
            expected_value,
            shap_interaction_values,
            dataset,
            show=False,
            feature_display_range=slice(None, -40, -1),
            alpha=1)
    else:
        plot_func = lambda: shap.summary_plot(  # noqa: E731
            shap_values, dataset, show=False, **summary_params)

    filename = 'summary.png'
    # Render and read the plot inside the temporary working directory so
    # the image file still exists when it is read back.
    with temp_file.TemporaryDirectory(as_cwd=True):
        explainer.plot_and_save(plot_func, filename=filename)
        with open(filename, 'rb') as f:
            img = f.read()
    img = base64.b64encode(img)
    if six.PY3:
        img = img.decode('utf-8')
    img = "<div align='center'><img src='data:image/png;base64,%s' /></div>" \
        % img
    print(img)
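# Usage sketch for this shap_explain variant (hypothetical arguments; here
# `features_df` stands for an already-built pandas DataFrame of input
# features, unlike the variant below that builds the dataset itself):
#
#   shap_explain(booster, datasource, features_df,
#                {"plot_type": "bar"}, "iris.explain_result")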
def create_explain_table(conn, model_type, explainer, estimator_string,
                         result_table, feature_column_names):
    drop_sql = "DROP TABLE IF EXISTS %s;" % result_table
    conn.execute(drop_sql)

    if model_type == EstimatorType.PAIML:
        return
    elif model_type == EstimatorType.TENSORFLOW and \
            estimator_string in ("BoostedTreesClassifier",
                                 "BoostedTreesRegressor"):
        # TensorFlow boosted trees model explain result:
        columns = ["feature", "dfc", "gain"]
        dtypes = [
            DataType.to_db_field_type(conn.driver, DataType.STRING),
            DataType.to_db_field_type(conn.driver, DataType.FLOAT32),
            DataType.to_db_field_type(conn.driver, DataType.FLOAT32),
        ]
    elif model_type == EstimatorType.XGBOOST and \
            explainer == "XGBoostExplainer":
        columns = ["feature", "fscore", "gain"]
        dtypes = [
            DataType.to_db_field_type(conn.driver, DataType.STRING),
            DataType.to_db_field_type(conn.driver, DataType.FLOAT32),
            DataType.to_db_field_type(conn.driver, DataType.FLOAT32),
        ]
    else:
        # SHAP explain result: one float column per feature.
        columns = feature_column_names
        dtypes = [DataType.to_db_field_type(conn.driver, DataType.FLOAT32)
                  ] * len(columns)

    column_strs = [
        "%s %s" % (name, dtype) for name, dtype in zip(columns, dtypes)
    ]
    create_sql = "CREATE TABLE %s (%s);" % (result_table,
                                            ",".join(column_strs))
    conn.execute(create_sql)
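# Usage sketch for create_explain_table (hypothetical arguments): an XGBoost
# model with the native explainer gets the fixed (feature, fscore, gain)
# schema, while other combinations fall through to one float column per
# feature for SHAP values:
#
#   create_explain_table(conn, EstimatorType.XGBOOST, "XGBoostExplainer",
#                        "xgboost.gbtree", "iris.explain_result",
#                        ["sepal_length", "sepal_width"])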
def shap_explain(booster, datasource, select, summary_params, result_table,
                 model):
    train_fc_map = model.get_meta("features")
    label_meta = model.get_meta("label").get_field_desc()[0].to_dict(
        dtype_to_string=True)
    field_descs = get_ordered_field_descs(train_fc_map)
    feature_column_names = [fd.name for fd in field_descs]
    feature_metas = dict([(fd.name, fd.to_dict(dtype_to_string=True))
                          for fd in field_descs])

    # NOTE: in the current implementation, we are generating a transform_fn
    # from the COLUMN clause. The transform_fn is executed during the process
    # of dumping the original data into the DMatrix SVM file.
    compiled_fc = compile_ir_feature_columns(train_fc_map, model.get_type())
    transform_fn = xgboost_extended.feature_column.ComposedColumnTransformer(
        feature_column_names, *compiled_fc["feature_columns"])

    dataset = xgb_shap_dataset(datasource, select, feature_column_names,
                               label_meta, feature_metas, transform_fn)

    tree_explainer = shap.TreeExplainer(booster)
    shap_values = tree_explainer.shap_values(dataset)
    if result_table:
        conn = db.connect_with_data_source(datasource)
        # TODO(typhoonzero): shap_values may be a list of shape
        # [3, num_samples, num_features]; we use the first dimension here
        # and should find out when to use the other two. When shap_values
        # is not a list, it can be used directly.
        if isinstance(shap_values, list):
            to_write = shap_values[0]
        else:
            to_write = shap_values

        columns = list(dataset.columns)
        dtypes = [DataType.to_db_field_type(conn.driver, DataType.FLOAT32)
                  ] * len(columns)
        _create_table(conn, result_table, columns, dtypes)
        with db.buffered_db_writer(conn, result_table, columns) as w:
            for row in to_write:
                w.write(list(row))
        conn.close()

    if summary_params.get("plot_type") == "decision":
        shap_interaction_values = tree_explainer.shap_interaction_values(
            dataset)
        expected_value = tree_explainer.expected_value
        if isinstance(shap_interaction_values, list):
            shap_interaction_values = shap_interaction_values[0]
        if isinstance(expected_value, list):
            expected_value = expected_value[0]
        plot_func = lambda: shap.decision_plot(  # noqa: E731
            expected_value,
            shap_interaction_values,
            dataset,
            show=False,
            feature_display_range=slice(None, -40, -1),
            alpha=1)
    else:
        plot_func = lambda: shap.summary_plot(  # noqa: E731
            shap_values, dataset, show=False, **summary_params)

    filename = 'summary.png'
    # Render and read the plot inside the temporary working directory so
    # the image file still exists when it is read back.
    with temp_file.TemporaryDirectory(as_cwd=True):
        explainer.plot_and_save(plot_func, filename=filename)
        with open(filename, 'rb') as f:
            img = f.read()
    img = base64.b64encode(img)
    if six.PY3:
        img = img.decode('utf-8')
    img = "<div align='center'><img src='data:image/png;base64,%s' /></div>" \
        % img
    print(img)
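# Usage sketch for this shap_explain variant (hypothetical arguments).
# summary_params keys other than "plot_type" are forwarded to
# shap.summary_plot:
#
#   shap_explain(booster, datasource, "SELECT * FROM iris.train",
#                {"plot_type": "dot"}, "iris.explain_result", model)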
def explain(datasource, select, explainer, model_params, result_table, model):
    """
    Explain a trained TensorFlow model.

    Args:
        datasource (str): the database connection string.
        select (str): the input data to predict.
        explainer (str): the explainer to explain the model.
            Not used in TensorFlow models.
        model_params (dict): the parameters for the explanation.
        result_table (str): the output data table.
        model (Model|str): the model object or where to load the model.

    Returns:
        None.
    """
    if isinstance(model, six.string_types):
        model = Model.load_from_db(datasource, model)
    else:
        assert isinstance(model,
                          Model), "not supported model type %s" % type(model)

    plot_type = model_params.get("summary.plot_type", "bar")

    train_attributes = model.get_meta("attributes")
    train_fc_map = model.get_meta("features")
    train_label_desc = model.get_meta("label").get_field_desc()[0]
    estimator_string = model.get_meta("class_name")
    save = "model_save"

    field_descs = get_ordered_field_descs(train_fc_map)
    feature_column_names = [fd.name for fd in field_descs]
    feature_metas = dict([(fd.name, fd.to_dict(dtype_to_string=True))
                          for fd in field_descs])
    feature_columns = compile_ir_feature_columns(train_fc_map,
                                                 model.get_type())

    label_name = model_params.get("label_col", train_label_desc.name)
    train_label_desc.name = label_name
    label_meta = train_label_desc.to_dict(dtype_to_string=True)

    if result_table:
        conn = db.connect_with_data_source(datasource)
        if estimator_string.startswith("BoostedTrees"):
            column_defs = [
                "feature %s" %
                DataType.to_db_field_type(conn.driver, DataType.STRING),
                "dfc %s" %
                DataType.to_db_field_type(conn.driver, DataType.FLOAT32),
                "gain %s" %
                DataType.to_db_field_type(conn.driver, DataType.FLOAT32),
            ]
        else:
            selected_cols = db.selected_cols(conn, select)
            if label_name in selected_cols:
                selected_cols.remove(label_name)

            name_to_shape = dict([(fd.name, fd.shape) for fd in field_descs])
            column_defs = []
            float_field_type = DataType.to_db_field_type(
                conn.driver, DataType.FLOAT32)
            for name in selected_cols:
                shape = name_to_shape.get(name, None)
                if shape is None:
                    raise ValueError("cannot find column %s" % name)

                # Expand multi-element features into one column per element.
                size = int(np.prod(shape))
                if size == 1:
                    column_def = "%s %s" % (name, float_field_type)
                    column_defs.append(column_def)
                else:
                    for i in six.moves.range(size):
                        column_def = "%s_%d %s" % (name, i, float_field_type)
                        column_defs.append(column_def)

        drop_sql = "DROP TABLE IF EXISTS %s;" % result_table
        create_sql = "CREATE TABLE %s (%s);" % (result_table,
                                                ",".join(column_defs))
        conn.execute(drop_sql)
        conn.execute(create_sql)
        conn.close()

    _explain(datasource=datasource,
             estimator_string=estimator_string,
             select=select,
             feature_columns=feature_columns,
             feature_column_names=feature_column_names,
             feature_metas=feature_metas,
             label_meta=label_meta,
             model_params=train_attributes,
             save=save,
             plot_type=plot_type,
             result_table=result_table)

    with open('summary.png', 'rb') as f:
        img = f.read()
    img = base64.b64encode(img)
    if six.PY3:
        img = img.decode('utf-8')
    img = "<div align='center'><img src='data:image/png;base64,%s' /></div>" \
        % img
    print(img)
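# Usage sketch for explain (hypothetical datasource, table, and model
# names). When a string is passed as `model`, it is loaded from the
# database first:
#
#   explain("mysql://root:root@tcp(127.0.0.1:3306)/iris",
#           "SELECT * FROM iris.test", "TreeExplainer",
#           {"summary.plot_type": "bar"}, "iris.explain_result",
#           "iris.my_dnn_model")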