Example #1
def get_store_contents():
    """
    Return an ordered tuple of attributes representing the store contents.
    Useful for ensuring key properties stay the same when switching between systems.
    """
    _get_one = [
        serialized_dataframe(global_state.get_data('1')),
        global_state.get_dtypes('1'),
        global_state.get_settings('1'),
        global_state.get_metadata('1'),
        global_state.get_context_variables('1'),
        global_state.get_history('1'),
    ]
    _get_all = [
        {
            k: serialized_dataframe(v)
            for k, v in global_state.get_data().items()
        },
        global_state.get_dtypes(),
        global_state.get_settings(),
        global_state.get_metadata(),
        global_state.get_context_variables(),
        global_state.get_history(),
    ]
    _lengths = [
        len(global_state.DATA),
        len(global_state.DTYPES),
        len(global_state.SETTINGS),
        len(global_state.METADATA),
        len(global_state.CONTEXT_VARIABLES),
        len(global_state.HISTORY),
    ]
    return (_get_one, _get_all, _lengths)
Example #2
def get_store_contents():
    """
    Return an ordered tuple of attributes representing the store contents.
    Useful for ensuring key properties stay the same when switching between systems.
    """
    _get_one = [
        serialized_dataframe(global_state.get_data("1")),
        global_state.get_dtypes("1"),
        global_state.get_settings("1"),
        global_state.get_metadata("1"),
        global_state.get_context_variables("1"),
        global_state.get_history("1"),
    ]
    _get_all = [
        {
            int(k): serialized_dataframe(v.data)
            for k, v in global_state.items()
        },
        {int(k): v.dtypes
         for k, v in global_state.items()},
        {int(k): v.settings
         for k, v in global_state.items()},
        {int(k): v.metadata
         for k, v in global_state.items()},
        {int(k): v.context_variables
         for k, v in global_state.items()},
        {int(k): v.history
         for k, v in global_state.items()},
    ]
    _lengths = [
        global_state.size(),
    ]
    return (_get_one, _get_all, _lengths)
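The two variants above are the same test helper before and after dtale's global_state refactor: the first reads the module-level DATA/DTYPES/... dicts directly, while the second iterates one unified store of instances. A minimal hedged sketch of how such a snapshot is used, assuming dtale's documented use_redis_store switch (the directory path is illustrative):

before = get_store_contents()
global_state.use_redis_store("/home/jdoe/dtale_data")  # hypothetical path
after = get_store_contents()
assert before == after  # key properties survive the backend switch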
Example #3
    def reshape(self):
        data = run_query(
            global_state.get_data(self.data_id),
            (global_state.get_settings(self.data_id) or {}).get("query"),
            global_state.get_context_variables(self.data_id),
        )
        return self.builder.reshape(data)
Example #4
    def run(self):
        data = run_query(
            global_state.get_data(self.data_id),
            (global_state.get_settings(self.data_id) or {}).get("query"),
            global_state.get_context_variables(self.data_id),
        )
        return self.report.run(data)
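Examples #3 and #4 share the same pattern: load the stored frame, apply the instance's saved query (guarding against missing settings with `or {}`), then run the actual operation on the filtered frame. A hedged standalone sketch of that pattern, assuming run_query from dtale.query (it evaluates a pandas query string with optional context variables):

import dtale.global_state as global_state
from dtale.query import run_query

def filtered_data(data_id):
    # returns the stored frame with the instance's saved query applied
    settings = global_state.get_settings(data_id) or {}
    return run_query(
        global_state.get_data(data_id),
        settings.get("query"),  # e.g. "Col0 > 1"; None means no filter
        global_state.get_context_variables(data_id),
    )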
Example #5
def test_transpose(custom_data, unittest):
    from dtale.views import build_dtypes_state

    global_state.clear_store()
    with app.test_client() as c:
        data = {c.port: custom_data}
        dtypes = {c.port: build_dtypes_state(custom_data)}
        settings = {c.port: {}}

        build_data_inst(data)
        build_dtypes(dtypes)
        build_settings(settings)
        reshape_cfg = dict(index=["security_id"], columns=["Col0"])
        resp = c.get(
            "/dtale/reshape/{}".format(c.port),
            query_string=dict(output="new",
                              type="transpose",
                              cfg=json.dumps(reshape_cfg)),
        )
        response_data = json.loads(resp.data)
        new_key = int(c.port) + 1
        assert "error" in response_data

        min_date = custom_data["date"].min().strftime("%Y-%m-%d")
        global_state.set_settings(c.port,
                                  dict(query="date == '{}'".format(min_date)))
        reshape_cfg = dict(index=["date", "security_id"], columns=["Col0"])
        resp = c.get(
            "/dtale/reshape/{}".format(c.port),
            query_string=dict(output="new",
                              type="transpose",
                              cfg=json.dumps(reshape_cfg)),
        )
        response_data = json.loads(resp.data)
        assert response_data["data_id"] == new_key
        assert len(global_state.keys()) == 2
        unittest.assertEqual(
            [d["name"] for d in global_state.get_dtypes(new_key)],
            [
                "index",
                "{} 00:00:00 100000".format(min_date),
                "{} 00:00:00 100001".format(min_date),
            ],
        )
        assert len(global_state.get_data(new_key)) == 1
        assert global_state.get_settings(new_key).get(
            "startup_code") is not None
        c.get("/dtale/cleanup-datasets", query_string=dict(dataIds=new_key))

        reshape_cfg = dict(index=["date", "security_id"])
        resp = c.get(
            "/dtale/reshape/{}".format(c.port),
            query_string=dict(output="override",
                              type="transpose",
                              cfg=json.dumps(reshape_cfg)),
        )
        response_data = json.loads(resp.data)
        assert response_data["data_id"] == c.port
Example #6
def build_code_export(data_id, imports='import pandas as pd\n\n', query=None):
    """
    Helper function for building a string representing the code that was run to get the data you are viewing to that
    point.

    :param data_id: integer string identifier for a D-Tale process's data
    :type data_id: str
    :param imports: string representing the imports at the top of the code string
    :type imports: string, optional
    :param query: pandas dataframe query string
    :type query: str, optional
    :return: python code string
    """
    history = global_state.get_history(data_id) or []
    settings = global_state.get_settings(data_id) or {}
    ctxt_vars = global_state.get_context_variables(data_id)

    startup_code = settings.get('startup_code')
    startup_code = '# Data Re-shaping\n{}\n\n'.format(
        startup_code) if startup_code else ''
    startup_str = (
        "# DISCLAIMER: 'df' refers to the data you passed in when calling 'dtale.show'\n\n"
        '{imports}'
        '{startup}'
        'if isinstance(df, (pd.DatetimeIndex, pd.MultiIndex)):\n'
        '\tdf = df.to_frame(index=False)\n\n'
        '# remove any pre-existing indices for ease of use in the D-Tale code, but this is not required\n'
        "df = df.reset_index().drop('index', axis=1, errors='ignore')\n"
        'df.columns = [str(c) for c in df.columns]  # update columns to strings in case they are numbers\n'
    ).format(imports=imports, startup=startup_code)
    final_history = [startup_str] + history
    final_query = query
    if final_query is None:
        final_query = settings.get('query')

    if final_query is not None:
        if len(ctxt_vars or {}):
            final_history.append((
                "\n# this is injecting any context variables you may have passed into 'dtale.show'\n"
                "import dtale.global_state as dtale_global_state\n"
                "\n# DISCLAIMER: running this line in a different process than the one it originated will produce\n"
                "#             differing results\n"
                "ctxt_vars = dtale_global_state.get_context_variables('{data_id}')\n\n"
                "df = df.query('{query}', local_dict=ctxt_vars)\n").format(
                    query=final_query, data_id=data_id))
        else:
            final_history.append("df = df.query('{}')\n".format(final_query))
    elif settings.get('query'):
        final_history.append("df = df.query('{}')\n".format(settings['query']))
    if 'sort' in settings:
        cols, dirs = [], []
        for col, dir in settings['sort']:
            cols.append(col)
            dirs.append('True' if dir == 'ASC' else 'False')
        final_history.append(
            "df = df.sort_values(['{cols}'], ascending=[{dirs}])\n".format(
                cols="', '".join(cols), dirs=', '.join(dirs)))
    return final_history
Example #7
    def display_page(_ts, pathname, search):
        """
        dash callback which gets called on initial load of each dash page (main & popup)
        """
        dash_app.config.suppress_callback_exceptions = False
        params = chart_url_params(search)
        data_id = get_data_id(pathname)
        df = global_state.get_data(data_id)
        settings = global_state.get_settings(data_id) or {}
        return charts_layout(df, settings, **params)
Example #8
    def save_filter(self):
        curr_settings = global_state.get_settings(self.data_id)
        curr_filters = curr_settings.get('columnFilters') or {}
        fltr = self.builder.build_filter()
        if fltr is None:
            curr_filters.pop(self.column, None)
        else:
            curr_filters[self.column] = fltr
        curr_settings['columnFilters'] = curr_filters
        global_state.set_settings(self.data_id, curr_settings)
Example #9
def load_filterable_data(data_id, req, query=None):
    filtered = get_bool_arg(req, "filtered")
    curr_settings = global_state.get_settings(data_id) or {}
    if filtered:
        final_query = query or build_query(data_id, curr_settings.get("query"))
        return run_query(
            handle_predefined(data_id),
            final_query,
            global_state.get_context_variables(data_id),
            ignore_empty=True,
        )
    return global_state.get_data(data_id)
Example #10
    def save_filter(self):
        curr_settings = global_state.get_settings(self.data_id)
        filters_key = '{}Filters'.format(
            'outlier' if self.cfg['type'] == 'outliers' else 'column')
        curr_filters = curr_settings.get(filters_key) or {}
        fltr = self.builder.build_filter()
        if fltr is None:
            curr_filters.pop(self.column, None)
        else:
            curr_filters[self.column] = fltr
        curr_settings[filters_key] = curr_filters
        global_state.set_settings(self.data_id, curr_settings)
        return curr_filters
Example #11
    def save_filter(self):
        curr_settings = global_state.get_settings(self.data_id)
        filters_key = "{}Filters".format(
            "outlier" if self.cfg["type"] == "outliers" else "column")
        curr_filters = curr_settings.get(filters_key) or {}
        fltr = self.builder.build_filter()
        if fltr is None:
            curr_filters.pop(self.column, None)
        else:
            curr_filters[self.column] = fltr
        curr_settings[filters_key] = curr_filters
        global_state.set_settings(self.data_id, curr_settings)
        return curr_filters
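Examples #8, #10 and #11 all follow the same read-modify-write cycle against global_state: fetch the settings dict, mutate one key, then write the whole dict back with set_settings. A minimal hedged sketch (the filter payload is hypothetical; dtale's real filter builders emit richer structures):

import dtale.global_state as global_state

data_id = "1"
curr_settings = global_state.get_settings(data_id) or {}
curr_filters = curr_settings.get("columnFilters") or {}
curr_filters["Col0"] = {"query": "`Col0` == 1"}  # hypothetical filter payload
curr_settings["columnFilters"] = curr_filters
global_state.set_settings(data_id, curr_settings)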
Example #12
    def execute(self):
        from dtale.views import startup

        data = global_state.get_data(self.data_id)
        try:
            df, code = self.checker.remove(data)
            instance = startup(data=df, **self.checker.startup_kwargs)
            curr_settings = global_state.get_settings(instance._data_id)
            global_state.set_settings(
                instance._data_id,
                dict_merge(curr_settings, dict(startup_code=code)),
            )
            return instance._data_id
        except NoDuplicatesException:
            return self.data_id
Example #13
File: views.py Project: standingtree/dtale
    def display_page(pathname, search):
        """
        dash callback which gets called on initial load of each dash page (main & popup)
        """
        dash_app.config.suppress_callback_exceptions = False
        if pathname is None:
            raise PreventUpdate
        params = chart_url_params(search)
        params["data_id"] = params.get("data_id") or get_data_id(pathname)
        df = global_state.get_data(params["data_id"])
        settings = global_state.get_settings(params["data_id"]) or {}
        return html.Div(
            charts_layout(df, settings, **params) + saved_charts.build_layout(),
            className="charts-body",
        )
Example #14
File: query.py Project: unachka/dtale
def handle_predefined(data_id, df=None):
    import dtale.predefined_filters as predefined_filters

    df = global_state.get_data(data_id) if df is None else df
    filters = predefined_filters.get_filters()
    if not filters:
        return df

    curr_settings = global_state.get_settings(data_id) or {}
    filter_values = curr_settings.get("predefinedFilters")
    if not filter_values:
        return df

    for f in filters:
        if f.name in filter_values:
            df = f.handler(df, filter_values[f.name])
    return df
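handle_predefined only applies filters that are both registered globally and enabled in the instance's settings under "predefinedFilters". Registering one looks roughly like the following (adapted from dtale's predefined-filter docs; the filter itself is illustrative):

import dtale.predefined_filters as predefined_filters

predefined_filters.set_filters([
    {
        "name": "A and B > 2",  # the key looked up in filter_values above
        "column": "A",
        "description": "rows where A equals the input and B > 2",
        "handler": lambda df, val: df[(df["A"] == val) & (df["B"] > 2)],
        "input_type": "input",
    }
])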
Example #15
def test_convert():
    from dtale.views import startup
    from tests.dtale.test_replacements import replacements_data
    import dtale.global_state as global_state

    global_state.clear_store()
    with app.test_client() as c:
        global_state.new_data_inst(c.port)
        startup(URL, data=replacements_data(), data_id=c.port)

        resp = c.get(
            "/dtale/to-xarray/{}".format(c.port),
            query_string=dict(index=json.dumps(["a"])),
        )
        assert resp.status_code == 200
        assert global_state.get_dataset(c.port) is not None
        assert global_state.get_settings(c.port)["locked"] == ["a"]
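After the conversion, the instance stores an xarray Dataset rather than a DataFrame (example #22 below shows how the code export rebuilds a frame from it). A hedged sketch of reading it back, using standard xarray API:

ds = global_state.get_dataset(c.port)  # xarray.Dataset after /dtale/to-xarray
df = ds.to_dataframe().reset_index()   # flatten back to a DataFrame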
Example #16
def test_resample(unittest):
    from dtale.views import build_dtypes_state, format_data

    start, end = "2000-10-01 23:30:00", "2000-10-03 00:30:00"
    rng = pd.date_range(start, end, freq="7min")
    ts = pd.Series(np.arange(len(rng)) * 3, index=rng)
    ts2 = pd.Series(np.arange(len(rng)) * 0.32, index=rng)
    df = pd.DataFrame(data={"col1": ts, "col2": ts2})
    df, _ = format_data(df)

    global_state.clear_store()
    with app.test_client() as c:
        data = {c.port: df}
        dtypes = {c.port: build_dtypes_state(df)}
        settings = {c.port: {}}

        build_data_inst(data)
        build_dtypes(dtypes)
        build_settings(settings)
        reshape_cfg = dict(index="index",
                           columns=["col1"],
                           freq="17min",
                           agg="mean")
        resp = c.get(
            "/dtale/reshape/{}".format(c.port),
            query_string=dict(output="new",
                              type="resample",
                              cfg=json.dumps(reshape_cfg)),
        )

        response_data = json.loads(resp.data)
        new_key = int(c.port) + 1
        assert response_data["data_id"] == new_key
        assert len(global_state.keys()) == 2
        unittest.assertEqual(
            [d["name"] for d in global_state.get_dtypes(new_key)],
            ["index_17min", "col1"],
        )
        assert len(global_state.get_data(new_key)) == 90
        assert global_state.get_settings(new_key).get(
            "startup_code") is not None
        c.get("/dtale/cleanup-datasets", query_string=dict(dataIds=new_key))
Example #17
    def __init__(self, data_id, req):
        self.data_id = data_id
        self.analysis_type = get_str_arg(req, "type")
        curr_settings = global_state.get_settings(data_id) or {}
        self.query = build_query(data_id, curr_settings.get("query"))
        data = load_filterable_data(data_id, req, query=self.query)
        self.selected_col = find_selected_column(
            data, get_str_arg(req, "col", "values")
        )
        self.data = data[~pd.isnull(data[self.selected_col])]
        self.dtype = find_dtype(self.data[self.selected_col])
        self.classifier = classify_type(self.dtype)
        self.code = build_code_export(
            data_id,
            imports="{}\n".format(
                "\n".join(
                    [
                        "import numpy as np",
                        "import pandas as pd",
                        "import plotly.graph_objs as go",
                    ]
                )
            ),
        )

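        # classify_type buckets the dtype into one-letter codes ("F" float,
        # "I" int, "D" date, "S" string, "B" bool); numeric and date columns
        # default to a histogram below, everything else to value counts.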
        if self.analysis_type is None:
            self.analysis_type = (
                "histogram" if self.classifier in ["F", "I", "D"] else "value_counts"
            )

        if self.analysis_type == "geolocation":
            self.analysis = GeolocationAnalysis(req)
        elif self.analysis_type == "histogram":
            self.analysis = HistogramAnalysis(req)
        elif self.analysis_type == "categories":
            self.analysis = CategoryAnalysis(req)
        elif self.analysis_type == "value_counts":
            self.analysis = ValueCountAnalysis(req)
        elif self.analysis_type == "word_value_counts":
            self.analysis = WordValueCountAnalysis(req)
        elif self.analysis_type == "qq":
            self.analysis = QQAnalysis()
Example #18
def build_query(data_id, query=None):
    curr_settings = global_state.get_settings(data_id) or {}
    return inner_build_query(curr_settings, query)
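A hedged note on the helper above: inner_build_query (also in dtale.query) is assumed to AND together any column/outlier filters stored in the instance's settings with the query passed in, so with empty settings the supplied query comes back unchanged:

assert build_query("1", "Col0 > 1") == "Col0 > 1"  # assumes no stored filters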
Example #19
def test_pivot(custom_data, unittest):
    from dtale.views import build_dtypes_state

    global_state.clear_store()
    with app.test_client() as c:
        data = {c.port: custom_data}
        dtypes = {c.port: build_dtypes_state(custom_data)}
        settings = {c.port: {}}

        build_data_inst(data)
        build_dtypes(dtypes)
        build_settings(settings)
        reshape_cfg = dict(index=["date"],
                           columns=["security_id"],
                           values=["Col0"],
                           aggfunc="mean")
        resp = c.get(
            "/dtale/reshape/{}".format(c.port),
            query_string=dict(output="new",
                              type="pivot",
                              cfg=json.dumps(reshape_cfg)),
        )
        response_data = json.loads(resp.data)
        new_key = int(c.port) + 1
        assert response_data["data_id"] == new_key
        assert len(global_state.keys()) == 2
        unittest.assertEqual(
            [d["name"] for d in global_state.get_dtypes(new_key)],
            ["date", "100000", "100001"],
        )
        assert len(global_state.get_data(new_key)) == 365
        assert global_state.get_settings(new_key).get(
            "startup_code") is not None

        resp = c.get("/dtale/cleanup-datasets",
                     query_string=dict(dataIds=new_key))
        assert json.loads(resp.data)["success"]
        assert len(global_state.keys()) == 1

        reshape_cfg["columnNameHeaders"] = True
        reshape_cfg["aggfunc"] = "sum"
        resp = c.get(
            "/dtale/reshape/{}".format(c.port),
            query_string=dict(output="new",
                              type="pivot",
                              cfg=json.dumps(reshape_cfg)),
        )
        response_data = json.loads(resp.data)
        assert response_data["data_id"] == new_key
        assert len(global_state.keys()) == 2
        unittest.assertEqual(
            [d["name"] for d in global_state.get_dtypes(new_key)],
            ["date", "security_id-100000", "security_id-100001"],
        )
        assert len(global_state.get_data(new_key)) == 365
        assert global_state.get_settings(new_key).get(
            "startup_code") is not None
        c.get("/dtale/cleanup-datasets", query_string=dict(dataIds=new_key))

        reshape_cfg["columnNameHeaders"] = False
        reshape_cfg["values"] = ["Col0", "Col1"]
        resp = c.get(
            "/dtale/reshape/{}".format(c.port),
            query_string=dict(output="new",
                              type="pivot",
                              cfg=json.dumps(reshape_cfg)),
        )
        response_data = json.loads(resp.data)
        assert response_data["data_id"] == new_key
        assert len(global_state.keys()) == 2
        unittest.assertEqual(
            [d["name"] for d in global_state.get_dtypes(new_key)],
            [
                "date", "Col0 100000", "Col0 100001", "Col1 100000",
                "Col1 100001"
            ],
        )
        assert len(global_state.get_data(new_key)) == 365
        assert global_state.get_settings(new_key).get(
            "startup_code") is not None
        c.get("/dtale/cleanup-datasets", query_string=dict(dataIds=new_key))
Example #20
def test_aggregate(custom_data, unittest):
    from dtale.views import build_dtypes_state

    global_state.clear_store()
    with app.test_client() as c:
        data = {c.port: custom_data}
        dtypes = {c.port: build_dtypes_state(custom_data)}
        settings = {c.port: {}}

        build_data_inst(data)
        build_dtypes(dtypes)
        build_settings(settings)
        reshape_cfg = dict(
            index="date",
            agg=dict(type="col",
                     cols={
                         "Col0": ["sum", "mean"],
                         "Col1": ["count"]
                     }),
        )
        resp = c.get(
            "/dtale/reshape/{}".format(c.port),
            query_string=dict(output="new",
                              type="aggregate",
                              cfg=json.dumps(reshape_cfg)),
        )
        response_data = json.loads(resp.data)
        new_key = int(c.port) + 1
        assert response_data["data_id"] == new_key
        assert len(global_state.keys()) == 2
        unittest.assertEqual(
            [d["name"] for d in global_state.get_dtypes(new_key)],
            ["date", "Col0 sum", "Col0 mean", "Col1 count"],
        )
        assert len(global_state.get_data(new_key)) == 365
        assert global_state.get_settings(new_key).get(
            "startup_code") is not None
        c.get("/dtale/cleanup-datasets", query_string=dict(dataIds=new_key))

        reshape_cfg = dict(index="date",
                           agg=dict(type="func",
                                    func="mean",
                                    cols=["Col0", "Col1"]))
        resp = c.get(
            "/dtale/reshape/{}".format(c.port),
            query_string=dict(output="new",
                              type="aggregate",
                              cfg=json.dumps(reshape_cfg)),
        )
        response_data = json.loads(resp.data)
        assert response_data["data_id"] == new_key
        assert len(global_state.keys()) == 2
        unittest.assertEqual(
            [d["name"] for d in global_state.get_dtypes(new_key)],
            ["date", "Col0", "Col1"],
        )
        assert len(global_state.get_data(new_key)) == 365
        assert global_state.get_settings(new_key).get(
            "startup_code") is not None
        c.get("/dtale/cleanup-datasets", query_string=dict(dataIds=new_key))

        reshape_cfg = dict(index="date", agg=dict(type="func", func="mean"))
        resp = c.get(
            "/dtale/reshape/{}".format(c.port),
            query_string=dict(output="new",
                              type="aggregate",
                              cfg=json.dumps(reshape_cfg)),
        )
        response_data = json.loads(resp.data)
        assert response_data["data_id"] == new_key
        assert len(global_state.keys()) == 2
        unittest.assertEqual(
            [d["name"] for d in global_state.get_dtypes(new_key)],
            [
                "date", "security_id", "int_val", "Col0", "Col1", "Col2",
                "bool_val"
            ],
        )
        assert len(global_state.get_data(new_key)) == 365
        assert global_state.get_settings(new_key).get(
            "startup_code") is not None
        c.get("/dtale/cleanup-datasets", query_string=dict(dataIds=new_key))
Example #21
    def load_drilldown_content(
        _click_data_ts,
        drilldown_type,
        drilldown_x,
        inputs,
        chart_inputs,
        yaxis_data,
        map_data,
        click_data,
        drilldowns_on,
    ):
        if not drilldowns_on:
            raise PreventUpdate
        data_id = inputs["data_id"]
        all_inputs = combine_inputs(dash_app, inputs, chart_inputs, yaxis_data,
                                    map_data)
        agg = all_inputs.get("agg") or "raw"
        chart_type = all_inputs.get("chart_type")
        frame_col = all_inputs.get("animate_by")
        all_inputs.pop("animate_by", None)
        if agg == "raw":
            raise PreventUpdate
        if drilldown_x is None and chart_type != "maps":
            raise PreventUpdate
        if click_data:
            click_point = next((p for p in click_data.get("points", [])), None)
            if click_point:
                curr_settings = global_state.get_settings(data_id) or {}
                query = build_query(
                    data_id,
                    all_inputs.get("query") or curr_settings.get("query"))
                x_col = all_inputs.get("x")
                y_col = next((y2 for y2 in make_list(all_inputs.get("y"))),
                             None)
                if chart_type in ZAXIS_CHARTS:
                    x, y, z, frame = (click_point.get(p)
                                      for p in ["x", "y", "z", "customdata"])
                    if chart_type == "heatmap":
                        click_point_vals = {}
                        for dim in click_point["text"].split("<br>"):
                            prop, val = dim.split(": ")
                            click_point_vals[prop] = val
                        x, y, frame = (click_point_vals.get(p)
                                       for p in [x_col, y_col, frame_col])
                    point_filter = {x_col: x, y_col: y}
                    if frame_col:
                        point_filter[frame_col] = frame
                    if drilldown_type == "histogram":
                        z_col = next(
                            (z2 for z2 in make_list(all_inputs.get("z"))),
                            None)
                        hist_chart = build_histogram(data_id, z_col, query,
                                                     point_filter)
                        return hist_chart, dict(display="none")
                    else:
                        xy_query, _ = build_group_inputs_filter(
                            global_state.get_data(data_id),
                            [point_filter],
                        )
                        if not query:
                            query = xy_query
                        else:
                            query = "({}) and ({})".format(query, xy_query)
                        all_inputs["query"] = query
                        all_inputs["chart_type"] = drilldown_type
                        all_inputs["agg"] = "raw"
                        all_inputs["modal"] = True
                        all_inputs["x"] = drilldown_x
                        all_inputs["y"] = [all_inputs["z"]]
                        chart, _, _ = build_chart(**all_inputs)
                        return chart, None

                elif chart_type == "maps":
                    map_type = all_inputs.get("map_type")
                    point_filter = {}
                    if frame_col:
                        point_filter[frame_col] = click_point["customdata"]
                    if map_type == "choropleth":
                        point_filter[
                            all_inputs["loc"]] = click_point["location"]
                    elif map_type == "scattergeo":
                        lat, lon = (click_point.get(p) for p in ["lat", "lon"])
                        point_filter[all_inputs["lat"]] = lat
                        point_filter[all_inputs["lon"]] = lon
                    map_val = all_inputs["map_val"]
                    if drilldown_type == "histogram":
                        hist_chart = build_histogram(data_id, map_val, query,
                                                     point_filter)
                        return hist_chart, dict(display="none")
                    else:
                        map_query, _ = build_group_inputs_filter(
                            global_state.get_data(data_id),
                            [point_filter],
                        )
                        if not query:
                            query = map_query
                        else:
                            query = "({}) and ({})".format(query, map_query)
                        all_inputs["query"] = query
                        all_inputs["chart_type"] = drilldown_type
                        all_inputs["agg"] = "raw"
                        all_inputs["modal"] = True

                        data = global_state.get_data(data_id)
                        all_inputs["x"] = drilldown_x
                        x_style = None
                        if map_type != "choropleth":
                            all_inputs["x"] = "lat_lon"
                            lat, lon = (all_inputs.get(p)
                                        for p in ["lat", "lon"])
                            data.loc[:, "lat_lon"] = (
                                data[lat].astype(str) + "|" + data[lon].astype(str))
                            x_style = dict(display="none")
                        all_inputs["y"] = [map_val]
                        chart, _, _ = build_chart(data=data, **all_inputs)
                        return chart, x_style
                else:
                    x_filter = click_point.get("x")
                    point_filter = {x_col: x_filter}
                    if frame_col:
                        point_filter[frame_col] = click_point.get("customdata")
                    if drilldown_type == "histogram":
                        hist_chart = build_histogram(data_id, y_col, query,
                                                     point_filter)
                        return hist_chart, dict(display="none")
                    else:
                        x_query, _ = build_group_inputs_filter(
                            global_state.get_data(data_id),
                            [point_filter],
                        )
                        if not query:
                            query = x_query
                        else:
                            query = "({}) and ({})".format(query, x_query)
                        all_inputs["query"] = query
                        all_inputs["chart_type"] = drilldown_type
                        all_inputs["agg"] = "raw"
                        all_inputs["modal"] = True
                        all_inputs["x"] = drilldown_x
                        chart, _, _ = build_chart(**all_inputs)
                        return chart, None
        return None, dict(display="none")
Example #22
def build_code_export(data_id, imports="import pandas as pd\n\n", query=None):
    """
    Helper function for building a string representing the code that was run to get the data you are viewing to that
    point.

    :param data_id: integer string identifier for a D-Tale process's data
    :type data_id: str
    :param imports: string representing the imports at the top of the code string
    :type imports: string, optional
    :param query: pandas dataframe query string
    :type query: str, optional
    :return: python code string
    """
    history = global_state.get_history(data_id) or []
    settings = global_state.get_settings(data_id) or {}
    ctxt_vars = global_state.get_context_variables(data_id)

    startup_code = settings.get("startup_code") or ""
    if startup_code and not startup_code.endswith("\n"):
        startup_code += "\n"
    xarray_setup = ""
    if data_id in global_state.DATASETS:
        xarray_dims = global_state.get_dataset_dim(data_id)
        if len(xarray_dims):
            xarray_setup = (
                "df = ds.sel({selectors}).to_dataframe()\n"
                "df = df.reset_index().drop('index', axis=1, errors='ignore')\n"
                "df = df.set_index(list(ds.dims.keys()))\n"
            ).format(
                selectors=", ".join(
                    "{}='{}'".format(k, v) for k, v in xarray_dims.items()
                )
            )
        else:
            xarray_setup = (
                "df = ds.to_dataframe()\n"
                "df = df.reset_index().drop('index', axis=1, errors='ignore')\n"
                "df = df.set_index(list(ds.dims.keys()))\n"
            )
    startup_str = (
        "# DISCLAIMER: 'df' refers to the data you passed in when calling 'dtale.show'\n\n"
        "{imports}"
        "{xarray_setup}"
        "{startup}"
        "if isinstance(df, (pd.DatetimeIndex, pd.MultiIndex)):\n"
        "\tdf = df.to_frame(index=False)\n\n"
        "# remove any pre-existing indices for ease of use in the D-Tale code, but this is not required\n"
        "df = df.reset_index().drop('index', axis=1, errors='ignore')\n"
        "df.columns = [str(c) for c in df.columns]  # update columns to strings in case they are numbers\n"
    ).format(imports=imports, xarray_setup=xarray_setup, startup=startup_code)
    final_history = [startup_str] + history
    final_query = query
    if final_query is None:
        final_query = settings.get("query")

    if final_query is not None and final_query != "":
        if len(ctxt_vars or {}):
            final_history.append(
                (
                    "\n# this is injecting any context variables you may have passed into 'dtale.show'\n"
                    "import dtale.global_state as dtale_global_state\n"
                    "\n# DISCLAIMER: running this line in a different process than the one it originated will produce\n"
                    "#             differing results\n"
                    "ctxt_vars = dtale_global_state.get_context_variables('{data_id}')\n\n"
                    "df = df.query({query}, local_dict=ctxt_vars)\n"
                ).format(query=triple_quote(final_query), data_id=data_id)
            )
        else:
            final_history.append(
                "df = df.query({})\n".format(triple_quote(final_query))
            )
    elif settings.get("query"):
        final_history.append(
            "df = df.query({})\n".format(triple_quote(settings["query"]))
        )
    if "sort" in settings:
        cols, dirs = [], []
        for col, dir in settings["sort"]:
            cols.append(col)
            dirs.append("True" if dir == "ASC" else "False")
        final_history.append(
            "df = df.sort_values(['{cols}'], ascending=[{dirs}])\n".format(
                cols=", ".join(cols), dirs="', '".join(dirs)
            )
        )
    return final_history
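Since the function returns a list of newline-terminated snippets, producing a runnable script is just a join. A minimal hedged usage sketch (the data_id "1" is illustrative):

code_segments = build_code_export("1")
print("".join(code_segments))  # each segment already ends with "\n"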