def __init__(self, tree: "DataTree") -> None:
    """Create the JupyterDash app, build its UI components and wire their callbacks.

    Args:
        tree: Stored as ``self.tree`` and used to build the ``CellModel``
            that backs every component.
    """
    self.tree = tree
    # Configure JupyterDash's proxy settings before the app object is created.
    JupyterDash.infer_jupyter_proxy_config()
    self.app = JupyterDash(__name__, external_stylesheets=STYLESHEET)
    # Keep a handle on the app's underlying server object.
    self.server = self.app.server
    self.model = CellModel(tree)

    # Components. Some receive ``self.app`` directly (Scatter, DropDownColumn,
    # GeneData); the others are constructed without it.
    self.scatter = Scatter(self.app, data={"model": self.model})
    self.dim_dropdowns = DropDownColumn(
        self.app, data=dict(title_list=self.model.DIMENSIONS, options=self.model.col_options))
    self.sliders = SliderColumn(
        data={"slider_params": self.model.MARKER_ARGS_PARAMS})
    self.col_adder = ColumnAdder()
    self.annotator = ScatterAnnotator(data={
        "model": self.model,
        "options": self.model.col_options
    })
    self.colname_store = DropDownOptionsStore()
    self.gene_data = GeneData(self.app, data={"model": self.model})

    # Callback wiring. All components must exist before this point, because
    # each ``callbacks`` call receives other components as arguments.
    # NOTE(review): presumably the options store pushes the annotator's column
    # names to the dimension dropdowns and the annotator dropdown -- confirm
    # against DropDownOptionsStore.callbacks.
    self.colname_store.callbacks(master=self.annotator, subscribers=[
        *self.dim_dropdowns.contents, self.annotator.dropdown
    ])
    self.annotator.callbacks(self.scatter, self.col_adder)
    self.scatter.callbacks(self.dim_dropdowns, self.sliders)
    self.gene_data.callbacks(self.scatter)
# Copyright (C) 2021, RTE (http://www.rte-france.com/) # See AUTHORS.txt # SPDX-License-Identifier: MPL-2.0 """ This file handles the html entry point of the application through dash components. It will generate the layout of a given page and handle the routing """ import dash_bootstrap_components as dbc # from dash import Dash from jupyter_dash import JupyterDash # We need to create app before importing the rest of the project as it uses @app decorators JupyterDash.infer_jupyter_proxy_config( ) # for binder or jupyterHub for instance app = JupyterDash(__name__, external_stylesheets=[ dbc.themes.BOOTSTRAP ]) # ,server_url="http://127.0.0.1:8050/") """ Get Imports to create layout and callbacks """ from grid2viz.main_callbacks import register_callbacks_main from grid2viz.layout import make_layout as layout from grid2viz.src.episodes.episodes_clbk import register_callbacks_episodes from grid2viz.src.overview.overview_clbk import ( register_callbacks_overview, ) # as overview_clbk from grid2viz.src.macro.macro_clbk import register_callbacks_macro # as macro_clbk from grid2viz.src.micro.micro_clbk import register_callbacks_micro # as micro_clbk """
# Import required libraries import pandas as pd import dash import dash_html_components as html import dash_core_components as dcc from dash.dependencies import Input, Output, State from jupyter_dash import JupyterDash import plotly.graph_objects as go import plotly.express as px from dash import no_update # Create a dash application app = JupyterDash(__name__) JupyterDash.infer_jupyter_proxy_config() # REVIEW1: Clear the layout and do not display exception till callback gets executed app.config.suppress_callback_exceptions = True # Read the airline data into pandas dataframe airline_data = pd.read_csv('https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/IBMDeveloperSkillsNetwork-DV0101EN-SkillsNetwork/Data%20Files/airline_data.csv', encoding = "ISO-8859-1", dtype={'Div1Airport': str, 'Div1TailNum': str, 'Div2Airport': str, 'Div2TailNum': str}) # List of years year_list = [i for i in range(2005, 2021, 1)] """Compute graph data for creating yearly airline performance report
def make_dash_app(self,
                  template='plotly_white',
                  server_mode='external',
                  port=8050,
                  app_type='jupyter',
                  plot_height=680,
                  external_stylesheets=None,
                  infer_proxy=False,
                  slider_width=140,
                  cutout_hdu=None,
                  cutout_size=10):
    """
    Create a Plotly/Dash app for interactive exploration

    Parameters
    ----------
    template : str
        `plotly` style `template <https://plotly.com/python/templates/#specifying-themes-in-graph-object-figures>`_.

    server_mode, port : str, int
        If `server_mode` is not `None`, the app server is started on `port`.
        For JupyterDash apps this is
        `app.run_server(mode=server_mode, port=port)`; plain `dash.Dash`
        apps have no `mode` argument and are started with
        `app.run_server(port=port)`.

    app_type : str
        If ``dash`` then `app = dash.Dash()`, otherwise
        `app = jupyter_dash.JupyterDash()`.

    plot_height : int
        Height in pixels of the scatter and SED+P(z) plot windows.

    external_stylesheets : list of str, optional
        Stylesheet URLs passed to the app constructor.  Defaults to
        ``['https://codepen.io/chriddyp/pen/bWLwgP.css']``.

    infer_proxy : bool
        Run `JupyterDash.infer_jupyter_proxy_config()`, before app
        initialization, e.g., for running on GoogleColab.

    slider_width : int
        Width in pixels of each range-slider container (the magnitude-filter
        dropdown is ``slider_width - 45`` pixels wide).

    cutout_hdu : object, optional
        Object with ``.header`` (used to build an `astropy.wcs.WCS`) and
        ``.data`` (the image).  If given, an image cutout around the current
        object is shown next to the SED.

    cutout_size : int
        Half-width in pixels of the image cutout.

    Returns
    -------
    app : object
        App object following `app_type`.
    """
    import dash
    from dash import dcc
    from dash import html
    from dash.dependencies import Input, Output
    import plotly.express as px
    from urllib.parse import urlparse, parse_qsl, quote
    import astropy.wcs as pywcs

    # Avoid a shared mutable default argument.
    if external_stylesheets is None:
        external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

    if app_type == 'dash':
        app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
    else:
        from jupyter_dash import JupyterDash
        if infer_proxy:
            JupyterDash.infer_jupyter_proxy_config()

        app = JupyterDash(__name__,
                          external_stylesheets=external_stylesheets)

    PLOT_TYPES = [
        'zphot-zspec', 'Mag-redshift', 'Mass-redshift', 'UVJ', 'RA/Dec',
        'UV-redshift', 'chi2-redshift'
    ]
    # User-defined plots are looked up by key in update_selection.
    PLOT_TYPES.extend(self.extra_plots)

    COLOR_TYPES = ['z_phot', 'z_spec', 'mass', 'sSFR', 'chi2']

    _title = [
        html.Strong(self.photoz.param['MAIN_OUTPUT_FILE']),
        ' / N',
        html.Sub('obj'),
        f'={self.photoz.NOBJ}',
        ' / N',
        html.Sub('filt'),
        f'={self.photoz.NFILT}',
    ]

    slider_row_style = {
        'width': '90%',
        'float': 'left',
        'margin-left': '10px'
    }
    slider_container = {
        'width': f'{slider_width}px',
        'margin-left': '-25px'
    }
    check_kwargs = dict(style={
        'text-align': 'center',
        'height': '14pt',
        'margin-top': '-20px'
    })

    if cutout_hdu is not None:
        cutout_wcs = pywcs.WCS(cutout_hdu.header, relax=True)
        cutout_data = cutout_hdu.data
        cutout_div = html.Div(
            [dcc.Graph(id='cutout-figure', style={})],
            style={
                'right': '70px',
                'width': '120px',
                'height': '120px',
                'border': '1px solid rgb(200,200,200)',
                'top': '10px',
                'position': 'absolute'
            })
        cutout_target = 'figure'
    else:
        # Tiny placeholder so the 'cutout-figure' callback output still exists.
        cutout_div = html.Div(id='cutout-figure',
                              style={
                                  'left': '1px',
                                  'width': '1px',
                                  'height': '1px',
                                  'bottom': '1px',
                                  'position': 'absolute'
                              })
        cutout_data = None
        cutout_target = 'children'

    ####### App layout
    app.layout = html.Div([
        # Selectors
        html.Div(
            [
                dcc.Location(id='url', refresh=False),
                html.Div([
                    html.Div(_title,
                             id='title-bar',
                             style={
                                 'float': 'left',
                                 'margin-top': '4pt'
                             }),
                    html.Div([
                        html.Div([
                            dcc.Dropdown(id='plot-type',
                                         options=[{
                                             'label': i,
                                             'value': i
                                         } for i in PLOT_TYPES],
                                         value='zphot-zspec',
                                         clearable=False,
                                         style={
                                             'width': '120px',
                                             'margin-right': '5px',
                                             'margin-left': '5px',
                                             'font-size': '8pt'
                                         }),
                        ],
                                 style={'float': 'left'}),
                        html.Div([
                            dcc.Dropdown(id='color-type',
                                         options=[{
                                             'label': i,
                                             'value': i
                                         } for i in COLOR_TYPES],
                                         value='sSFR',
                                         clearable=False,
                                         style={
                                             'width': '80px',
                                             'margin-right': '5px',
                                             'font-size': '8pt'
                                         }),
                        ],
                                 style={
                                     'display': 'inline-block',
                                     'margin-left': '10px'
                                 }),
                    ],
                             style={'float': 'right'}),
                ],
                         style=slider_row_style),
                html.Div(
                    [
                        html.Div([
                            dcc.Dropdown(
                                id='mag-filter',
                                options=[{
                                    'label': i,
                                    'value': i
                                } for i in self.photoz.flux_columns],
                                value=self.DEFAULT_FILTER,
                                style={
                                    'width': f'{slider_width-45}px',
                                    'margin-right': '20px',
                                    'font-size': '8pt'
                                },
                                clearable=False),
                        ],
                                 style={'float': 'left'}),
                        html.Div([
                            dcc.RangeSlider(id='mag-slider',
                                            min=12,
                                            max=32,
                                            step=0.2,
                                            value=[18, 27],
                                            updatemode='drag',
                                            tooltip={"placement": 'left'}),
                            dcc.Checklist(id='mag-checked',
                                          options=[{
                                              'label': 'AB mag',
                                              'value': 'checked'
                                          }],
                                          value=['checked'],
                                          **check_kwargs),
                        ],
                                 style=dict(display='inline-block',
                                            **slider_container)),
                        html.Div([
                            dcc.RangeSlider(id='chi2-slider',
                                            min=0,
                                            max=20,
                                            step=0.1,
                                            value=[0, 6],
                                            updatemode='drag',
                                            tooltip={"placement": 'left'}),
                            dcc.Checklist(id='chi2-checked',
                                          options=[{
                                              'label': 'chi2',
                                              'value': 'checked'
                                          }],
                                          value=[],
                                          **check_kwargs),
                        ],
                                 style=dict(display='inline-block',
                                            **slider_container)),
                        html.Div([
                            dcc.RangeSlider(id='nfilt-slider',
                                            min=1,
                                            max=self.MAXNFILT,
                                            step=1,
                                            value=[3, self.MAXNFILT],
                                            updatemode='drag',
                                            tooltip={"placement": 'left'}),
                            dcc.Checklist(id='nfilt-checked',
                                          options=[{
                                              'label': 'nfilt',
                                              'value': 'checked'
                                          }],
                                          value=['checked'],
                                          **check_kwargs),
                        ],
                                 style=dict(display='inline-block',
                                            **slider_container)),
                    ],
                    style=slider_row_style),
                html.Div(
                    [
                        html.Div([
                            dcc.RangeSlider(id='zphot-slider',
                                            min=-0.5,
                                            max=12,
                                            step=0.1,
                                            value=[0, self.ZMAX],
                                            updatemode='drag',
                                            tooltip={"placement": 'left'}),
                            dcc.Checklist(id='zphot-checked',
                                          options=[{
                                              'label': 'z_phot',
                                              'value': 'checked'
                                          }],
                                          value=['checked'],
                                          **check_kwargs),
                        ],
                                 style=dict(float='left',
                                            **slider_container)),
                        html.Div([
                            dcc.RangeSlider(id='zspec-slider',
                                            min=-0.5,
                                            max=12,
                                            step=0.1,
                                            value=[-0.5, 6.5],
                                            updatemode='drag',
                                            tooltip={"placement": 'left'}),
                            dcc.Checklist(id='zspec-checked',
                                          options=[{
                                              'label': 'z_spec',
                                              'value': 'checked'
                                          }],
                                          value=['checked'],
                                          **check_kwargs),
                        ],
                                 style=dict(display='inline-block',
                                            **slider_container)),
                        html.Div([
                            dcc.RangeSlider(id='mass-slider',
                                            min=7,
                                            max=13,
                                            step=0.1,
                                            value=[8, 11.8],
                                            updatemode='drag',
                                            tooltip={"placement": 'left'}),
                            dcc.Checklist(id='mass-checked',
                                          options=[{
                                              'label': 'mass',
                                              'value': 'checked'
                                          }],
                                          value=['checked'],
                                          **check_kwargs),
                        ],
                                 style=dict(display='inline-block',
                                            **slider_container)),
                    ],
                    style=slider_row_style),
            ],
            style={
                'float': 'left',
                'width': '55%'
            }),

        # Object-level controls
        html.Div([
            html.Div([
                html.Div('ID / RA,Dec.',
                         style={
                             'float': 'left',
                             'width': '100px',
                             'margin-top': '5pt'
                         }),
                dcc.Input(id='id-input',
                          type='text',
                          style={
                              'width': '120px',
                              'padding': '2px',
                              'display': 'inline',
                              'font-size': '8pt'
                          }),
                html.Div(children='',
                         id='match-sep',
                         style={
                             'margin': '5pt',
                             'display': 'inline',
                             'width': '50px',
                             'font-size': '8pt'
                         }),
                dcc.RadioItems(id='sed-unit-selector',
                               options=[{
                                   'label': i,
                                   'value': i
                               } for i in ['Fλ', 'Fν', 'νFν']],
                               value='Fλ',
                               labelStyle={
                                   'display': 'inline',
                                   'padding': '3px',
                               },
                               style={
                                   'display': 'inline',
                                   'width': '130px'
                               })
            ],
                     style={
                         'width': '260px',  # was the invalid CSS unit '260pix'
                         'float': 'left',
                         'margin-right': '20px'
                     }),
        ]),
        html.Div(
            [
                html.Div(id='object-info',
                         children='ID: ',
                         style={
                             'margin': 'auto',
                             'margin-top': '10px',
                             'font-size': '10pt'
                         })
            ],
            style={
                'float': 'right',
                'width': '45%'
            }),

        # Plots
        html.Div(
            [  # Scatter plot
                dcc.Graph(id='sample-selection-scatter',
                          hoverData={
                              'points': [{
                                  'customdata':
                                  (self.df['id'][0], 1.0, -9.0)
                              }]
                          },
                          style={'width': '95%'})
            ],
            style={
                'float': 'left',
                'height': '70%',
                'width': '49%'
            }),
        html.Div(
            [  # SED
                dcc.Graph(id='object-sed-figure', style={'width': '95%'})
            ],
            style={
                'float': 'right',
                'width': '49%',
                'height': '70%'
            }),
        cutout_div
    ])

    ##### Helpers for the URL query string

    def _url_quote(value):
        """Percent-encode a free-text query value; '/' and ',' are kept."""
        return quote(str(value), safe='/,')

    def _parse_range(text, cast, default):
        """Parse ``'lo,hi'`` into ``[cast(lo), cast(hi)]``; else *default*."""
        try:
            vals = [cast(v) for v in text.split(',')]
        except ValueError:
            return default
        return vals if len(vals) == 2 else default

    ##### Callback functions
    @app.callback(Output('url', 'search'), [
        Input('plot-type', 'value'),
        Input('color-type', 'value'),
        Input('mag-filter', 'value'),
        Input('mag-slider', 'value'),
        Input('mass-slider', 'value'),
        Input('chi2-slider', 'value'),
        Input('nfilt-slider', 'value'),
        Input('zphot-slider', 'value'),
        Input('zspec-slider', 'value'),
        Input('id-input', 'value')
    ])
    def update_url_state(plot_type, color_type, mag_filter, mag_range,
                         mass_range, chi2_range, nfilt_range, zphot_range,
                         zspec_range, id_input):
        """
        Serialize the current control state into the URL query string.

        Free-text values are percent-encoded so that characters such as
        '&', '=', '#' or '+' typed into the ID box cannot corrupt the query.
        """
        search = f'?plot_type={_url_quote(plot_type)}'
        search += f'&color_type={_url_quote(color_type)}'
        search += f'&mag_filter={_url_quote(mag_filter)}'
        search += f'&mag={mag_range[0]},{mag_range[1]}'
        search += f'&mass={mass_range[0]},{mass_range[1]}'
        search += f'&chi2={chi2_range[0]},{chi2_range[1]}'
        search += f'&nfilt={nfilt_range[0]},{nfilt_range[1]}'
        search += f'&zphot={zphot_range[0]},{zphot_range[1]}'
        search += f'&zspec={zspec_range[0]},{zspec_range[1]}'
        if id_input is not None:
            search += f'&id={_url_quote(id_input)}'

        return search

    @app.callback([
        Output('plot-type', 'value'),
        Output('color-type', 'value'),
        Output('mag-filter', 'value'),
        Output('mag-slider', 'value'),
        Output('mass-slider', 'value'),
        Output('chi2-slider', 'value'),
        Output('nfilt-slider', 'value'),
        Output('zphot-slider', 'value'),
        Output('zspec-slider', 'value'),
        Output('id-input', 'value'),
    ], [Input('url', 'href')])
    def set_state_from_url(href):
        """
        Restore the control state from the URL written by update_url_state.

        Unknown keys and malformed values are ignored and the corresponding
        default below is kept.
        """
        # NOTE(review): the mass, chi2 and nfilt defaults here differ from
        # the slider `value`s in the layout.  This callback also returns them
        # for a URL without a query string, so they are the effective initial
        # values -- confirm which set is intended.
        plot_type = 'zphot-zspec'
        color_type = 'sSFR'
        mag_filter = self.DEFAULT_FILTER
        mag_range = [18, 27]
        mass_range = [8, 11.6]
        chi2_range = [0, 4]
        nfilt_range = [1, self.MAXNFILT]
        zphot_range = [0, self.ZMAX]
        zspec_range = [-0.5, 6.5]
        id_input = None

        # parse_qsl splits on '&', percent-decodes, and skips blank values.
        query = urlparse(href).query if href else ''
        for key, val in parse_qsl(query):
            if key == 'plot_type':
                if val in PLOT_TYPES:
                    plot_type = val
            elif key == 'color_type':
                if val in COLOR_TYPES:
                    color_type = val
            elif key == 'mag_filter':
                if val in self.photoz.flux_columns:
                    mag_filter = val
            elif key == 'mag':
                mag_range = _parse_range(val, float, mag_range)
            elif key == 'mass':
                mass_range = _parse_range(val, float, mass_range)
            elif key == 'chi2':
                chi2_range = _parse_range(val, float, chi2_range)
            elif key == 'nfilt':
                nfilt_range = _parse_range(val, int, nfilt_range)
            elif key == 'zspec':
                zspec_range = _parse_range(val, float, zspec_range)
            elif key == 'zphot':
                zphot_range = _parse_range(val, float, zphot_range)
            elif key == 'id':
                id_input = val or None

        return (plot_type, color_type, mag_filter, mag_range, mass_range,
                chi2_range, nfilt_range, zphot_range, zspec_range, id_input)

    @app.callback(Output('sample-selection-scatter', 'figure'), [
        Input('plot-type', 'value'),
        Input('color-type', 'value'),
        Input('mag-filter', 'value'),
        Input('mag-slider', 'value'),
        Input('mag-checked', 'value'),
        Input('mass-slider', 'value'),
        Input('mass-checked', 'value'),
        Input('chi2-slider', 'value'),
        Input('chi2-checked', 'value'),
        Input('nfilt-slider', 'value'),
        Input('nfilt-checked', 'value'),
        Input('zphot-slider', 'value'),
        Input('zphot-checked', 'value'),
        Input('zspec-slider', 'value'),
        Input('zspec-checked', 'value'),
        Input('id-input', 'value')
    ])
    def update_selection(plot_type, color_type, mag_filter, mag_range,
                         mag_checked, mass_range, mass_checked, chi2_range,
                         chi2_checked, nfilt_range, nfilt_checked,
                         zphot_range, zphot_checked, zspec_range,
                         zspec_checked, id_input):
        """
        Apply slider selections and redraw the sample scatter plot.

        Each range cut is applied only when its checkbox is ticked.  The
        object matched by ``id_input`` (if any) is always kept, and is
        flagged in ``self.df['is_selected']`` for highlighting.
        """
        sel = np.isfinite(self.df['z_phot'])

        if 'checked' in zphot_checked:
            sel &= (self.df['z_phot'] > zphot_range[0])
            sel &= (self.df['z_phot'] < zphot_range[1])

        if 'checked' in zspec_checked:
            sel &= (self.df['z_spec'] > zspec_range[0])
            sel &= (self.df['z_spec'] < zspec_range[1])

        if 'checked' in mass_checked:
            sel &= (self.df['mass'] > mass_range[0])
            sel &= (self.df['mass'] < mass_range[1])

        if 'checked' in chi2_checked:
            sel &= (self.df['chi2'] >= chi2_range[0])
            sel &= (self.df['chi2'] <= chi2_range[1])

        if 'checked' in nfilt_checked:
            sel &= (self.df['nusefilt'] >= nfilt_range[0])
            sel &= (self.df['nusefilt'] <= nfilt_range[1])

        if mag_filter is None:
            mag_filter = self.DEFAULT_FILTER

        mag_col = 'mag_' + mag_filter
        if 'checked' in mag_checked:
            sel &= (self.df[mag_col] > mag_range[0])
            sel &= (self.df[mag_col] < mag_range[1])

        # The chosen band is copied into the generic 'mag' column, which the
        # hover data and the object-info panel read.
        self.df['mag'] = self.df[mag_col]

        if plot_type == 'zphot-zspec':
            sel &= self.df['z_spec'] > 0

        if id_input is not None:
            id_i, dr_i = parse_id_input(id_input)
            if id_i is not None:
                self.df['is_selected'] = self.df['id'] == id_i
                sel |= self.df['is_selected']
            else:
                self.df['is_selected'] = False
        else:
            self.df['is_selected'] = False

        dff = self.df[sel]

        # Color-coding by color-type pulldown
        if color_type == 'z_phot':
            color_kwargs = dict(color=np.clip(dff['z_phot'], *zphot_range),
                                color_continuous_scale='portland')
        elif color_type == 'z_spec':
            color_kwargs = dict(color=np.clip(dff['z_spec'], *zspec_range),
                                color_continuous_scale='portland')
        elif color_type == 'mass':
            color_kwargs = dict(color=np.clip(dff['mass'], *mass_range),
                                color_continuous_scale='magma_r')
        elif color_type == 'chi2':
            color_kwargs = dict(color=np.clip(dff['chi2'], *chi2_range),
                                color_continuous_scale='viridis')
        else:
            color_kwargs = dict(color=np.clip(dff['ssfr'], -12., -8.),
                                color_continuous_scale='portland_r')

        # Scatter plot: (xcol, ycol, x_label, y_label, x_range, y_range)
        plot_defs = {
            'Mass-redshift': ('z_phot', 'mass', 'z<sub>phot</sub>',
                              'log Stellar mass', (-0.1, self.ZMAX),
                              (7.5, 12.5)),
            'Mag-redshift': ('z_phot', 'mag', 'z<sub>phot</sub>',
                             f'AB mag ({mag_filter})', (-0.1, self.ZMAX),
                             (18, 28)),
            'RA/Dec': ('ra', 'dec', 'R.A.', 'Dec.', self.ra_bounds,
                       self.dec_bounds),
            'zphot-zspec': ('z_spec', 'z_phot', 'z<sub>spec</sub>',
                            'z<sub>phot</sub>', (0, 4.5), (0, 4.5)),
            'UVJ': ('vj', 'uv', '(V-J)', '(U-V)', (-0.1, 2.5), (-0.1, 2.5)),
            'UV-redshift': ('z_phot', 'uv', 'z<sub>phot</sub>',
                            '(U-V)<sub>rest</sub>', (0, 4), (-0.1, 2.50)),
            'chi2-redshift': ('z_phot', 'chi2', 'z<sub>phot</sub>',
                              'chi<sup>2</sup>', (0, 4), (0.1, 30))
        }

        if plot_type in self.extra_plots:
            args = (*self.extra_plots[plot_type], {}, color_kwargs)
        elif plot_type in plot_defs:
            args = (*plot_defs[plot_type], {}, color_kwargs)
        else:
            args = (*plot_defs['zphot-zspec'], {}, color_kwargs)

        fig = update_sample_scatter(dff, *args)

        # Update ranges for some parameters
        if 'Mass' in plot_type and 'checked' in mass_checked:
            fig.update_yaxes(range=mass_range)

        if 'Mag' in plot_type and 'checked' in mag_checked:
            fig.update_yaxes(range=mag_range)

        if 'redshift' in plot_type and 'checked' in zphot_checked:
            fig.update_xaxes(range=zphot_range)

        if 'zspec' in plot_type and 'checked' in zspec_checked:
            fig.update_yaxes(range=zspec_range)

        return fig

    def update_sample_scatter(dff, xcol, ycol, x_label, y_label, x_range,
                              y_range, extra, color_kwargs):
        """
        Make the sample scatter plot of ``dff[ycol]`` against ``dff[xcol]``.

        Rows flagged in ``dff['is_selected']`` are overplotted with a labelled
        open red circle; the zphot-zspec plot gets a one-to-one line; the
        annotation shows the number of plotted vs. total objects.  ``extra``
        is accepted for call compatibility but not used.
        """
        import plotly.graph_objects as go

        fig = px.scatter(
            data_frame=dff,
            x=xcol,
            y=ycol,
            custom_data=['id', 'z_phot', 'mass', 'ssfr', 'mag'],
            **color_kwargs)

        htempl = '(%{x:.2f}, %{y:.2f}) <br>'
        htempl += 'id: %{customdata[0]:0d} z_phot: %{customdata[1]:.2f}'
        htempl += '<br> mag: %{customdata[4]:.1f} '
        htempl += 'mass: %{customdata[2]:.2f} ssfr: %{customdata[3]:.2f}'

        fig.update_traces(hovertemplate=htempl, opacity=0.7)

        if dff['is_selected'].sum() > 0:
            dffs = dff[dff['is_selected']]
            _sel = go.Scatter(x=dffs[xcol],
                              y=dffs[ycol],
                              mode="markers+text",
                              text=[f'{obj_id}' for obj_id in dffs['id']],
                              textposition="bottom center",
                              marker=dict(color='rgba(250,0,0,0.5)',
                                          size=20,
                                          symbol='circle-open'))

            fig.add_trace(_sel)

        fig.update_xaxes(range=x_range, title_text=x_label)
        fig.update_yaxes(range=y_range, title_text=y_label)

        fig.update_layout(template=template,
                          autosize=True,
                          showlegend=False,
                          margin=dict(l=0,
                                      r=0,
                                      b=0,
                                      t=20,
                                      pad=0,
                                      autoexpand=True))

        if plot_height is not None:
            fig.update_layout(height=plot_height)

        fig.update_traces(marker_showscale=False,
                          selector=dict(type='scatter'))
        fig.update_coloraxes(showscale=False)

        if (xcol, ycol) == ('z_spec', 'z_phot'):
            _one2one = go.Scatter(x=[0, 8],
                                  y=[0, 8],
                                  mode="lines",
                                  marker=dict(color='rgba(250,0,0,0.5)'))
            fig.add_trace(_one2one)

        fig.add_annotation(text=f'N = {len(dff)} / {len(self.df)}',
                           xref="x domain",
                           yref="y domain",
                           x=0.98,
                           y=0.05,
                           showarrow=False)

        return fig

    def sed_cutout_figure(id_i):
        """
        Image cutout of ``cutout_data`` centred on object ``id_i``.

        Returns a square ``px.imshow`` figure ``2*cutout_size+1`` pixels on a
        side.  Pixels outside the image are zero.  The whole cutout is zero
        when ``id_i`` is not in ``self.df`` or its pixel position is
        non-finite.  Only called when ``cutout_data`` is not None.
        """
        npix = 2 * cutout_size + 1
        cutout = np.zeros((npix, npix))

        # Use the first matching row only; several matches would turn the
        # pixel coordinates below into arrays instead of scalars.
        ix = np.where(self.df['id'] == id_i)[0][:1]
        if len(ix) > 0:
            ri, di = self.df['ra'][ix], self.df['dec'][ix]
            xi, yi = np.squeeze(cutout_wcs.all_world2pix([ri], [di], 0))
            if np.isfinite(xi) and np.isfinite(yi):
                x0 = int(np.round(xi)) - cutout_size
                y0 = int(np.round(yi)) - cutout_size
                ny, nx = cutout_data.shape[:2]
                # Intersect the window with the image bounds.  Plain slicing
                # with a negative start would silently wrap around.
                xs0, xs1 = max(x0, 0), min(x0 + npix, nx)
                ys0, ys1 = max(y0, 0), min(y0 + npix, ny)
                if xs0 < xs1 and ys0 < ys1:
                    cutout[ys0 - y0:ys1 - y0, xs0 - x0:xs1 - x0] = \
                        cutout_data[ys0:ys1, xs0:xs1]

        fig = px.imshow(cutout, color_continuous_scale='gray_r')

        fig.update_coloraxes(showscale=False)
        fig.update_layout(width=120,
                          height=120,
                          margin=dict(l=0,
                                      r=0,
                                      b=0,
                                      t=0,
                                      pad=0,
                                      autoexpand=True))

        fig.update_xaxes(range=(0, 2 * cutout_size),
                         visible=False,
                         showticklabels=False)
        fig.update_yaxes(range=(0, 2 * cutout_size),
                         visible=False,
                         showticklabels=False)

        return fig

    def parse_id_input(id_input):
        """
        Parse the ID / RA,Dec text box.

        Returns
        -------
        (id, None)
            For a single integer token.
        (nearest id, separation in arcsec)
            For exactly two numeric tokens ``ra dec`` (space and/or comma
            separated); the separation uses a cos(dec)-scaled RA offset.
        (None, None)
            For empty or unparseable input.
        """
        if id_input in ['None', None, '']:
            return None, None

        inp_split = id_input.replace(',', ' ').split()

        try:
            if len(inp_split) == 1:
                return int(inp_split[0]), None
            ra, dec = map(float, inp_split)
        except ValueError:
            # Non-numeric text or the wrong number of tokens.
            return None, None

        cosd = np.cos(self.df['dec'] / 180 * np.pi)
        dx = (self.df['ra'] - ra) * cosd
        dy = (self.df['dec'] - dec)
        dr = np.sqrt(dx**2 + dy**2) * 3600.
        try:
            imin = np.nanargmin(dr)
        except ValueError:
            # All separations are NaN (e.g. 'nan' was typed).
            return None, None

        return self.df['id'][imin], dr[imin]

    @app.callback([
        Output('object-sed-figure', 'figure'),
        Output('object-info', 'children'),
        Output('match-sep', 'children'),
        Output('cutout-figure', cutout_target)
    ], [
        Input('sample-selection-scatter', 'hoverData'),
        Input('sed-unit-selector', 'value'),
        Input('id-input', 'value')
    ])
    def update_object_sed(hoverData, sed_unit, id_input):
        """
        SED + p(z) plot, object info panel, match separation and cutout.

        The object comes from the ID box when it resolves to an id present
        in ``self.zout['id']``; otherwise from the hovered scatter point.
        """
        id_i, dr_i = parse_id_input(id_input)
        if id_i is None or id_i not in self.zout['id']:
            id_i = hoverData['points'][0]['customdata'][0]

        match_sep = '' if dr_i is None else f'{dr_i:.1f}"'

        show_fnu = {'Fλ': 0, 'Fν': 1, 'νFν': 2}

        layout_kwargs = dict(template=template,
                             autosize=True,
                             showlegend=False,
                             margin=dict(l=0,
                                         r=0,
                                         b=0,
                                         t=20,
                                         pad=0,
                                         autoexpand=True))

        fig = self.photoz.show_fit_plotly(id_i,
                                          show_fnu=show_fnu[sed_unit],
                                          vertical=True,
                                          panel_ratio=[0.6, 0.4],
                                          show=False,
                                          layout_kwargs=layout_kwargs)

        if plot_height is not None:
            fig.update_layout(height=plot_height)

        ix = self.df['id'] == id_i
        if ix.sum() == 0:
            object_info = 'ID: N/A'
        else:
            ix = np.where(ix)[0][0]
            ra, dec = self.df['ra'][ix], self.df['dec'][ix]
            object_info = [
                f'ID: {id_i}  |  α, δ = {ra:.6f} {dec:.6f} ', ' | ',
                html.A('ESO',
                       href=utils.eso_query(ra, dec, radius=1.0, unit='s')),
                ' | ',
                html.A('CDS',
                       href=utils.cds_query(ra, dec, radius=1.0, unit='s')),
                ' | ',
                html.A('LegacySurvey',
                       href=utils.show_legacysurvey(ra, dec,
                                                    layer='ls-dr9')),
                html.Br(), f"z_phot: {self.df['z_phot'][ix]:.3f}  ",
                f" | z_spec: {self.df['z_spec'][ix]:.3f}",
                html.Br(), f"mag: {self.df['mag'][ix]:.2f}  ",
                f" | mass: {self.df['mass'][ix]:.2f} ",
                f" | sSFR: {self.df['ssfr'][ix]:.2f}",
                html.Br()
            ]

        if cutout_data is None:
            cutout_fig = ['']
        else:
            cutout_fig = sed_cutout_figure(id_i)

        return fig, object_info, match_sep, cutout_fig

    if server_mode is not None:
        if app_type == 'dash':
            # Plain dash.Dash has no `mode` argument (JupyterDash-only).
            app.run_server(port=port)
        else:
            app.run_server(mode=server_mode, port=port)

    return app
def graph(G, mode="external", **kwargs): """ G: a multidirectional graph kwargs are passed to the Jupyter_Dash.run_server() function. Some usefull arguments are: mode: "inline" to run app inside the jupyter nodebook, default is external debug: True or False, Usefull to catch errors during development. """ import dash from jupyter_dash import JupyterDash import dash_cytoscape as cyto from dash.dependencies import Output, Input import dash_html_components as html import dash_core_components as dcc import dash_table import networkx as nx import scConnect as cn import plotly.graph_objs as go import plotly.io as pio import pandas as pd import numpy as np import json import matplotlib import matplotlib.pyplot as plt pio.templates.default = "plotly_white" cyto.load_extra_layouts() JupyterDash.infer_jupyter_proxy_config() app = JupyterDash(__name__) server = app.server # Add a modified index string to change the title to scConnect app.index_string = ''' <!DOCTYPE html> <html> <head> {%metas%} <title>scConnect</title> {%favicon%} {%css%} </head> <body> {%app_entry%} <footer> {%config%} {%scripts%} {%renderer%} </body> </html> ''' # Add colors to each node nodes = pd.Categorical(G.nodes()) # make a list of RGBA tuples, one for each node colors = plt.cm.tab20c(nodes.codes / len(nodes.codes), bytes=True) # zip node to color color_map_nodes = dict(zip(nodes, colors)) # add these colors to original graph for node, color in color_map_nodes.items(): G.nodes[node]["color"] = color[0:3] # Save only RGB # Add colors to edges(source node color) for G for u, v, k in G.edges(keys=True): G.edges[u, v, k]["color"] = color_map_nodes[u][0:3] # load graph into used formes def G_to_flat(G, weight): G_flat = cn.graph.flatten_graph(G, weight=weight, log=True) # Add colors to edges(source node color) for G_flat for u, v, in G_flat.edges(): G_flat.edges[u, v]["color"] = color_map_nodes[u][0:3] return G_flat # produce full graph variante to extract metadata G_flat = G_to_flat(G, weight="score") 
G_split = cn.graph.split_graph(G) # find and sort all found interactions interactions = list(G_split.keys()) interactions.sort() G_cyto = nx.cytoscape_data(G_flat) # get min and max weight for all edges for flat and normal graph #weights = [d["weight"] for u, v, d in G_flat.edges(data=True)] scores = [d["score"] for u, v, d in G.edges(data=True)] cent = [d["centrality"] for n, d in G.nodes(data=True)] # prepare data for network graph nodes = G_cyto["elements"]["nodes"] elements = [] # collect all available genes genes = list(nodes[0]["data"]["genes"].keys()) # Styling parameters font_size = 20 # Style for network graph default_stylesheet = [{ 'selector': 'node', 'style': { 'background-color': 'data(color)', 'label': 'data(id)', 'shape': 'ellipse', 'opacity': 1, 'font-size': f'{font_size}', 'font-weight': 'bold', 'text-wrap': 'wrap', 'text-max-width': "100px", 'text-opacity': 1, 'text-outline-color': "white", 'text-outline-opacity': 1, 'text-outline-width': 2 } }, { 'selector': 'node:selected', 'style': { 'background-color': 'data(color)', 'label': 'data(id)', 'shape': 'ellipse', 'opacity': 1, 'border-color': "black", 'border-width': "5" } }, { 'selector': 'edge', 'style': { 'line-color': 'data(color)', "opacity": 0.7, "curve-style": "unbundled-bezier", "width": "data(weight)", "target-arrow-shape": "vee", "target-arrow-color": "black", 'z-index': 1, 'font-size': f'{font_size}' } }, { 'selector': 'edge:selected', 'style': { 'line-color': 'red', 'line-style': "dashed", 'opacity': 1, 'z-index': 10, } }] app.layout = html.Div( className="wrapper", children=[ # wrapper html.Div( className="header", children=[ # header html.Img(src="assets/logo.png", alt="scConnect logo"), html.Div( className="graph-info", id="graph-stat", children=[ html. 
H3(f'Loaded graph with {len(G.nodes())} nodes and {len(G.edges())} edges' ) ]) ]), html.Div( className="network-settings", children=[ # network settings html.H2("Network settings", style={"text-align": "center"}), html.Label("Interactions"), dcc.Dropdown(id="network-interaction", options=[{ 'label': "all interactions", 'value': "all" }] + [{ 'label': interaction, 'value': interaction } for interaction in interactions], value="all"), # select if only significant ligands and receptors should be shown html.Label("Graph weight:"), dcc.RadioItems(id="weight-select", options=[{ "label": "Score", "value": "score" }, { "label": "Log score", "value": "log_score" }, { "label": "Specificity", "value": "specificity" }, { "label": "Importance", "value": "importance" }], value="importance", labelStyle={ 'display': 'block', "margin-left": "50px" }, style={ "padding": "10px", "margin": "auto" }), html.Label("Graph Layout"), dcc.Dropdown( id="network-layout", options=[{ 'label': name.capitalize(), 'value': name } for name in [ 'grid', 'random', 'circle', 'cose', 'concentric', 'breadthfirst', 'cose-bilkent', 'cola', 'euler', 'spread', 'dagre', 'klay' ]], value="circle", clearable=False), html.Label("Weight Filter", style={ "paddingBottom": 500, "paddingTop": 500 }), dcc. 
Slider( # min, max and value are set dynamically via a callback id="network-filter", step=0.001, updatemode="drag", tooltip={ "always_visible": True, "placement": "right" }, ), html.Label("Node size"), dcc.RangeSlider(id="node-size", value=[10, 50], min=0, max=100, updatemode="drag"), html.Label("Select gene"), dcc.Dropdown( id="gene_dropdown", options=[{ "label": gene, "value": gene } for gene in genes], clearable=True, placeholder="Color by gene expression", ), # Store node colors "hidden" for gene expresison html.Div(id="node-colors", style={"display": "none"}, children=[""]), html.Div(id="min-max", children=[]), # Click to download image of network graph html.Button(children="Download current view", id="download-network-graph", style={"margin": "10px"}) ]), # end network settings html.Div( id="network-graph", className="network-graph", children=[ # network graph html.H2("Network graph", style={"text-align": "center"}), cyto.Cytoscape(id="cyto-graph", style={ 'width': '100%', 'height': '80vh' }, stylesheet=default_stylesheet, elements=elements, autoRefreshLayout=True, zoomingEnabled=False) ]), # end network graph html.Div( className="sankey-settings", children=[ # network settings html.H2("Sankey Settings", style={"text-align": "center"}), html.Label("Weight Filter"), dcc.Slider(id="sankey-filter", min=min(scores), max=max(scores), value=0.75, step=0.001, updatemode="drag", tooltip={ "always_visible": True, "placement": "right" }), html.Label("Toggle weighted"), dcc.RadioItems(id="sankey-toggle", options=[{ "label": "Score", "value": "score" }, { "label": "Log score", "value": "log_score" }, { "label": "Specificity", "value": "specificity" }, { "label": "Importance", "value": "importance" }], value="importance", labelStyle={"display": "block"}) ]), # end network settings html.Div( className="sankey", id="sankey", children=[ # sankey graph html.H2("Sankey graph", style={"text-align": "center"}), dcc.Graph(id="sankey-graph") ]), # end sankey graph html.Div( 
className="interaction-list", children=[ # interaction list html.Div(id="selection", children=[ html.H2("Interactions", style={"text-align": "center"}), html.H3(id="edge-info", style={"text-align": "center"}), dcc.Graph(id="interaction-scatter"), html.Div(id="interaction-selection", style={"display": "none"}, children=[""]) ]), html.Div(children=[ dash_table.DataTable( id="edge-selection", page_size=20, style_table={ "overflowX": "scroll", "overflowY": "scroll", "height": "50vh", "width": "95%" }, style_cell_conditional=[{ "if": { "column_id": "interaction" }, "textAlign": "left" }, { "if": { "column_id": "receptorfamily" }, "textAlign": "left" }, { "if": { "column_id": "pubmedid" }, "textAlign": "left" }], style_header={ "fontWeight": "bold", "maxWidth": "200px", "minWidth": "70px" }, style_data={ "maxWidth": "200px", "minWidth": "70px", "textOverflow": "ellipsis" }, sort_action="native", fixed_rows={ 'headers': True, 'data': 0 }) ]) ]), # end interaction list html.Div( className="L-R-scores", children=[ # ligand and receptor lists html.H2("Ligand and receptors", style={"text-align": "center"}), html.Div(children=[ html.H3( id="selected-node", style={"text-align": "center"}, children=["Select a node in the notwork graph"]), html.Label("Search for ligands and receptors:", style={"margin-right": "10px"}), dcc.Input(id="filter_l_r", type="search", value="", placeholder="Search") ]), dcc.Tabs([ dcc.Tab(label="Ligands", children=[ dcc.Graph(id="ligand-graph", config=dict(autosizable=True, responsive=True)), dash_table.DataTable( id="ligand-table", page_size=20, style_table={ "overflowX": "scroll", "overflowY": "scroll", "height": "50vh", "width": "95%" }, style_cell_conditional=[{ "if": { "column_id": "Ligand" }, "textAlign": "left" }], style_header={ "fontWeight": "bold", "maxWidth": "200px", "minWidth": "70px" }, style_data={ "maxWidth": "200px", "minWidth": "70px", "textOverflow": "ellipsis" }, sort_action="native", fixed_rows={ 'headers': True, 'data': 0 }) ]), 
dcc.Tab(label="Receptors", children=[ dcc.Graph(id="receptor-graph", config=dict(autosizable=True, responsive=True)), dash_table.DataTable( id="receptor-table", page_size=20, style_table={ "overflowX": "scroll", "overflowY": "scroll", "height": "50vh", "width": "95%" }, style_cell_conditional=[{ "if": { "column_id": "Receptor" }, "textAlign": "left" }], style_header={ "fontWeight": "bold", "maxWidth": "200px", "minWidth": "70px" }, style_data={ "maxWidth": "200px", "minWidth": "70px", "textOverflow": "ellipsis" }, sort_action="native", fixed_rows={ 'headers': True, 'data': 0 }) ]) ]) ]) # end ligand receptor list ]) # end wrapper # Instantiate the graph and produce the bounderies for filters @app.callback([ Output("cyto-graph", "elements"), Output("network-filter", "min"), Output("network-filter", "max"), Output("network-filter", "value") ], [ Input("network-interaction", "value"), Input("weight-select", "value") ]) def make_graph(interaction, score): G_flat = G_to_flat(G, score) if interaction == "all": # if no interaction is selected, use full graph G_cyto = nx.cytoscape_data(G_flat) weights = [d["weight"] for u, v, d in G_flat.edges(data=True)] # prepare data for network graph nodes = G_cyto["elements"]["nodes"] edges = G_cyto["elements"]["edges"] elements = nodes + edges return elements, min(weights), max(weights), np.mean(weights) else: # an interaction is selected, select only that interaction G_split = cn.graph.split_graph(G) G_split_flat = G_to_flat(G_split[interaction], score) G_cyto = nx.cytoscape_data(G_split_flat) weights = [ d["weight"] for u, v, d in G_split_flat.edges(data=True) ] # prepare data for network graph nodes = G_cyto["elements"]["nodes"] edges = G_cyto["elements"]["edges"] elements = nodes + edges return elements, min(weights), max(weights), np.mean(weights) # Change layout of network graph @app.callback(Output("cyto-graph", "layout"), [Input("network-layout", "value")]) def update_network_layout(layout): return {"name": layout, 
"automate": True, "fit": True} # Choose gene to color nodes by @app.callback( [Output("node-colors", "children"), Output("min-max", "children")], [Input("gene_dropdown", "value")]) def calculate_colors(gene): if gene is None: return [None, ""] # get all gene expression values for selected gene gene_data = { celltype["data"]["id"]: celltype["data"]["genes"][gene] for celltype in nodes } min_value = min(gene_data.values()) max_value = max(gene_data.values()) # package min max expression information to a list that will be returned expression = html.Ul(children=[ html.Li(f"minimum gene expression: {min_value}"), html.Li(f"maximum gene expression: {max_value}") ]) cmap = matplotlib.cm.get_cmap("coolwarm") color_dict = dict() for k, v in gene_data.items(): color_dict[k] = {"rgb": cmap(v, bytes=True)[0:3], "expression": v} color = pd.Series(color_dict) return color.to_json(), expression # Select visible edges of network graph depending on filter value # node color depending on selected gene # width of edges @app.callback(Output("cyto-graph", "stylesheet"), [ Input("network-filter", "value"), Input("network-filter", "min"), Input("network-filter", "max"), Input("node-size", "value"), Input("node-colors", "children") ]) def style_network_graph(th, min_weight, max_weight, size, colors): # create a filter for edges filter_style = [{ "selector": f"edge[weight < {th}]", "style": { "display": "none" } }, { "selector": "node", "style": { 'height': f'mapData(centrality, {min(cent)}, {max(cent)}, {size[0]}, {size[1]})', 'width': f'mapData(centrality, {min(cent)}, {max(cent)}, {size[0]}, {size[1]})' } }] # create a color style for nodes based on gene expression if isinstance(colors, str): colors = pd.read_json(colors, typ="series", convert_dates=False) color_style = [{ 'selector': f'node[id = "{str(index)}"]', 'style': { 'background-color': f'rgb{tuple(colors[index]["rgb"])}' } } for index in colors.index] filter_style += color_style else: color_style = { "selector": "node", 
"style": { 'background-color': 'BFD7B5' } } # Map edges width to a set min and max value (scale for visibility) edge_style = [{ "selector": "edge", "style": { "width": f"mapData(weight, {min_weight}, {max_weight}, 1, 10)" } }] return default_stylesheet + filter_style + edge_style # download an image of current network graph view @app.callback(Output("cyto-graph", "generateImage"), Input("download-network-graph", "n_clicks")) def download_networkgraph_image(get_request): if get_request == None: return dict() return {"type": "svg", "action": "download"} # Produce a table of all edge data from tapped edge @app.callback([ Output("edge-info", "children"), Output("edge-selection", "columns"), Output("edge-selection", "data") ], [ Input("cyto-graph", "tapEdgeData"), Input("interaction-selection", "children") ]) def update_data(edge, selection): import pandas as pd import json # check if an edge has really been clicked, return default otherwise if edge is None: return ["", None, None] info = f"Interactions from {edge['source']} to {edge['target']}." 
# map visible names for columns with columns in edge[interaction] columns = [{ "name": "Interaction", "id": "interaction" }, { "name": "Receptor Family", "id": "receptorfamily" }, { "name": "Score", "id": "score" }, { "name": "Log10(score)", "id": "log_score" }, { "name": "Specificity", "id": "specificity" }, { "name": "Importance", "id": "importance" }, { "name": "Ligand z-score", "id": "ligand_zscore" }, { "name": "Ligand p-value", "id": "ligand_pval" }, { "name": "Receptor z-score", "id": "receptor_zscore" }, { "name": "Receptor p-value", "id": "receptor_pval" }, { "name": "PubMed ID", "id": "pubmedid" }] interactions = pd.DataFrame(edge["interactions"])[[ "interaction", "receptorfamily", "score", "log_score", "specificity", "importance", "ligand_zscore", "ligand_pval", "receptor_zscore", "receptor_pval", "pubmedid" ]] # Sort values based on score interactions.sort_values(by="score", ascending=False, inplace=True) # round values for scores to two decimals interactions[[ "score", "log_score", "specificity", "importance", "ligand_zscore", "receptor_zscore" ]] = interactions[[ "score", "log_score", "specificity", "importance", "ligand_zscore", "receptor_zscore" ]].round(decimals=2) interactions[["ligand_pval", "receptor_pval" ]] = interactions[["ligand_pval", "receptor_pval"]].round(decimals=4) # if selection from interaction graph, filter dataframe if selection != "": selection = json.loads(selection) interactions = interactions.loc[interactions["interaction"].isin( selection)] records = interactions.to_dict("records") return [info, columns, records] @app.callback([Output("interaction-scatter", "figure")], [Input("cyto-graph", "tapEdgeData")]) def interaction_scatter_plot(edge): import plotly.express as px fig = go.Figure() if not isinstance(edge, dict): return [ fig, ] interactions = pd.DataFrame(edge["interactions"])[[ "interaction", "receptorfamily", "score", "log_score", "ligand_zscore", "ligand_pval", "receptor_zscore", "receptor_pval", "specificity", 
"importance", "pubmedid" ]] # add 10% to the min and max value to not clip the datapoint range_x = (-max(interactions["log_score"]) * 0.1, max(interactions["log_score"]) * 1.1) range_y = (-max(interactions["specificity"]) * 0.1, max(interactions["specificity"]) * 1.1) #interactions["specificity"] = np.log10( interactions["specificity"]) fig = px.scatter(interactions, x="log_score", range_x=range_x, y="specificity", range_y=range_y, color="importance", hover_name="interaction", hover_data=[ "ligand_pval", "receptor_pval", "score", "specificity", "receptorfamily" ], color_continuous_scale=px.colors.sequential.Viridis_r, labels={ "ligand_zscore": "Ligand Z-score", "receptor_zscore": "Receptor Z-score", "log_score": "log(Interaction score)", "score": "Interaction score", "specificity": "Specificity", "importance": "Importance", "receptorfamily": "Receptor family", "pubmedid": "PubMed ID", "ligand_pval": "Ligand p-value", "receptor_pval": "Receptor p-value" }) return [ fig, ] @app.callback(Output("interaction-selection", "children"), [Input("interaction-scatter", "selectedData")]) def interaction_select(selected_data): import json if isinstance(selected_data, dict): interactions = [ point["hovertext"] for point in selected_data["points"] ] else: return "" return json.dumps(interactions) # Produce ligand and receptor graphs based on tapped node @app.callback([ Output("ligand-graph", "figure"), Output("receptor-graph", "figure"), Output("selected-node", "children") ], [Input("cyto-graph", "tapNodeData"), Input("filter_l_r", "value")]) def plot_l_r_expression(node, filter_text): # set output variables to empty figures ligand_fig = go.Figure() receptor_fig = go.Figure() node_id = "Select a node in the network graph" if isinstance(node, dict): import plotly.express as px node_id = node["id"] ligands_score = pd.DataFrame.from_dict(node["ligands_score"], orient="index", columns=["Score"]) ligands_zscore = np.log2( pd.DataFrame.from_dict(node["ligands_zscore"], orient="index", 
columns=["Z-score"])) ligands_corr_pval = pd.DataFrame.from_dict( node["ligands_corr_pval"], orient="index", columns=["p-value"]) ligands_merge = ligands_score.merge(ligands_zscore, how="left", left_index=True, right_index=True) ligands_merge = ligands_merge.merge(ligands_corr_pval, how="left", left_index=True, right_index=True) ligands_merge["log(score + 1)"] = np.log10(ligands_merge["Score"] + 1) ligands_merge["Significant"] = [ True if p_val < 0.05 else False for p_val in ligands_merge["p-value"] ] ligands_merge["-log(p-value)"] = -np.log10( ligands_merge["p-value"]) if filter_text != "": ligands_merge = ligands_merge.filter(like=filter_text, axis=0) ligand_fig = px.scatter(ligands_merge, x="log(score + 1)", y="-log(p-value)", color="Significant", hover_name=ligands_merge.index, hover_data=["Score", "Z-score", "p-value"]) receptors_score = pd.DataFrame.from_dict(node["receptors_score"], orient="index", columns=["Score"]) receptors_zscore = np.log2( pd.DataFrame.from_dict(node["receptors_zscore"], orient="index", columns=["Z-score"])) receptors_corr_pval = pd.DataFrame.from_dict( node["receptors_corr_pval"], orient="index", columns=["p-value"]) receptors_merge = receptors_score.merge(receptors_zscore, how="left", left_index=True, right_index=True) receptors_merge = receptors_merge.merge(receptors_corr_pval, how="left", left_index=True, right_index=True) receptors_merge["log(score + 1)"] = np.log10( receptors_merge["Score"] + 1) receptors_merge["Significant"] = [ True if p_val < 0.05 else False for p_val in receptors_merge["p-value"] ] receptors_merge["-log(p-value)"] = -np.log10( receptors_merge["p-value"]) if filter_text != "": receptors_merge = receptors_merge.filter(like=filter_text, axis=0) receptor_fig = px.scatter( receptors_merge, x="log(score + 1)", y="-log(p-value)", color="Significant", hover_name=receptors_merge.index, hover_data=["Score", "Z-score", "p-value"]) return [ligand_fig, receptor_fig, node_id] # Builds a sankey graph based on the tapped node 
(store in global G_s) G_s = nx.MultiDiGraph() #variable holding sankey graph @app.callback([ Output("sankey-filter", "min"), Output("sankey-filter", "max"), Output("sankey-filter", "value") ], [Input("cyto-graph", "tapNodeData"), Input("sankey-toggle", "value")]) def build_sankey_graph(node, score): import numpy as np # If no node has been selected, dont try to build graph if node is None: return (0, 0, 0) node = node["id"] # Find all interactions where node is target or source node nonlocal G_s G_s = nx.MultiDiGraph() # reset content weight = list( ) # list to store all weights (used to set min and max for the filter) for n, nbrs in G.adj.items( ): # graph has been modified by network graph before for nbr, edict in nbrs.items(): if n == node: for e, d in edict.items(): G_s.add_edge(n, " Post " + nbr, **d) weight.append(d[score]) if nbr == node: for e, d in edict.items(): G_s.add_edge("Pre " + n, nbr, **d) weight.append(d[score]) if len(weight) == 0: weight = [0, 1] if score == "specificity": # set default start value to specificity value for ligand and receptor # p-value of (0.05 and 0.05)/2 = 1.3 return (min(weight), max(weight), 1.3) return (min(weight), max(weight), np.mean(weight)) @app.callback(Output("sankey-graph", "figure"), [ Input("sankey-filter", "value"), Input("sankey-toggle", "value"), Input("cyto-graph", "tapNodeData") ]) def filter_sankey_graph(th, score, node): if node: node = node["id"] _G_s = nx.MultiDiGraph() for u, v, n, d in G_s.edges(data=True, keys=True): if d[score] > th: _G_s.add_edge(u, v, n, **d) _G_s.add_nodes_from(G_s.nodes(data=True)) edges = nx.to_pandas_edgelist(_G_s) if len(edges) < 1: fig = dict() return fig # add same color scheme as network graph for node_s in _G_s.nodes(): if " Post" in node_s: original_node = str(node_s).split(sep=" Post")[1] elif "Pre " in node_s: original_node = str(node_s).split(sep="Pre ")[1] else: original_node = str(node_s) new_color = color_map_nodes[original_node.strip()] G_s.nodes[node_s]["color"] = 
new_color nodes = G_s.nodes() node_map = {cluster: id for id, cluster in enumerate(list(nodes))} sankey = go.Sankey(node=dict(pad=15, thickness=20, line=dict(color="black", width=0.5), label=list(nodes), color=[ f'rgb{tuple(d["color"][0:3])}' for n, d in G_s.nodes(data=True) ]), link=dict( source=list(edges["source"].map(node_map)), target=list(edges["target"].map(node_map)), value=list(edges[score]), label=edges["interaction"])) data = [sankey] layout = go.Layout(autosize=True, title=f"Interactions: {node}", font=dict(size=font_size)) fig = go.Figure(data=data, layout=layout) return fig @app.callback( [Output("ligand-table", "columns"), Output("ligand-table", "data")], [ Input("ligand-graph", "figure"), Input("ligand-graph", "selectedData") ]) def select_ligands(figure, selected): import json ligands = [] score = [] zscore = [] pval = [] for group in figure["data"]: for ligand in group["hovertext"]: ligands.append(ligand) for data in group["customdata"]: score.append(data[0]) zscore.append(data[1]) pval.append(data[2]) df = pd.DataFrame({ "Ligand": ligands, "Score": score, "Z-score": zscore, "P-value": pval }) df.index = df["Ligand"] df.sort_values(by="Score", ascending=False, inplace=True) if isinstance(selected, dict): filt = [] for point in selected["points"]: filt.append(point["hovertext"]) df = df.loc[filt] columns = [{ "name": "Ligand", "id": "Ligand" }, { "name": "Score", "id": "Score" }, { "name": "Z-score", "id": "Z-score" }, { "name": "P-value", "id": "P-value" }] data = df.to_dict("records") return columns, data @app.callback([ Output("receptor-table", "columns"), Output("receptor-table", "data") ], [ Input("receptor-graph", "figure"), Input("receptor-graph", "selectedData") ]) def select_ligands(figure, selected): import json receptors = [] score = [] zscore = [] pval = [] for group in figure["data"]: for receptor in group["hovertext"]: receptors.append(receptor) for data in group["customdata"]: score.append(data[0]) zscore.append(data[1]) 
pval.append(data[2]) df = pd.DataFrame({ "Receptor": receptors, "Score": score, "Z-score": zscore, "P-value": pval }) df.index = df["Receptor"] df.sort_values(by="Score", ascending=False, inplace=True) if isinstance(selected, dict): filt = [] for point in selected["points"]: filt.append(point["hovertext"]) df = df.loc[filt] columns = [{ "name": "Receptor", "id": "Receptor" }, { "name": "Score", "id": "Score" }, { "name": "Z-score", "id": "Z-score" }, { "name": "P-value", "id": "P-value" }] data = df.to_dict("records") return columns, data # Run server app.run_server(**kwargs)