Exemplo n.º 1
0
def link_plot(adata, key, genes=None, basis=['umap', 'pca'], components=[1, 2],
              subsample=None, steps=[40, 40], sample_size=500,
              distance=2, cutoff=True, highlight_only=None, palette=None,
              show_legend=False, legend_loc='top_right', plot_width=None, plot_height=None, save=None,
              seed=None):
    """
    Display the distances of cells from currently highlighted cell.

    The output is a column with a distance slider, a row of static scatter
    plots colored by `adata.obs[key]` (one per basis), and a row of linked
    scatter plots (one per basis) colored by the distance of every cell to
    the cell hovered in the first linked plot. The slider sets the upper
    bound of the color mapper and, with `cutoff=True`, the distance threshold.

    Params
    --------
    adata: AnnData
        annotated data object
    key: str
        key in `adata.obs_keys()` to color the static plot
    genes: list(str), optional (default: `None`)
        list of genes in `adata.var_names`,
        which are used to compute the distance;
        if None, take all the genes
    basis: list(str), optional (default:`['umap', 'pca']`)
        list of basis to use when plotting;
        only the first plot is hoverable
    components: list(int); list(list(int)), optional (default: `[1, 2]`)
        list of components for each basis
    subsample: str, optional (default: `None`)
        subsample strategy to use when there are too many cells
        possible values are: `"density"`, `"uniform"`, `None`
    steps: int; list(int), optional (default: `[40, 40]`)
        number of steps in each direction when using `subsample="uniform"`
    sample_size: int, optional (default: `500`)
        number of cells to sample based on their density in the respective embedding
        when using `subsample="density"`; should be < `1000`
    distance: int; str, optional (default: `2`)
        for integers, use p-norm,
        for strings, only `'dpt'` is available
    cutoff: bool, optional (default: `True`)
        if `True`, do not color cells whose distance is further away
        than the threshold specified by the slider
    highlight_only: 'str', optional (default: `None`)
        key in `adata.obs_keys()`, which makes highlighting
        work only on clusters specified by this parameter
    palette: matplotlib.colors.Colormap; list(str), optional (default: `None`)
        colormap to use, if None, use plt.cm.RdYlBu
    show_legend: bool, optional (default: `False`)
        display the legend also in the linked plot
    legend_loc: str, optional (default `'top_right'`)
        location of the legend
    seed: int, optional (default: `None`)
        seed when `subsample='density'`
    plot_width: int, optional (default: `None`)
        width of the plot
    plot_height: int, optional (default: `None`)
        height of the plot
    save: Union[os.PathLike, Str, NoneType], optional (default: `None`)
        path where to save the plot; `.html` is appended if missing

    Returns
    --------
    None
    """

    assert key in adata.obs.keys(), f'`{key}` not found in `adata.obs`.'

    # Optionally reduce the number of cells; sampling is driven by the first basis.
    if subsample == 'uniform':
        adata, _ = sample_unif(adata, steps, basis[0])
    elif subsample == 'density':
        adata, _ = sample_density(adata, sample_size, basis[0], seed=seed)
    elif subsample is not None:
        raise ValueError(f'Unknown subsample strategy: `{subsample}`.')

    palette = cm.RdYlBu if palette is None else palette
    if isinstance(palette, matplotlib.colors.Colormap):
        # Sample every colormap entry (as RGBA bytes) and turn it into a palette list.
        palette = to_hex_palette(palette(range(palette.N), 1., bytes=True))

    # Normalize `components` to one [x, y] pair per basis.
    if not isinstance(components[0], list):
        components = [components]

    if len(components) != len(basis):
        assert len(basis) % len(components) == 0 and len(basis) >= len(components)
        components = components * (len(basis) // len(components))

    if not isinstance(components, np.ndarray):
        components = np.asarray(components)

    if highlight_only is not None:
        assert highlight_only in adata.obs_keys(), f'`{highlight_only}` is not in adata.obs_keys().'

    genes = adata.var_names if genes is None else genes
    gene_subset = np.in1d(adata.var_names, genes)

    # Cell x cell distance matrix restricted to `gene_subset`.
    if distance != 'dpt':
        d = adata.X[:, gene_subset]
        if issparse(d):
            d = d.A
        dmat = distance_matrix(d, d, p=distance)
    else:
        if not all(gene_subset):
            warnings.warn('`genes` is not None, are you sure this is what you want when using `dpt` distance?')

        # Row i holds the diffusion pseudotime of every cell when rooted at cell i.
        dmat = []
        ad_tmp = adata.copy()
        ad_tmp = ad_tmp[:, gene_subset]
        for i in range(ad_tmp.n_obs):
            ad_tmp.uns['iroot'] = i
            sc.tl.dpt(ad_tmp)
            # NaN -> 0 and inf -> 1 so the values stay plottable.
            dmat.append(list(ad_tmp.obs['dpt_pseudotime'].replace([np.nan, np.inf], [0, 1])))

    # Columns '0'..'n-1' are looked up by the JS callbacks using the hovered cell index.
    dmat = pd.DataFrame(dmat, columns=list(map(str, range(adata.n_obs))))
    # Components are 1-based for every basis except 'diffmap', which is indexed as given.
    df = pd.concat([pd.DataFrame(adata.obsm[f'X_{bs}'][:, comp - (bs != 'diffmap')], columns=[f'x{i}', f'y{i}'])
                    for i, (bs, comp) in enumerate(zip(basis, components))] + [dmat], axis=1)
    df['hl_color'] = np.nan
    df['index'] = range(len(df))
    # Highlighting is restricted to cells sharing the hovered cell's `hl_key`.
    df['hl_key'] = list(adata.obs[highlight_only]) if highlight_only is not None else 0
    df[key] = list(map(str, adata.obs[key]))

    start_ix = '0'  # our root cell
    ds = ColumnDataSource(df)
    mapper = linear_cmap(field_name='hl_color', palette=palette,
                         low=df[start_ix].min(), high=df[start_ix].max())
    static_fig_mapper = _create_mapper(adata, key)

    static_figs = []
    figs, renderers = [], []
    for i, bs in enumerate(basis):
        # linked plots
        fig = figure(tools='pan, reset, save, ' + ('zoom_in, zoom_out' if i == 0 else 'wheel_zoom'),
                     title=bs, plot_width=400, plot_height=400)
        _set_plot_wh(fig, plot_width, plot_height)

        kwargs = {}
        if show_legend and legend_loc is not None:
            kwargs['legend_group'] = 'hl_key' if highlight_only is not None else key

        scatter = fig.scatter(f'x{i}', f'y{i}', source=ds, line_color=mapper, color=mapper,
                              hover_color='black', size=8, line_width=8, line_alpha=0, **kwargs)
        if show_legend and legend_loc is not None:
            fig.legend.location = legend_loc

        figs.append(fig)
        renderers.append(scatter)

        # static plots
        fig = figure(title=bs, plot_width=400, plot_height=400)

        fig.scatter(f'x{i}', f'y{i}', source=ds, size=8,
                    color={'field': key, 'transform': static_fig_mapper}, **kwargs)

        if legend_loc is not None:
            fig.legend.location = legend_loc

        static_figs.append(fig)

    # Only the first linked plot receives the hover tool.
    fig = figs[0]

    end = dmat[~np.isinf(dmat)].max().max() if distance != 'dpt' else 1.0
    # Parenthesized so that "Distance " prefixes both variants of the title.
    slider_title = 'Distance ' + ('(dpt)' if distance == 'dpt' else f'({distance}-norm)')
    slider = Slider(start=0, end=end, value=end / 2, step=end / 1000, title=slider_title)
    # Remembers the last hovered cell so the slider can recolor relative to it.
    col_ds = ColumnDataSource(dict(value=[start_ix]))
    update_color_code = f'''
        source.data['hl_color'] = source.data[first].map(
            (x, i) => {{ return isNaN(x) ||
                        {'x > slider.value || ' if cutoff else ''}
                        source.data['hl_key'][first] != source.data['hl_key'][i]  ? NaN : x; }}
        );
    '''
    slider.callback = CustomJS(args={'slider': slider, 'mapper': mapper['transform'], 'source': ds, 'col': col_ds}, code=f'''
        mapper.high = slider.value;
        var first = col.data['value'];
        {update_color_code}
        source.change.emit();
    ''')

    h_tool = HoverTool(renderers=renderers, tooltips=[], show_arrow=False)
    h_tool.callback = CustomJS(args=dict(source=ds, slider=slider, col=col_ds), code=f'''
        var indices = cb_data.index['1d'].indices;
        if (indices.length == 0) {{
            source.data['hl_color'] = source.data['hl_color'];
        }} else {{
            var first = indices[0];
            source.data['hl_color'] = source.data[first];
            {update_color_code}
            col.data['value'] = first;
            col.change.emit();
        }}
        source.change.emit();
    ''')
    # Registered once; the tool was previously added to the same figure twice.
    fig.add_tools(h_tool)

    color_bar = ColorBar(color_mapper=mapper['transform'], width=12, location=(0,0))
    fig.add_layout(color_bar, 'left')

    plot = column(slider, row(*static_figs), row(*figs))

    if save is not None:
        save = save if str(save).endswith('.html') else str(save) + '.html'
        bokeh_save(plot, save)
    else:
        show(plot)
Exemplo n.º 2
0
def spectroscopy_plot(obj_id, spec_id=None):
    """Build the interactive spectroscopy plot for object `obj_id`.

    Every spectrum of the object (or only the one whose id equals `spec_id`)
    is normalized and drawn as a line. The layout adds:
    - a checkbox legend to show/hide individual spectra;
    - three checkbox columns to overlay the reference lines in `SPEC_LINES`;
    - redshift and expansion-velocity controls that shift those lines.

    Parameters
    ----------
    obj_id :
        Identifier passed to `Obj.query.get`.
    spec_id : optional
        When given, only the spectrum whose `id` equals `int(spec_id)` is shown.

    Returns
    -------
    The value of `_plot_to_json(layout)`, or `(None, None, None)` when there
    is no spectrum to plot.
    """
    obj = Obj.query.get(obj_id)
    spectra = obj.spectra
    if spec_id is not None:
        wanted_id = int(spec_id)
        spectra = [spec for spec in spectra if spec.id == wanted_id]
    if len(spectra) == 0:
        return None, None, None

    color_map = dict(zip([s.id for s in spectra], viridis(len(spectra))))

    frames = []
    for s in spectra:
        # normalize spectra to a common average flux per resolving
        # element of 1 (facilitates easy visual comparison)
        normfac = np.sum(np.gradient(s.wavelengths) * s.fluxes) / len(s.fluxes)

        if not (np.isfinite(normfac) and normfac > 0):
            # otherwise normalize the value at the median wavelength to 1
            median_wave_index = np.argmin(
                np.abs(s.wavelengths - np.median(s.wavelengths)))
            normfac = s.fluxes[median_wave_index]

        frames.append(pd.DataFrame({
            'wavelength': s.wavelengths,
            'flux': s.fluxes / normfac,
            'id': s.id,
            'telescope': s.instrument.telescope.name,
            'instrument': s.instrument.name,
            'date_observed': s.observed_at.date().isoformat(),
            'pi': s.assignment.run.pi if s.assignment is not None else "",
        }))
    data = pd.concat(frames)
    split = data.groupby('id')

    hover = HoverTool(tooltips=[
        ('wavelength', '$x'),
        ('flux', '$y'),
        ('telescope', '@telescope'),
        ('instrument', '@instrument'),
        ('UTC date observed', '@date_observed'),
        ('PI', '@pi'),
    ])
    xmin = np.min(data['wavelength']) - 100
    xmax = np.max(data['wavelength']) + 100
    # TODO how to choose a good default?
    ymax = 1.03 * data.flux.max()
    plot = figure(
        plot_width=600,
        plot_height=300,
        y_range=Range1d(0, ymax),
        x_range=(xmin, xmax),
        sizing_mode='scale_both',
        tools='box_zoom,wheel_zoom,pan,reset',
        active_drag='box_zoom',
    )
    plot.add_tools(hover)

    # Lines are created in `spectra` order so that glyph `s{i}`, toggle label i
    # and toggle color i all refer to the same spectrum (groupby order is
    # sorted by id, which need not match `spectra`).
    model_dict = {}
    for i, s in enumerate(spectra):
        model_dict[f's{i}'] = plot.line(
            x='wavelength',
            y='flux',
            color=color_map[s.id],
            source=ColumnDataSource(split.get_group(s.id)),
        )
    plot.xaxis.axis_label = 'Wavelength (Å)'
    plot.yaxis.axis_label = 'Flux'
    plot.toolbar.logo = None

    toggle = CheckboxWithLegendGroup(
        labels=[
            f'{s.instrument.telescope.nickname}/{s.instrument.name} ({s.observed_at.date().isoformat()})'
            for s in spectra
        ],
        active=list(range(len(spectra))),
        colors=[color_map[s.id] for s in spectra],
    )
    toggle.callback = CustomJS(
        args={
            'toggle': toggle,
            **model_dict
        },
        code="""
          for (let i = 0; i < toggle.labels.length; i++) {
              eval("s" + i).visible = (toggle.active.includes(i))
          }
    """,
    )

    redshift = obj.redshift if obj.redshift is not None else 0.0

    z_title = Div(text="Redshift (<i>z</i>): ")
    z_slider = Slider(
        value=redshift,
        start=0.0,
        end=1.0,
        step=0.001,
        show_value=False,
        format="0[.]000",
    )
    z_textinput = TextInput(value=str(redshift))
    z_slider.callback = CustomJS(
        args={
            'slider': z_slider,
            'textinput': z_textinput
        },
        code="""
            textinput.value = slider.value.toFixed(3).toString();
            textinput.change.emit();
        """,
    )
    z = column(z_title, z_slider, z_textinput)

    v_title = Div(text="<i>V</i><sub>expansion</sub> (km/s): ")
    v_exp_slider = Slider(
        value=0.0,
        start=0.0,
        end=3e4,
        step=10.0,
        show_value=False,
    )
    v_exp_textinput = TextInput(value='0')
    v_exp_slider.callback = CustomJS(
        args={
            'slider': v_exp_slider,
            'textinput': v_exp_textinput
        },
        code="""
            textinput.value = slider.value.toFixed(0).toString();
            textinput.change.emit();
        """,
    )
    v_exp = column(v_title, v_exp_slider, v_exp_textinput)

    # Reference lines, hidden until selected; glyph `el{i}` follows SPEC_LINES order.
    for i, (wavelengths, color) in enumerate(SPEC_LINES.values()):
        el_data = pd.DataFrame({'wavelength': wavelengths})
        el_data['x'] = el_data['wavelength'] * (1.0 + redshift)
        model_dict[f'el{i}'] = plot.segment(
            x0='x',
            x1='x',
            # Span the initial y-range: the spectra are normalized to an
            # average flux of ~1, so a fixed tiny height (1e-13) is invisible.
            y0=0,
            y1=ymax,
            color=color,
            source=ColumnDataSource(el_data),
        )
        model_dict[f'el{i}'].visible = False

    # Split spectral lines into 3 columns
    element_dicts = np.array_split(
        np.array(list(SPEC_LINES.items()), dtype=object), 3)
    elements_groups = []
    col_offset = 0
    for element_dict in element_dicts:
        labels = [key for key, value in element_dict]
        colors = [c for key, (w, c) in element_dict]
        elements = CheckboxWithLegendGroup(
            labels=labels,
            active=[],
            colors=colors,
        )
        elements_groups.append(elements)

        # TODO callback policy: don't require submit for text changes?
        # `col_offset` maps this column's local checkbox index to the global `el{i}` glyph.
        elements.callback = CustomJS(
            args={
                'elements': elements,
                'z': z_textinput,
                'v_exp': v_exp_textinput,
                **model_dict,
            },
            code=f"""
            let c = 299792.458; // speed of light in km / s
            const i_max = {col_offset} + elements.labels.length;
            let local_i = 0;
            for (let i = {col_offset}; i < i_max; i++) {{
                let el = eval("el" + i);
                el.visible = (elements.active.includes(local_i))
                el.data_source.data.x = el.data_source.data.wavelength.map(
                    x_i => (x_i * (1 + parseFloat(z.value)) /
                                    (1 + parseFloat(v_exp.value) / c))
                );
                el.data_source.change.emit();
                local_i++;
            }}
        """,
        )

        col_offset += len(labels)

    # Our current version of Bokeh doesn't properly execute multiple callbacks
    # https://github.com/bokeh/bokeh/issues/6508
    # Workaround is to manually put the code snippets together
    update_all_lines_js = """
            // Update plot data for each element
            const c = 299792.458; // speed of light in km / s
            const offset_col_1 = elements0.labels.length;
            const offset_col_2 = offset_col_1 + elements1.labels.length;
            const i_max = offset_col_2 + elements2.labels.length;
            for (let i = 0; i < i_max; i++) {
                let el = eval("el" + i);
                el.visible =
                    elements0.active.includes(i) ||
                    elements1.active.includes(i - offset_col_1) ||
                    elements2.active.includes(i - offset_col_2);
                el.data_source.data.x = el.data_source.data.wavelength.map(
                    x_i => (x_i * (1 + parseFloat(z.value)) /
                                    (1 + parseFloat(v_exp.value) / c))
                );
                el.data_source.change.emit();
            }
    """
    line_args = {
        'elements0': elements_groups[0],
        'elements1': elements_groups[1],
        'elements2': elements_groups[2],
        'z': z_textinput,
        'v_exp': v_exp_textinput,
        **model_dict,
    }

    # Number(...) keeps Slider.value numeric; toFixed() alone returns a string.
    z_textinput.js_on_change(
        'value',
        CustomJS(
            args={**line_args, 'slider': z_slider},
            code="""
            // Update slider value to match text input
            slider.value = Number(parseFloat(z.value).toFixed(3));
    """ + update_all_lines_js,
        ),
    )

    v_exp_textinput.js_on_change(
        'value',
        CustomJS(
            args={**line_args, 'slider': v_exp_slider},
            code="""
            // Update slider value to match text input
            slider.value = Number(parseFloat(v_exp.value).toFixed(3));
    """ + update_all_lines_js,
        ),
    )

    row1 = row(plot, toggle)
    row2 = row(elements_groups)
    row3 = row(z, v_exp)
    layout = column(row1, row2, row3)
    return _plot_to_json(layout)
Exemplo n.º 3
0
# Figure styling; `p`, `x`, `y` and `source` are defined before this chunk.
p.title.align = 'center'
p.xgrid.grid_line_color = None
p.ygrid.grid_line_color = None
# Static curve from the module-level `x`, `y` arrays.
p.line(x, y, line_width=2, color='blue')
# Curve recomputed in the browser by the slider callback below.
p.line(x='x', y='y', source=source, color='orange', alpha=.5)
# The slider value is the number of integration steps.
slider = Slider(start=20, end=1000, value=20, step=1, title="Nombre")

# Explicit Euler integration of y' = cos(t) * y with y(0) = 1 on [0, 10],
# using `slider.value` steps; the result replaces `source.data`.
# `F` is declared with `const` so it does not leak into the page's global scope.
slider.callback = CustomJS(args=dict(source=source, slider=slider),
                           code="""
    const F = (t, y) => Math.cos(t) * y;
    var a = 0;
    var b = 10;
    var data = source.data;
    data.x = new Array();
    data.y = new Array();
    var x = 0;
    var y = 1;
    data.x.push(x);
    data.y.push(y);
    var step = (b-a)/slider.value;
    for (var k=0; k<slider.value; k++){
        y += F(x, y) * step;
        x += step;
        data.x.push(x);
        data.y.push(y);
    }
    source.change.emit();
""")

show(column(p, slider))
def principal_window_words(inputs_file):
    """Render an HTML dashboard of 2-D word representations, one layer per input file.

    Parameters
    ----------
    inputs_file : list of str
        Paths to JSON files. Each file holds a ``'content'`` list of
        ``{word: vector}`` mappings; ``vector[0]`` and ``vector[1]`` are
        plotted. The legend label is ``input_file[11:-5]``.
        NOTE(review): this assumes an 11-character directory prefix and a
        5-character extension such as ``.json`` — confirm with callers.
        At least three files are needed: the first three feed the tables
        titled English / Spanish / French (presumably in that order — verify).

    The dashboard is written to ``../dashboards/words_representations.html``
    and displayed with ``show``. A slider selects the percentage of each
    file's words shown in the plot and tables.
    """
    import math

    output_file(filename="../dashboards/words_representations.html", title="Words representations")

    # Percentage of words shown initially; also the slider's starting value.
    initial_percentage = 10

    ###################################################
    # PREPARING DATA
    ###################################################
    sources = []
    sources_visible = []
    languages = []
    colors = viridis(len(inputs_file))
    for input_file in inputs_file:
        languages.append(input_file[11:-5])

        # JSON is UTF-8 text; don't rely on the platform's default encoding
        # (the files contain non-ASCII Spanish/French words).
        with open(input_file, 'r', encoding='utf-8') as file:
            content = json.load(file)['content']

        words, x, y = [], [], []
        for translation in content:
            for word, vector in translation.items():
                words.append(word)
                x.append(vector[0])
                y.append(vector[1])

        # Same count the slider callback produces (j < len / 100 * pct), so the
        # initial view matches the slider position even for fewer than 100 words.
        n_visible = math.ceil(len(words) / 100 * initial_percentage)
        sources.append(ColumnDataSource({
            'x': x,
            'y': y,
            'words': words,
        }))
        sources_visible.append(ColumnDataSource({
            'x': x[:n_visible],
            'y': y[:n_visible],
            'words': words[:n_visible],
        }))

    ###################################################
    # SET UP MAIN FIGURE
    ###################################################

    tools = "hover, tap, box_zoom, box_select, reset, help"
    p = figure(tools=tools, title='Words intermediate representations\n', plot_width=1000, plot_height=650)

    for index, (source_visible, color, language) in enumerate(zip(sources_visible, colors, languages)):
        p.circle('x', 'y',
                 size=4, alpha=0.4,
                 hover_color='red', hover_alpha=1.0,
                 selection_color='red',
                 nonselection_color='white',
                 source=source_visible,
                 color=color,
                 legend=language,
                 name="data_{}".format(index))

    p.legend.click_policy = "hide"

    ###################################################
    # ADD LINKS IN HOVERTOOL
    ###################################################

    hover = p.select(dict(type=HoverTool))
    hover.tooltips = [
        ("words", "@words"),
    ]

    ###################################################
    # SET UP SLIDER
    ###################################################

    slider = Slider(title='Percentage of words',
                    value=initial_percentage,
                    start=1,
                    end=100,
                    step=1)
    # Rebuild every visible column (including `words`, used by the hover tool
    # and the tables) so all columns of a source keep the same length.
    slider.callback = CustomJS(args=dict(sources_visible=sources_visible, sources=sources), code="""
            var percentage = cb_obj.value;
            // Get the data from the data sources
            for (var i = 0; i < sources.length; i++) {
                var point_visible = sources_visible[i].data;
                var point_available = sources[i].data;
                var nbr_points = (point_available.x.length / 100) * percentage;

                point_visible.x = [];
                point_visible.y = [];
                point_visible.words = [];

                // Update the visible data
                for (var j = 0; j < nbr_points; j++) {
                    point_visible.x.push(point_available.x[j]);
                    point_visible.y.push(point_available.y[j]);
                    point_visible.words.push(point_available.words[j]);
                }
                sources_visible[i].change.emit();
            }
            """)

    ###################################################
    # SET UP DATATABLE
    ###################################################

    columns0 = [TableColumn(field="words", title="Words in English")]
    columns1 = [TableColumn(field="words", title="Words in Spanish")]
    columns2 = [TableColumn(field="words", title="Words in French")]

    data_table0 = DataTable(source=sources_visible[0], columns=columns0, width=300, height=175)
    data_table1 = DataTable(source=sources_visible[1], columns=columns1, width=300, height=175)
    data_table2 = DataTable(source=sources_visible[2], columns=columns2, width=300, height=175)

    ###################################################
    # CREATION OF THE LAYOUT
    ###################################################

    window = layout([[p, column(slider, data_table0, data_table1, data_table2)]])
    show(window)
Exemplo n.º 5
0
def get_spectral_network_bokeh_plot(
    spectral_network_data,
    plot_range=None,
    plot_joints=False,
    plot_data_points=False,
    plot_on_cylinder=False,
    plot_two_way_streets=False,
    soliton_trees=None,
    no_unstable_streets=False,
    soliton_tree_data=None,
    plot_width=800,
    plot_height=800,
    notebook=False,
    slide=False,
    logger_name=None,
    marked_points=[],
    without_errors=False,
    download=False,
    #downsample=True,
    downsample=False,
    downsample_ratio=None,
):
    logger = logging.getLogger(logger_name)

    # Determine if the data set corresponds to a multi-parameter
    # configuration.
    if type(spectral_network_data.sw_data) is list:
        multi_parameter = True
    else:
        multi_parameter = False

    if without_errors is True:
        spectral_networks = [
            sn for sn in spectral_network_data.spectral_networks
            if len(sn.errors) == 0
        ]
    else:
        spectral_networks = spectral_network_data.spectral_networks

    if soliton_trees is None:
        soliton_trees = spectral_network_data.soliton_trees

    if (len(spectral_networks) == 0
            and (soliton_trees is None or len(soliton_trees) == 0)):
        raise RuntimeError('get_spectral_network_bokeh_plot(): '
                           'No spectral network to plot.')

    sw_data = spectral_network_data.sw_data

    plot_x_range, plot_y_range = plot_range
    y_min, y_max = plot_y_range

    # Setup tools.
    hover = HoverTool(tooltips=[
        ('name', '@label'),
        ('root', '@root'),
    ])

    # Prepare a bokeh Figure.
    bokeh_figure = figure(
        tools='reset,box_zoom,pan,wheel_zoom,save,tap',
        plot_width=plot_width,
        plot_height=plot_height,
        title=None,
        x_range=plot_x_range,
        y_range=plot_y_range,
    )
    bokeh_figure.add_tools(hover)
    bokeh_figure.grid.grid_line_color = None

    # Data source for marked points, which are drawn for an illustration.
    mpds = ColumnDataSource({
        'x': [],
        'y': [],
        'color': [],
        'label': [],
        'root': []
    })
    for mp in marked_points:
        z, color = mp
        mpds.data['x'].append(z.real)
        mpds.data['y'].append(z.imag)
        mpds.data['color'].append(color)
        mpds.data['label'].append('')
        mpds.data['root'].append('')
    bokeh_figure.circle(
        x='x',
        y='y',
        size=5,
        color='color',
        source=mpds,
    )

    # Data source for punctures.
    if multi_parameter is False:
        puncts = sw_data.regular_punctures + sw_data.irregular_punctures
    else:
        puncts = sw_data[0].regular_punctures + sw_data[0].irregular_punctures
    ppds = ColumnDataSource({'x': [], 'y': [], 'label': [], 'root': []})
    for pp in puncts:
        if pp.z == oo:
            continue
        ppds.data['x'].append(pp.z.real)
        ppds.data['y'].append(pp.z.imag)
        ppds.data['label'].append(str(pp.label))
        ppds.data['root'].append('')
    bokeh_figure.circle(
        'x',
        'y',
        size=10,
        color="#e6550D",
        fill_color=None,
        line_width=3,
        source=ppds,
    )

    # Data source for branch points & cuts.
    if multi_parameter is False:
        bpds = ColumnDataSource({'x': [], 'y': [], 'label': [], 'root': []})
        for bp in sw_data.branch_points:
            if bp.z == oo:
                continue
            bpds.data['x'].append(bp.z.real)
            bpds.data['y'].append(bp.z.imag)
            bpds.data['label'].append(str(bp.label))
            positive_roots = bp.positive_roots
            if len(positive_roots) > 0:
                root_label = ''
                for root in positive_roots:
                    root_label += str(root.tolist()) + ', '
                bpds.data['root'].append(root_label[:-2])
            else:
                bpds.data['root'].append('')
        bokeh_figure.x(
            'x',
            'y',
            size=10,
            color="#e6550D",
            line_width=3,
            source=bpds,
        )

        bcds = ColumnDataSource({'xs': [], 'ys': []})
        try:
            branch_cut_rotation = sw_data.branch_cut_rotation
        except AttributeError:
            branch_cut_rotation = None
        if branch_cut_rotation is not None:
            for bl in sw_data.branch_points + sw_data.irregular_singularities:
                y_r = (2j * y_max) * complex(sw_data.branch_cut_rotation)
                bcds.data['xs'].append([bl.z.real, bl.z.real + y_r.real])
                bcds.data['ys'].append([bl.z.imag, bl.z.imag + y_r.imag])

            bokeh_figure.multi_line(
                xs='xs',
                ys='ys',
                line_width=2,
                color='gray',
                line_dash='dashed',
                source=bcds,
            )

    # XXX: Need to clean up copy-and-pasted codes.
    else:
        bpds = []
        bcds = []
        for swd in sw_data:
            bpds_i = ColumnDataSource({
                'x': [],
                'y': [],
                'label': [],
                'root': []
            })
            for bp in swd.branch_points:
                if bp.z == oo:
                    continue
                bpds_i.data['x'].append(bp.z.real)
                bpds_i.data['y'].append(bp.z.imag)
                bpds_i.data['label'].append(str(bp.label))
                root_label = ''
                for root in bp.positive_roots:
                    root_label += str(root.tolist()) + ', '
                bpds_i.data['root'].append(root_label[:-2])
            bpds.append(bpds_i)

            bcds_i = ColumnDataSource({'xs': [], 'ys': []})
            for bl in swd.branch_points + swd.irregular_singularities:
                y_r = (2j * y_max) * complex(swd.branch_cut_rotation)
                bcds_i.data['xs'].append([bl.z.real, bl.z.real + y_r.real])
                bcds_i.data['ys'].append([bl.z.imag, bl.z.imag + y_r.imag])
            bcds.append(bcds_i)

        # In this case the branch points and cuts will be
        # drawn differently for each spectral network.
        # Each call of the slider will deal with them.

    # Data source for the current plot
    cds = ColumnDataSource({
        'xs': [],
        'ys': [],
        'ranges': [],
        'color': [],
        'alpha': [],
        'arrow_x': [],
        'arrow_y': [],
        'arrow_angle': [],
        'label': [],
        'root': [],
    })

    # Data source for plotting data points
    dpds = ColumnDataSource({
        'x': [],
        'y': [],
    })

    # Data source for phases
    pds = ColumnDataSource({
        'phase': [],
    })
    #    for sn in spectral_networks:
    #        # sn_phase = '{:.3f}'.format(sn.phase / pi)
    #        sn_phase = '{:.3f}'.format(sn.phase)
    #        pds.data['phase'].append(sn_phase)

    # Data source containing all the spectral networks
    snds = ColumnDataSource({
        'spectral_networks': [],
    })

    if soliton_trees is not None and len(soliton_trees) > 0:
        # snds['spectral_networks'] is a 1-dim array,
        # of soliton trees.
        for tree in soliton_trees:
            if no_unstable_streets and tree.stability != 1:
                continue
            elif tree.stability == 1 or tree.stability is None:
                s_wall_color = '#0000FF'
            elif tree.stability == 0:
                s_wall_color = '#00FF00'
            elif tree.stability > 1:
                s_wall_color = '#FF0000'
            tree_data = get_s_wall_plot_data(
                tree.streets,
                sw_data,
                logger_name,
                tree.phase,
                s_wall_color=s_wall_color,
                downsample=downsample,
                downsample_ratio=downsample_ratio,
            )
            snds.data['spectral_networks'].append(tree_data)
            pds.data['phase'].append('{:.3f}'.format(tree.phase))
        init_data = snds.data['spectral_networks'][0]

    elif plot_two_way_streets:
        if soliton_tree_data is not None:
            # snds['spectral_networks'] is a 2-dim array,
            # where the first index chooses a spectral network
            # and the second index chooses a soliton tree
            # of the two-way streets of the spectral network.
            for i, soliton_trees in enumerate(soliton_tree_data):
                data_entry = []
                theta_i = spectral_networks[i].phase
                if len(soliton_trees) == 0:
                    # Fill with empty data.
                    empty_data = get_s_wall_plot_data(
                        [],
                        sw_data,
                        logger_name,
                        theta_i,
                        downsample=downsample,
                        downsample_ratio=downsample_ratio,
                    )
                    data_entry.append(empty_data)
                else:
                    for tree in soliton_trees:
                        tree_data = get_s_wall_plot_data(
                            tree.streets,
                            sw_data,
                            logger_name,
                            theta_i,
                            downsample=downsample,
                            downsample_ratio=downsample_ratio,
                        )
                        # The first data contains all the soliton trees
                        # of the two-way streets in a spectral network.
                        if len(data_entry) == 0:
                            data_entry.append(deepcopy(tree_data))
                        else:
                            for key in tree_data.keys():
                                data_entry[0][key] += tree_data[key]
                        data_entry.append(tree_data)

                snds.data['spectral_networks'].append(data_entry)
                pds.data['phase'].append('{:.3f}'.format(theta_i))

            init_data = snds.data['spectral_networks'][0][0]
        else:
            logger.warning('No data to plot.')
    else:
        # snds['spectral_networks'] is a 1-dim array,
        # of spectral network data.
        for sn in spectral_networks:
            skip_plotting = False
            for error in sn.errors:
                error_type, error_msg = error
                if error_type == 'Unknown':
                    skip_plotting = True
            if skip_plotting is True:
                continue

            sn_data = get_s_wall_plot_data(
                sn.s_walls,
                sw_data,
                logger_name,
                sn.phase,
                downsample=downsample,
                downsample_ratio=downsample_ratio,
            )
            snds.data['spectral_networks'].append(sn_data)
            pds.data['phase'].append('{:.3f}'.format(sn.phase))

        init_data = snds.data['spectral_networks'][0]

    # Initialization of the current plot data source.
    for key in cds.data.keys():
        cds.data[key] = init_data[key]

    bokeh_figure.scatter(
        x='x',
        y='y',
        alpha=0.5,
        source=dpds,
    )

    bokeh_figure.multi_line(
        xs='xs',
        ys='ys',
        color='color',
        alpha='alpha',
        line_width=1.5,
        source=cds,
    )

    bokeh_figure.triangle(
        x='arrow_x',
        y='arrow_y',
        angle='arrow_angle',
        color='color',
        alpha='alpha',
        size=8,
        source=cds,
    )

    bokeh_obj = {}
    notebook_vform_elements = []

    # XXX: Where is a good place to put the following?
    custom_js_code = ''
    if notebook is True or slide is True:
        with open('static/bokeh_callbacks.js', 'r') as fp:
            custom_js_code += fp.read()
            custom_js_code += '\n'

    # Data source for plot ranges
    if download is False and notebook is False and slide is False:
        range_callback = CustomJS(
            args={
                'x_range': bokeh_figure.x_range,
                'y_range': bokeh_figure.y_range
            },
            code=(custom_js_code + 'update_plot_range(x_range, y_range);'),
        )
        bokeh_figure.x_range.callback = range_callback
        bokeh_figure.y_range.callback = range_callback

    # 'Redraw arrows' button.
    redraw_arrows_button = Button(
        label='Redraw arrows',
        callback=CustomJS(
            args={
                'cds': cds,
                'x_range': bokeh_figure.x_range,
                'y_range': bokeh_figure.y_range
            },
            code=(custom_js_code + 'redraw_arrows(cds, x_range, y_range);'),
        ),
    )
    bokeh_obj['redraw_arrows_button'] = redraw_arrows_button
    notebook_vform_elements.append(redraw_arrows_button)

    # 'Show data points' button
    show_data_points_button = Button(label='Show data points', )
    show_data_points_button.callback = CustomJS(
        args={
            'cds': cds,
            'dpds': dpds,
            'hover': hover
        },
        code=(custom_js_code + 'show_data_points(cds, dpds, hover);'),
    )
    bokeh_obj['show_data_points_button'] = show_data_points_button
    notebook_vform_elements.append(show_data_points_button)

    # 'Hide data points' button
    hide_data_points_button = Button(label='Hide data points', )
    hide_data_points_button.callback = CustomJS(
        args={
            'cds': cds,
            'dpds': dpds,
            'hover': hover
        },
        code=(custom_js_code + 'hide_data_points(cds, dpds, hover);'),
    )
    bokeh_obj['hide_data_points_button'] = hide_data_points_button
    notebook_vform_elements.append(hide_data_points_button)

    # Prev/Next soliton tree button
    tree_idx_ds = ColumnDataSource({'j': ['0']})
    sn_idx_ds = ColumnDataSource({'i': ['0']})
    plot_options_ds = ColumnDataSource({
        'notebook': [notebook],
        'show_trees': [plot_two_way_streets]
    })

    if plot_two_way_streets and soliton_tree_data is not None:
        prev_soliton_tree_button = Button(label='<', )
        prev_soliton_tree_button.callback = CustomJS(
            args={
                'cds': cds,
                'snds': snds,
                'sn_idx_ds': sn_idx_ds,
                'tree_idx_ds': tree_idx_ds,
                'plot_options_ds': plot_options_ds,
            },
            code=(custom_js_code +
                  'show_prev_soliton_tree(cds, snds, sn_idx_ds, tree_idx_ds, '
                  'plot_options_ds);'),
        )
        bokeh_obj['prev_soliton_tree_button'] = prev_soliton_tree_button
        notebook_vform_elements.append(prev_soliton_tree_button)

        next_soliton_tree_button = Button(label='>', )
        next_soliton_tree_button.callback = CustomJS(
            args={
                'cds': cds,
                'snds': snds,
                'sn_idx_ds': sn_idx_ds,
                'tree_idx_ds': tree_idx_ds,
                'plot_options_ds': plot_options_ds,
            },
            code=(custom_js_code +
                  'show_next_soliton_tree(cds, snds, sn_idx_ds, tree_idx_ds, '
                  'plot_options_ds);'),
        )
        bokeh_obj['next_soliton_tree_button'] = next_soliton_tree_button
        notebook_vform_elements.append(next_soliton_tree_button)

    # Slider
    if (plot_two_way_streets and soliton_trees is not None
            and len(soliton_trees) > 0):
        slider_title = 'soliton tree #'
    else:
        slider_title = 'spectral network #'
    num_of_plots = len(snds.data['spectral_networks'])
    if num_of_plots > 1:
        if multi_parameter is False:
            sn_slider = Slider(
                start=0,
                end=num_of_plots - 1,
                value=0,
                step=1,
                title=slider_title,
            )

            sn_slider.callback = CustomJS(
                args={
                    'cds': cds,
                    'snds': snds,
                    'sn_idx_ds': sn_idx_ds,
                    'dpds': dpds,
                    'pds': pds,
                    'hover': hover,
                    'plot_options': plot_options_ds,
                    'tree_idx_ds': tree_idx_ds
                },
                code=(custom_js_code +
                      'sn_slider(cb_obj, cds, snds, sn_idx_ds, dpds, pds, '
                      'hover, plot_options, tree_idx_ds);'),
            )
            plot = vform(
                bokeh_figure,
                sn_slider,
                width=plot_width,
            )
            notebook_vform_elements = ([bokeh_figure, sn_slider] +
                                       notebook_vform_elements)

        else:
            # TODO: implement new js routine for sn_slider when
            # there are multiple parameters.
            # Need to draw branch points and cuts for each value of the
            # parameters.
            sn_slider = Slider(start=0,
                               end=num_of_plots - 1,
                               value=0,
                               step=1,
                               title="spectral network #")

            sn_slider.callback = CustomJS(
                args={
                    'cds': cds,
                    'snds': snds,
                    'sn_idx_ds': sn_idx_ds,
                    'dpds': dpds,
                    'pds': pds,
                    'hover': hover,
                    'plot_options': plot_options_ds,
                    'tree_idx_ds': tree_idx_ds
                },
                code=(custom_js_code +
                      'sn_slider(cb_obj, cds, snds, sn_idx_ds, dpds, pds, '
                      'hover, plot_options, tree_idx_ds);'),
            )
            plot = vform(
                bokeh_figure,
                sn_slider,
                width=plot_width,
            )
            notebook_vform_elements = ([bokeh_figure, sn_slider] +
                                       notebook_vform_elements)

    else:
        plot = bokeh_figure
        notebook_vform_elements = ([bokeh_figure] + notebook_vform_elements)

    bokeh_obj['plot'] = plot

    if notebook is True:
        # TODO: Include phase text input
        return vform(*notebook_vform_elements, width=plot_width)
    elif slide is True:
        return plot
    else:
        return bokeh.embed.components(bokeh_obj)
Exemplo n.º 6
0
def modify_doc():
    """Build the interactive LFPCA explorer and return its layout.

    Every ``*.npz`` file in the current working directory is loaded with
    ``lfpca.lfpca_load_spec`` and offered in the ``lf_condition`` selector,
    keyed by file name without the ``.npz`` suffix.  For the selected
    condition and channel the layout shows:

    * the PSD and SCV spectra on log-log axes, each with a dashed red marker
      at the selected frequency;
    * a density histogram of ``lf.spg[chan, freq, :]`` at the selected
      frequency index, overlaid with an exponential fit (location fixed at
      0) and titled with the matching ``lf.ks_pvals`` entry.

    Glyphs are created exactly once; widget changes only replace the data in
    the shared ``ColumnDataSource`` objects, which Bokeh re-renders.

    Returns:
        A bokeh ``column`` layout (selectors, sliders, plot grid).  The caller
        adds it to a document, e.g. ``doc.add_root(modify_doc())``.

    Raises:
        FileNotFoundError: if the current directory contains no ``.npz`` file.
    """
    # collect all the names of .npz files in the folder (sorted for a stable order)
    lfpca_all_names = sorted(glob.glob("*.npz"))
    if not lfpca_all_names:
        raise FileNotFoundError("no .npz files found in the current directory")

    # load the npz files, keyed by file name without the '.npz' suffix
    lfpca_all = {name[:-4]: lfpca.lfpca_load_spec(name)
                 for name in lfpca_all_names}

    # initialize with the first lfpca object
    initial_key = lfpca_all_names[0][:-4]
    lf = lfpca_all[initial_key]

    # channel count from the psd array (indexed below as lf.psd[chan])
    chan_count, _ = lf.psd.shape

    # selector options: every channel index, every loaded condition
    DEFAULT_TICKERS = [str(ch) for ch in range(chan_count)]
    LF_TICKERS = list(lfpca_all)

    # initial channel, frequency index and histogram bin count
    chan = 0
    select_freq = 10
    select_bin = 20

    def _spectra_data(lf, chan):
        """Return the PSD/SCV column data for one channel.

        The first frequency bin is dropped (presumably the 0 Hz bin, which a
        log axis cannot display).
        """
        return dict(freq_vals=lf.f_axis[1:],
                    psd_vals=lf.psd[chan].T[1:],
                    scv_vals=lf.scv[chan].T[1:])

    def _hist_data(lf, chan, select_freq, select_bin):
        """Return (histogram quad data, exponential-fit line data) for one frequency index."""
        power = lf.spg[chan, select_freq, :]
        hist, edges = np.histogram(power, bins=select_bin, density=True)
        # expon.fit returns (loc, scale); loc is pinned at 0, keep the scale
        rv = expon(scale=sp.stats.expon.fit(power, floc=0)[1])
        return ({'top': hist, 'left': edges[:-1], 'right': edges[1:]},
                {'x': edges, 'y': rv.pdf(edges)})

    def _hist_title(lf, chan, select_freq):
        """Histogram title with the selected frequency and its ks_pvals entry."""
        # NOTE(review): prints the frequency *index* as Hz; this assumes a
        # 1 Hz resolution starting at 0 -- TODO confirm against lf.f_axis.
        return 'Freq = %.1fHz, p-value = %.4f' % (
            select_freq, lf.ks_pvals[chan, select_freq])

    # selectors and sliders
    lf_ticker = Select(value=initial_key,
                       title='lf_condition',
                       options=LF_TICKERS)
    ticker = Select(value=str(chan), title='channel', options=DEFAULT_TICKERS)
    freq_slider = Slider(start=1,
                         end=199,
                         value=select_freq,
                         step=1,
                         title="Frequency",
                         callback_policy="mouseup")
    bin_slider = Slider(start=10,
                        end=55,
                        value=select_bin,
                        step=5,
                        title="Number of bins",
                        callback_policy="mouseup")

    # data sources shared by the plots; updates only ever replace their .data
    source = ColumnDataSource(data=_spectra_data(lf, chan))
    initial_hist, initial_fit = _hist_data(lf, chan, select_freq, select_bin)
    hist_source = ColumnDataSource(initial_hist)
    fit_hist_source = ColumnDataSource(initial_fit)

    TOOLS = "help"  #tapTool work in progress

    def _spectrum_figure(title, y_label):
        """Create a log-log spectrum figure with the shared styling."""
        fig = figure(tools=TOOLS,
                     title=title,
                     x_axis_type='log',
                     y_axis_type='log')
        fig.legend.location = 'top_left'
        fig.xaxis.axis_label = 'Frequency (Hz)'
        fig.yaxis.axis_label = y_label
        fig.grid.grid_line_alpha = 0.3
        return fig

    psd_plot = _spectrum_figure('PSD', 'Power/Frequency (dB/Hz)')
    scv_plot = _spectrum_figure('SCV', '(Unitless)')

    hist_fig = figure(x_axis_label='Power',
                      y_axis_label='Probability',
                      background_fill_color="#E8DDCB")
    hist_fig.axis.visible = False
    hist_fig.title.text = _hist_title(lf, chan, select_freq)

    def _draw_series(plot, y_col, glyph_name):
        """Draw a navy line plus faint, selectable markers for one spectrum column."""
        plot.line('freq_vals', y_col, source=source, color='navy')
        plot.circle(
            'freq_vals',
            y_col,
            source=source,
            size=5,
            color='darkgrey',
            alpha=0.2,
            # set visual properties for selected glyphs
            selection_color="firebrick",
            # set visual properties for non-selected glyphs
            nonselection_fill_alpha=0.2,
            nonselection_fill_color="darkgrey",
            name=glyph_name)

    def _freq_marker():
        """Dashed red vertical line at the initially selected frequency."""
        return Span(location=select_freq,
                    dimension='height',
                    line_color='red',
                    line_dash='dashed',
                    line_width=3)

    # glyphs are created exactly once; later updates only swap data
    _draw_series(psd_plot, 'psd_vals', 'psd_circ')
    _draw_series(scv_plot, 'scv_vals', 'scv_circ')
    vline_psd = _freq_marker()
    vline_scv = _freq_marker()
    psd_plot.add_layout(vline_psd)
    scv_plot.add_layout(vline_scv)
    hist_fig.quad(top='top',
                  bottom=0,
                  left='left',
                  right='right',
                  fill_color="#036564",
                  line_color="#033649",
                  source=hist_source)
    fit_line = bokeh.models.glyphs.Line(x='x',
                                        y='y',
                                        line_width=8,
                                        line_alpha=0.7,
                                        line_color="#D95B43")
    hist_fig.add_glyph(fit_hist_source, fit_line)

    all_plots = gridplot([[psd_plot, scv_plot, hist_fig]],
                         plot_width=300,
                         plot_height=300)

    # move both frequency markers in the browser as the slider changes
    freq_slider.callback = CustomJS(args=dict(span1=vline_psd,
                                              span2=vline_scv,
                                              slider=freq_slider),
                                    code="""span1.location = slider.value; 
                                                        span2.location = slider.value"""
                                    )

    def update(attrname, old, new):
        """Refresh every data source and the histogram title from the widgets."""
        chan = int(ticker.value)
        lf = lfpca_all[lf_ticker.value]
        # slider values may arrive as floats; array indices and
        # np.histogram's bin count must be ints
        select_freq = int(freq_slider.value)
        select_bin = int(bin_slider.value)

        # Replace data only: the renderers bound to these sources redraw
        # themselves.  Re-adding glyphs here would stack duplicate renderers
        # on every interaction.
        source.data = _spectra_data(lf, chan)
        hist_source.data, fit_hist_source.data = _hist_data(
            lf, chan, select_freq, select_bin)
        hist_fig.title.text = _hist_title(lf, chan, select_freq)

    # whenever a widget changes, the plots and histogram are updated
    for widget in [lf_ticker, ticker, bin_slider, freq_slider]:
        widget.on_change('value', update)

    # organize layout
    widgets = row(lf_ticker, ticker)
    sliders = row(freq_slider, bin_slider)
    return column(widgets, sliders, all_plots)


# In the notebook, just pass the function that defines the app to show
# show(modify_doc, notebook_handle=True)
# curdoc().add_root(layout) - for .py file
# curdoc().add_root(layout)
# h = show(layout, notebook_handle=True)
# In the notebook, just pass the function that defines the app to show
# show(modify_doc)
# curdoc().add_root(layout) # for .py file
Exemplo n.º 7
0
# Client-side callback: rebuild the xleft/xright/yleft/yright/bottom columns
# of `source` by sampling Math.cos on [0, pi] in `slider.value` equal steps,
# then notify Bokeh that the data changed.
slider.callback = CustomJS(
    code="""
    var a = 0;
    var b = Math.PI;
    var data = source.data;
    var f = Math.cos;
    var x = a;
    var y = f(a);
    var step = (b-a)/slider.value;
    data.xleft = new Array();
    data.xright = new Array();
    data.yleft = new Array();
    data.yright = new Array();
    data.bottom = new Array();
    data.xleft.push(x);
    data.yright.push(y)
    data.bottom.push(0);
    for (var k=0; k<slider.value-1; k++){
        x += step;
        y = f(x);
        data.xleft.push(x);
        data.xright.push(x)
        data.yleft.push(y);
        data.yright.push(y);
        data.bottom.push(0);
    }
    x += step;
    y = f(x);
    data.xright.push(x);
    data.yleft.push(y);
    source.change.emit();
""",
    args={'source': source, 'slider': slider},
)
Exemplo n.º 8
0
def graph_year_property(h, p_no=0):
    '''
    Creates bar graphs for particular year and property

    Draws one bar per state showing property ``col_index_names1000[p_no]``
    of that state's first city, with a slider that switches the displayed
    year.  The figure is written to "bars.html" and shown.

    Arguments:
        h {dict} -- state -> list of per-city DataFrames; the first
            DataFrame of h['CA'] provides the year range via its index
            (presumably one row per year -- confirm with the data loader)

    Keyword Arguments:
        p_no {int} -- index into col_index_names1000 of the property to plot
    '''
    from bokeh.plotting import figure
    from bokeh.io import show, output_file
    from bokeh.layouts import column
    from bokeh.models import ColumnDataSource
    from bokeh.models.callbacks import CustomJS
    output_file("bars.html")

    # x-axis labels in the form "State: City", naming each state's first city
    xlabels = [state + ': ' + city_names[0]
               for state, city_names in cities.items()]

    syear = h['CA'][0].index[0]
    nyears = len(h['CA'][0].index)

    # Yearly values of the selected property for each state's FIRST city, the
    # city named in the label.  None when a state has no cities, which leaves
    # its bar at 0.
    # NOTE(review): assumes h and cities iterate states in the same order.
    column_name = str(col_index_names1000[p_no])
    first_city_rows = []
    for city_list in h.values():
        a = array([d.get(column_name) for d in city_list])
        first_city_rows.append(a[0] if len(a) else None)

    # p_values[k][i] is the value of state i in year syear + k; each list is
    # as long as xlabels so every ColumnDataSource column has equal length
    p_values = []
    for yr_ind in range(nyears):
        p_value = [0] * len(xlabels)
        for i, first_row in enumerate(first_city_rows):
            if first_row is not None:
                p_value[i] = first_row[yr_ind]
        p_values.append(p_value)

    # one column per year, keyed by the year as a string so the JS callback
    # can look it up with the slider value
    alldat = {str(syear + k): values for k, values in enumerate(p_values)}
    source_available = ColumnDataSource(data=alldat)
    source_visible = ColumnDataSource(
        data=dict(counties=xlabels, pvalue=p_values[0]))
    TOOLS = "pan,wheel_zoom,reset,hover,save"
    p = figure(x_range=xlabels, plot_height=450, plot_width=800,
               title='Year', toolbar_location=None, tools=TOOLS)
    p.vbar(x='counties', top='pvalue', source=source_visible,
           width=0.4, alpha=0.7, color='#3FE0D0')
    p.x_range.range_padding = 0.1
    p.title.align = 'center'
    p.yaxis.axis_label = col_index_names1000[p_no]
    p.yaxis.axis_label_text_font_size = '12pt'
    p.xaxis.major_label_orientation = 3.14/2
    p.xaxis.major_label_text_font_size = '12pt'
    p.axis.minor_tick_line_color = 'black'
    p.outline_line_color = 'black'
    slider = Slider(start=syear, end=syear+nyears-1,
                    value=syear, step=1, title="Year")

    # swap the visible bar heights for the selected year's column, client-side
    slider.callback = CustomJS(
        args=dict(source_visible=source_visible,
                  source_available=source_available), code="""
        var selected_year = cb_obj.value;
        // Get the data from the data sources
        var data_visible = source_visible.data;
        var data_available = source_available.data;
        // Change y-axis data according to the selected value
        data_visible.pvalue = data_available[selected_year];
        // Update the plot
        source_visible.change.emit();
    """)
    hover = p.select_one(HoverTool)
    prop_label = h['CA'][0].columns[p_no]
    hover.tooltips = [("County", "@counties"), (prop_label,
                                                "$y")]
    show(column(p, widgetbox(slider),))
Exemplo n.º 9
0
            year.value=2017;
            month.value=2;
            alert("Select month less than 4");
        }
        var column = year*100+month;
        
        data['count'] = data['count'+column].slice();
        for(i=0;i<668;i++){
            data['transform'][i] = data['transform'+column][i]*300;
        }
        data.fill2 = data.fill.slice();
        
        var arr = data.count.slice();
        arr.sort(function(a, b){return b - a});
        for(var i=0;i<5;i++){
            var high = data.count.indexOf(arr[i]);
            data.fill2[high] = "yellow";
        }
        console.log(data['count']);
        console.log(data['transform']);
        source.change.emit();
    """)
# Both widgets share the same JS callback.
year.callback = callback
month.callback = callback

p = Div(text="Click Reset in the top right tool box.<br>Please select year and month to get started.<br>Yellow circles represent top five popular stations", width=400, height=50)
layout = gridplot([[p, year, month], [fig]])

# Render first so a rendering failure cannot leave a truncated file behind;
# `with` guarantees the file is closed, and UTF-8 matches the charset that
# Bokeh's standalone HTML template declares.
page_html = file_html(layout, CDN, 'Popular Stations')
with open('plot_popular.html', 'w', encoding='utf-8') as outfile:
    outfile.write(page_html)
Exemplo n.º 10
0
def graph_year_property(h, p_no=0):
    '''
    Creates bar graphs with slider for years from 2006 to 2017 for a particular parameter

    One bar per state shows parameter ``col_index_names1000[p_no]`` of that
    state's first city; the slider switches the displayed year.  Population
    (p_no 0) is shown in millions and unlinked passenger trips (p_no 5) in
    thousands.  The figure is written to "bars.html" and shown.

    Arguments:
        h{dict} -- Dictionary with filled dataframes

    Keyword Arguments:
        p_no{int} -- Parameter to be plotted (0 to 6)

    Raises:
        TypeError -- if p_no is not an int or h is not a dict
        ValueError -- if p_no is outside 0..6
    '''
    # explicit checks: `assert` would be stripped when running with -O
    if not isinstance(p_no, int):
        raise TypeError("p_no must be an int, got %s" % type(p_no).__name__)
    if not 0 <= p_no <= 6:
        raise ValueError("p_no must be between 0 and 6, got %d" % p_no)
    if not isinstance(h, dict):
        raise TypeError("h must be a dict, got %s" % type(h).__name__)

    from bokeh.plotting import figure
    from bokeh.io import show, output_file
    from bokeh.layouts import column
    from bokeh.models import ColumnDataSource
    from bokeh.models.callbacks import CustomJS
    output_file("bars.html")

    # labels used for the x-axis in the form "State: City"
    xlabels = [state + ': ' + city_names[0]
               for state, city_names in cities.items()]

    syear = h['CA'][0].index[0]
    nyears = len(h['CA'][0].index)

    # divisor that keeps large quantities readable on the y axis
    divisor = {0: 1000000,  # "Population" in millions
               5: 1000}.get(p_no)  # "Unlinked passenger trips" in thousands

    # Yearly values of the parameter for each state's FIRST city, computed
    # once per state.  None when a state has no cities, leaving its bar at 0.
    # NOTE(review): assumes h and cities iterate states in the same order.
    column_name = str(col_index_names1000[p_no])
    first_city_rows = []
    for city_list in h.values():
        a = array([d.get(column_name) for d in city_list])
        if divisor is not None:
            a = a / divisor
        first_city_rows.append(a[0] if len(a) else None)

    # p_values[k][i] is the value of state i in year syear + k; each list is
    # as long as xlabels so every ColumnDataSource column has equal length
    p_values = []
    for yr_ind in range(nyears):
        p_value = [0] * len(xlabels)
        for i, first_row in enumerate(first_city_rows):
            if first_row is not None:
                p_value[i] = first_row[yr_ind]
        p_values.append(p_value)

    # one column per year, keyed by the year as a string for the JS lookup
    alldat = {str(syear + k): values for k, values in enumerate(p_values)}
    source_available = ColumnDataSource(data=alldat)
    source_visible = ColumnDataSource(
        data=dict(counties=xlabels, pvalue=p_values[0]))
    TOOLS = "pan,wheel_zoom,reset,hover,save"
    p = figure(x_range=xlabels, plot_height=450, plot_width=800, title=col_index_names1000[p_no], toolbar_location=None, tools=TOOLS)
    p.vbar(x='counties', top='pvalue', source=source_visible,
           width=0.4, alpha=0.7, color='#643fe0')
    p.x_range.range_padding = 0.1
    p.title.align = 'center'
    p.title.text_font_size = '14pt'
    if p_no == 0:
        p.yaxis.axis_label = 'In millions'
    if p_no == 5:
        p.yaxis.axis_label = 'In thousands'
    p.yaxis.axis_label_text_font_size = '12pt'
    p.yaxis.major_label_text_font_size = '12pt'
    p.xaxis.major_label_orientation = 3.14/4
    p.xaxis.major_label_text_font_size = '12pt'
    p.axis.minor_tick_line_color = 'black'
    p.outline_line_color = 'black'
    slider = Slider(start=syear, end=syear+nyears-1,
                    value=syear, step=1, title="Year",bar_color='#643fe0',align="center")

    # swap the visible bar heights for the selected year's column, client-side
    slider.callback = CustomJS(
        args=dict(source_visible=source_visible,
                  source_available=source_available), code="""
        var selected_year = cb_obj.value;
        // Get the data from the data sources
        var data_visible = source_visible.data;
        var data_available = source_available.data;
        // Change y-axis data according to the selected value
        data_visible.pvalue = data_available[selected_year];
        // Update the plot
        source_visible.change.emit();
    """)
    hover = p.select_one(HoverTool)
    prop_label = h['CA'][0].columns[p_no]
    hover.tooltips = [("County", "@counties"), (prop_label,
                                                "$y")]
    show(column(p, widgetbox(slider),))
Exemplo n.º 11
0
    c = (a + b) / 2
    la.append(a)
    lb.append(b)
    lc.append(c)

# Per-iteration bisection bounds: la/lb are the interval ends, lc the midpoint.
source = ColumnDataSource(data=dict(la=la, lb=lb, lc=lc))

p = figure(title="Résolution par dichotomie", plot_width=700, plot_height=500)
p.title.align = 'center'
p.line(x, y, line_width=2)

# One slider step per recorded iteration.  Deriving the range from the data
# keeps the JS callback from indexing past the end of la/lb/lc, which would
# leave the spans at an undefined location.
slider = Slider(start=0, end=len(la) - 1, value=0, step=1, title="Itérations")

# Dashed vertical markers: a (red), b (green), midpoint c (orange).
spa = Span(location=la[slider.value], dimension='height',
           line_color='red', line_dash='dashed', line_width=3)
p.add_layout(spa)
spb = Span(location=lb[slider.value], dimension='height',
           line_color='green', line_dash='dashed', line_width=3)
p.add_layout(spb)
spc = Span(location=lc[slider.value], dimension='height',
           line_color='orange', line_dash='dashed', line_width=3)
p.add_layout(spc)

# Move the three markers to the selected iteration, client-side.
slider.callback = CustomJS(args=dict(spa=spa, spb=spb, spc=spc, source=source, slider=slider), code="""
    var n = slider.value;
    spa.location = source.data['la'][n];
    spb.location = source.data['lb'][n];
    spc.location = source.data['lc'][n];
""")

show(column(p, widgetbox(slider)))
Exemplo n.º 12
0
  working_copy_catchment_data['parameters'][7] = vr_slider.value 
  working_copy_catchment_data['parameters'][8] = k0_slider.value 
  working_copy_catchment_data['parameters'][9] = CD_slider.value 
  y = getTopModel(working_copy_catchment_data)
  model.data_source.data["y"] = y

# Proxy data source: each slider's JS callback writes its value here, and the
# on_change hook below hands that to the real Python callback, slider_cb.
# Presumably the hop exists so the sliders' throttle policy governs how
# often slider_cb runs.
source = ColumnDataSource(data=dict(value=[]))
source.on_change('data', slider_cb)

init_value = int((3.167914e-05 / 4e-5) * 100)
qs0_slider = Slider(
    start=1, end=100, value=init_value, step=0.1,
    title="Initial subsurface flow / unit area [% 0 - 4e-5 m]",
    callback_policy="throttle", callback_throttle=300,
)
lnTe_slider = Slider(
    start=-2, end=1, value=-5.990615e-01, step=.1,
    title="Log of the areal average of T0 [m2/h]",
    callback_policy="throttle", callback_throttle=300,
)
m_slider = Slider(
    start=0, end=0.2, value=2.129723e-02, step=.01,
    title="Decline of transmissivity in soil profile",
    callback_policy="throttle", callback_throttle=300,
)

# Each slider gets its own CustomJS that pushes the new value into the proxy.
qs0_slider.callback = CustomJS(args={'source': source}, code="""
  source.data = { value: [cb_obj.value] }
""")
lnTe_slider.callback = CustomJS(args={'source': source}, code="""
  source.data = { value: [cb_obj.value] }
""")
m_slider.callback = CustomJS(args={'source': source}, code="""
  source.data = { value: [cb_obj.value] }
""")

Sr0_slider = Slider(start=0, end=0.02, value=2.626373e-03, step=0.001,
Exemplo n.º 13
0
# Proxy source for the age slider: its JS callback writes the slider value
# here, and on_change forwards it to age_slider_update on the Python side.
fake_callback_source5 = ColumnDataSource(data=dict(value=[]))  # for age
fake_callback_source5.on_change('data', age_slider_update)

exposure_controls = []
visual_controls = [widget]

# The range 5.5 .. 10.15 only makes physical sense as log10(age / yr):
# 10.15 is about 14 Gyr, the age of the universe.  The old "Gyr" label
# contradicted this.
age_slider = Slider(start=5.5,
                    end=10.15,
                    value=10.,
                    step=0.05,
                    title="Log(Age in yr)",
                    callback_policy='mouseup')
age_slider.callback = CustomJS(args=dict(source=fake_callback_source5),
                               code="""
    source.data = { value: [cb_obj.value] }
""")

# NOTE(review): forwards through fake_callback_source4, defined elsewhere;
# presumably the proxy for the metallicity handler -- confirm.
metallicity_slider = Slider(start=-2.,
                            end=0.0,
                            value=0.,
                            step=0.5,
                            title="Log(Z/Zsun)",
                            callback_policy='mouseup')
metallicity_slider.callback = CustomJS(args=dict(source=fake_callback_source4),
                                       code="""
    source.data = { value: [cb_obj.value] }
""")

astro_controls = [age_slider, metallicity_slider]
Exemplo n.º 14
0

# Frame-stepping buttons (g and r bands) and the light-curve tap event.
rbutton_g.on_click(go_right_by_one_gframe)
lbutton_g.on_click(go_left_by_one_gframe)
rbutton_r.on_click(go_right_by_one_rframe)
lbutton_r.on_click(go_left_by_one_rframe)
fig_lc.on_event('tap', jump_to_lightcurve_position)

# Sliders use a two-step callback so throttling applies: the JS callback
# copies the slider value into a proxy source, whose on_change then runs
# the Python frame update.
fake_source_g, fake_source_r = (
    ColumnDataSource(data=dict(value=[])) for _ in range(2)
)
fake_source_g.on_change('data', update_g_frame)
fake_source_r.on_change('data', update_r_frame)
g_frame_slider.callback = CustomJS(
    args={'source': fake_source_g},
    code="""
    source.data = { value: [cb_obj.value] }
    """,
)
r_frame_slider.callback = CustomJS(
    args={'source': fake_source_r},
    code="""
    source.data = { value: [cb_obj.value] }
    """,
)

# Plot grid: light curve and g image on top; controls, info panels and the
# r image below.
l = layout(
    [fig_lc, fig_img],
    [
        column(
            row(lbutton_g, rbutton_g, g_frame_slider),
            row(lbutton_r, rbutton_r, r_frame_slider),
            row(fig_infog, fig_infor),
        ),
        fig_imr,
    ],
)

# Add everything into the Bokeh document
curdoc().add_root(l)
Exemplo n.º 15
0
            y[i] = data['y'][i]
        }

        txtData = txtSrc.get('data');
        txtData['text'][0] = 'speed: ' + data.speed
        txtMinData = txtMinSrc.get('data');
        txtMinData['text'][0] = 'minWL: ' + data.minwl
        txtMaxData = txtMaxSrc.get('data');
        txtMaxData['text'][0] = 'maxWL: ' + data.maxwl
});

        source.trigger('change');
        sPlot.trigger('change');
        txtSrc.trigger('change');
        txtMinSrc.trigger('change');
        txtMaxSrc.trigger('change');

    """)

# All four sliders share the same JS callback (assigned left to right).
sliderPDE.callback = sliderMinWL.callback = sliderASIC.callback = sliderSPTR.callback = callbackHist

# Sliders in one row above the two plots.
sliders = hplot(sliderPDE, sliderMinWL, sliderASIC, sliderSPTR)
layout = vplot(
    vplot(sliders, hplot(ph, p2)),
    width=800,
    height=800,
)

output_file("scatter.html", title="color_scatter.py example")
show(layout)
Exemplo n.º 16
0
def volcano_plot(data_sans):
    """Build an interactive volcano plot (fold change vs. p-value).

    Three sliders control a horizontal p-value threshold and left/right
    fold-change thresholds. A shared CustomJS callback moves the matching
    threshold line and relabels every point as 'up', 'down' or 'normal'
    client-side, which recolours it through a categorical colour mapper.
    A data table of the points and a download button sit beside the plot.

    Parameters
    ----------
    data_sans : mapping of columns
        Must provide ``'access'`` (accession ids), ``'logfc'`` (x values),
        ``'pvalue'`` (y values) and ``'pos'`` (initial 'up'/'normal'/'down'
        label per point). The y axis is labelled '-log(pvalue)', so the
        ``'pvalue'`` column is presumably already transformed; confirm with
        the caller.

    Returns
    -------
    bokeh layout
        ``row(plot, widgetbox(left slider, right slider, p-value slider,
        table, download button))``.
    """
    index = data_sans['access']
    x = data_sans['logfc']
    y = data_sans['pvalue']
    pos = data_sans['pos']

    source = ColumnDataSource(
        data=dict(x=x, y=y, accession=index, position=pos))

    color_mapper = CategoricalColorMapper(factors=["up", "normal", "down"],
                                          palette=['yellow', 'green', 'blue'])
    # tooltips for the hover tool
    hover = HoverTool(tooltips=[("accession", "@accession"), ("x", "@x"),
                                ("y", "@y")])

    # Axis ranges padded beyond the data (with extra room on the left).
    yrange = [min(y) - 0.1, max(y) + 0.1]
    xrange = [min(x) - 1, max(x) + 0.1]

    TOOLS = "pan,wheel_zoom,box_zoom,reset,box_select,lasso_select,previewsave"

    # create a new plot with a title and axis labels
    p = figure(y_range=yrange,
               x_range=xrange,
               x_axis_label='log(fc)',
               y_axis_label='-log(pvalue)',
               tools=TOOLS,
               plot_width=800,
               plot_height=800)

    p.add_tools(hover)

    # title styling
    p.title.text = "pvalue versus fold-change"
    p.title.align = "center"
    p.title.text_color = "blue"
    p.title.text_font_size = "25px"

    # threshold sliders (created once; the spans, callbacks and layout all
    # refer to these same objects)
    h_slider = Slider(start=yrange[0],
                      end=yrange[1],
                      value=1,
                      step=.1,
                      title="variation of log(pvalue)")
    v_slider_right = Slider(start=0,
                            end=xrange[1],
                            value=0.5,
                            step=.01,
                            title="right fold change")
    v_slider_left = Slider(start=xrange[0],
                           end=0,
                           value=-0.5,
                           step=.01,
                           title="left log fold change")

    # Horizontal p-value threshold line
    hline = Span(location=h_slider.value,
                 dimension='width',
                 line_color='green',
                 line_width=2)
    # Vertical fold-change threshold lines (right, left)
    vline1 = Span(location=v_slider_right.value,
                  dimension='height',
                  line_color='blue',
                  line_width=2)
    vline2 = Span(location=v_slider_left.value,
                  dimension='height',
                  line_color='black',
                  line_width=2)
    p.renderers.extend([vline1, vline2, hline])

    # scatter of all points, coloured by their 'position' label
    p.circle('x',
             'y',
             source=source,
             color=dict(field='position', transform=color_mapper),
             legend='position')

    # Real-time relabelling of points when any threshold slider moves.
    # `span`/`slider` are the line and slider of the widget that fired;
    # all variables are declared so the code also runs in strict mode.
    code = """
    var data = source.data;
    var low = v_slider_left.value;
    var up = v_slider_right.value;
    var back_value = h_slider.value;

    var x = data['x'];
    var y = data['y'];
    var pos = data['position'];

    span.location = slider.value;

    for (var i = 0; i < x.length; i++) {
        if ((x[i] < low) && (y[i] > back_value)) {
            pos[i] = 'down';
        } else if ((x[i] > up) && (y[i] > back_value)) {
            pos[i] = 'up';
        } else {
            pos[i] = 'normal';
        }
    }
    source.change.emit();
    """

    def _threshold_callback(span, slider):
        # Every slider shares the same JS; only the line it moves differs.
        return CustomJS(args=dict(source=source,
                                  span=span,
                                  slider=slider,
                                  v_slider_left=v_slider_left,
                                  h_slider=h_slider,
                                  v_slider_right=v_slider_right),
                        code=code)

    h_slider.callback = _threshold_callback(hline, h_slider)
    v_slider_right.callback = _threshold_callback(vline1, v_slider_right)
    v_slider_left.callback = _threshold_callback(vline2, v_slider_left)

    # table of the (selected) data
    columns = [
        TableColumn(field="accession", title="numero d'accession"),
        TableColumn(field="x", title="log(fc)"),
        TableColumn(field="y", title="-log(pvalue)"),
        TableColumn(field="position", title="position"),
    ]

    data_table = DataTable(source=source,
                           columns=columns,
                           width=400,
                           height=280)

    # download button; the JS lives next to this module and the file is
    # closed as soon as it has been read
    button = Button(label="Download", button_type="success")
    with open(join(dirname(__file__), "static/js/download.js"),
              encoding="utf-8") as js_file:
        download_js = js_file.read()
    button.callback = CustomJS(args=dict(source=source), code=download_js)

    layout = row(
        p,
        widgetbox(v_slider_left, v_slider_right, h_slider, data_table, button))
    return layout
Exemplo n.º 17
0
    def create_bokeh_choro(self, ff, prop=0):
        """Create an interactive Bokeh choropleth of US county transportation data.

        Draws state outlines plus one patch per county listed in
        ``self.bokeh_counties``, coloured (log scale) by column ``prop`` of the
        per-county DataFrames in ``ff``. A "Year" slider swaps the coloured
        values client-side. The figure is written to
        ``<property name without spaces>.html`` via ``output_file`` and
        opened with ``show``.

        Arguments:
            ff {dict} -- Dictionary containing filled dataframes: each value
                is a sequence of DataFrames (one per county) whose rows are
                years and whose columns are properties. ``ff['CA'][0]``
                serves as the reference for column names and years.

        Keyword Arguments:
            prop {int} -- Select the property for which choropleth needs to be created (default: {0})

        Side effects:
            Removes "AK" from the module-level ``states`` mapping if present.
        """
        assert isinstance(ff, dict)
        assert isinstance(prop, int)
        reference = ff['CA'][0]
        assert len(reference.columns) > prop >= 0
        # at least one year (row) is needed to draw anything
        assert len(reference.index) > 0

        # Alaska is left out of the state outlines. This mutates the shared
        # ``states`` mapping; pop() keeps repeated calls harmless.
        states.pop("AK", None)

        nyears = len(reference.index)
        state_xs = [states[code]["lons"] for code in states]
        state_ys = [states[code]["lats"] for code in states]

        # Index counties by detailed name once (first match wins) instead of
        # rescanning every county twice for each district.
        county_by_name = {}
        for county in counties.values():
            county_by_name.setdefault(county["detailed name"], county)

        county_xs = []
        county_ys = []
        district_name = []
        for cs in self.bokeh_counties.values():
            for dname in cs:
                county = county_by_name[dname]
                county_xs.append(county["lons"])
                county_ys.append(county["lats"])
                district_name.append(dname)

        if isinstance(palette, dict):
            # NOTE(review): presumably a Bokeh palette family keyed by size;
            # the palette under the last key is used.
            color_mapper = LogColorMapper(
                palette=palette[list(palette.keys())[-1]])
        else:
            color_mapper = LogColorMapper(palette=palette)

        # pvalues[year_idx] holds the property value of every county for that
        # year, in ff's state/frame order.
        # NOTE(review): that order must match district_name (built from
        # self.bokeh_counties); confirm both are populated consistently.
        pvalues = [
            [frame.iloc[yx, prop] for frames in ff.values() for frame in frames]
            for yx in range(nyears)
        ]

        # One column per year, keyed by the actual year from the index as a
        # string; the slider JS looks columns up by the selected year.
        years = reference.index
        alldat = {str(int(yy)): pvalues[ix] for ix, yy in enumerate(years)}

        prop_name = reference.columns[prop]
        source_available = ColumnDataSource(data=alldat)
        source_visible = ColumnDataSource(data=dict(
            x=county_xs, y=county_ys, name=district_name, pvalue=pvalues[0]))
        TOOLS = "pan,wheel_zoom,reset,hover,save"
        p = figure(title=f"{prop_name} across Counties",
                   tools=TOOLS,
                   plot_width=850,
                   plot_height=400,
                   x_axis_location=None,
                   y_axis_location=None)
        p.toolbar.active_scroll = "auto"
        p.toolbar.active_drag = 'auto'
        p.background_fill_color = "#B0E0E6"
        p.patches(state_xs,
                  state_ys,
                  fill_alpha=1.0,
                  fill_color='#FFFFE0',
                  line_color="#884444",
                  line_width=2,
                  line_alpha=0.3)
        p.patches('x',
                  'y',
                  source=source_visible,
                  fill_color={
                      'field': 'pvalue',
                      'transform': color_mapper
                  },
                  fill_alpha=0.8,
                  line_color="white",
                  line_width=0.3)
        hover = p.select_one(HoverTool)
        hover.point_policy = "follow_mouse"
        hover.tooltips = [("County", "@name"), (prop_name, "@pvalue"),
                          ("(Long, Lat)", "($x, $y)")]
        output_file(f"{prop_name.replace(' ','')}.html",
                    title="US Public Transport")
        slider = Slider(start=int(years[0]),
                        end=int(years[-1]),
                        value=int(years[0]),
                        step=1,
                        title="Year")
        slider.callback = CustomJS(args=dict(
            source_visible=source_visible, source_available=source_available),
                                   code="""
            var selected_year = cb_obj.value;
            // Get the data from the data sources
            var data_visible = source_visible.data;
            var data_available = source_available.data;
            // Change y-axis data according to the selected value
            data_visible.pvalue = data_available[selected_year];
            // Update the plot
            source_visible.change.emit();
        """)
        show(column(
            p,
            widgetbox(slider),
        ))
def modify_doc(doc):
    """Bokeh server app: inspect FIDs, magnetization decays and R1 fits.

    For the module-level experiment ``ie`` it plots the free induction decay
    (p1) and the magnetization decay with a mono-exponential fit (p2). It
    also overlays the normalized decays of all experiments (p3) and shows an
    interactive scatter of the parameter table ``par_df`` (p4). The
    experiment slider and the FID-range slider reach their Python handlers
    through proxy ColumnDataSources that a CustomJS callback updates (used
    to throttle the callbacks, together with ``callback_policy='mouseup'``).

    Relies on module-level names defined elsewhere: ``polymer``, ``ie``,
    ``nr_experiments``, ``par_df``, ``columns``, ``quantileable``,
    ``discrete``, ``time`` and the helpers ``get_x_axis``,
    ``get_mag_amplitude``, ``magnetization_fit`` and ``model_exp_dec``.

    Args:
        doc: the Bokeh document the layouts and proxy sources are added to.

    Side effects:
        Stores ``df_magnetization``, ``popt(mono_exp)``, ``amp``, ``r1`` and
        ``noise`` (and ``fid_range`` when the FID slider is used) through
        ``polymer.addparameter``, and adds an ``r1`` column to ``par_df``.
    """

    def _magnetization_frame(tau, phi):
        """Return a DataFrame with columns tau, phi and phi_normalized."""
        frame = pd.DataFrame(data=np.c_[tau, phi], columns=['tau', 'phi'])
        # NOTE(review): the offset subtracts the first sample (iloc[0]) while
        # the scale uses the second (iloc[1]); kept unchanged, confirm intent.
        frame['phi_normalized'] = (frame['phi'] - frame['phi'].iloc[0]) / (
            frame['phi'].iloc[-1] - frame['phi'].iloc[1])
        return frame

    # CustomJS bodies that forward a widget value into a proxy
    # ColumnDataSource whose Python on_change handler does the real work.
    # Unfortunately this is needed to throttle the callbacks in the Bokeh
    # version this app targets.
    experiment_js = """
        source.data = { value: [cb_obj.value] }
    """
    fid_range_js = """
        source.data = { range: cb_obj.range }
    """

    fid = polymer.getfid(ie)  # dataframe

    # TODO: more testing on get_x_axis
    tau = get_x_axis(polymer.getparameter(ie))

    # calculate magnetization from a fixed window (5%-10%) of the FID
    # TODO: make a range slider to get start- and endpoint interactively
    startpoint = int(0.05 * polymer.getparvalue(ie, 'BS'))
    endpoint = int(0.1 * polymer.getparvalue(ie, 'BS'))
    phi = get_mag_amplitude(fid, startpoint, endpoint,
                            polymer.getparvalue(ie, 'NBLK'),
                            polymer.getparvalue(ie, 'BS'))

    # prepare magnetization decay curve for fit
    df = _magnetization_frame(tau, phi)
    polymer.addparameter(ie, 'df_magnetization', df)

    fit_option = 2  # mono exponential, 3 parameter fit
    p0 = [1, 2 * polymer.getparvalue(ie, 'T1MX')**-1, 0]
    df, popt = magnetization_fit(df, p0, fit_option)
    polymer.addparameter(ie, 'popt(mono_exp)', popt)

    df['fit_phi'] = model_exp_dec(df.tau, *popt)

    # convert data to handle in bokeh
    source_fid = ColumnDataSource(data=ColumnDataSource.from_df(fid))
    source_df = ColumnDataSource(data=ColumnDataSource.from_df(df))
    # proxy source for the FID-range slider; its handler is attached below
    source2 = ColumnDataSource(data=dict(range=[], ie=[]))

    # create and plot figures
    p1 = figure(plot_width=300,
                plot_height=300,
                title='Free Induction Decay',
                webgl=True)
    p1.line('index', 'im', source=source_fid, color='blue')
    p1.line('index', 'real', source=source_fid, color='green')
    p1.line('index', 'magnitude', source=source_fid, color='red')

    fid_slider = RangeSlider(start=1,
                             end=polymer.getparvalue(ie, 'BS'),
                             step=1,
                             callback_policy='mouseup')

    p2 = figure(plot_width=300, plot_height=300, title='Magnetization Decay')
    p2.circle_cross('tau', 'phi_normalized', source=source_df, color="navy")
    p2.line('tau', 'fit_phi', source=source_df, color="teal")

    # marker sizes and colours used for the quantile-binned scatter in p4
    SIZES = list(range(6, 22, 3))
    COLORS = Spectral5

    def plot_par():
        """Build the p4 scatter of par_df for the current Select values."""
        xs = par_df[x.value].values
        ys = par_df[y.value].values
        x_title = x.value.title()
        y_title = y.value.title()

        kw = dict()  # holds optional keyword arguments for figure()
        if x.value in discrete:
            kw['x_range'] = sorted(set(xs))
        if y.value in discrete:
            kw['y_range'] = sorted(set(ys))
        if y.value in time:
            kw['y_axis_type'] = 'datetime'
        if x.value in time:
            kw['x_axis_type'] = 'datetime'

        kw['title'] = "%s vs %s" % (x_title, y_title)

        p4 = figure(plot_height=300,
                    plot_width=600,
                    tools='pan,box_zoom,reset',
                    **kw)

        p4.xaxis.axis_label = x_title
        p4.yaxis.axis_label = y_title

        if x.value in discrete:
            p4.xaxis.major_label_orientation = np.pi / 4  # rotates labels...

        sz = 9
        if size.value != 'None':
            groups = pd.qcut(pd.to_numeric(par_df[size.value].values),
                             len(SIZES))
            sz = [SIZES[xx] for xx in groups.codes]

        c = "#31AADE"
        if color.value != 'None':
            groups = pd.qcut(
                pd.to_numeric(par_df[color.value]).values, len(COLORS))
            c = [COLORS[xx] for xx in groups.codes]

        p4.circle(x=xs,
                  y=ys,
                  color=c,
                  size=sz,
                  line_color="white",
                  alpha=0.6,
                  hover_color='white',
                  hover_alpha=0.5)
        return p4

    def update(attr, old, new):
        layout_p4.children[1] = plot_par()

    def cb(attr, old, new):
        """Load experiment ``new['value'][0]`` into p1/p2 and refit its decay."""
        ie = new['value'][0]
        fid = polymer.getfid(ie)
        source_fid.data = ColumnDataSource.from_df(fid)
        try:
            tau = get_x_axis(polymer.getparameter(ie))
            try:
                startpoint = polymer.getparvalue(ie, 'fid_amp_start')
                endpoint = polymer.getparvalue(ie, 'fid_amp_stop')
            except Exception:
                # no stored FID window: fall back to 5%-10% of the block size
                startpoint = int(0.05 * polymer.getparvalue(ie, 'BS'))
                endpoint = int(0.1 * polymer.getparvalue(ie, 'BS'))
            phi = get_mag_amplitude(fid, startpoint, endpoint,
                                    polymer.getparvalue(ie, 'NBLK'),
                                    polymer.getparvalue(ie, 'BS'))
            df = _magnetization_frame(tau, phi)
            polymer.addparameter(ie, 'df_magnetization', df)
            fit_option = 2  # mono exponential, 3 parameter fit
            p0 = [1.0, polymer.getparvalue(ie, 'T1MX')**-1 * 2, 0]
            df, popt = magnetization_fit(df, p0, fit_option)
            source_df.data = ColumnDataSource.from_df(df)

            polymer.addparameter(ie, 'popt(mono_exp)', popt)
            print(popt)
            print(polymer.getparvalue(ie, 'df_magnetization'))

            new_fid_slider = RangeSlider(start=1,
                                         end=polymer.getparvalue(ie, 'BS'),
                                         range=(startpoint, endpoint),
                                         step=1,
                                         callback_policy='mouseup')
            # The replacement slider needs the same JS hook as the initial
            # one, otherwise moving it no longer reaches calculate_mag_dec.
            new_fid_slider.callback = CustomJS(args=dict(source=source2),
                                               code=fid_range_js)
            layout_p1.children[2] = new_fid_slider

        except KeyError:
            print('no relaxation experiment found')
            tau = np.zeros(1)
            phi = np.zeros(1)
            df = pd.DataFrame(data=np.c_[tau, phi], columns=['tau', 'phi'])
            df['phi_normalized'] = np.zeros(1)
            df['fit_phi'] = np.zeros(1)
            source_df.data = ColumnDataSource.from_df(df)

    # this source is only used to communicate to the actual callback (cb)
    source = ColumnDataSource(data=dict(value=[]))
    source.on_change('data', cb)

    slider = Slider(start=1,
                    end=nr_experiments,
                    value=1,
                    step=1,
                    callback_policy='mouseup')
    slider.callback = CustomJS(args=dict(source=source), code=experiment_js)

    def calculate_mag_dec(attr, old, new):
        """Refit the current experiment's decay for a new FID range."""
        ie = slider.value
        polymer.addparameter(ie, 'fid_range', new['range'])
        print(polymer.getparvalue(ie, 'fid_range'))  # this works
        start = new['range'][0]
        stop = new['range'][1]
        fid = polymer.getfid(ie)
        tau = polymer.getparvalue(ie, 'df_magnetization').tau
        phi = get_mag_amplitude(fid, start, stop,
                                polymer.getparvalue(ie, 'NBLK'),
                                polymer.getparvalue(ie, 'BS'))

        df = _magnetization_frame(tau, phi)

        fit_option = 2  # mono exponential, 3 parameter fit
        # start from the previous fit's parameters
        p0 = polymer.getparvalue(ie, 'popt(mono_exp)')
        df, popt = magnetization_fit(df, p0, fit_option)
        source_df.data = ColumnDataSource.from_df(df)
        polymer.addparameter(ie, 'df_magnetization', df)
        polymer.addparameter(ie, 'popt(mono_exp)', popt)

    source2.on_change('data', calculate_mag_dec)
    fid_slider.callback = CustomJS(args=dict(source=source2),
                                   code=fid_range_js)

    # select boxes for p4
    x = Select(title='X-Axis', value='ZONE', options=columns)
    x.on_change('value', update)

    y = Select(title='Y-Axis', value='TIME', options=columns)
    y.on_change('value', update)

    size = Select(title='Size', value='None', options=['None'] + quantileable)
    size.on_change('value', update)

    color = Select(title='Color',
                   value='None',
                   options=['None'] + quantileable)
    color.on_change('value', update)

    controls_p4 = widgetbox([x, y, color, size], width=150)
    layout_p4 = row(controls_p4, plot_par())

    # fitting on all experiments
    p3 = figure(plot_width=300,
                plot_height=300,
                title='normalized phi vs normalized tau',
                webgl=True,
                y_axis_type='log',
                x_axis_type='linear')

    # fit magnetization decay for all experiments
    r1 = np.zeros(nr_experiments)
    p3_line_glyph = []
    for i in range(nr_experiments):
        try:
            par = polymer.getparameter(i)
            fid = polymer.getfid(i)
            tau = get_x_axis(par)
            # TODO: make a range slider to get start- and endpoint interactively
            startpoint = int(0.05 * par['BS'])
            endpoint = int(0.1 * par['BS'])
            phi = get_mag_amplitude(fid, startpoint, endpoint, par['NBLK'],
                                    par['BS'])
            df = _magnetization_frame(tau, phi)
            polymer.addparameter(i, 'df_magnetization', df)

            # initial rate guess 2/T1MX, consistent with the fits above
            p0 = [1, 2 * polymer.getparvalue(i, 'T1MX')**-1, 0]
            df, popt = magnetization_fit(df, p0, fit_option=2)
            polymer.addparameter(i, 'amp', popt[0])
            polymer.addparameter(i, 'r1', popt[1])
            polymer.addparameter(i, 'noise', popt[2])
            r1[i] = popt[1]
            # rescale both axes with the fit so all decays can be compared
            tau = popt[1] * df.tau
            phi = popt[0]**-1 * (df.phi_normalized - popt[2])
            p3_df = pd.DataFrame(data=np.c_[tau, phi], columns=['tau', 'phi'])
            source_p3 = ColumnDataSource(data=ColumnDataSource.from_df(p3_df))
            p3_line_glyph.append(p3.line('tau', 'phi', source=source_p3))
        except KeyError:
            print('no relaxation experiment found')

    # Colour the fitted curves. A separate name keeps the Spectral5 COLORS
    # that plot_par() closes over intact for later p4 updates.
    line_colors = viridis(len(p3_line_glyph))
    for line_renderer, line_color in zip(p3_line_glyph, line_colors):
        line_renderer.glyph.line_color = line_color
    par_df['r1'] = r1
    layout_p1 = column(slider, p1, fid_slider, p2, p3)
    doc.add_root(layout_p1)
    doc.add_root(layout_p4)
    # the proxy sources must be in the document so JS changes reach Python
    doc.add_root(source)
    doc.add_root(source2)