예제 #1
0
    def execute_scatter(view: Vis, tbl: LuxSQLTable):
        """
        Given a scatterplot vis and a LuxSQLTable, fetch the data required to render the vis.

        1) Collect the attributes referenced by the vis' inferred intent
           (the "Record" attribute is skipped)
        2) Generate the WHERE clause for the SQL query
        3) Count the rows matched by the WHERE clause
        4) If the row count exceeds ``lux.config.sampling_cap``, build the query from the
           'sample_query' template with ``num_rows = 10000``; otherwise build it from the
           'scatter_query' template
        5) Run the query and store the result on the vis

        Parameters
        ----------
        view : lux.Vis
            Scatterplot vis whose ``_inferred_intent`` names the attributes to fetch.
            Its ``_vis_data`` and ``_query`` attributes are overwritten.
        tbl : LuxSQLTable
            Table to query; ``tbl.table_name`` is substituted into the query templates.

        Returns
        -------
        None
            The fetched data is stored in ``view._vis_data`` and the executed SQL
            string in ``view._query``.

        Raises
        ------
        AssertionError
            If the query result does not have the expected number of columns.
        """
        attributes = {
            clause.attribute
            for clause in view._inferred_intent
            if clause.attribute and clause.attribute != "Record"
        }
        where_clause, filterVars = SQLExecutor.execute_filter(view)
        templates = lux.config.query_templates

        # Columns to select: plotted attributes plus any attribute used in a filter.
        required_variables = attributes | set(filterVars)
        if lux.config.handle_quotes:
            column_names = ['"' + name + '"' for name in required_variables]
        else:
            column_names = list(required_variables)
        required_variables_str = ",".join(column_names)

        # The row count is queried exactly once; it decides between the sampled and the full query.
        length_sql = templates["length_query"].format(
            table_name=tbl.table_name, where_clause=where_clause
        )
        row_count = list(pandas.read_sql(length_sql, lux.config.SQLconnection)["length"])[0]

        if row_count > lux.config.sampling_cap:
            query = templates["sample_query"].format(
                columns=required_variables_str,
                table_name=tbl.table_name,
                where_clause=where_clause,
                num_rows=10000,
            )
        else:
            query = templates["scatter_query"].format(
                columns=required_variables_str,
                table_name=tbl.table_name,
                where_clause=where_clause,
            )
        data = pandas.read_sql(query, lux.config.SQLconnection)

        # Internal sanity check: two selected columns for a plain scatterplot, three otherwise.
        expected_n_columns = 2 if len(required_variables) == 2 else 3
        assert len(data.columns) == expected_n_columns

        view._vis_data = utils.pandas_to_lux(data)
        view._query = query

        # NOTE(review): this message is emitted for every scatterplot, regardless of the
        # row count computed above — confirm whether it should be conditional.
        tbl._message.add_unique(
            f"Large scatterplots detected: Lux is automatically binning scatterplots to heatmaps.",
            priority=98,
        )
예제 #2
0
    def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True):
        """
        Aggregate data points on an axis for bar or line charts.

        The axis whose attribute carries an aggregation is the measure; the other axis is
        the group-by attribute. A single SQL query (count, mean, sum or max per group,
        optionally also grouped by the color attribute) is built from
        ``lux.config.query_templates`` and executed. Groups that are known from
        ``tbl.unique_values`` but missing from the result are back-filled with 0.

        Parameters
        ----------
        view : lux.Vis
            lux.Vis object that represents a visualization. Its ``_vis_data`` and
            ``_query`` attributes are overwritten.
        tbl : LuxSQLTable
            LuxSQLTable with specified intent; provides ``table_name`` and ``unique_values``.
        isFiltered : boolean
            boolean that represents whether a vis has had a filter applied to its data

        Returns
        -------
        None
            Results are stored in ``view._vis_data`` and the executed SQL string in
            ``view._query``. Returns without touching the vis when either axis has
            ``aggregation is None`` or when neither axis is aggregated.

        Raises
        ------
        ValueError
            If the measure is not "Record" and its aggregation is not one of
            "mean", "sum" or "max" (the only aggregations with a query template here).
        AssertionError
            If the query result does not have the expected columns, or if back-filling
            does not produce one row per expected group.
        """
        x_attr = view.get_attr_by_channel("x")[0]
        y_attr = view.get_attr_by_channel("y")[0]
        if x_attr.aggregation is None or y_attr.aggregation is None:
            return

        # The aggregated axis is the measure. When both axes carry an aggregation the
        # x axis takes precedence, matching the historical behaviour of this method.
        if x_attr.aggregation != "":
            groupby_attr, measure_attr = y_attr, x_attr
        elif y_attr.aggregation != "":
            groupby_attr, measure_attr = x_attr, y_attr
        else:
            # Neither axis is aggregated, so there is nothing to compute.
            return
        agg_func = measure_attr.aggregation
        group_name = groupby_attr.attribute

        # Known values of the group-by attribute; stays None when the table has no entry
        # for it, in which case the back-fill step below is skipped.
        attr_unique_vals = None
        if group_name in tbl.unique_values:
            attr_unique_vals = tbl.unique_values[group_name]

        # checks if color is specified in the Vis
        color_attrs = view.get_attr_by_channel("color")
        has_color = len(color_attrs) == 1
        if has_color:
            color_attr = color_attrs[0]
            color_attr_vals = tbl.unique_values[color_attr.attribute]
            color_cardinality = len(color_attr_vals)
        else:
            color_cardinality = 1

        where_clause, _ = SQLExecutor.execute_filter(view)
        query_args = {
            "groupby_attr": group_name,
            "table_name": tbl.table_name,
            "where_clause": where_clause,
        }
        if has_color:
            query_args["color_attr"] = color_attr.attribute
            template_prefix = "colored_barchart_"
            expected_n_columns = 3
        else:
            template_prefix = "barchart_"
            expected_n_columns = 2

        # "Record" means a plain count per group; anything else is an aggregate of the measure.
        is_count = measure_attr.attribute == "Record"
        if is_count:
            template_suffix = "counts"
            value_column = "count"
        else:
            suffix_by_agg = {"mean": "average", "sum": "sum", "max": "max"}
            if not isinstance(agg_func, str) or agg_func not in suffix_by_agg:
                # Fail with a clear message before any query is run.
                raise ValueError(
                    f"Unsupported aggregation {agg_func!r} for SQL execution; "
                    f"expected one of {sorted(suffix_by_agg)}."
                )
            template_suffix = suffix_by_agg[agg_func]
            query_args["measure_attr"] = measure_attr.attribute
            value_column = measure_attr.attribute

        query = lux.config.query_templates[template_prefix + template_suffix].format(**query_args)
        vis_data = pandas.read_sql(query, lux.config.SQLconnection)
        # Internal sanity check on the shape of the query result.
        assert len(vis_data.columns) == expected_n_columns and value_column in vis_data.columns
        if is_count:
            vis_data = vis_data.rename(columns={"count": "Record"})
        view._vis_data = utils.pandas_to_lux(vis_data)
        view._query = query

        result_vals = list(view._vis_data[group_name])
        # For aggregations that have missing groupby-attribute values, set these aggregated
        # values to 0, since there are no datapoints. Only possible when the full set of
        # group values is known and non-empty.
        has_known_groups = attr_unique_vals is not None and len(attr_unique_vals) > 0
        if has_known_groups and (isFiltered or has_color):
            N_unique_vals = len(attr_unique_vals)
            if len(result_vals) != N_unique_vals * color_cardinality:
                columns = view._vis_data.columns
                if has_color:
                    # Full cross product of group values x color values.
                    df = pandas.DataFrame(
                        {
                            columns[0]: attr_unique_vals * color_cardinality,
                            columns[1]: pandas.Series(color_attr_vals).repeat(N_unique_vals),
                        }
                    )
                    view._vis_data = view._vis_data.merge(
                        df,
                        on=[columns[0], columns[1]],
                        how="right",
                        suffixes=["", "_right"],
                    )
                    for col in columns[2:]:
                        # Triggers __setitem__
                        view._vis_data[col] = view._vis_data[col].fillna(0)
                    assert len(list(view._vis_data[groupby_attr.attribute])) == N_unique_vals * len(color_attr_vals), f"Aggregated data missing values compared to original range of values of `{groupby_attr.attribute, color_attr.attribute}`."
                    # Keep only the three relevant columns not the *_right columns resulting from merge
                    view._vis_data = view._vis_data.iloc[:, :3]
                else:
                    df = pandas.DataFrame({columns[0]: attr_unique_vals})
                    view._vis_data = view._vis_data.merge(
                        df, on=columns[0], how="right", suffixes=["", "_right"]
                    )
                    for col in columns[1:]:
                        view._vis_data[col] = view._vis_data[col].fillna(0)
                    assert (len(list(view._vis_data[groupby_attr.attribute])) == N_unique_vals), f"Aggregated data missing values compared to original range of values of `{groupby_attr.attribute}`."

        view._vis_data = view._vis_data.sort_values(by=group_name, ascending=True)
        # drop=True discards the old index directly, so a data column that happens to be
        # named "index" is never removed by accident.
        view._vis_data = view._vis_data.reset_index(drop=True)