Esempio n. 1
0
    def __init__(self, tree: "DataTree"):
        """Create the Dash app, its components, and register their callbacks.

        Parameters
        ----------
        tree : DataTree
            Stored on the instance and handed to ``CellModel``.
        """
        self.tree = tree

        # Application object and the server it exposes.
        JupyterDash.infer_jupyter_proxy_config()
        app = JupyterDash(__name__, external_stylesheets=STYLESHEET)
        self.app = app
        self.server = app.server

        # Model built from the tree; most components below receive it.
        model = CellModel(tree)
        self.model = model

        # Components, created in the same order as they are wired below.
        self.scatter = Scatter(app, data=dict(model=model))
        self.dim_dropdowns = DropDownColumn(
            app,
            data={
                "title_list": model.DIMENSIONS,
                "options": model.col_options,
            },
        )
        self.sliders = SliderColumn(
            data=dict(slider_params=model.MARKER_ARGS_PARAMS))
        self.col_adder = ColumnAdder()
        self.annotator = ScatterAnnotator(
            data=dict(model=model, options=model.col_options))
        self.colname_store = DropDownOptionsStore()
        self.gene_data = GeneData(app, data=dict(model=model))

        # The annotator is the master of the column-name store; every
        # dimension dropdown plus the annotator's own dropdown subscribe.
        dropdown_subscribers = list(self.dim_dropdowns.contents)
        dropdown_subscribers.append(self.annotator.dropdown)
        self.colname_store.callbacks(master=self.annotator,
                                     subscribers=dropdown_subscribers)

        self.annotator.callbacks(self.scatter, self.col_adder)
        self.scatter.callbacks(self.dim_dropdowns, self.sliders)
        self.gene_data.callbacks(self.scatter)
Esempio n. 2
0
# Copyright (C) 2021, RTE (http://www.rte-france.com/)
# See AUTHORS.txt
# SPDX-License-Identifier: MPL-2.0
"""
This file handles the html entry point of the application through dash components.
It will generate the layout of a given page and handle the routing
"""

import dash_bootstrap_components as dbc

# from dash import Dash
from jupyter_dash import JupyterDash

# The app object must exist before the rest of the project is imported,
# because those modules use @app decorators.
# Proxy inference is for binder or jupyterHub, for instance.
JupyterDash.infer_jupyter_proxy_config()
app = JupyterDash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
)  # ,server_url="http://127.0.0.1:8050/")
"""
Get Imports to create layout and callbacks 
"""
from grid2viz.main_callbacks import register_callbacks_main
from grid2viz.layout import make_layout as layout

from grid2viz.src.episodes.episodes_clbk import register_callbacks_episodes
from grid2viz.src.overview.overview_clbk import (
    register_callbacks_overview, )  # as overview_clbk
from grid2viz.src.macro.macro_clbk import register_callbacks_macro  # as macro_clbk
from grid2viz.src.micro.micro_clbk import register_callbacks_micro  # as micro_clbk
"""
# Import required libraries
import pandas as pd
import dash
import dash_html_components as html
import dash_core_components as dcc
from dash.dependencies import Input, Output, State
from jupyter_dash import JupyterDash
import plotly.graph_objects as go
import plotly.express as px
from dash import no_update


# Create the Dash application
app = JupyterDash(__name__)
JupyterDash.infer_jupyter_proxy_config()

# REVIEW1: Clear the layout and do not display exceptions until the callback
# gets executed
app.config.suppress_callback_exceptions = True

# Read the airline data into a pandas dataframe; the diverted-flight airport
# and tail-number columns are forced to str.
airline_data = pd.read_csv(
    'https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/IBMDeveloperSkillsNetwork-DV0101EN-SkillsNetwork/Data%20Files/airline_data.csv',
    encoding="ISO-8859-1",
    dtype=dict.fromkeys(
        ['Div1Airport', 'Div1TailNum', 'Div2Airport', 'Div2TailNum'], str),
)

# List of years, 2005 through 2020 inclusive
year_list = list(range(2005, 2021))

"""Compute graph data for creating yearly airline performance report 
Esempio n. 4
0
    def make_dash_app(self,
                      template='plotly_white',
                      server_mode='external',
                      port=8050,
                      app_type='jupyter',
                      plot_height=680,
                      external_stylesheets=[
                          'https://codepen.io/chriddyp/pen/bWLwgP.css'
                      ],
                      infer_proxy=False,
                      slider_width=140,
                      cutout_hdu=None,
                      cutout_size=10):
        """
        Create a Plotly/Dash app for interactive exploration
        
        Parameters
        ----------
        template : str
            `plotly` style `template <https://plotly.com/python/templates/#specifying-themes-in-graph-object-figures>`_.
        
        server_mode, port : str, int
            If not `None`, the app server is started with 
            `app.run_server(mode=server_mode, port=port)`.
        
        app_type : str
            If ``jupyter`` then `app = jupyter_dash.JupyterDash()`, else
            `app = dash.Dash()`
            
        plot_height : int
            Height in pixels of the scatter and SED+P(z) plot windows.
        
        infer_proxy : bool
            Run `JupyterDash.infer_jupyter_proxy_config()` before app
            initialization, e.g., for running on GoogleColab.
            
        Returns
        -------
        app : object
            App object following `app_type`.
            
        """
        import dash
        from dash import dcc
        from dash import html
        import plotly.express as px
        from urllib.parse import urlparse, parse_qsl, urlencode
        import astropy.wcs as pywcs

        if app_type == 'dash':
            app = dash.Dash(__name__,
                            external_stylesheets=external_stylesheets)
        else:
            from jupyter_dash import JupyterDash
            if infer_proxy:
                JupyterDash.infer_jupyter_proxy_config()

            app = JupyterDash(__name__,
                              external_stylesheets=external_stylesheets)

        PLOT_TYPES = [
            'zphot-zspec', 'Mag-redshift', 'Mass-redshift', 'UVJ', 'RA/Dec',
            'UV-redshift', 'chi2-redshift'
        ]

        for _t in self.extra_plots:
            PLOT_TYPES.append(_t)

        COLOR_TYPES = ['z_phot', 'z_spec', 'mass', 'sSFR', 'chi2']

        #_title = f"{self.photoz.param['MAIN_OUTPUT_FILE']}"
        #_subhead = f"Nobj={self.photoz.NOBJ}  Nfilt={self.photoz.NFILT}"
        _title = [
            html.Strong(self.photoz.param['MAIN_OUTPUT_FILE']),
            ' / N',
            html.Sub('obj'),
            f'={self.photoz.NOBJ}',
            ' / N',
            html.Sub('filt'),
            f'={self.photoz.NFILT}',
        ]

        slider_row_style = {
            'width': '90%',
            'float': 'left',
            'margin-left': '10px'
        }
        slider_container = {
            'width': f'{slider_width}px',
            'margin-left': '-25px'
        }
        check_kwargs = dict(style={
            'text-align': 'center',
            'height': '14pt',
            'margin-top': '-20px'
        })

        # bool_options = {'has_zspec': 'z_spec > 0',
        #                 'use': 'Use == 1'}

        if cutout_hdu is not None:
            cutout_wcs = pywcs.WCS(cutout_hdu.header, relax=True)
            cutout_data = cutout_hdu.data
            print('xxx', cutout_data.shape)

            cutout_div = html.Div(
                [dcc.Graph(id='cutout-figure', style={})],
                style={
                    'right': '70px',
                    'width': '120px',
                    'height': '120px',
                    'border': '1px solid rgb(200,200,200)',
                    'top': '10px',
                    'position': 'absolute'
                })
            cutout_target = 'figure'
        else:
            cutout_div = html.Div(id='cutout-figure',
                                  style={
                                      'left': '1px',
                                      'width': '1px',
                                      'height': '1px',
                                      'bottom': '1px',
                                      'position': 'absolute'
                                  })
            cutout_data = None
            cutout_target = 'children'

        ####### App layout
        app.layout = html.Div([
            # Selectors
            html.Div(
                [
                    dcc.Location(id='url', refresh=False),
                    html.Div([
                        html.Div(_title,
                                 id='title-bar',
                                 style={
                                     'float': 'left',
                                     'margin-top': '4pt'
                                 }),
                        html.Div([
                            html.Div([
                                dcc.Dropdown(id='plot-type',
                                             options=[{
                                                 'label': i,
                                                 'value': i
                                             } for i in PLOT_TYPES],
                                             value='zphot-zspec',
                                             clearable=False,
                                             style={
                                                 'width': '120px',
                                                 'margin-right': '5px',
                                                 'margin-left': '5px',
                                                 'font-size': '8pt'
                                             }),
                            ],
                                     style={'float': 'left'}),
                            html.Div([
                                dcc.Dropdown(id='color-type',
                                             options=[{
                                                 'label': i,
                                                 'value': i
                                             } for i in COLOR_TYPES],
                                             value='sSFR',
                                             clearable=False,
                                             style={
                                                 'width': '80px',
                                                 'margin-right': '5px',
                                                 'font-size': '8pt'
                                             }),
                            ],
                                     style={
                                         'display': 'inline-block',
                                         'margin-left': '10px'
                                     }),
                        ],
                                 style={'float': 'right'}),
                    ],
                             style=slider_row_style),
                    html.Div(
                        [
                            html.Div([
                                dcc.Dropdown(
                                    id='mag-filter',
                                    options=[{
                                        'label':
                                        i,
                                        'value':
                                        i
                                    } for i in self.photoz.flux_columns],
                                    value=self.DEFAULT_FILTER,
                                    style={
                                        'width': f'{slider_width-45}px',
                                        'margin-right': '20px',
                                        'font-size': '8pt'
                                    },
                                    clearable=False),
                            ],
                                     style={'float': 'left'}),
                            html.Div([
                                dcc.RangeSlider(id='mag-slider',
                                                min=12,
                                                max=32,
                                                step=0.2,
                                                value=[18, 27],
                                                updatemode='drag',
                                                tooltip={"placement": 'left'}),
                                dcc.Checklist(id='mag-checked',
                                              options=[{
                                                  'label': 'AB mag',
                                                  'value': 'checked'
                                              }],
                                              value=['checked'],
                                              **check_kwargs),
                            ],
                                     style=dict(display='inline-block',
                                                **slider_container)),
                            #
                            html.Div([
                                dcc.RangeSlider(id='chi2-slider',
                                                min=0,
                                                max=20,
                                                step=0.1,
                                                value=[0, 6],
                                                updatemode='drag',
                                                tooltip={"placement": 'left'}),
                                dcc.Checklist(id='chi2-checked',
                                              options=[{
                                                  'label': 'chi2',
                                                  'value': 'checked'
                                              }],
                                              value=[],
                                              **check_kwargs),
                            ],
                                     style=dict(display='inline-block',
                                                **slider_container)),
                            html.Div([
                                dcc.RangeSlider(id='nfilt-slider',
                                                min=1,
                                                max=self.MAXNFILT,
                                                step=1,
                                                value=[3, self.MAXNFILT],
                                                updatemode='drag',
                                                tooltip={"placement": 'left'}),
                                dcc.Checklist(id='nfilt-checked',
                                              options=[{
                                                  'label': 'nfilt',
                                                  'value': 'checked'
                                              }],
                                              value=['checked'],
                                              **check_kwargs),
                            ],
                                     style=dict(display='inline-block',
                                                **slider_container)),
                        ],
                        style=slider_row_style),
                    html.Div(
                        [
                            html.Div([
                                dcc.RangeSlider(id='zphot-slider',
                                                min=-0.5,
                                                max=12,
                                                step=0.1,
                                                value=[0, self.ZMAX],
                                                updatemode='drag',
                                                tooltip={"placement": 'left'}),
                                dcc.Checklist(id='zphot-checked',
                                              options=[{
                                                  'label': 'z_phot',
                                                  'value': 'checked'
                                              }],
                                              value=['checked'],
                                              **check_kwargs),
                            ],
                                     style=dict(float='left',
                                                **slider_container)),
                            html.Div([
                                dcc.RangeSlider(id='zspec-slider',
                                                min=-0.5,
                                                max=12,
                                                step=0.1,
                                                value=[-0.5, 6.5],
                                                updatemode='drag',
                                                tooltip={"placement": 'left'}),
                                dcc.Checklist(id='zspec-checked',
                                              options=[{
                                                  'label': 'z_spec',
                                                  'value': 'checked'
                                              }],
                                              value=['checked'],
                                              **check_kwargs),
                            ],
                                     style=dict(display='inline-block',
                                                **slider_container)),
                            html.Div([
                                dcc.RangeSlider(id='mass-slider',
                                                min=7,
                                                max=13,
                                                step=0.1,
                                                value=[8, 11.8],
                                                updatemode='drag',
                                                tooltip={"placement": 'left'}),
                                dcc.Checklist(id='mass-checked',
                                              options=[{
                                                  'label': 'mass',
                                                  'value': 'checked'
                                              }],
                                              value=['checked'],
                                              **check_kwargs),
                            ],
                                     style=dict(display='inline-block',
                                                **slider_container)),

                            # Boolean dropdown
                            # dcc.Dropdown(id='bool-checks',
                            #              options=[{'label': self.bool_options[k],
                            #                        'value': k}
                            #                       for k in self.bool_options],
                            #              value=[],
                            #              multi=True,
                            #              style={'width':'100px',
                            #                     'display':'inline-block',
                            #                     'margin-left':'0px',
                            #                     'font-size':'8pt'},
                            #              clearable=True),
                        ],
                        style=slider_row_style),
                ],
                style={
                    'float': 'left',
                    'width': '55%'
                }),

            # Object-level controls
            html.Div([
                html.Div([
                    html.Div('ID / RA,Dec.',
                             style={
                                 'float': 'left',
                                 'width': '100px',
                                 'margin-top': '5pt'
                             }),
                    dcc.Input(id='id-input',
                              type='text',
                              style={
                                  'width': '120px',
                                  'padding': '2px',
                                  'display': 'inline',
                                  'font-size': '8pt'
                              }),
                    html.Div(children='',
                             id='match-sep',
                             style={
                                 'margin': '5pt',
                                 'display': 'inline',
                                 'width': '50px',
                                 'font-size': '8pt'
                             }),
                    dcc.RadioItems(id='sed-unit-selector',
                                   options=[{
                                       'label': i,
                                       'value': i
                                   } for i in ['Fλ', 'Fν', 'νFν']],
                                   value='Fλ',
                                   labelStyle={
                                       'display': 'inline',
                                       'padding': '3px',
                                   },
                                   style={
                                       'display': 'inline',
                                       'width': '130px'
                                   })
                ],
                         style={
                             'width': '260pix',
                             'float': 'left',
                             'margin-right': '20px'
                         }),
            ]),
            html.Div(
                [
                    # html.Div([
                    # ],  style={'width':'120px', 'float':'left'}),
                    html.Div(id='object-info',
                             children='ID: ',
                             style={
                                 'margin': 'auto',
                                 'margin-top': '10px',
                                 'font-size': '10pt'
                             })
                ],
                style={
                    'float': 'right',
                    'width': '45%'
                }),

            # Plots
            html.Div(
                [  # Scatter plot
                    dcc.Graph(id='sample-selection-scatter',
                              hoverData={
                                  'points': [{
                                      'customdata':
                                      (self.df['id'][0], 1.0, -9.0)
                                  }]
                              },
                              style={'width': '95%'})
                ],
                style={
                    'float': 'left',
                    'height': '70%',
                    'width': '49%'
                }),
            html.Div(
                [  # SED
                    dcc.Graph(id='object-sed-figure', style={'width': '95%'})
                ],
                style={
                    'float': 'right',
                    'width': '49%',
                    'height': '70%'
                }),
            cutout_div
        ])

        ##### Callback functions
        @app.callback(dash.dependencies.Output('url', 'search'), [
            dash.dependencies.Input('plot-type', 'value'),
            dash.dependencies.Input('color-type', 'value'),
            dash.dependencies.Input('mag-filter', 'value'),
            dash.dependencies.Input('mag-slider', 'value'),
            dash.dependencies.Input('mass-slider', 'value'),
            dash.dependencies.Input('chi2-slider', 'value'),
            dash.dependencies.Input('nfilt-slider', 'value'),
            dash.dependencies.Input('zphot-slider', 'value'),
            dash.dependencies.Input('zspec-slider', 'value'),
            dash.dependencies.Input('id-input', 'value')
        ])
        def update_url_state(plot_type, color_type, mag_filter, mag_range,
                             mass_range, chi2_range, nfilt_range, zphot_range,
                             zspec_range, id_input):
            """
            Encode the current control values as a URL query string.

            Each ``*_range`` argument is indexed at positions 0 and 1 and
            written as ``key=low,high``.  ``id_input`` is appended only when
            it is not None, with spaces replaced by ``%20``.

            Returns
            -------
            search : str
                Query string beginning with ``?``.
            """
            pieces = [
                f'?plot_type={plot_type}&color_type={color_type}',
                f'&mag_filter={mag_filter}',
                f'&mag={mag_range[0]},{mag_range[1]}',
                f'&mass={mass_range[0]},{mass_range[1]}',
                f'&chi2={chi2_range[0]},{chi2_range[1]}',
                f'&nfilt={nfilt_range[0]},{nfilt_range[1]}',
                f'&zphot={zphot_range[0]},{zphot_range[1]}',
                f'&zspec={zspec_range[0]},{zspec_range[1]}',
            ]

            # The object id is optional
            if id_input is not None:
                pieces.append(f"&id={id_input.replace(' ', '%20')}")

            return ''.join(pieces)

        @app.callback([
            dash.dependencies.Output('plot-type', 'value'),
            dash.dependencies.Output('color-type', 'value'),
            dash.dependencies.Output('mag-filter', 'value'),
            dash.dependencies.Output('mag-slider', 'value'),
            dash.dependencies.Output('mass-slider', 'value'),
            dash.dependencies.Output('chi2-slider', 'value'),
            dash.dependencies.Output('nfilt-slider', 'value'),
            dash.dependencies.Output('zphot-slider', 'value'),
            dash.dependencies.Output('zspec-slider', 'value'),
            dash.dependencies.Output('id-input', 'value'),
        ], [dash.dependencies.Input('url', 'href')])
        def set_state_from_url(href):
            """
            Restore the control values from the query string of the page URL.

            Parameters are matched on their exact key (the text before the
            first ``=``).  Unknown keys, parameters without ``=``, values that
            are not among the allowed choices, and ranges that do not parse
            into exactly two numbers are ignored, leaving the default in
            place.

            Parameters
            ----------
            href : str or None
                Full page URL.  ``None`` or a URL without ``?`` yields the
                defaults.

            Returns
            -------
            state : tuple
                ``(plot_type, color_type, mag_filter, mag_range, mass_range,
                chi2_range, nfilt_range, zphot_range, zspec_range,
                id_input)``, in the order of the callback outputs.
            """
            # Choice parameters: key -> [current value, allowed values]
            choices = {
                'plot_type': ['zphot-zspec', PLOT_TYPES],
                'color_type': ['sSFR', COLOR_TYPES],
                'mag_filter': [self.DEFAULT_FILTER, self.photoz.flux_columns],
            }

            # Range parameters: key -> [current value, element parser]
            # NOTE(review): some defaults differ from the initial values in
            # the layout (mass upper limit 11.8, chi2 [0, 6], nfilt lower
            # limit 3) -- confirm which set is intended.
            ranges = {
                'mag': [[18, 27], float],
                'mass': [[8, 11.6], float],
                'chi2': [[0, 4], float],
                'nfilt': [[1, self.MAXNFILT], int],
                'zphot': [[0, self.ZMAX], float],
                'zspec': [[-0.5, 6.5], float],
            }
            id_input = None

            # href can be None; treat that like a URL with no query string
            query = ''
            if href and '?' in href:
                query = href.split('?')[1]

            for param in query.split('&'):
                key, sep, val = param.partition('=')
                if not sep:
                    # No '=' in this parameter, nothing to read
                    continue

                if key in choices:
                    if val in choices[key][1]:
                        choices[key][0] = val

                elif key in ranges:
                    parse = ranges[key][1]
                    try:
                        vals = [parse(v) for v in val.split(',')]
                    except ValueError:
                        continue

                    if len(vals) == 2:
                        ranges[key][0] = vals

                elif key == 'id':
                    # Empty id is treated as no id
                    id_input = val.replace('%20', ' ') or None

            return (choices['plot_type'][0], choices['color_type'][0],
                    choices['mag_filter'][0], ranges['mag'][0],
                    ranges['mass'][0], ranges['chi2'][0], ranges['nfilt'][0],
                    ranges['zphot'][0], ranges['zspec'][0], id_input)

        @app.callback(
            dash.dependencies.Output('sample-selection-scatter', 'figure'), [
                dash.dependencies.Input('plot-type', 'value'),
                dash.dependencies.Input('color-type', 'value'),
                dash.dependencies.Input('mag-filter', 'value'),
                dash.dependencies.Input('mag-slider', 'value'),
                dash.dependencies.Input('mag-checked', 'value'),
                dash.dependencies.Input('mass-slider', 'value'),
                dash.dependencies.Input('mass-checked', 'value'),
                dash.dependencies.Input('chi2-slider', 'value'),
                dash.dependencies.Input('chi2-checked', 'value'),
                dash.dependencies.Input('nfilt-slider', 'value'),
                dash.dependencies.Input('nfilt-checked', 'value'),
                dash.dependencies.Input('zphot-slider', 'value'),
                dash.dependencies.Input('zphot-checked', 'value'),
                dash.dependencies.Input('zspec-slider', 'value'),
                dash.dependencies.Input('zspec-checked', 'value'),
                dash.dependencies.Input('id-input', 'value')
            ])
        def update_selection(plot_type, color_type, mag_filter, mag_range,
                             mag_checked, mass_range, mass_checked, chi2_range,
                             chi2_checked, nfilt_range, nfilt_checked,
                             zphot_range, zphot_checked, zspec_range,
                             zspec_checked, id_input):
            """
            Apply slider selections
            """
            sel = np.isfinite(self.df['z_phot'])
            if 'checked' in zphot_checked:
                sel &= (self.df['z_phot'] > zphot_range[0])
                sel &= (self.df['z_phot'] < zphot_range[1])

            if 'checked' in zspec_checked:
                sel &= (self.df['z_spec'] > zspec_range[0])
                sel &= (self.df['z_spec'] < zspec_range[1])

            if 'checked' in mass_checked:
                sel &= (self.df['mass'] > mass_range[0])
                sel &= (self.df['mass'] < mass_range[1])

            if 'checked' in chi2_checked:
                sel &= (self.df['chi2'] >= chi2_range[0])
                sel &= (self.df['chi2'] <= chi2_range[1])

            if 'checked' in nfilt_checked:
                sel &= (self.df['nusefilt'] >= nfilt_range[0])
                sel &= (self.df['nusefilt'] <= nfilt_range[1])

            #print('redshift: ', sel.sum())

            if mag_filter is None:
                mag_filter = self.DEFAULT_FILTER

            #self.self.df['mag'] = self.ABZP
            #self.self.df['mag'] -= 2.5*np.log10(self.photoz.cat[mag_filter])
            mag_col = 'mag_' + mag_filter
            if 'checked' in mag_checked:
                sel &= (self.df[mag_col] > mag_range[0])
                sel &= (self.df[mag_col] < mag_range[1])

            self.df['mag'] = self.df[mag_col]

            #print('mag: ', sel.sum())

            if plot_type == 'zphot-zspec':
                sel &= self.df['z_spec'] > 0

            #print('zspec: ', sel.sum())

            if id_input is not None:
                id_i, dr_i = parse_id_input(id_input)
                if id_i is not None:
                    self.df['is_selected'] = self.df['id'] == id_i
                    sel |= self.df['is_selected']
                else:
                    self.df['is_selected'] = False
            else:
                self.df['is_selected'] = False

            dff = self.df[sel]

            # Color-coding by color-type pulldown
            if color_type == 'z_phot':
                color_kwargs = dict(color=np.clip(dff['z_phot'], *zphot_range),
                                    color_continuous_scale='portland')
            elif color_type == 'z_spec':
                color_kwargs = dict(color=np.clip(dff['z_spec'], *zspec_range),
                                    color_continuous_scale='portland')
            elif color_type == 'mass':
                color_kwargs = dict(color=np.clip(dff['mass'], *mass_range),
                                    color_continuous_scale='magma_r')
            elif color_type == 'chi2':
                color_kwargs = dict(color=np.clip(dff['chi2'], *chi2_range),
                                    color_continuous_scale='viridis')
            else:
                color_kwargs = dict(color=np.clip(dff['ssfr'], -12., -8.),
                                    color_continuous_scale='portland_r')

            # Scatter plot
            plot_defs = {
                'Mass-redshift':
                ('z_phot', 'mass', 'z<sub>phot</sub>', 'log Stellar mass',
                 (-0.1, self.ZMAX), (7.5, 12.5)),
                'Mag-redshift':
                ('z_phot', 'mag', 'z<sub>phot</sub>', f'AB mag ({mag_filter})',
                 (-0.1, self.ZMAX), (18, 28)),
                'RA/Dec':
                ('ra', 'dec', 'R.A.', 'Dec.', self.ra_bounds, self.dec_bounds),
                'zphot-zspec': ('z_spec', 'z_phot', 'z<sub>spec</sub>',
                                'z<sub>phot</sub>', (0, 4.5), (0, 4.5)),
                'UVJ':
                ('vj', 'uv', '(V-J)', '(U-V)', (-0.1, 2.5), (-0.1, 2.5)),
                'UV-redshift': ('z_phot', 'uv', 'z<sub>phot</sub>',
                                '(U-V)<sub>rest</sub>', (0, 4), (-0.1, 2.50)),
                'chi2-redshift': ('z_phot', 'chi2', 'z<sub>phot</sub>',
                                  'chi<sup>2</sup>', (0, 4), (0.1, 30))
            }

            if plot_type in self.extra_plots:
                args = (*self.extra_plots[plot_type], {}, color_kwargs)
            elif plot_type in plot_defs:
                args = (*plot_defs[plot_type], {}, color_kwargs)
            else:
                args = (*plot_defs['zphot-zspec'], {}, color_kwargs)

            fig = update_sample_scatter(dff, *args)

            # Update ranges for some parameters
            if ('Mass' in plot_type) & ('checked' in mass_checked):
                fig.update_yaxes(range=mass_range)

            if ('Mag' in plot_type) & ('checked' in mag_checked):
                fig.update_yaxes(range=mag_range)

            if ('redshift' in plot_type) & ('checked' in zphot_checked):
                fig.update_xaxes(range=zphot_range)

            if ('zspec' in plot_type) & ('checked' in zspec_checked):
                fig.update_yaxes(range=zspec_range)

            return fig

        def update_sample_scatter(dff, xcol, ycol, x_label, y_label, x_range,
                                  y_range, extra, color_kwargs):
            """
            Draw the sample scatter plot for the current selection.

            Parameters
            ----------
            dff
                Selected rows to plot.  Must provide the ``xcol`` and ``ycol``
                columns plus ``id``, ``z_phot``, ``mass``, ``ssfr``, ``mag``
                (used for the hover text) and ``is_selected``.
            xcol, ycol
                Names of the columns plotted on the x and y axes.
            x_label, y_label
                Axis titles.
            x_range, y_range
                Axis ranges.
            extra
                Accepted for call compatibility; not used by this function.
            color_kwargs
                Extra keyword arguments forwarded to ``px.scatter`` (the
                color-coding settings).

            Returns
            -------
            The plotly figure.
            """
            import plotly.graph_objects as go

            highlight_rgba = 'rgba(250,0,0,0.5)'

            hover_text = ('(%{x:.2f}, %{y:.2f}) <br>'
                          'id: %{customdata[0]:0d}  z_phot: %{customdata[1]:.2f}'
                          '<br> mag: %{customdata[4]:.1f}  '
                          'mass: %{customdata[2]:.2f}  ssfr: %{customdata[3]:.2f}')

            fig = px.scatter(
                data_frame=dff,
                x=xcol,
                y=ycol,
                custom_data=['id', 'z_phot', 'mass', 'ssfr', 'mag'],
                **color_kwargs)

            # Applied before any extra trace is added, so only the main
            # scatter points get the hover template and the opacity
            fig.update_traces(hovertemplate=hover_text, opacity=0.7)

            # Ring + label around the object(s) picked through the ID box
            picked_mask = dff['is_selected']
            if picked_mask.sum() > 0:
                picked = dff[picked_mask]
                fig.add_trace(
                    go.Scatter(x=picked[xcol],
                               y=picked[ycol],
                               mode="markers+text",
                               text=[f'{obj_id}' for obj_id in picked['id']],
                               textposition="bottom center",
                               marker=dict(color=highlight_rgba,
                                           size=20,
                                           symbol='circle-open')))

            fig.update_xaxes(range=x_range, title_text=x_label)
            fig.update_yaxes(range=y_range, title_text=y_label)

            layout_opts = {
                'template': template,
                'autosize': True,
                'showlegend': False,
                'margin': {
                    'l': 0,
                    'r': 0,
                    'b': 0,
                    't': 20,
                    'pad': 0,
                    'autoexpand': True
                }
            }
            if plot_height is not None:
                layout_opts['height'] = plot_height

            fig.update_layout(**layout_opts)

            # Hide the color bars
            fig.update_traces(marker_showscale=False,
                              selector=dict(type='scatter'))
            fig.update_coloraxes(showscale=False)

            # One-to-one reference line for the zphot-zspec comparison
            if xcol == 'z_spec' and ycol == 'z_phot':
                fig.add_trace(
                    go.Scatter(x=[0, 8],
                               y=[0, 8],
                               mode="lines",
                               marker=dict(color=highlight_rgba)))

            # Selected / total object counts in the lower-right corner
            fig.add_annotation(text=f'N = {len(dff)} / {len(self.df)}',
                               xref="x domain",
                               yref="y domain",
                               x=0.98,
                               y=0.05,
                               showarrow=False)

            return fig

        def sed_cutout_figure(id_i):
            """
            Build the small image-cutout figure for one object.

            The object's ``ra``/``dec`` from ``self.df`` are converted to pixel
            coordinates with ``cutout_wcs.all_world2pix`` and a window of
            ``2 * cutout_size + 1`` pixels on a side is sliced from
            ``cutout_data``.  An all-zero placeholder image is shown instead
            when no usable cutout can be extracted: the id is not in
            ``self.df``, the pixel position is not finite, the window starts
            before the first row/column of the image, or the slice fails or
            comes back empty.

            Parameters
            ----------
            id_i
                Value to look up in ``self.df['id']``.  If several rows match,
                the first one is used.

            Returns
            -------
            The plotly figure, or ``None`` when ``cutout_data`` is ``None``.
            """
            if cutout_data is None:
                return None

            # Placeholder used whenever a real cutout is not available
            cutout = np.zeros((2 * cutout_size, 2 * cutout_size))

            # First matching row, the same choice update_object_sed makes;
            # an unknown id leaves the placeholder in place instead of
            # failing on an empty coordinate array
            rows = np.where(self.df['id'] == id_i)[0]
            if len(rows) > 0:
                ix = rows[0]
                ri, di = self.df['ra'][ix], self.df['dec'][ix]
                xi, yi = np.squeeze(cutout_wcs.all_world2pix([ri], [di], 0))

                if np.isfinite(xi) and np.isfinite(yi):
                    xp = int(np.round(xi))
                    yp = int(np.round(yi))
                    x0 = xp - cutout_size
                    y0 = yp - cutout_size

                    # A negative slice start does not raise: it wraps around
                    # to the far side of the array and returns the wrong
                    # (usually empty) region, so it is rejected explicitly
                    if (x0 >= 0) and (y0 >= 0):
                        slx = slice(x0, xp + cutout_size + 1)
                        sly = slice(y0, yp + cutout_size + 1)

                        try:
                            stamp = cutout_data[sly, slx]
                        except Exception:
                            # Best effort: the type of cutout_data is not
                            # fixed here, so any indexing failure falls back
                            # to the placeholder
                            stamp = None

                        if (stamp is not None) and (np.size(stamp) > 0):
                            cutout = stamp

            fig = px.imshow(cutout, color_continuous_scale='gray_r')

            fig.update_coloraxes(showscale=False)
            fig.update_layout(width=120,
                              height=120,
                              margin=dict(l=0,
                                          r=0,
                                          b=0,
                                          t=0,
                                          pad=0,
                                          autoexpand=True))

            fig.update_xaxes(range=(0, 2 * cutout_size),
                             visible=False,
                             showticklabels=False)
            fig.update_yaxes(range=(0, 2 * cutout_size),
                             visible=False,
                             showticklabels=False)

            return fig

        def parse_id_input(id_input):
            """
            Parse the ID box as an object id or an (ra, dec) position.

            Parameters
            ----------
            id_input : str, int or None
                Raw value of the input box.  A single token is read as an
                integer object id.  Two tokens, separated by whitespace and/or
                a comma, are read as R.A. and Dec. (treated as degrees) and
                matched to the nearest row of ``self.df``.

            Returns
            -------
            id_i
                The parsed id, the ``id`` of the nearest row for a coordinate
                query, or ``None`` when the input is empty or cannot be
                interpreted (e.g., text that is only partially typed).
            dr_i
                Separation between the requested position and the matched row
                (coordinate offsets multiplied by 3600), or ``None`` when no
                coordinate match was made.
            """
            if id_input in ['None', None, '']:
                return None, None

            tokens = str(id_input).replace(',', ' ').split()

            if len(tokens) == 1:
                try:
                    return int(tokens[0]), None
                except ValueError:
                    # Not an integer id, e.g. a half-typed coordinate
                    return None, None

            if len(tokens) != 2:
                return None, None

            # Plain float() instead of np.cast, which NumPy 2.0 removed
            try:
                ra, dec = float(tokens[0]), float(tokens[1])
            except ValueError:
                return None, None

            # 'nan' / 'inf' parse as floats but cannot be matched to a row
            if not (np.isfinite(ra) and np.isfinite(dec)):
                return None, None

            # Small-angle separation with the cos(dec) correction on R.A.
            cosd = np.cos(self.df['dec'] / 180 * np.pi)
            dx = (self.df['ra'] - ra) * cosd
            dy = (self.df['dec'] - dec)
            dr = np.sqrt(dx**2 + dy**2) * 3600.
            imin = np.nanargmin(dr)

            return self.df['id'][imin], dr[imin]

        @app.callback([
            dash.dependencies.Output('object-sed-figure', 'figure'),
            dash.dependencies.Output('object-info', 'children'),
            dash.dependencies.Output('match-sep', 'children'),
            dash.dependencies.Output('cutout-figure', cutout_target)
        ], [
            dash.dependencies.Input('sample-selection-scatter', 'hoverData'),
            dash.dependencies.Input('sed-unit-selector', 'value'),
            dash.dependencies.Input('id-input', 'value')
        ])
        def update_object_sed(hoverData, sed_unit, id_input):
            """
            Refresh the SED + p(z) panel, the object summary and the cutout.

            Parameters
            ----------
            hoverData
                Hover payload of the sample scatter plot; the id of the
                hovered point is read from its first ``customdata`` entry.
            sed_unit
                Selected flux unit: one of 'Fλ', 'Fν', 'νFν'.
            id_input
                Raw text of the ID box, parsed with ``parse_id_input``.

            Returns
            -------
            fig
                Figure produced by ``self.photoz.show_fit_plotly``.
            object_info
                Children for the info panel: coordinates, query links and
                a few catalog values, or ``'ID: N/A'`` if the id is not in
                ``self.df``.
            match_sep
                Formatted separation of a coordinate match, else ``''``.
            cutout_fig
                Result of ``sed_cutout_figure``, or ``['']`` when there is no
                cutout image.
            """
            id_i, dr_i = parse_id_input(id_input)

            # Fall back to the hovered point when the box gives no usable id
            if (id_i is None) or (id_i not in self.zout['id']):
                id_i = hoverData['points'][0]['customdata'][0]

            match_sep = '' if dr_i is None else f'{dr_i:.1f}"'

            # Menu label -> show_fnu code passed to show_fit_plotly
            unit_codes = {'Fλ': 0, 'Fν': 1, 'νFν': 2}

            layout_kwargs = {
                'template': template,
                'autosize': True,
                'showlegend': False,
                'margin': {
                    'l': 0,
                    'r': 0,
                    'b': 0,
                    't': 20,
                    'pad': 0,
                    'autoexpand': True
                }
            }

            fig = self.photoz.show_fit_plotly(id_i,
                                              show_fnu=unit_codes[sed_unit],
                                              vertical=True,
                                              panel_ratio=[0.6, 0.4],
                                              show=False,
                                              layout_kwargs=layout_kwargs)

            if plot_height is not None:
                fig.update_layout(height=plot_height)

            matches = self.df['id'] == id_i
            if matches.sum() != 0:
                # Summarize the first matching catalog row
                ix = np.where(matches)[0][0]
                ra, dec = self.df['ra'][ix], self.df['dec'][ix]

                header = f'ID: {id_i}  |  α, δ = {ra:.6f} {dec:.6f} '
                eso_link = html.A('ESO',
                                  href=utils.eso_query(ra,
                                                       dec,
                                                       radius=1.0,
                                                       unit='s'))
                cds_link = html.A('CDS',
                                  href=utils.cds_query(ra,
                                                       dec,
                                                       radius=1.0,
                                                       unit='s'))
                legacy_link = html.A('LegacySurvey',
                                     href=utils.show_legacysurvey(
                                         ra, dec, layer='ls-dr9'))

                object_info = [
                    header, ' | ', eso_link, ' | ', cds_link, ' | ',
                    legacy_link,
                    html.Br(), f"z_phot: {self.df['z_phot'][ix]:.3f}  ",
                    f" | z_spec: {self.df['z_spec'][ix]:.3f}",
                    html.Br(), f"mag: {self.df['mag'][ix]:.2f}  ",
                    f" | mass: {self.df['mass'][ix]:.2f} ",
                    f" | sSFR: {self.df['ssfr'][ix]:.2f}",
                    html.Br()
                ]
            else:
                object_info = 'ID: N/A'

            cutout_fig = [''] if cutout_data is None else sed_cutout_figure(id_i)

            return fig, object_info, match_sep, cutout_fig

        if server_mode is not None:
            app.run_server(mode=server_mode, port=port)

        return app
Esempio n. 5
0
def graph(G, mode="external", **kwargs):
    """
    G: a multidirectional graph

    kwargs are passed to the Jupyter_Dash.run_server() function. Some usefull arguments are:
        mode: "inline" to run app inside the jupyter nodebook, default is external 
        debug: True or False, Usefull to catch errors during development.
    """

    import dash
    from jupyter_dash import JupyterDash
    import dash_cytoscape as cyto
    from dash.dependencies import Output, Input
    import dash_html_components as html
    import dash_core_components as dcc
    import dash_table
    import networkx as nx
    import scConnect as cn
    import plotly.graph_objs as go
    import plotly.io as pio
    import pandas as pd
    import numpy as np
    import json
    import matplotlib
    import matplotlib.pyplot as plt
    pio.templates.default = "plotly_white"

    cyto.load_extra_layouts()

    JupyterDash.infer_jupyter_proxy_config()

    app = JupyterDash(__name__)

    server = app.server
    # Add a modified index string to change the title to scConnect
    app.index_string = '''
        <!DOCTYPE html>
        <html>
            <head>
                {%metas%}
                <title>scConnect</title>
                {%favicon%}
                {%css%}
            </head>
            <body>
                {%app_entry%}
                <footer>
                    {%config%}
                    {%scripts%}
                    {%renderer%}
            </body>
        </html>
        '''
    # Add colors to each node
    nodes = pd.Categorical(G.nodes())
    # make a list of RGBA tuples, one for each node
    colors = plt.cm.tab20c(nodes.codes / len(nodes.codes), bytes=True)
    # zip node to color
    color_map_nodes = dict(zip(nodes, colors))

    # add these colors to original graph
    for node, color in color_map_nodes.items():
        G.nodes[node]["color"] = color[0:3]  # Save only RGB

    # Add colors to edges(source node color) for  G
    for u, v, k in G.edges(keys=True):
        G.edges[u, v, k]["color"] = color_map_nodes[u][0:3]

    # load graph into used formes
    def G_to_flat(G, weight):
        G_flat = cn.graph.flatten_graph(G, weight=weight, log=True)

        # Add colors to edges(source node color) for G_flat
        for u, v, in G_flat.edges():
            G_flat.edges[u, v]["color"] = color_map_nodes[u][0:3]
        return G_flat

    # produce full graph variante to extract metadata
    G_flat = G_to_flat(G, weight="score")
    G_split = cn.graph.split_graph(G)

    # find and sort all found interactions
    interactions = list(G_split.keys())
    interactions.sort()

    G_cyto = nx.cytoscape_data(G_flat)

    # get min and max weight for all edges for flat and normal graph
    #weights = [d["weight"] for u, v, d in G_flat.edges(data=True)]
    scores = [d["score"] for u, v, d in G.edges(data=True)]
    cent = [d["centrality"] for n, d in G.nodes(data=True)]

    # prepare data for network graph
    nodes = G_cyto["elements"]["nodes"]
    elements = []

    # collect all available genes
    genes = list(nodes[0]["data"]["genes"].keys())

    # Styling parameters
    font_size = 20

    # Style for network graph
    default_stylesheet = [{
        'selector': 'node',
        'style': {
            'background-color': 'data(color)',
            'label': 'data(id)',
            'shape': 'ellipse',
            'opacity': 1,
            'font-size': f'{font_size}',
            'font-weight': 'bold',
            'text-wrap': 'wrap',
            'text-max-width': "100px",
            'text-opacity': 1,
            'text-outline-color': "white",
            'text-outline-opacity': 1,
            'text-outline-width': 2
        }
    }, {
        'selector': 'node:selected',
        'style': {
            'background-color': 'data(color)',
            'label': 'data(id)',
            'shape': 'ellipse',
            'opacity': 1,
            'border-color': "black",
            'border-width': "5"
        }
    }, {
        'selector': 'edge',
        'style': {
            'line-color': 'data(color)',
            "opacity": 0.7,
            "curve-style": "unbundled-bezier",
            "width": "data(weight)",
            "target-arrow-shape": "vee",
            "target-arrow-color": "black",
            'z-index': 1,
            'font-size': f'{font_size}'
        }
    }, {
        'selector': 'edge:selected',
        'style': {
            'line-color': 'red',
            'line-style': "dashed",
            'opacity': 1,
            'z-index': 10,
        }
    }]
    app.layout = html.Div(
        className="wrapper",
        children=[  # wrapper
            html.Div(
                className="header",
                children=[  # header
                    html.Img(src="assets/logo.png", alt="scConnect logo"),
                    html.Div(
                        className="graph-info",
                        id="graph-stat",
                        children=[
                            html.
                            H3(f'Loaded graph with {len(G.nodes())} nodes and {len(G.edges())} edges'
                               )
                        ])
                ]),
            html.Div(
                className="network-settings",
                children=[  # network settings
                    html.H2("Network settings", style={"text-align":
                                                       "center"}),
                    html.Label("Interactions"),
                    dcc.Dropdown(id="network-interaction",
                                 options=[{
                                     'label': "all interactions",
                                     'value': "all"
                                 }] + [{
                                     'label': interaction,
                                     'value': interaction
                                 } for interaction in interactions],
                                 value="all"),
                    # select if only significant ligands and receptors should be shown
                    html.Label("Graph weight:"),
                    dcc.RadioItems(id="weight-select",
                                   options=[{
                                       "label": "Score",
                                       "value": "score"
                                   }, {
                                       "label": "Log score",
                                       "value": "log_score"
                                   }, {
                                       "label": "Specificity",
                                       "value": "specificity"
                                   }, {
                                       "label": "Importance",
                                       "value": "importance"
                                   }],
                                   value="importance",
                                   labelStyle={
                                       'display': 'block',
                                       "margin-left": "50px"
                                   },
                                   style={
                                       "padding": "10px",
                                       "margin": "auto"
                                   }),
                    html.Label("Graph Layout"),
                    dcc.Dropdown(
                        id="network-layout",
                        options=[{
                            'label':
                            name.capitalize(),
                            'value':
                            name
                        } for name in [
                            'grid', 'random', 'circle', 'cose', 'concentric',
                            'breadthfirst', 'cose-bilkent', 'cola', 'euler',
                            'spread', 'dagre', 'klay'
                        ]],
                        value="circle",
                        clearable=False),
                    html.Label("Weight Filter",
                               style={
                                   "paddingBottom": 500,
                                   "paddingTop": 500
                               }),
                    dcc.
                    Slider(  # min, max and value are set dynamically via a callback
                        id="network-filter",
                        step=0.001,
                        updatemode="drag",
                        tooltip={
                            "always_visible": True,
                            "placement": "right"
                        },
                    ),
                    html.Label("Node size"),
                    dcc.RangeSlider(id="node-size",
                                    value=[10, 50],
                                    min=0,
                                    max=100,
                                    updatemode="drag"),
                    html.Label("Select gene"),
                    dcc.Dropdown(
                        id="gene_dropdown",
                        options=[{
                            "label": gene,
                            "value": gene
                        } for gene in genes],
                        clearable=True,
                        placeholder="Color by gene expression",
                    ),

                    # Store node colors "hidden" for gene expresison
                    html.Div(id="node-colors",
                             style={"display": "none"},
                             children=[""]),
                    html.Div(id="min-max", children=[]),
                    # Click to download image of network graph
                    html.Button(children="Download current view",
                                id="download-network-graph",
                                style={"margin": "10px"})
                ]),  # end network settings
            html.Div(
                id="network-graph",
                className="network-graph",
                children=[  # network graph
                    html.H2("Network graph", style={"text-align": "center"}),
                    cyto.Cytoscape(id="cyto-graph",
                                   style={
                                       'width': '100%',
                                       'height': '80vh'
                                   },
                                   stylesheet=default_stylesheet,
                                   elements=elements,
                                   autoRefreshLayout=True,
                                   zoomingEnabled=False)
                ]),  # end network graph
            html.Div(
                className="sankey-settings",
                children=[  # network settings
                    html.H2("Sankey Settings", style={"text-align": "center"}),
                    html.Label("Weight Filter"),
                    dcc.Slider(id="sankey-filter",
                               min=min(scores),
                               max=max(scores),
                               value=0.75,
                               step=0.001,
                               updatemode="drag",
                               tooltip={
                                   "always_visible": True,
                                   "placement": "right"
                               }),
                    html.Label("Toggle weighted"),
                    dcc.RadioItems(id="sankey-toggle",
                                   options=[{
                                       "label": "Score",
                                       "value": "score"
                                   }, {
                                       "label": "Log score",
                                       "value": "log_score"
                                   }, {
                                       "label": "Specificity",
                                       "value": "specificity"
                                   }, {
                                       "label": "Importance",
                                       "value": "importance"
                                   }],
                                   value="importance",
                                   labelStyle={"display": "block"})
                ]),  # end network settings
            html.Div(
                className="sankey",
                id="sankey",
                children=[  # sankey graph
                    html.H2("Sankey graph", style={"text-align": "center"}),
                    dcc.Graph(id="sankey-graph")
                ]),  # end sankey graph
            html.Div(
                className="interaction-list",
                children=[  # interaction list
                    html.Div(id="selection",
                             children=[
                                 html.H2("Interactions",
                                         style={"text-align": "center"}),
                                 html.H3(id="edge-info",
                                         style={"text-align": "center"}),
                                 dcc.Graph(id="interaction-scatter"),
                                 html.Div(id="interaction-selection",
                                          style={"display": "none"},
                                          children=[""])
                             ]),
                    html.Div(children=[
                        dash_table.DataTable(
                            id="edge-selection",
                            page_size=20,
                            style_table={
                                "overflowX": "scroll",
                                "overflowY": "scroll",
                                "height": "50vh",
                                "width": "95%"
                            },
                            style_cell_conditional=[{
                                "if": {
                                    "column_id": "interaction"
                                },
                                "textAlign": "left"
                            }, {
                                "if": {
                                    "column_id": "receptorfamily"
                                },
                                "textAlign": "left"
                            }, {
                                "if": {
                                    "column_id": "pubmedid"
                                },
                                "textAlign": "left"
                            }],
                            style_header={
                                "fontWeight": "bold",
                                "maxWidth": "200px",
                                "minWidth": "70px"
                            },
                            style_data={
                                "maxWidth": "200px",
                                "minWidth": "70px",
                                "textOverflow": "ellipsis"
                            },
                            sort_action="native",
                            fixed_rows={
                                'headers': True,
                                'data': 0
                            })
                    ])
                ]),  # end interaction list
            html.Div(
                className="L-R-scores",
                children=[  # ligand and receptor lists
                    html.H2("Ligand and receptors",
                            style={"text-align": "center"}),
                    html.Div(children=[
                        html.H3(
                            id="selected-node",
                            style={"text-align": "center"},
                            children=["Select a node in the notwork graph"]),
                        html.Label("Search for ligands and receptors:",
                                   style={"margin-right": "10px"}),
                        dcc.Input(id="filter_l_r",
                                  type="search",
                                  value="",
                                  placeholder="Search")
                    ]),
                    dcc.Tabs([
                        dcc.Tab(label="Ligands",
                                children=[
                                    dcc.Graph(id="ligand-graph",
                                              config=dict(autosizable=True,
                                                          responsive=True)),
                                    dash_table.DataTable(
                                        id="ligand-table",
                                        page_size=20,
                                        style_table={
                                            "overflowX": "scroll",
                                            "overflowY": "scroll",
                                            "height": "50vh",
                                            "width": "95%"
                                        },
                                        style_cell_conditional=[{
                                            "if": {
                                                "column_id": "Ligand"
                                            },
                                            "textAlign":
                                            "left"
                                        }],
                                        style_header={
                                            "fontWeight": "bold",
                                            "maxWidth": "200px",
                                            "minWidth": "70px"
                                        },
                                        style_data={
                                            "maxWidth": "200px",
                                            "minWidth": "70px",
                                            "textOverflow": "ellipsis"
                                        },
                                        sort_action="native",
                                        fixed_rows={
                                            'headers': True,
                                            'data': 0
                                        })
                                ]),
                        dcc.Tab(label="Receptors",
                                children=[
                                    dcc.Graph(id="receptor-graph",
                                              config=dict(autosizable=True,
                                                          responsive=True)),
                                    dash_table.DataTable(
                                        id="receptor-table",
                                        page_size=20,
                                        style_table={
                                            "overflowX": "scroll",
                                            "overflowY": "scroll",
                                            "height": "50vh",
                                            "width": "95%"
                                        },
                                        style_cell_conditional=[{
                                            "if": {
                                                "column_id": "Receptor"
                                            },
                                            "textAlign":
                                            "left"
                                        }],
                                        style_header={
                                            "fontWeight": "bold",
                                            "maxWidth": "200px",
                                            "minWidth": "70px"
                                        },
                                        style_data={
                                            "maxWidth": "200px",
                                            "minWidth": "70px",
                                            "textOverflow": "ellipsis"
                                        },
                                        sort_action="native",
                                        fixed_rows={
                                            'headers': True,
                                            'data': 0
                                        })
                                ])
                    ])
                ])  # end ligand receptor list
        ])  # end wrapper

    # Instantiate the graph and produce the boundaries for the filters
    @app.callback([
        Output("cyto-graph", "elements"),
        Output("network-filter", "min"),
        Output("network-filter", "max"),
        Output("network-filter", "value")
    ], [
        Input("network-interaction", "value"),
        Input("weight-select", "value")
    ])
    def make_graph(interaction, score):
        """Build the cytoscape elements and the bounds of the edge-weight filter.

        Args:
            interaction: ``"all"`` to display the complete graph, otherwise a
                key into the mapping returned by ``cn.graph.split_graph(G)``.
            score: value handed to ``G_to_flat`` when flattening a graph.

        Returns:
            Tuple ``(elements, min_weight, max_weight, mean_weight)``, where
            ``elements`` is the cytoscape node list followed by the edge list.
            The three numbers fall back to the statistics of ``[0, 1]`` when
            the selected graph has no edges.
        """
        # The full graph is flattened on every call, as before. G_to_flat is
        # defined elsewhere, so the call is kept unconditional.
        # NOTE(review): if G_to_flat has no side effects on G, this call can be
        # skipped when a single interaction is selected - confirm.
        flat = G_to_flat(G, score)

        if interaction != "all":
            # an interaction is selected: keep only that interaction's graph
            flat = G_to_flat(cn.graph.split_graph(G)[interaction], score)

        # prepare data for the network graph
        cyto = nx.cytoscape_data(flat)
        elements = cyto["elements"]["nodes"] + cyto["elements"]["edges"]

        weights = [d["weight"] for _, _, d in flat.edges(data=True)]
        if not weights:
            # min()/max() raise ValueError on an empty sequence; use the same
            # placeholder bounds the sankey filter uses for an empty graph.
            weights = [0, 1]

        return elements, min(weights), max(weights), np.mean(weights)

    # Change layout of network graph

    @app.callback(Output("cyto-graph", "layout"),
                  [Input("network-layout", "value")])
    def update_network_layout(layout):
        """Return the cytoscape layout options for the chosen layout name."""
        fixed_options = dict(automate=True, fit=True)
        return {"name": layout, **fixed_options}

    # Choose gene to color nodes by

    @app.callback(
        [Output("node-colors", "children"),
         Output("min-max", "children")], [Input("gene_dropdown", "value")])
    def calculate_colors(gene):
        """Compute a colour per node for ``gene`` plus a min/max summary.

        Args:
            gene: key into each node's ``data["genes"]`` mapping, or ``None``
                when no gene is selected.

        Returns:
            ``[None, ""]`` when ``gene`` is ``None``; otherwise a tuple of a
            JSON string (node id -> ``{"rgb": ..., "expression": ...}``) and an
            ``html.Ul`` listing the minimum and maximum expression values.
        """
        if gene is None:
            return [None, ""]

        # expression of the selected gene for every node, keyed by node id
        expression_by_node = {}
        for entry in nodes:
            node_data = entry["data"]
            expression_by_node[node_data["id"]] = node_data["genes"][gene]

        lowest = min(expression_by_node.values())
        highest = max(expression_by_node.values())

        # min/max expression packaged as a list element for display
        summary = html.Ul(children=[
            html.Li(f"minimum gene expression: {lowest}"),
            html.Li(f"maximum gene expression: {highest}")
        ])

        # NOTE(review): the raw expression value is passed to the colormap
        # without scaling by lowest/highest - confirm the values are already on
        # the scale the colormap expects.
        colormap = matplotlib.cm.get_cmap("coolwarm")
        palette = {
            node_id: {
                "rgb": colormap(value, bytes=True)[0:3],
                "expression": value
            }
            for node_id, value in expression_by_node.items()
        }

        return pd.Series(palette).to_json(), summary

    # Select visible edges of network graph depending on filter value
    # node color depending on selected gene
    # width of edges

    @app.callback(Output("cyto-graph", "stylesheet"), [
        Input("network-filter", "value"),
        Input("network-filter", "min"),
        Input("network-filter", "max"),
        Input("node-size", "value"),
        Input("node-colors", "children")
    ])
    def style_network_graph(th, min_weight, max_weight, size, colors):
        """Assemble the cytoscape stylesheet for the network graph.

        Args:
            th: edges whose ``weight`` is below this value are hidden.
            min_weight: lower bound used to map edge weight to edge width.
            max_weight: upper bound used to map edge weight to edge width.
            size: two-item sequence; ``size[0]`` and ``size[1]`` are the node
                sizes that the smallest and largest centrality map to.
            colors: JSON string mapping node id to an object with an ``"rgb"``
                entry. Any non-string value means no per-node colour is added.

        Returns:
            ``default_stylesheet`` followed by the filter/size/colour rules and
            the edge-width rule.
        """
        import io

        # node height and width share one centrality -> size mapping
        node_size = (f'mapData(centrality, {min(cent)}, {max(cent)}, '
                     f'{size[0]}, {size[1]})')

        filter_style = [{
            "selector": f"edge[weight < {th}]",
            "style": {
                "display": "none"
            }
        }, {
            "selector": "node",
            "style": {
                'height': node_size,
                'width': node_size
            }
        }]

        # Colour nodes by gene expression when a colour table is available.
        # Without one no colour rule is added here, so node colour is left to
        # default_stylesheet. (The previous else branch built a rule that was
        # never appended to the stylesheet.)
        if isinstance(colors, str):
            # Wrap the string in StringIO: passing literal JSON to read_json is
            # deprecated in recent pandas releases.
            colors = pd.read_json(io.StringIO(colors),
                                  typ="series",
                                  convert_dates=False)
            filter_style.extend({
                'selector': f'node[id = "{str(index)}"]',
                'style': {
                    'background-color': f'rgb{tuple(colors[index]["rgb"])}'
                }
            } for index in colors.index)

        # Map edge width to a fixed 1-10 range (scaled for visibility)
        edge_style = [{
            "selector": "edge",
            "style": {
                "width": f"mapData(weight, {min_weight}, {max_weight}, 1, 10)"
            }
        }]

        return default_stylesheet + filter_style + edge_style

    # download an image of current network graph view
    @app.callback(Output("cyto-graph", "generateImage"),
                  Input("download-network-graph", "n_clicks"))
    def download_networkgraph_image(get_request):
        """Return the ``generateImage`` request for the network graph.

        Args:
            get_request: click count of the download button, or ``None`` when
                no click has been reported yet.

        Returns:
            An empty dict when ``get_request`` is ``None``; otherwise a request
            to download the current view as SVG.
        """
        # identity check: only a missing click count suppresses the download
        if get_request is None:
            return dict()

        return {"type": "svg", "action": "download"}

    # Produce a table of all edge data from tapped edge
    @app.callback([
        Output("edge-info", "children"),
        Output("edge-selection", "columns"),
        Output("edge-selection", "data")
    ], [
        Input("cyto-graph", "tapEdgeData"),
        Input("interaction-selection", "children")
    ])
    def update_data(edge, selection):
        """Produce the interaction table for the tapped edge.

        Args:
            edge: data dict of the tapped edge with ``"source"``, ``"target"``
                and ``"interactions"`` (a list of records), or ``None`` when no
                edge has been tapped.
            selection: JSON-encoded list of interaction names to keep. An empty
                string or ``None`` means no filtering.

        Returns:
            ``[info, columns, records]``: a caption, the DataTable column
            definitions and the table rows sorted by descending score.
            ``["", None, None]`` when ``edge`` is ``None``.
        """
        import json

        import pandas as pd

        # check if an edge has really been clicked, return default otherwise
        if edge is None:
            return ["", None, None]

        info = f"Interactions from {edge['source']} to {edge['target']}."

        # visible column names mapped to the keys of each interaction record
        columns = [
            {"name": "Interaction", "id": "interaction"},
            {"name": "Receptor Family", "id": "receptorfamily"},
            {"name": "Score", "id": "score"},
            {"name": "Log10(score)", "id": "log_score"},
            {"name": "Specificity", "id": "specificity"},
            {"name": "Importance", "id": "importance"},
            {"name": "Ligand z-score", "id": "ligand_zscore"},
            {"name": "Ligand p-value", "id": "ligand_pval"},
            {"name": "Receptor z-score", "id": "receptor_zscore"},
            {"name": "Receptor p-value", "id": "receptor_pval"},
            {"name": "PubMed ID", "id": "pubmedid"},
        ]
        column_ids = [column["id"] for column in columns]

        # explicit copy so later transformations never act on a slice
        interactions = pd.DataFrame(edge["interactions"])[column_ids].copy()

        # sort values based on score
        interactions = interactions.sort_values(by="score", ascending=False)

        # scores are shown with two decimals, p-values with four
        decimals = dict.fromkeys([
            "score", "log_score", "specificity", "importance", "ligand_zscore",
            "receptor_zscore"
        ], 2)
        decimals.update(dict.fromkeys(["ligand_pval", "receptor_pval"], 4))
        interactions = interactions.round(decimals)

        # if there is a selection from the interaction graph, filter the rows;
        # both "" and None mean "nothing selected"
        if selection:
            selected = json.loads(selection)
            interactions = interactions[interactions["interaction"].isin(
                selected)]

        records = interactions.to_dict("records")

        return [info, columns, records]

    @app.callback([Output("interaction-scatter", "figure")],
                  [Input("cyto-graph", "tapEdgeData")])
    def interaction_scatter_plot(edge):
        """Scatter the interactions of the tapped edge (log score vs specificity).

        Args:
            edge: data dict of the tapped edge holding an ``"interactions"``
                list of records, or any non-dict value when nothing is tapped.

        Returns:
            One-element list with the figure. The figure is empty when ``edge``
            is not a dict.
        """
        import plotly.express as px

        if not isinstance(edge, dict):
            return [
                go.Figure(),
            ]

        interactions = pd.DataFrame(edge["interactions"])[[
            "interaction", "receptorfamily", "score", "log_score",
            "ligand_zscore", "ligand_pval", "receptor_zscore", "receptor_pval",
            "specificity", "importance", "pubmedid"
        ]]

        def padded_range(values):
            # The span always includes zero and is widened by 10% on both
            # sides. For non-negative data this equals the former
            # (-0.1 * max, 1.1 * max); negative values are no longer clipped.
            low = min(float(values.min()), 0.0)
            high = max(float(values.max()), 0.0)
            pad = (high - low) * 0.1
            return (low - pad, high + pad)

        fig = px.scatter(interactions,
                         x="log_score",
                         range_x=padded_range(interactions["log_score"]),
                         y="specificity",
                         range_y=padded_range(interactions["specificity"]),
                         color="importance",
                         hover_name="interaction",
                         hover_data=[
                             "ligand_pval", "receptor_pval", "score",
                             "specificity", "receptorfamily"
                         ],
                         color_continuous_scale=px.colors.sequential.Viridis_r,
                         labels={
                             "ligand_zscore": "Ligand Z-score",
                             "receptor_zscore": "Receptor Z-score",
                             "log_score": "log(Interaction score)",
                             "score": "Interaction score",
                             "specificity": "Specificity",
                             "importance": "Importance",
                             "receptorfamily": "Receptor family",
                             "pubmedid": "PubMed ID",
                             "ligand_pval": "Ligand p-value",
                             "receptor_pval": "Receptor p-value"
                         })
        return [
            fig,
        ]

    @app.callback(Output("interaction-selection", "children"),
                  [Input("interaction-scatter", "selectedData")])
    def interaction_select(selected_data):
        """Serialise the names of the selected scatter points.

        Args:
            selected_data: selection dict with a ``"points"`` list whose items
                carry a ``"hovertext"`` entry, or any non-dict value when
                nothing is selected.

        Returns:
            JSON-encoded list of the selected hover texts, or ``""`` when
            ``selected_data`` is not a dict.
        """
        import json

        # guard clause: no selection means an empty marker string
        if not isinstance(selected_data, dict):
            return ""

        names = []
        for point in selected_data["points"]:
            names.append(point["hovertext"])
        return json.dumps(names)

    # Produce ligand and receptor graphs based on tapped node

    @app.callback([
        Output("ligand-graph", "figure"),
        Output("receptor-graph", "figure"),
        Output("selected-node", "children")
    ], [Input("cyto-graph", "tapNodeData"),
        Input("filter_l_r", "value")])
    def plot_l_r_expression(node, filter_text):
        """Plot ligand and receptor scores of the tapped node.

        Args:
            node: data dict of the tapped node. It must provide ``"id"`` and,
                for both ``ligands`` and ``receptors``, the mappings
                ``"<kind>_score"``, ``"<kind>_zscore"`` and
                ``"<kind>_corr_pval"``. Any non-dict value means no node is
                selected.
            filter_text: substring the row labels must contain to be plotted.
                An empty string or ``None`` disables the filter.

        Returns:
            ``[ligand_fig, receptor_fig, node_id]``. Without a selected node
            both figures are empty and ``node_id`` is a hint for the user.
        """
        # outputs used when no node is selected
        ligand_fig = go.Figure()
        receptor_fig = go.Figure()
        node_id = "Select a node in the network graph"

        if not isinstance(node, dict):
            return [ligand_fig, receptor_fig, node_id]

        import plotly.express as px

        node_id = node["id"]

        def column(kind, suffix, name):
            # one single-column frame, indexed by the keys of the node mapping
            return pd.DataFrame.from_dict(node[f"{kind}_{suffix}"],
                                          orient="index",
                                          columns=[name])

        def expression_scatter(kind):
            # shared pipeline for ``kind`` in {"ligands", "receptors"}
            table = column(kind, "score", "Score")
            # NOTE(review): log2 of a z-score is NaN for negative values -
            # confirm this transform is intended.
            for extra in (np.log2(column(kind, "zscore", "Z-score")),
                          column(kind, "corr_pval", "p-value")):
                table = table.merge(extra,
                                    how="left",
                                    left_index=True,
                                    right_index=True)

            table["log(score + 1)"] = np.log10(table["Score"] + 1)
            table["Significant"] = table["p-value"] < 0.05
            table["-log(p-value)"] = -np.log10(table["p-value"])

            # both "" and None mean "no filter"; like=None would raise
            if filter_text:
                table = table.filter(like=filter_text, axis=0)

            return px.scatter(table,
                              x="log(score + 1)",
                              y="-log(p-value)",
                              color="Significant",
                              hover_name=table.index,
                              hover_data=["Score", "Z-score", "p-value"])

        ligand_fig = expression_scatter("ligands")
        receptor_fig = expression_scatter("receptors")

        return [ligand_fig, receptor_fig, node_id]

    # Sankey graph for the tapped node. build_sankey_graph rebinds this name via
    # ``nonlocal`` and filter_sankey_graph reads it, so it is state of the
    # enclosing function shared by both callbacks (not a module-level global).
    G_s = nx.MultiDiGraph()  # empty until build_sankey_graph fills it

    @app.callback([
        Output("sankey-filter", "min"),
        Output("sankey-filter", "max"),
        Output("sankey-filter", "value")
    ], [Input("cyto-graph", "tapNodeData"),
        Input("sankey-toggle", "value")])
    def build_sankey_graph(node, score):
        import numpy as np
        # If no node has been selected, dont try to build graph
        if node is None:
            return (0, 0, 0)

        node = node["id"]
        # Find all interactions where node is target or source node
        nonlocal G_s
        G_s = nx.MultiDiGraph()  # reset content
        weight = list(
        )  # list to store all weights (used to set min and max for the filter)
        for n, nbrs in G.adj.items(
        ):  # graph has been modified by network graph before
            for nbr, edict in nbrs.items():
                if n == node:
                    for e, d in edict.items():
                        G_s.add_edge(n, " Post " + nbr, **d)
                        weight.append(d[score])
                if nbr == node:
                    for e, d in edict.items():
                        G_s.add_edge("Pre " + n, nbr, **d)
                        weight.append(d[score])

        if len(weight) == 0:
            weight = [0, 1]
        if score == "specificity":
            # set default start value to specificity value for ligand and receptor
            # p-value of (0.05 and 0.05)/2 = 1.3
            return (min(weight), max(weight), 1.3)
        return (min(weight), max(weight), np.mean(weight))

    @app.callback(Output("sankey-graph", "figure"), [
        Input("sankey-filter", "value"),
        Input("sankey-toggle", "value"),
        Input("cyto-graph", "tapNodeData")
    ])
    def filter_sankey_graph(th, score, node):
        """Draw the sankey diagram of ``G_s`` restricted to edges above ``th``.

        Args:
            th: only edges whose ``score`` attribute is greater than this value
                are drawn.
            score: name of the edge attribute used for filtering and as the
                link value.
            node: data dict of the tapped node (its ``"id"`` is shown in the
                title), or a falsy value when nothing has been tapped.

        Returns:
            The sankey figure, or an empty dict when no edge passes the filter.
        """
        if node:
            node = node["id"]
        center = str(node) if node else None

        def base_name(label):
            # Recover the cell-type name behind a sankey node label. Only the
            # leading " Post " / "Pre " marker is removed, and the tapped node
            # itself is never touched, so names that merely contain "Pre " or
            # " Post" still resolve to the right colour.
            label = str(label)
            if label != center:
                for prefix in (" Post ", "Pre "):
                    if label.startswith(prefix):
                        label = label[len(prefix):]
                        break
            return label.strip()

        _G_s = nx.MultiDiGraph()
        for u, v, key, d in G_s.edges(data=True, keys=True):
            if d[score] > th:
                _G_s.add_edge(u, v, key, **d)
        # keep every node, including those left without edges by the filter
        _G_s.add_nodes_from(G_s.nodes(data=True))

        edges = nx.to_pandas_edgelist(_G_s)
        if len(edges) < 1:
            return dict()

        # add same color scheme as network graph (stored on the shared G_s)
        for node_s in _G_s.nodes():
            G_s.nodes[node_s]["color"] = color_map_nodes[base_name(node_s)]

        nodes = G_s.nodes()

        # position of every node label, used as sankey source/target index
        node_map = {cluster: idx for idx, cluster in enumerate(list(nodes))}

        sankey = go.Sankey(node=dict(pad=15,
                                     thickness=20,
                                     line=dict(color="black", width=0.5),
                                     label=list(nodes),
                                     color=[
                                         f'rgb{tuple(d["color"][0:3])}'
                                         for _, d in G_s.nodes(data=True)
                                     ]),
                           link=dict(
                               source=list(edges["source"].map(node_map)),
                               target=list(edges["target"].map(node_map)),
                               value=list(edges[score]),
                               label=edges["interaction"]))

        layout = go.Layout(autosize=True,
                           title=f"Interactions: {node}",
                           font=dict(size=font_size))

        return go.Figure(data=[sankey], layout=layout)

    @app.callback(
        [Output("ligand-table", "columns"),
         Output("ligand-table", "data")], [
             Input("ligand-graph", "figure"),
             Input("ligand-graph", "selectedData")
         ])
    def select_ligands(figure, selected):
        """Fill the ligand table from the ligand scatter plot.

        Args:
            figure: figure dict of the ligand graph. Each trace in ``"data"``
                provides ``"hovertext"`` (ligand names) and ``"customdata"``
                (rows of score, z-score, p-value). ``None`` is treated as a
                figure without traces.
            selected: selection dict with a ``"points"`` list whose items carry
                ``"hovertext"``; any non-dict value means "show everything".

        Returns:
            Tuple ``(columns, data)`` for the DataTable. Rows are sorted by
            descending score, or follow the selection order when a selection
            is given. Selected names missing from the figure are ignored.
        """
        ligands = []
        score = []
        zscore = []
        pval = []

        traces = figure.get("data") or [] if figure else []
        for group in traces:
            ligands.extend(group.get("hovertext") or [])
            for entry in group.get("customdata") or []:
                score.append(entry[0])
                zscore.append(entry[1])
                pval.append(entry[2])

        df = pd.DataFrame({
            "Ligand": ligands,
            "Score": score,
            "Z-score": zscore,
            "P-value": pval
        })
        df.index = df["Ligand"]
        df = df.sort_values(by="Score", ascending=False)

        if isinstance(selected, dict):
            # A selection can outlive the figure it was made on (e.g. after
            # another node is tapped); .loc raises KeyError for labels that no
            # longer exist, so keep only the names still present.
            wanted = [
                point["hovertext"] for point in selected["points"]
                if point["hovertext"] in df.index
            ]
            df = df.loc[wanted]

        columns = [{"name": name, "id": name}
                   for name in ("Ligand", "Score", "Z-score", "P-value")]

        data = df.to_dict("records")

        return columns, data

    @app.callback([
        Output("receptor-table", "columns"),
        Output("receptor-table", "data")
    ], [
        Input("receptor-graph", "figure"),
        Input("receptor-graph", "selectedData")
    ])
    def select_ligands(figure, selected):
        """Fill the receptor table from the receptor scatter plot.

        NOTE(review): this definition reuses the name of the ligand-table
        callback defined earlier and rebinds it. The name is kept here for
        interface stability; consider renaming it to ``select_receptors``.

        Args:
            figure: figure dict of the receptor graph. Each trace in ``"data"``
                provides ``"hovertext"`` (receptor names) and ``"customdata"``
                (rows of score, z-score, p-value). ``None`` is treated as a
                figure without traces.
            selected: selection dict with a ``"points"`` list whose items carry
                ``"hovertext"``; any non-dict value means "show everything".

        Returns:
            Tuple ``(columns, data)`` for the DataTable. Rows are sorted by
            descending score, or follow the selection order when a selection
            is given. Selected names missing from the figure are ignored.
        """
        receptors = []
        score = []
        zscore = []
        pval = []

        traces = figure.get("data") or [] if figure else []
        for group in traces:
            receptors.extend(group.get("hovertext") or [])
            for entry in group.get("customdata") or []:
                score.append(entry[0])
                zscore.append(entry[1])
                pval.append(entry[2])

        df = pd.DataFrame({
            "Receptor": receptors,
            "Score": score,
            "Z-score": zscore,
            "P-value": pval
        })
        df.index = df["Receptor"]
        df = df.sort_values(by="Score", ascending=False)

        if isinstance(selected, dict):
            # A selection can outlive the figure it was made on (e.g. after
            # another node is tapped); .loc raises KeyError for labels that no
            # longer exist, so keep only the names still present.
            wanted = [
                point["hovertext"] for point in selected["points"]
                if point["hovertext"] in df.index
            ]
            df = df.loc[wanted]

        columns = [{"name": name, "id": name}
                   for name in ("Receptor", "Score", "Z-score", "P-value")]

        data = df.to_dict("records")

        return columns, data

    # Start the Dash server now that the callbacks above are defined. kwargs is
    # passed through untouched - presumably the enclosing function's keyword
    # arguments (TODO confirm against its signature).
    app.run_server(**kwargs)