def execute_scatter(view: Vis, tbl: LuxSQLTable):
    """
    Given a scatterplot vis and a LuxSQLTable, fetch the data required to render the vis.

    1) Generate the WHERE clause for the SQL query
    2) Count the rows that satisfy the WHERE clause
    3) If that count exceeds ``lux.config.sampling_cap``, fetch a 10000-row sample
       through the ``sample_query`` template instead of every matching row
    4) Query the columns needed for the scatterplot (intent attributes plus filter attributes)
    5) Store the result on the vis as a Lux DataFrame

    Parameters
    ----------
    view : lux.Vis
        Scatterplot vis; its ``_vis_data`` and ``_query`` attributes are set in place.
    tbl : lux.core.frame
        LuxSQLTable with specified intent.

    Returns
    -------
    None
    """
    # Columns referenced by the intent. "Record" is excluded because it is not a table column.
    attributes = set()
    for clause in view._inferred_intent:
        if clause.attribute and clause.attribute != "Record":
            attributes.add(clause.attribute)
    where_clause, filterVars = SQLExecutor.execute_filter(view)
    required_variables = attributes | set(filterVars)

    # Count matching rows once (the identical length query used to be issued twice).
    length_query = pandas.read_sql(
        lux.config.query_templates["length_query"].format(
            table_name=tbl.table_name, where_clause=where_clause
        ),
        lux.config.SQLconnection,
    )
    row_count = list(length_query["length"])[0]

    def add_quotes(var_name):
        return '"' + var_name + '"'

    column_names = required_variables
    if lux.config.handle_quotes:
        column_names = map(add_quotes, required_variables)
    required_variables_str = ",".join(column_names)

    if row_count > lux.config.sampling_cap:
        query = lux.config.query_templates["sample_query"].format(
            columns=required_variables_str,
            table_name=tbl.table_name,
            where_clause=where_clause,
            num_rows=10000,
        )
    else:
        query = lux.config.query_templates["scatter_query"].format(
            columns=required_variables_str,
            table_name=tbl.table_name,
            where_clause=where_clause,
        )
    data = pandas.read_sql(query, lux.config.SQLconnection)
    # The query selects exactly one column per required variable. The previous check only
    # allowed 2 or 3 columns, so filtering on an attribute outside x/y/color failed spuriously.
    assert len(data.columns) == len(required_variables), (
        f"Scatterplot query returned {len(data.columns)} columns, "
        f"expected {len(required_variables)}."
    )
    view._vis_data = utils.pandas_to_lux(data)
    view._query = query

    # NOTE(review): this message is added unconditionally, although this function itself
    # does not bin anything into a heatmap — confirm whether it should depend on row_count.
    tbl._message.add_unique(
        "Large scatterplots detected: Lux is automatically binning scatterplots to heatmaps.",
        priority=98,
    )
def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True):
    """
    Aggregate data points on an axis for bar or line charts.

    The x/y channel whose clause has a non-empty ``aggregation`` is the measure and the
    other channel is the group-by attribute (x wins if both are aggregated). A "Record"
    measure is computed with the ``barchart_counts`` query template; "mean", "sum" and
    "max" use ``barchart_average``, ``barchart_sum`` and ``barchart_max``. The
    ``colored_`` variant of each template is used when the vis has one color channel.
    When the group-by attribute's unique values are known and the vis is filtered or
    colored, groups missing from the query result are added back with a value of 0.

    Parameters
    ----------
    view : lux.Vis
        lux.Vis object that represents a visualization; ``_vis_data`` and ``_query``
        are set in place.
    tbl : lux.core.frame
        LuxSQLTable with specified intent.
    isFiltered : bool
        Whether a filter has been applied to the vis's data.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If a non-"Record" measure uses an aggregation other than "mean", "sum" or "max".
    """
    x_attr = view.get_attr_by_channel("x")[0]
    y_attr = view.get_attr_by_channel("y")[0]
    if x_attr.aggregation is None or y_attr.aggregation is None:
        return

    # The aggregated channel is the measure; x takes precedence if both are aggregated.
    if x_attr.aggregation != "":
        groupby_attr, measure_attr, agg_func = y_attr, x_attr, x_attr.aggregation
    elif y_attr.aggregation != "":
        groupby_attr, measure_attr, agg_func = x_attr, y_attr, y_attr.aggregation
    else:
        # Neither channel is aggregated, so there is nothing to compute
        # (this used to crash on `"".attribute`).
        return

    # Unique values of the group-by attribute, if recorded; used to fill in empty groups.
    attr_unique_vals = None
    if groupby_attr.attribute in tbl.unique_values.keys():
        attr_unique_vals = tbl.unique_values[groupby_attr.attribute]

    # checks if color is specified in the Vis
    color_attrs = view.get_attr_by_channel("color")
    has_color = len(color_attrs) == 1
    if has_color:
        # NOTE: might want to have a check somewhere to not use categorical variables
        # with greater than some number of categories as a Color variable
        color_attr = color_attrs[0]
        color_attr_vals = tbl.unique_values[color_attr.attribute]
        color_cardinality = len(color_attr_vals)
    else:
        color_cardinality = 1

    where_clause, _ = SQLExecutor.execute_filter(view)

    # Pick the query template and the name of the value column it produces.
    is_count = measure_attr.attribute == "Record"
    query_args = {
        "groupby_attr": groupby_attr.attribute,
        "table_name": tbl.table_name,
        "where_clause": where_clause,
    }
    if has_color:
        query_args["color_attr"] = color_attr.attribute
    if is_count:
        # barchart case, need count data for each group
        template_name = "barchart_counts"
        value_column = "count"
    else:
        # aggregate barchart case, need aggregate data (mean, sum, max) for each group
        template_suffix = {"mean": "average", "sum": "sum", "max": "max"}.get(agg_func)
        if template_suffix is None:
            raise ValueError(
                f"Unsupported aggregation `{agg_func}` for SQL execution; "
                "expected one of 'mean', 'sum', 'max'."
            )
        template_name = "barchart_" + template_suffix
        value_column = measure_attr.attribute
        query_args["measure_attr"] = measure_attr.attribute
    if has_color:
        template_name = "colored_" + template_name

    query = lux.config.query_templates[template_name].format(**query_args)
    vis_data = pandas.read_sql(query, lux.config.SQLconnection)
    # Key columns (group-by, plus color when present) and one value column.
    expected_columns = 3 if has_color else 2
    assert len(vis_data.columns) == expected_columns and value_column in vis_data.columns, (
        f"Unexpected columns {list(vis_data.columns)} returned by `{template_name}` query."
    )
    if is_count:
        vis_data = vis_data.rename(columns={"count": "Record"})
    view._vis_data = utils.pandas_to_lux(vis_data)
    view._query = query

    # For filtered or colored aggregations, groups (or group/color combinations) with no
    # datapoints are absent from the result; add them back with an aggregated value of 0.
    # The condition is parenthesized explicitly: the old `isFiltered or has_color and
    # attr_unique_vals` raised NameError when the unique values were unknown.
    has_unique_vals = attr_unique_vals is not None and len(attr_unique_vals) > 0
    if (isFiltered or has_color) and has_unique_vals:
        n_unique_vals = len(attr_unique_vals)
        if len(view._vis_data) != n_unique_vals * color_cardinality:
            columns = view._vis_data.columns
            if has_color:
                # Cross product: each group-by value paired with each color value.
                df = pandas.DataFrame(
                    {
                        columns[0]: list(attr_unique_vals) * color_cardinality,
                        columns[1]: [
                            color_val
                            for color_val in color_attr_vals
                            for _ in range(n_unique_vals)
                        ],
                    }
                )
                view._vis_data = view._vis_data.merge(
                    df, on=[columns[0], columns[1]], how="right", suffixes=("", "_right")
                )
                for col in columns[2:]:
                    # Triggers __setitem__
                    view._vis_data[col] = view._vis_data[col].fillna(0)
                assert len(view._vis_data) == n_unique_vals * len(color_attr_vals), f"Aggregated data missing values compared to original range of values of `{groupby_attr.attribute, color_attr.attribute}`."
                # Keep only the three relevant columns not the *_right columns resulting from merge
                view._vis_data = view._vis_data.iloc[:, :3]
            else:
                df = pandas.DataFrame({columns[0]: attr_unique_vals})
                view._vis_data = view._vis_data.merge(
                    df, on=columns[0], how="right", suffixes=("", "_right")
                )
                for col in columns[1:]:
                    view._vis_data[col] = view._vis_data[col].fillna(0)
                assert len(view._vis_data) == n_unique_vals, f"Aggregated data missing values compared to original range of values of `{groupby_attr.attribute}`."

    view._vis_data = view._vis_data.sort_values(by=groupby_attr.attribute, ascending=True)
    # drop=True discards the old index outright; reset_index() + drop(columns="index")
    # would delete a real data column that happened to be named "index".
    view._vis_data = view._vis_data.reset_index(drop=True)