def change(choice, time_v, scn_v):
    """Toggle the time/scn controls and display the matching Gaussian plot.

    With ``choice == 'time'`` the ``scn`` control is disabled, ``time`` is
    enabled and ``plot.gaussian(t=time_v)`` is displayed. With
    ``choice == 'scn'`` the roles swap and ``plot.gaussian(scn=scn_v)`` is
    displayed. Any other value does nothing.

    ``scn``, ``time``, ``plot``, ``display`` and ``go`` are resolved from the
    enclosing module scope (``scn``/``time`` presumably ipywidgets controls —
    confirm against the defining module).
    """
    if choice == 'time':
        gaussian_kwargs = {'t': time_v}
    elif choice == 'scn':
        gaussian_kwargs = {'scn': scn_v}
    else:
        return

    time_selected = choice == 'time'
    scn.disabled = time_selected
    time.disabled = not time_selected
    display(go.FigureWidget(plot.gaussian(**gaussian_kwargs)))
예제 #2
0
    def set_out_fig(self):
        """Build ``self.out_fig`` for the selected timeseries and keep it in sync.

        Reads the timeseries and the time window from ``self.controls``.
        2-D data gets one subplot row per column, with y-axes labelled
        x/y/z and then ``dim <i>`` for any further columns (``units``
        appended when available). 1-D data gets a single trace. An observer
        registered on the ``time_window`` control rewrites every trace's
        x/y in place when the control changes.
        """
        timeseries = self.controls['timeseries'].value

        time_window = self.controls['time_window'].value
        istart = timeseries_time_to_ind(timeseries, time_window[0])
        istop = timeseries_time_to_ind(timeseries, time_window[1])

        data, units = get_timeseries_in_units(timeseries, istart, istop)
        tt = get_timeseries_tt(timeseries, istart, istop)

        if len(data.shape) > 1:
            n_rows = data.shape[1]
            self.out_fig = go.FigureWidget(make_subplots(rows=n_rows, cols=1))

            for i, yy in enumerate(data.T):
                # Add one trace for EVERY column. Zipping with ('x', 'y', 'z')
                # dropped columns past the third, which left empty subplots and
                # made on_change index past the end of out_fig.data.
                axis_name = 'xyz'[i] if i < 3 else 'dim {}'.format(i)
                if units:
                    yaxes_label = '{} ({})'.format(axis_name, units)
                else:
                    yaxes_label = axis_name
                self.out_fig.add_trace(go.Scatter(x=tt, y=yy), row=i + 1, col=1)
                self.out_fig.update_yaxes(title_text=yaxes_label, row=i + 1, col=1)

            if n_rows:
                # Only the bottom subplot carries the shared time-axis title.
                self.out_fig.update_xaxes(title_text='time (s)', row=n_rows, col=1)
        else:
            self.out_fig = go.FigureWidget()
            self.out_fig.add_trace(go.Scatter(x=tt, y=data))
            self.out_fig.update_xaxes(title_text='time (s)')

        self.out_fig.update_layout(showlegend=False, title=timeseries.name)

        def on_change(change):
            # Re-slice the timeseries to the new window and overwrite the
            # existing traces' data in a single batched update.
            window = self.controls['time_window'].value
            start = timeseries_time_to_ind(timeseries, window[0])
            stop = timeseries_time_to_ind(timeseries, window[1])

            new_tt = get_timeseries_tt(timeseries, start, stop)
            new_data, _ = get_timeseries_in_units(timeseries, start, stop)

            with self.out_fig.batch_update():
                if len(new_data.shape) == 1:
                    self.out_fig.data[0].x = new_tt
                    self.out_fig.data[0].y = new_data
                else:
                    for trace, column in zip(self.out_fig.data, new_data.T):
                        trace.x = new_tt
                        trace.y = column

        self.controls['time_window'].observe(on_change)
예제 #3
0
 def _init_cross_section_figure(self):
     """Return an empty, axis-free FigureWidget for the cross-section view.

     The figure has no traces, no legend and the ``'none'`` template. It uses
     a 30 px left margin only and is half of ``self.figure_width`` wide. Its
     x-axis is scale-anchored to the y-axis.
     """
     # Both axes hide tick labels, the zero line and grid lines.
     hidden_axis = dict(showticklabels=False, zeroline=False, showgrid=False)
     figure_layout = dict(
         showlegend=False,
         template='none',
         margin=dict(t=0, r=0, b=0, l=30),
         width=self.figure_width / 2,
         xaxis=dict(scaleanchor='y', **hidden_axis),
         yaxis=dict(hidden_axis),
     )
     return go.FigureWidget(data=[], layout=figure_layout)
    def graph_tft(self, cumul=False, single=False):
        """Plot monthly token movements (the ``tft_*`` rows) as line traces.

        One trace is drawn per series in farmed / cultivated / sold / burned /
        farmer_income, plus ``cumul`` when requested. Each covers months
        ``self.month_start`` up to (excluding)
        ``self.month_start + self.months_left``. With ``single=True`` every
        value is divided by ``self.nrnodes`` and the title reports
        ``nrnodes:1``.

        Returns the ``go.FigureWidget``.
        """
        import plotly.graph_objects as go

        series_names = ["farmed", "cultivated", "sold", "burned", "farmer_income"]
        if cumul:
            series_names.append("cumul")

        first_month = self.month_start
        stop_month = self.month_start + self.months_left

        fig = go.FigureWidget()
        for series in series_names:
            tft_row = getattr(self.rows, f"tft_{series}")
            series_values = tft_row.values_all[first_month:stop_month]
            months = list(range(first_month, stop_month))
            if single:
                series_values = [v / self.nrnodes for v in series_values]
            fig.add_trace(go.Scatter(x=months, y=series_values, name=series))

        nodes_in_title = 1 if single else self.nrnodes
        fig.update_layout(
            title=f"Tokens movement per month (batch:{self.batch_nr},nrnodes:{nodes_in_title}).",
            showlegend=True,
        )

        return fig
예제 #5
0
def base_plot(X, y, slope, intercept):
    """Scatter ``(X, y)`` and overlay the line ``slope * X + intercept``.

    Parameters
    ----------
    X, y : array-like
        Point coordinates. ``X`` goes through ``np.asarray`` so plain lists
        work (``slope * list`` would otherwise repeat the list instead of
        scaling it).
    slope, intercept : float
        Parameters of the line drawn in gray.

    Returns
    -------
    (go.FigureWidget, scatter trace)
        The figure and its first (points) trace. The trace's per-point marker
        colour and size lists are initialised so callers can restyle
        individual points.
    """
    import numpy as np

    X = np.asarray(X)
    n_points = len(X)

    f = go.FigureWidget([
        go.Scatter(x=X,
                   y=y,
                   mode='markers',
                   showlegend=False,
                   hoverinfo='none')
    ])

    y_hat = slope * X + intercept

    scatter = f.data[0]
    # Per-point lists (rather than scalars) so individual markers can be
    # recoloured/resized later through ``scatter``.
    scatter.marker.color = ["rgba(.0,.0,.6,0.5)"] * n_points
    scatter.marker.size = [8] * n_points
    f.layout.hovermode = 'closest'

    f.add_trace(go.Scatter(x=X,
                           y=y_hat,
                           mode="lines",
                           line=go.scatter.Line(color="gray"),
                           showlegend=False))
    return f, scatter
def show_traction_separation(tensile_test, ts_data_name="ts_energy"):
    """Plot a traction-separation curve for ``tensile_test``.

    The trace data comes from
    ``tensile_test.get_traction_separation_plot_data(ts_data_name)``. It is
    drawn as markers joined by 0.5-wide lines on a 400x400 figure.

    Returns the ``graph_objects.FigureWidget``.
    """
    plot_data = tensile_test.get_traction_separation_plot_data(ts_data_name)

    # Styling keys are applied last so they override anything in plot_data.
    curve = dict(plot_data)
    curve.update({
        "mode": "markers+lines",
        "marker": {},
        "line": {"width": 0.5},
    })

    figure_layout = {
        "width": 400,
        "height": 400,
        "xaxis": {"title": "Separation distance /Ang"},
        "yaxis": {"title": "TS energy / Jm<sup>-2</sup>"},
    }
    return graph_objects.FigureWidget(data=[curve], layout=figure_layout)
예제 #7
0
def woe_plot_widget(iv_table, width=500, height=500):
    """Bar chart of Weight of Evidence (WOE) per feature, as a FigureWidget.

    Parameters
    ----------
    iv_table : DataFrame-like
        Must provide a ``'VAR_NAME'`` column (bar categories and bar text)
        and a ``'WOE'`` column (bar heights).
    width, height : int
        Figure size in pixels; autosize is disabled.

    Returns
    -------
    go.FigureWidget
    """
    bar = go.Bar(x=iv_table['VAR_NAME'],
                 y=iv_table['WOE'],
                 text=iv_table['VAR_NAME'],
                 marker=dict(color='orange',
                             line=dict(color='rgb(8,48,107)', width=1.5)),
                 opacity=0.6)

    layout = go.Layout(
        title='Weight of Evidence(WOE)',
        xaxis=dict(title='Features',
                   tickangle=-45,
                   tickfont=dict(size=10, color='rgb(107, 107, 107)')),
        # Axis title styling goes through ``title.font``; the old
        # ``titlefont`` attribute is deprecated and rejected by plotly 6.
        yaxis=dict(title=dict(text='Weight of Evidence(WOE)',
                              font=dict(size=14, color='rgb(107, 107, 107)')),
                   tickfont=dict(size=14, color='rgb(7, 7, 7)')),
        autosize=False,
        width=width,
        height=height,
    )

    return go.FigureWidget(data=[bar], layout=layout)
예제 #8
0
def iv_plot_widget(iv, width=500, height=500):
    """Bar chart of Information Value (IV) per feature, as a FigureWidget.

    Parameters
    ----------
    iv : DataFrame-like
        Must provide a ``'VAR_NAME'`` column (bar categories and bar text)
        and an ``'IV'`` column (bar heights).
    width, height : int
        Figure size in pixels; autosize is disabled.

    Returns
    -------
    go.FigureWidget
    """
    bar = go.Bar(x=iv['VAR_NAME'],
                 y=iv['IV'],
                 text=iv['VAR_NAME'],
                 # RGB components are 0-255; the previous 256 was out of range.
                 marker=dict(color='rgb(58,255,225)',
                             line=dict(color='rgb(8,48,107)', width=1.5)),
                 opacity=0.6)

    layout = go.Layout(
        title='Information Values',
        xaxis=dict(tickangle=-45,
                   title='Features',
                   tickfont=dict(size=10, color='rgb(7, 7, 7)')),
        # Axis title styling goes through ``title.font``; the old
        # ``titlefont`` attribute is deprecated and rejected by plotly 6.
        yaxis=dict(title=dict(text='Information Value(IV)',
                              font=dict(size=14, color='rgb(107, 107, 107)')),
                   tickfont=dict(size=14, color='rgb(107, 107, 107)')),
        autosize=False,
        width=width,
        height=height,
    )

    return go.FigureWidget(data=[bar], layout=layout)
예제 #9
0
파일: final.py 프로젝트: twlim1/dse241
def update_expl_vis_parcats(features, df_input):
    """Build a linked scatter + parallel-categories FigureWidget.

    ``df_input`` is grouped by region / country / attack type / weapon type /
    suicide / success. Each group is aggregated into an attack count (size of
    ``eventid``) and summed ``Killed`` / ``Wounded``. The top panel scatters
    Killed vs Wounded with lasso selection. The bottom panel is a Parcats
    over ``'Country'`` followed by ``features``, with every line colour
    value starting at 0 (gray end of the colour scale).

    Returns ``(fig, n_rows)``, where ``n_rows`` is the number of aggregated
    rows.
    """
    parcats_order = ['Country', *features]
    group_cols = ['Region', 'Country', 'Attack Type', 'Weapon Type', 'Suicide', 'Success']
    aggregations = {'eventid': ['size'], 'Killed': ['sum'], 'Wounded': ['sum']}

    aggregated = df_input.groupby(group_cols).agg(aggregations).reset_index()
    # Flatten the two-level column index produced by list-valued aggregations.
    aggregated.columns = group_cols + ['Attack', 'Killed', 'Wounded']
    n_rows = len(aggregated)

    dimensions = [dict(values=aggregated[name], label=name) for name in parcats_order]

    # Per-line colour values mapped through a gray -> firebrick scale
    # (presumably set to 1 for highlighted lines by a selection callback).
    line_color = np.zeros(n_rows, dtype='uint8')
    colorscale = [[0, 'gray'], [1, 'firebrick']]

    scatter = go.Scatter(x=aggregated['Killed'],
                         y=aggregated['Wounded'],
                         marker={'color': 'gray'},
                         mode='markers',
                         selected={'marker': {'color': 'firebrick'}},
                         unselected={'marker': {'opacity': 0.3}})
    parcats = go.Parcats(domain={'y': [0, 0.4]},
                         dimensions=dimensions,
                         line={'colorscale': colorscale, 'cmin': 0, 'cmax': 1,
                               'color': line_color, 'shape': 'hspline'},
                         labelfont={'size': 18, 'family': 'Times'},
                         tickfont={'size': 16, 'family': 'Times'})

    fig = go.FigureWidget(data=[scatter, parcats])
    fig.update_layout(margin={'l': 40, 'b': 40, 't': 40, 'r': 40},
                      height=800, xaxis={'title': 'Killed'},
                      yaxis={'title': 'Wounded', 'domain': [0.6, 1]},
                      dragmode='lasso', hovermode='closest')
    return fig, n_rows
예제 #10
0
def score_table(quality_estimation, field_accuracy) -> go.FigureWidget:
    """Render the field-accuracy and overall-quality scores as a small table.

    Read row-wise, the 2x2 table is laid out as:
    "Field Accuracy Score | <field_accuracy>" (header row), then
    "Overall Quality Score | <quality_estimation>" (body row).
    All text is bold and size 20. The body cell fill is
    ``[[get_color(quality_estimation)]]``.

    Returns a ``go.FigureWidget`` with a fixed height of 150 px.
    """
    labels = ["<b>Field Accuracy Score</b>", "<b>Overall Quality Score</b>"]
    values = [
        # Closing tag must be "</b>"; a second "<b>" left the span unterminated.
        "<b>" + str(field_accuracy) + "</b>",
        "<b>" + str(quality_estimation) + "</b>",
    ]

    font = dict(color="black", size=20)
    trace = go.Table(
        header=dict(values=[labels[0], values[0]],
                    fill=dict(color="gray"),
                    font=font),
        cells=dict(
            values=[labels[1:], values[1:]],
            fill=dict(color=[[get_color(quality_estimation)]]),
            font=font,
        ),
    )

    layout = go.Layout(autosize=True,
                       margin=dict(l=0, t=25, b=25, r=0),
                       height=150)
    return go.FigureWidget(data=[trace], layout=layout)
예제 #11
0
 def _add_fig_trace(img_fig: go.Figure, index):
     """Show ``img_fig`` in ``self.figure`` and title it with the frame index.

     On first use ``self.figure`` is created as a FigureWidget from
     ``img_fig``. After that, every existing trace is updated in place from
     ``img_fig.data[0]``. ``self`` is captured from the enclosing scope.
     """
     if self.figure is not None:
         def _refresh(trace):
             # Overwrite this trace's properties with the new frame's first trace.
             trace.update(img_fig.data[0])

         self.figure.for_each_trace(_refresh)
     else:
         self.figure = go.FigureWidget(img_fig)

     frame_title = f"Frame no: {index}"
     self.figure.layout.title = frame_title
예제 #12
0
def visualize_simrank(g1, u, pos, html_path="graph.html"):
    """Plot ``g1`` with nodes coloured by their SimRank similarity to ``u``.

    Args:
    -----
        g1: a networkx graph
        u: source node passed to ``other_sims.simrank``
        pos: mapping of node -> (x, y) coordinates (e.g. from
            ``nx.spring_layout(g1)``). Its iteration order is assumed to
            match the row order of the DataFrame returned by
            ``other_sims.simrank``; this is not checked (TODO confirm).
        html_path: file to write a standalone HTML copy of the figure to,
            with scroll-zoom enabled. ``None`` skips writing the file.

    Returns:
    --------
        fig: a ``go.FigureWidget`` with two traces. ``edges`` is a single
        line trace; ``nodes`` is a marker+text trace whose hover shows the
        node name, SimRank and degree, with markers coloured by the
        ``simrank`` column on the Jet colour scale.
    """
    df = other_sims.simrank(g1, u)

    coords = [(p[0], p[1]) for p in pos.values()]
    df['nodes_x'] = [x for x, _ in coords]
    df['nodes_y'] = [y for _, y in coords]

    # All edges go into one line trace; a None between segments tells plotly
    # not to connect consecutive edges.
    edges_x = []
    edges_y = []
    for e in g1.edges():
        x0, y0 = pos[e[0]]
        x1, y1 = pos[e[1]]
        edges_x.extend((x0, x1, None))
        edges_y.extend((y0, y1, None))

    fig = go.FigureWidget()
    fig.add_trace(go.Scatter(x=edges_x, y=edges_y, name='edges', mode='lines', line={'width': 1}))
    fig.add_trace(go.Scatter(x=df['nodes_x'],
                             y=df['nodes_y'],
                             customdata=df[['simrank', 'degree']].values,
                             hovertemplate="Node: %{text} <br> SimRank: %{customdata[0]} <br> Degree: %{customdata[1]} <extra></extra>",
                             text=df['node_name'],
                             name="nodes",
                             mode='markers+text'))
    fig.update_layout(template="plotly_dark", dragmode='pan')
    fig.update_traces(marker={'size': 15, 'color': df['simrank'], 'colorscale': 'Jet'})

    if html_path is not None:
        fig.write_html(html_path, config={'scrollZoom': True})

    return fig
예제 #13
0
    def __init__(self, electrodes: pynwb.base.DynamicTable, **kwargs):
        """Build the brain + electrode viewer widget.

        Creates two hemisphere-opacity sliders (range 0..1, starting at 1)
        and a "Color By:" dropdown over ``electrodes.colnames`` (initially
        ``"group_name"``). These are wired to the ``observe_*`` /
        ``color_electrode_by`` handlers. The method then draws the brain and
        the electrodes into ``self.fig`` and sets the children to
        ``[self.fig, VBox([HBox(sliders), dropdown])]``. ``kwargs`` is
        accepted but not used.
        """
        super().__init__()
        self.electrodes = electrodes

        def _opacity_slider(label):
            # Full-range opacity slider whose description is not truncated.
            return widgets.FloatSlider(
                description=label,
                value=1.0,
                min=0.0,
                max=1.0,
                style={"description_width": "initial"},
            )

        left_slider = _opacity_slider("left hemi opacity")
        right_slider = _opacity_slider("right hemi opacity")

        color_picker = widgets.Dropdown(
            options=list(electrodes.colnames),
            value="group_name",
            description="Color By:",
            disabled=False,
        )

        color_picker.observe(self.color_electrode_by)
        left_slider.observe(self.observe_left_opacity)
        right_slider.observe(self.observe_right_opacity)

        self.fig = go.FigureWidget()
        self.plot_human_brain()
        self.show_electrodes(electrodes, color_picker.value)

        controls = widgets.VBox([widgets.HBox([left_slider, right_slider]), color_picker])
        self.children = [self.fig, controls]
예제 #14
0
def default_chart():
    """Return an empty marker-scatter ``go.FigureWidget`` with default styling.

    The single trace has no points and uses size-1 markers. The layout uses
    a 16 px title font, no legend, the module-level ``chart_size`` as margin
    and ``default_axis_params`` for both axes.
    """
    empty_scatter = go.Scatter(x=[], y=[], mode='markers', marker=dict(size=1))

    # Pass the trace and the layout separately. Given a
    # {'data': ..., 'layout': ...} dict as ``data``, plotly treats it as a
    # whole figure spec, and its embedded layout replaces the ``layout``
    # argument. That silently drops the legend, margin and title settings.
    return go.FigureWidget(
        data=[empty_scatter],
        layout=go.Layout(
            # ``title.font`` replaces the deprecated ``titlefont`` attribute.
            title=dict(font=dict(size=16)),
            showlegend=False,
            margin=chart_size,
            xaxis=default_axis_params,
            yaxis=default_axis_params,
        ),
    )
예제 #15
0
def multi_trace(x, y, color, label=None, fig=None):
    """Add one line trace per series in ``y``, all sharing one legend entry.

    Parameters
    ----------
    x: array-like
        Shared x values for every trace.
    y: array-like
        Iterable of y series; one scatter trace is added for each.
    color: str
        Line colour used for every trace.
    label: str, optional
        Legend group and trace name. When given, only the first trace is
        shown in the legend. When ``None``, every trace keeps
        ``showlegend=True``.
    fig: go.FigureWidget, optional
        Figure to draw into; a new ``go.FigureWidget`` is created if omitted.

    Returns
    -------
    The figure the traces were added to.
    """
    target = go.FigureWidget() if fig is None else fig

    for idx, series in enumerate(y):
        # Hide duplicate legend entries for the 2nd+ member of a labelled group.
        in_legend = label is None or idx == 0
        target.add_scatter(x=x, y=series, legendgroup=label, name=label,
                           showlegend=in_legend, line={'color': color})

    return target
예제 #16
0
def show_gamma_surface_fit(gamma_surface, shift, data_name='energy'):
    """Plot the gamma-surface fit at ``shift`` against the underlying data.

    Three traces are taken from
    ``gamma_surface.get_fit_plot_data(data_name, shift)``:
    ``'fitted_data'`` (named "Fit"), ``'data'`` (named ``data_name``) and
    ``'minimum'`` (named "Fit min."). They are drawn on a 400x400 figure
    with the x-axis titled "Expansion" and the y-axis titled ``data_name``.
    """
    fit_data = gamma_surface.get_fit_plot_data(data_name, shift)

    # (key in fit_data, legend name) in drawing order.
    trace_specs = [
        ('fitted_data', 'Fit'),
        ('data', data_name),
        ('minimum', 'Fit min.'),
    ]
    traces = [{**fit_data[key], 'name': legend_name} for key, legend_name in trace_specs]

    figure_layout = {
        'xaxis': {'title': 'Expansion'},
        'yaxis': {'title': data_name},
        'width': 400,
        'height': 400,
    }
    return graph_objects.FigureWidget(data=traces, layout=figure_layout)
    def _build_plots(self):
        """Add one line trace per mission and store the figure in ``self._fig``.

        For every entry in ``self.missions``, a ``go.Scatter`` of the column
        selected in ``self._x_widget`` against the one in ``self._y_widget``
        is appended to ``self._fig``. The figure is created as an empty
        ``go.Figure`` if it is ``None``. Once all traces are added, the
        figure is converted to a ``go.FigureWidget`` a single time, and its
        title and axis titles are set. If there are no missions and no
        existing figure, nothing is done.
        """
        x_name = self._x_widget.value
        y_name = self._y_widget.value

        added_trace = False
        for name in self.missions:
            if self._fig is None:
                self._fig = go.Figure()
            # pylint: disable=invalid-name # that's a common naming
            x = self.missions[name][x_name]
            # pylint: disable=invalid-name # that's a common naming
            y = self.missions[name][y_name]

            self._fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name=name))
            added_trace = True

        if added_trace:
            # Wrap once after all traces are in, instead of building a
            # throwaway widget on every loop iteration.
            self._fig = go.FigureWidget(self._fig)

        if self._fig is None:
            # No missions and no figure yet: nothing to lay out.
            return

        self._fig.update_layout(title_text="Mission",
                                title_x=0.5,
                                xaxis_title=x_name,
                                yaxis_title=y_name)
예제 #18
0
    def create_figure(plot_node, node_trace, my_shapes):
        """Create a plotly figure based on node trace and edge shape

        Parameters
        ----------
        plot_node: str
            The node of interest (the one selected in the figure); shown in
            the title.
        node_trace : plotly.graph_objs._scatter.Scatter
            Scatter plot of node location
        my_shapes : [dict]
            Shape objects suitable for plotly layout inclusion (drawn as the
            edges).

        Returns
        -------
        fig : plotly.graph_objs._figurewidget.FigureWidget
            Figure widget capable of responding to click events. Axes are
            hidden and the legend is off.
        """
        hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
        layout = go.Layout(
            # ``title.font`` replaces the deprecated ``titlefont`` attribute.
            title=dict(
                text=f'Interactive Graph of Network Failures<br>Selected Node: {plot_node}',
                font=dict(size=16),
            ),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                # NOTE(review): this "more info" link points to a YouTube video
                # rather than project documentation; confirm it is intended.
                text="<a href='https://www.youtube.com/watch?v=dQw4w9WgXcQ'> Click me for more info</a>",
                showarrow=False,
                xref="paper", yref="paper",
                x=0.005, y=-0.002)],
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            shapes=my_shapes,
        )
        fig = go.FigureWidget(data=[node_trace], layout=layout)
        return fig
예제 #19
0
def scroll_images(images, show=True):
    """Show a stack of images in one figure with a slider to flip through them.

    ``images`` is iterated along its first axis. A 3-D input gets a trailing
    channel axis (``images[..., np.newaxis]``); a 4-D input is used as-is;
    other ranks are passed through unchanged. One ``go.Image`` trace is
    built per image, and only the first is visible initially. Each slider
    step sets ``visible`` so that exactly one trace is shown.

    Returns the ``go.FigureWidget``. It is also displayed when ``show`` is
    true.
    """
    if len(images.shape) == 3:
        images = images[..., np.newaxis]

    traces = []
    for frame in images:
        traces.append(go.Image(z=frame, visible=False))
    traces[0]['visible'] = True
    n_traces = len(traces)

    slider_steps = []
    for active in range(len(images)):
        visibility = [k == active for k in range(n_traces)]
        slider_steps.append(dict(method='update', args=[{'visible': visibility}]))

    fig = go.FigureWidget(go.Figure(data=traces, layout=__layout_noaxis__()))
    fig.layout.sliders = [
        dict(active=0,
             currentvalue={"prefix": "image : "},
             pad={"t": 20},
             steps=slider_steps)
    ]

    if show:
        fig.show()

    return fig
예제 #20
0
def plot_grouped_events_plotly(data, window=None, group_inds=None, colors=color_wheel, labels=None,
                               show_legend=True, unobserved_intervals_list=None, progress_bar=None, fig=None, **kwargs):
    """Draw event rows with ``event_group``, optionally grouped and coloured.

    Parameters
    ----------
    data : sequence
        One entry per row; converted to an object array and passed to
        ``event_group``.
    window, show_legend, unobserved_intervals_list, progress_bar
        Accepted for signature compatibility; not used here.
    group_inds : array-like of int, optional
        Group index for each entry of ``data``. When given, rows are drawn
        group by group in sorted group order and stacked via a running
        ``offset``. Each group is coloured ``colors[group % len(colors)]``.
    colors : sequence
        Colour cycle indexed by group index.
    labels : sequence, optional
        Legend label per group index (``labels[group]``). When omitted,
        groups are drawn with ``label=None``.
    fig : go.FigureWidget, optional
        Figure to draw into; a new one is created if omitted.
    **kwargs
        Forwarded to ``event_group``.

    Returns
    -------
    The figure, with its x-axis titled "time (s)".
    """
    data = np.array(data, dtype=object)

    if fig is None:
        fig = go.FigureWidget()

    if group_inds is not None:
        # asarray so ``group_inds == ui`` is an element-wise mask even for a
        # plain list (list == scalar would just be False).
        group_inds = np.asarray(group_inds)
        offset = 0
        for ui in np.unique(group_inds):
            this_data = data[group_inds == ui]
            event_group(this_data,
                        offset=offset,
                        label=labels[ui] if labels is not None else None,
                        color=colors[ui % len(colors)],
                        fig=fig,
                        **kwargs)
            offset += len(this_data)
    else:
        event_group(data, fig=fig, **kwargs)

    fig.update_layout(xaxis_title="time (s)")

    return fig
예제 #21
0
def pp_barplot_widget(preds):
    """Horizontal bar chart of prediction scores as a ``go.FigureWidget``.

    ``preds`` is a sequence of ``(name, score)`` pairs; only the first two
    items of each are used. The order is reversed before plotting
    (presumably so the first prediction ends up at the top of the
    categorical y-axis).
    """
    ordered = list(reversed(preds))
    names = [p[0] for p in ordered]
    scores = [p[1] for p in ordered]
    return go.FigureWidget(go.Bar(x=scores, y=names, orientation='h'))
def create_grid(head_vals, cell_vals, head_frmt, cell_frmt, width, height, head_height, cell_height,
                l, r, b, t, header_fill_colour=c.lightgray, header_font_colour=c.white,
                cell_font_colour=c.white):
    """Build a styled ``go.Table`` FigureWidget.

    ``head_vals`` / ``cell_vals`` are the header and column values;
    ``head_frmt`` / ``cell_frmt`` are their plotly format strings; and
    ``head_height`` / ``cell_height`` are the row heights. ``l, r, b, t``
    are the figure margins and ``width`` / ``height`` its size.

    Every column is left-aligned and fonts are size 12. Borders use
    ``c.darkgray`` and body cells are filled with ``c.gray``. Paper and plot
    backgrounds use ``c.transparent``, with the ``c.template`` template.
    """
    header_spec = dict(
        values=head_vals,
        height=head_height,
        font=dict(size=12, color=header_font_colour),
        format=head_frmt,
        line_color=c.darkgray,
        fill_color=header_fill_colour,
        align=["left"] * len(head_vals),
    )
    cell_spec = dict(
        values=cell_vals,
        height=cell_height,
        font=dict(size=12, color=cell_font_colour),
        format=cell_frmt,
        line_color=c.darkgray,
        fill_color=c.gray,
        align=["left"] * len(cell_vals),
    )
    table = go.Table(header=header_spec, cells=cell_spec)

    grid_layout = go.Layout(
        margin=go.layout.Margin(l=l, r=r, b=b, t=t),
        paper_bgcolor=c.transparent,
        plot_bgcolor=c.transparent,
        template=c.template,
        height=height,
        width=width,
    )
    return go.FigureWidget(data=[table], layout=grid_layout)
예제 #23
0
def plot(traces, show=True, **kwargs):
    """
    General plot function used to plot any list of plotly traces.

    :param traces: list of plotly traces to plot
    :type traces: list

    :param show: whether to display the figure with ``fig.show()``.
        Ignored when ``widget=True``.
    :type show: bool, optional

    :param kwargs: optional layout arguments. Other keys are ignored.
    Possible kwargs are :
    - x_axis_name string representing name of x axis
    - y_axis_name string representing name of y axis
    - x_min float value or string date representing minimal value to show along x_axis
    - x_max float value or string date representing maximal value to show along x_axis
    - y_min float value representing minimal value to show along y_axis
    - y_max float value representing maximal value to show along y_axis
    - title title of the graph (default: empty string)
    - template string indicating plotly graph template to use (default "plotly_dark").
            Possible choices are :
            "plotly", "plotly_white", "plotly_dark", "ggplot2", "seaborn", "simple_white", "none".
            See https://plot.ly/python/templates for more information
    - widget bool (default False). When True, a ``go.FigureWidget`` is
            returned and the figure is not shown.

    :return: plotly figure object (``go.Figure``, or ``go.FigureWidget`` when ``widget=True``)
    """
    x_axis_name = kwargs.pop('x_axis_name', None)
    y_axis_name = kwargs.pop('y_axis_name', None)
    template = kwargs.pop('template', 'plotly_dark')
    widget = kwargs.pop('widget', False)

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    fig.update_layout(
        title=kwargs.get('title', ''),
        xaxis={
            'title': x_axis_name,
            'range': [kwargs.get('x_min'), kwargs.get('x_max')]
        },
        yaxis={
            'title': y_axis_name,
            'range': [kwargs.get('y_min'), kwargs.get('y_max')]
        },
        showlegend=True,
        # Alpha 0 in the rgba background makes the legend box transparent.
        legend=dict(x=-0.1, y=1.1, bgcolor='rgba(0,0,0,0)'),
        legend_orientation="h",
        template=template)

    if widget:
        return go.FigureWidget(fig)
    if show is True:
        fig.show()
    return fig
예제 #24
0
def create_interactive_plot(df, ccaa_series, dates, traces, title_text):
    """Overlaid bar chart with a dropdown that filters it to one CCAA.

    Parameters
    ----------
    df : DataFrame
        Data rows. Must contain a ``'fecha'`` column (x values after
        filtering) and every column named in ``traces``.
    ccaa_series : Series
        CCAA value for each row of ``df``, matched positionally. Its unique
        values become the dropdown options.
    dates : array-like
        Initial x values for every bar trace.
    traces : iterable of (column, trace_name)
        One bar trace per pair, initially showing ``df[column]``.
    title_text : str
        Figure title.

    Returns
    -------
    widgets.VBox with the dropdown above the figure. Selecting a CCAA
    replaces each trace's x with the filtered ``'fecha'`` column and its y
    with the filtered data column.
    """
    widget = widgets.Dropdown(options=ccaa_series.unique().tolist(),
                              description='CCAA')
    g = go.FigureWidget(layout=go.Layout(title=dict(text=title_text),
                                         legend=dict(orientation='h',
                                                     bgcolor='LightSteelBlue'),
                                         barmode='overlay'))
    for trace, trace_name in traces:
        g.add_trace(go.Bar(x=dates, y=df[trace], name=trace_name))

    def validate():
        # Only accept values that actually occur in the CCAA column.
        return widget.value in ccaa_series.unique()

    def response(change):
        # Stop on an unknown value. The update below needs the filtered
        # frame, and running it anyway raised UnboundLocalError.
        if not validate():
            return
        # Positional boolean mask: a plain list, so pandas does not re-align
        # ccaa_series' index against df's.
        mask = (ccaa_series == widget.value).tolist()
        temp_df = df[mask]
        x = temp_df['fecha']
        with g.batch_update():
            for bar, (trace, _trace_name) in zip(g.data, traces):
                bar.x = x
                bar.y = temp_df[trace]
            g.layout.barmode = 'overlay'
            g.layout.xaxis.title = ''
            g.layout.yaxis.title = ''

    widget.observe(response, names='value')
    return widgets.VBox([widget, g])
예제 #25
0
    def __init__(self, electrodes: pynwb.base.DynamicTable, **kwargs):
        """Build the brain + electrode viewer with hemisphere-opacity sliders.

        Two sliders (range 0..1, starting at 1) are wired to
        ``observe_left_opacity`` / ``observe_right_opacity``. The brain and
        ``electrodes`` are drawn into ``self.fig``, and the children become
        ``[self.fig, HBox(sliders)]``. ``kwargs`` is accepted but not used.
        """
        super().__init__()

        left_slider, right_slider = (
            widgets.FloatSlider(description=f'{side} hemi opacity',
                                value=1.,
                                min=0.,
                                max=1.,
                                style={'description_width': 'initial'})
            for side in ('left', 'right')
        )

        left_slider.observe(self.observe_left_opacity)
        right_slider.observe(self.observe_right_opacity)

        self.fig = go.FigureWidget()
        self.plot_human_brain()
        self.show_electrodes(electrodes)

        slider_row = widgets.HBox([left_slider, right_slider])
        self.children = [self.fig, slider_row]
예제 #26
0
def reaction_plot_threshold(thresh_df):
    """Scatter correlation vs. time for each threshold segment.

    One marker trace is drawn per unique ``segment_id`` in ``thresh_df``,
    using the ``'Time (s)'`` and ``'Correlation'`` columns. Traces have 0.2
    opacity, except the first, which is fully opaque. Colours come from
    ``colors.get_colors(plt.cm.plasma, n_segments + 1)``, indexed by
    segment id.

    Returns the ``go.FigureWidget``.
    """
    idxs = thresh_df.segment_id.unique()
    cols = colors.get_colors(plt.cm.plasma, len(idxs) + 1)
    traces = []

    for segment_idx in idxs:
        sub_df = thresh_df[thresh_df['segment_id'] == segment_idx]
        traces.append(
            go.Scatter(
                x=sub_df['Time (s)'],
                y=sub_df['Correlation'],
                name=f'Segment: {segment_idx} | No of spikes: {len(sub_df)}',
                opacity=0.20,
                mode='markers',
                # Wrap so non-contiguous segment ids (>= len(cols)) no longer
                # raise IndexError; in-range ids keep the same colour.
                marker=dict(color=cols[segment_idx % len(cols)])))

    gg = go.FigureWidget(data=traces,
                         layout=go.Layout(
                             title=dict(text='Threshold signals correlation'),
                             barmode='overlay',
                             xaxis=dict(title='Time (s)', range=[-0.02, 4.02]),
                             yaxis=dict(title='Correlation',
                                        range=[-0.02, 1.02])))

    # Highlight the first segment; an empty frame simply yields an empty plot.
    if gg.data:
        gg.data[0].opacity = 1
    return gg
예제 #27
0
    def update(
        self,
        index: int,
        start_label: str = "start_time",
        before: float = 0.0,
        after: float = 1.0,
        order=None,
        group_inds=None,
        labels=None,
        align_to_zero=False,
        fig: go.FigureWidget = None,
    ):
        """Re-plot the trial-aligned data for ``index`` into ``fig``.

        Data is aligned on ``start_label`` with ``before`` / ``after``
        padding via ``self.align_data``. With ``group_inds`` omitted, all
        trials go into group 0.

        With ``align_to_zero``, each trial in ``order`` is offset by its
        sample at ``bisect(time_ts_aligned[trial], 0)``. When ``order`` is
        None, this covers every index of ``data``. ``fig`` (or a new
        FigureWidget) is cleared and redrawn by ``self.plot_group``, whose
        result is returned. ``labels`` is accepted but not used here.
        """
        data, time_ts_aligned = self.align_data(start_label, before, after,
                                                index)
        if group_inds is None:
            # ``np.int`` was an alias of the builtin and was removed in NumPy 1.24.
            group_inds = np.zeros(len(self.trials), dtype=int)
        if align_to_zero:
            # ``order`` defaults to None, which previously made this loop raise
            # TypeError. Assumes data is indexable by trial number 0..len-1
            # (TODO confirm against align_data).
            trial_order = order if order is not None else range(len(data))
            for trial_no in trial_order:
                data_zero_id = bisect(time_ts_aligned[trial_no], 0)
                data[trial_no] -= data[trial_no][data_zero_id]
        fig = fig if fig is not None else go.FigureWidget()
        fig.data = []
        fig.layout = {}
        return self.plot_group(group_inds, data, time_ts_aligned, fig, order)
예제 #28
0
    def _get_figure_widget(self):
        """Build the parallel-categories trace and the FigureWidget that hosts it.

        The trace has one Parcats dimension per column in
        ``self.selected_columns``, with values from
        ``self.data_source.data``. Lines are drawn as h-splines coloured
        through ``Config().color_scale`` (cmin 0, cmax 1, initial colour
        ``color_scale[1][1]``). Clicks on the widget's trace go to
        ``self.on_selection``.

        Returns ``(trace, figure_widget)``.
        """
        config = Config()

        dimensions = []
        for col in self.selected_columns:
            dimensions.append({"label": col, "values": self.data_source.data[col]})

        line_style = dict(
            color=config.color_scale[1][1],
            colorscale=config.color_scale,
            cmin=0,
            cmax=1,
            shape="hspline",
        )
        trace = go.Parcats(dimensions=dimensions, line=line_style)

        widget_layout = go.Layout(
            margin=dict(l=20, r=20, b=20, t=20, pad=5),
            autosize=True,
            showlegend=False,
        )
        figure_widget = go.FigureWidget(data=[trace], layout=widget_layout)

        # Attach the click handler to the trace held by the widget itself.
        figure_widget.data[0].on_click(self.on_selection)
        return trace, figure_widget
예제 #29
0
 def before_fit(self, **kwargs):
     """Project training data and SOM weights with PCA and build ``self.fig``.

     The training inputs (cached in ``self.train_pca``) and
     ``self.learn.model.weights`` are projected via ``self.do_pca`` and wrapped
     in marker traces via ``scatter``. Markers are smaller in 3-D. The result
     is ``self.fig``, a FigureWidget titled with each component's explained
     variance, e.g. "SOM Visualization (45%, 20% explained variance)".
     """
     # NOTE: ``ifnone`` receives an already-evaluated argument, so ``do_pca``
     # (with ``do_train=True``) runs even when ``self.train_pca`` is cached.
     # Kept as-is because it presumably (re)fits ``self.pca``, which is used
     # below; confirm before making it lazy.
     self.train_pca = ifnone(
         self.train_pca,
         self.do_pca(get_xy(self.learn.dls)[0], do_train=True))
     self.train_trace = scatter(self.train_pca,
                                name='Training data',
                                mode='markers',
                                marker_color='#539dcc',
                                marker_size=1.5 if self.is_3d else 4)
     self.weight_pca = self.do_pca(self.learn.model.weights)
     self.weight_trace = scatter(self.weight_pca,
                                 name='SOM weights',
                                 mode='markers',
                                 marker_color='#e58368',
                                 marker_size=3 if self.is_3d else 6)
     # Join directly: slicing str(tuple(...)) left quote characters (and a
     # trailing comma for a single component) in the title.
     expl_var = ', '.join(
         f'{pct:.0f}%' for pct in self.pca.explained_variance_ratio_ * 100)
     layout = go.Layout(
         title=f"SOM Visualization ({expl_var} explained variance)")
     self.fig = go.FigureWidget([self.train_trace, self.weight_trace],
                                layout=layout)
     self.fig.update_layout(margin=dict(l=20, r=20, t=20, b=20),
                            paper_bgcolor="LightSteelBlue")
예제 #30
0
def show_master_gamma_surface(gamma_surface, data_name='energy'):
    """Contour plot of the fitted gamma surface, with the sampled grid overlaid.

    The contour trace is built from
    ``gamma_surface.get_fitted_surface_plot_data(data_name, xy_as_grid=False)``.
    It defaults to the viridis colour scale with a colorbar titled
    ``data_name``, and any key in that data overrides these defaults. The
    grid points from ``gamma_surface.get_xy_plot_data()`` are drawn as
    size-2 markers with a legend entry. The x-axis is scale-anchored to the
    y-axis.
    """
    surface_data = gamma_surface.get_fitted_surface_plot_data(
        data_name, xy_as_grid=False)
    grid_points = gamma_surface.get_xy_plot_data()

    contour = {
        'type': 'contour',
        'colorscale': 'viridis',
        'colorbar': {'title': data_name},
    }
    contour.update(surface_data)

    grid_markers = dict(grid_points)
    grid_markers.update({
        'mode': 'markers',
        'marker': {'size': 2},
        'showlegend': True,
    })

    return graph_objects.FigureWidget(
        data=[contour, grid_markers],
        layout={'xaxis': {'scaleanchor': 'y'}},
    )