Beispiel #1
0
    def process(self, inputs):
        """
        Plot an OHLC candlestick chart with a volume bar overlay for a stock.

        Takes the `datetime`, `open`, `high`, `low`, `close` and `volume`
        columns of the first input dataframe and builds a bqplot figure.
        When the dataframe has more than ``self.conf['points']`` rows it is
        down-sampled with a fixed stride so roughly that many rows are drawn.
        ``self.conf['label']`` (default ``'stock'``) labels the OHLC mark.

        Arguments
        -------
         inputs: list
            list of input dataframes; only ``inputs[0]`` is used.
        Returns
        -------
        bqplot.Figure
        """
        stock = inputs[0]
        num_points = self.conf['points']
        # Down-sample so that roughly `num_points` rows are drawn.
        stride = max(len(stock) // num_points, 1)
        label = self.conf.get('label', 'stock')
        sc = LinearScale()
        sc2 = LinearScale()
        dt_scale = DateScale()
        ax_x = Axis(label='Date', scale=dt_scale)
        ax_y = Axis(label='Price',
                    scale=sc,
                    orientation='vertical',
                    tick_format='0.0f')
        # Shared x values for both marks (computed once).
        dates = stock['datetime'][::stride].to_array()
        # Construct the marks
        ohlc = OHLC(x=dates,
                    # Strided OHLC matrix, copied from the GPU to a host array.
                    y=cp.asnumpy(stock[['open', 'high', 'low', 'close'
                                        ]].as_gpu_matrix()[::stride, :]),
                    marker='candle',
                    scales={
                        'x': dt_scale,
                        'y': sc
                    },
                    format='ohlc',
                    stroke='blue',
                    display_legend=True,
                    labels=[label])
        bar = Bars(x=dates,
                   y=stock['volume'][::stride].to_array(),
                   scales={
                       'x': dt_scale,
                       'y': sc2
                   },
                   padding=0.2)
        def_tt = Tooltip(fields=['x', 'y'], formats=['%Y-%m-%d', '.2f'])
        bar.tooltip = def_tt
        bar.interactions = {
            'legend_hover': 'highlight_axes',
            'hover': 'tooltip',
            'click': 'select',
        }
        # Bound the price scale by the candle extremes (high/low), not by
        # `close`, so wicks above the highest close / below the lowest close
        # are not cut off. The lower bound is extended by 30% of the price
        # range, presumably to leave room for the volume bars, whose scale
        # max is 4x the peak volume (so they fill at most a quarter of the
        # plot height).
        price_high = stock['high'].max()
        price_low = stock['low'].min()
        sc.min = price_low - 0.3 * (price_high - price_low)
        sc.max = price_high
        sc2.max = stock['volume'].max() * 4.0
        f = Figure(axes=[ax_x, ax_y],
                   marks=[ohlc, bar],
                   fig_margin={
                       "top": 0,
                       "bottom": 60,
                       "left": 60,
                       "right": 60
                   })
        return f
Beispiel #2
0
    def process(self, inputs):
        """
        Plot the Scatter plot

        Reads the dataframe on ``self.INPUT_PORT_NAME`` (computing it first
        if it is a dask_cudf frame) and scatters column ``conf['col_x']``
        against ``conf['col_y']``. The data is strided so that roughly
        ``conf['points']`` points are drawn. When ``conf['col_color']`` is set,
        points are coloured by that column and a colour axis is added.
        Scales are looked up in ``scaleMap`` by ``conf['col_x_scale']`` /
        ``conf['col_y_scale']`` (default ``'LinearScale'``).

        Arguments
        -------
         inputs: list
            list of input dataframes.
        Returns
        -------
        dict
            ``{self.OUTPUT_PORT_NAME: Figure}``

        """
        input_df = inputs[self.INPUT_PORT_NAME]
        if isinstance(input_df, dask_cudf.DataFrame):
            input_df = input_df.compute()  # get the computed value
        num_points = self.conf['points']
        stride = max(len(input_df) // num_points, 1)

        sc_x = scaleMap[self.conf.get('col_x_scale', 'LinearScale')]()
        sc_y = scaleMap[self.conf.get('col_y_scale', 'LinearScale')]()

        x_col = self.conf['col_x']
        y_col = self.conf['col_y']
        ax_y = Axis(label=y_col,
                    scale=sc_y,
                    orientation='vertical',
                    side='left')

        ax_x = Axis(label=x_col,
                    scale=sc_x,
                    num_ticks=10,
                    label_location='end')
        m_chart = dict(top=50, bottom=70, left=50, right=100)

        # GPU frames need `.to_array()` to hand bqplot a host array; other
        # frames (e.g. pandas) are passed through unchanged.
        on_gpu = isinstance(input_df, (cudf.DataFrame, dask_cudf.DataFrame))

        def host_column(name):
            """Return the strided column `name`, converted for GPU frames."""
            column = input_df[name][::stride]
            return column.to_array() if on_gpu else column

        scales = {'x': sc_x, 'y': sc_y}
        axes = [ax_x, ax_y]
        color_kwargs = {}
        if 'col_color' in self.conf:
            color_col = self.conf['col_color']
            sc_c1 = ColorScale()
            ax_c = ColorAxis(scale=sc_c1,
                             tick_format='0.2%',
                             label=color_col,
                             orientation='vertical',
                             side='right')
            scales['color'] = sc_c1
            color_kwargs['color'] = host_column(color_col)
            # Colour axis sits between the x and y axes, as before.
            axes = [ax_x, ax_c, ax_y]

        scatter = Scatter(x=host_column(x_col),
                          y=host_column(y_col),
                          scales=scales,
                          stroke='black',
                          **color_kwargs)
        fig = Figure(axes=axes,
                     marks=[scatter],
                     fig_margin=m_chart,
                     title=self.conf['title'])
        return {self.OUTPUT_PORT_NAME: fig}
Beispiel #3
0
def create_figure(stock, dt_scale, sc, color_id, f, indicator_figure_height,
                  figure_width, add_new_indicator):
    """
    Plot the seven PPSR indicator series in a new figure.

    Each column ``out0`` .. ``out6`` of ``stock`` is drawn as its own line,
    on its own vertical linear scale, against ``stock.datetime``. The axes
    are labelled 'PPSR PP', 'PPSR R1', 'PPSR S1', 'PPSR R2', 'PPSR S2',
    'PPSR R3', 'PPSR S3' (presumably pivot point / resistance / support
    levels). The first axis keeps the default side; the rest are on the
    right. The figure is sized and passed to ``add_new_indicator``.

    Arguments
    ---------
    stock: dataframe with a ``datetime`` column and ``out0`` .. ``out6``
    dt_scale: date scale used for the x values of every line
    sc, f: not used by this function; kept for interface compatibility
    color_id: sequence whose first element is the base CATEGORY20 index
    indicator_figure_height, figure_width: layout height / width to apply
    add_new_indicator: callback that receives the new figure

    Returns
    -------
    list
        The seven Lines marks, in ``out0`` .. ``out6`` order.
    """
    labels = ['PPSR PP', 'PPSR R1', 'PPSR S1', 'PPSR R2', 'PPSR S2',
              'PPSR R3', 'PPSR S3']
    lines = []
    axes = []
    for offset, label in enumerate(labels):
        scale = LinearScale()
        # Only the first axis keeps the default side; the others go right.
        side = {} if offset == 0 else {'side': 'right'}
        axes.append(
            Axis(label=label, scale=scale, orientation='vertical', **side))
        lines.append(
            Lines(x=stock.datetime.to_array(),
                  y=stock['out{}'.format(offset)].to_array(),
                  scales={
                      'x': dt_scale,
                      'y': scale
                  },
                  # Wrap every index (including the first) into the palette
                  # so a large base colour id cannot raise IndexError.
                  colors=[CATEGORY20[(color_id[0] + offset) % len(CATEGORY20)]]))
    new_fig = Figure(marks=lines, axes=axes)
    new_fig.layout.height = indicator_figure_height
    new_fig.layout.width = figure_width
    add_new_indicator(new_fig)
    # Return a separate list so the caller does not share the figure's marks list.
    return list(lines)
Beispiel #4
0
def cost_display(n_days=7):
    """
    Display an interactive cost estimator and return its usage figure.

    The user hand-draws the expected number of users over the time range
    built by ``create_date_range(n_days)``. Any change to the drawing,
    machine type, persistent-storage type, storage per user or RAM per
    user recomputes the machine, storage and total cost and the average
    number of machines, shown in the read-out text widgets.

    Arguments
    ---------
    n_days: int
        Passed to ``create_date_range`` to build the time axis.

    Returns
    -------
    bqplot.Figure
        The hand-draw figure (also displayed along with the input widgets).
    """

    users = widgets.IntText(value=8, description='Number of total users')
    storage_per_user = widgets.IntText(value=10, description='Storage per user (GB)')
    mem_per_user = widgets.IntText(value=2, description="RAM per user (GB)")
    machines = widgets.Dropdown(description='Machine',
                                options=machines_list['Machine type'].values.tolist())
    persistent = widgets.Dropdown(description="Persistent Storage?",
                                  options={'HDD': 'hdd', 'SSD': 'ssd'},
                                  value='hdd')
    autoscaling = widgets.Checkbox(value=False, description='Autoscaling?')
    text_avg_num_machine = widgets.Text(value='', description='Average # Machines:')
    text_cost_machine = widgets.Text(value='', description='Machine Cost:')
    text_cost_storage = widgets.Text(value='', description='Storage Cost:')
    text_cost_total = widgets.Text(value='', description='Total Cost:')

    hr = widgets.HTML(value="---")

    # Define axes limits
    y_max = 100.
    date_stop, date_range = create_date_range(n_days)

    # Create axes and extra variables for the viz
    # NOTE(review): `date_start` is not defined in this function, so it must
    # exist at module scope; confirm whether it should instead come from
    # `create_date_range`.
    xs_hd = DateScale(min=date_start, max=date_stop, )
    ys_hd = LinearScale(min=0., max=y_max)

    # Shading for weekends: a filled line that rises above the plot on
    # weekend samples and drops below it otherwise. Assumes `date_range` is
    # a pandas DatetimeIndex, whose `dayofweek` is Monday=0 .. Sunday=6, so
    # Saturday/Sunday are 5 and 6 (7 never occurs).
    is_weekend = np.where(np.isin(date_range.dayofweek, [5, 6]),
                          float(y_max) + 50., -10.)
    line_fill = Lines(x=date_range, y=is_weekend,
                      scales={'x': xs_hd, 'y': ys_hd}, colors=['black'],
                      fill_opacities=[.2], fill='bottom')

    # Set up hand draw widget
    line_hd = Lines(x=date_range, y=10 * np.ones(len(date_range)),
                    scales={'x': xs_hd, 'y': ys_hd}, colors=['#E46E2E'])
    line_users = Lines(x=date_range, y=10 * np.ones(len(date_range)),
                       scales={'x': xs_hd, 'y': ys_hd}, colors=['#e5e5e5'])
    line_autoscale = Lines(x=date_range, y=10 * np.ones(len(date_range)),
                           scales={'x': xs_hd, 'y': ys_hd}, colors=['#000000'])
    handdraw = HandDraw(lines=line_hd)
    xax = Axis(scale=xs_hd, label='Day', grid_lines='none',
               tick_format='%b %d')
    yax = Axis(scale=ys_hd, label='Number of Users',
               orientation='vertical', grid_lines='none')
    # FIXME add `line_autoscale` when autoscale is enabled
    fig = Figure(marks=[line_fill, line_hd, line_users],
                 axes=[xax, yax], interaction=handdraw)

    def _update_cost(change):
        """Recompute costs from the drawn curve and the current widget values."""
        # Pull values from the plot
        max_users = max(handdraw.lines.y)
        max_buffer = max_users * 1.05  # 5% buffer
        line_users.y = [max_buffer] * len(handdraw.lines.y)
        if max_users > users.value:
            users.value = max_users

        autoscaled_users = autoscale(handdraw.lines.y)
        line_autoscale.y = autoscaled_users

        # Calculate costs
        active_machine = machines_list[machines_list['Machine type'] == machines.value]
        machine_cost = active_machine['Price (USD / hr)'].values.astype(float) * 24  # To make it cost per day
        users_for_cost = autoscaled_users if autoscaling.value else [max_buffer] * len(handdraw.lines.y)
        num_machines = calculate_machines_needed(users_for_cost, mem_per_user.value, active_machine)
        avg_num_machines = np.mean(num_machines)
        cost_machine = integrate_cost(num_machines, machine_cost)
        cost_storage = integrate_cost(num_machines, storage_cost[persistent.value] * storage_per_user.value)
        cost_total = cost_machine + cost_storage

        # Set the values: currencies for costs, two decimals for the average.
        for iwidget, icost in [(text_cost_machine, cost_machine),
                               (text_cost_storage, cost_storage),
                               (text_cost_total, cost_total),
                               (text_avg_num_machine, avg_num_machines)]:
            if iwidget is not text_avg_num_machine:
                icost = locale.currency(icost, grouping=True)
            else:
                icost = '{:.2f}'.format(icost)
            iwidget.value = icost

        # Highlight whichever user curve drives the cost.
        if autoscaling.value:
            line_autoscale.colors = ['#000000']
            line_users.colors = ['#e5e5e5']
        else:
            line_autoscale.colors = ['#e5e5e5']
            line_users.colors = ['#000000']

    line_hd.observe(_update_cost, names='y')
    # autoscaling.observe(_update_cost)  # FIXME Uncomment when we implement autoscaling
    persistent.observe(_update_cost)
    machines.observe(_update_cost)
    storage_per_user.observe(_update_cost)
    mem_per_user.observe(_update_cost)

    # Show it
    fig.title = 'Draw your usage pattern over time.'
    # FIXME autoscaling when it's ready
    display(users, machines, mem_per_user, storage_per_user, persistent, fig, hr,
            text_cost_machine, text_avg_num_machine, text_cost_storage, text_cost_total)
    return fig
Beispiel #5
0
    def create_fig(self, ts):
        """
        Build ``self.fig`` from ``ts`` and, if present, rebuild ``self.vbox``.

        When ``self.ptype == 'PCA'`` or ``self.dims`` is set, the data is
        drawn as a Scatter coloured by ``self.colors``, together with a
        colour axis. Otherwise ``ts`` is sorted by index (in place) and reset,
        and drawn as Lines. A falsy ``self.ptype`` gives circle markers; any
        other ptype enables selected/unselected styling.
        ``self.xlabel`` selects the x column and ``self.cols`` the y columns.
        """
        is_scatter = self.ptype == 'PCA' or self.dims is not None

        if is_scatter:
            df = ts
        else:
            # NOTE: sorts the caller's `ts` in place before resetting the index.
            ts.sort_index(inplace=True)
            df = ts.reset_index()  # time = ts.Time
        self.xd = df[self.xlabel]
        self.yd = df[self.cols].T

        # One colour scale shared by the scatter mark and the colour axis, so
        # the axis reflects the mark's colour domain (two independent
        # ColorScale instances are not linked).
        color_scale = ColorScale(scheme=self.scheme)

        if is_scatter:
            pplt = Scatter(x=self.xd.values.ravel(),
                           y=self.yd.values.ravel(),
                           scales={
                               'x': self.xScale,
                               'y': self.yScale,
                               'color': color_scale
                           },
                           selected_style={'opacity': '1'},
                           unselected_style={'opacity': '0.2'},
                           color=self.colors,
                           default_size=32)
        elif not self.ptype:
            # interpolation options: 'linear', 'basis', 'cardinal', 'monotone'
            pplt = Lines(x=self.xd, y=self.yd,
                         scales={'x': self.xScale, 'y': self.yScale},
                         labels=self.legends,
                         display_legend=True,
                         line_style=self.linestyle,
                         stroke_width=1,
                         marker='circle',
                         interpolation=self.interp)
        else:
            pplt = Lines(x=self.xd, y=self.yd,
                         scales={'x': self.xScale, 'y': self.yScale},
                         labels=self.legends,
                         display_legend=True,
                         line_style=self.linestyle,
                         selected_style={'opacity': '1'},
                         unselected_style={'opacity': '0.2'},
                         interpolation=self.interp)

        x_axis = Axis(scale=self.xScale, label=self.xlabel, grid_lines='none')
        y_axis = Axis(scale=self.yScale,
                      label=self.ylabel,
                      orientation='vertical',
                      grid_lines='none')
        axis = [x_axis, y_axis]
        if is_scatter:
            axis.append(ColorAxis(scale=color_scale,
                                  orientation='vertical',
                                  side='right'))

        if self.debug:
            margin = dict(top=0, bottom=40, left=50, right=50)
        else:
            margin = dict(top=0, bottom=50, left=50, right=50)

        self.fig = Figure(marks=[pplt],
                          axes=axis,
                          legend_location='top-right',
                          fig_margin=margin)

        if self.debug:
            self.deb = HTML()

        # Only rebuild the container if one was created earlier.
        if getattr(self, "vbox", None) is not None:
            box_layout = Layout(display='flex',
                                flex_flow='column',
                                align_items='stretch')
            if self.debug:
                self.vbox = VBox(
                    [self.selection_interacts, self.fig, self.deb],
                    layout=box_layout)
            else:
                self.vbox = VBox([self.selection_interacts, self.fig],
                                 layout=box_layout)
Beispiel #6
0
    def init_dashboard(self):
        """
        Display a CPU / memory / throughput dashboard and start its updater.

        One line per entry of ``self.other_nodes`` (keyed by ``str(node)``)
        is drawn on each of three figures shown in a Tab. Every node gets a
        shared ``perf_queue`` attribute. A background thread
        (``self.perf_thread``) drains that queue and appends each sample to
        the matching lines until ``self.exit_perf`` is set.
        """
        from bqplot import LinearScale, Axis, Lines, Figure, Tooltip
        from bqplot.colorschemes import CATEGORY20
        from ipywidgets import Tab
        from IPython.display import display

        cpu_sx = LinearScale()
        cpu_sy = LinearScale()
        cpu_x = Axis(label='Time (s)', scale=cpu_sx)
        cpu_y = Axis(label='CPU Usage (%)',
                     scale=cpu_sy,
                     orientation='vertical')
        mem_sx = LinearScale()
        mem_sy = LinearScale()
        mem_x = Axis(label='Time (s)', scale=mem_sx)
        mem_y = Axis(label='Memory Usage (MB)',
                     scale=mem_sy,
                     orientation='vertical')
        thru_sx = LinearScale()
        thru_sy = LinearScale()
        thru_x = Axis(label='Time (s)', scale=thru_sx)
        thru_y = Axis(label='Data Processed (MB)',
                      scale=thru_sy,
                      orientation='vertical')

        colors = CATEGORY20
        tt = Tooltip(fields=['name'], labels=['Filter Name'])

        def make_lines(sx, sy):
            """One single-point Lines mark per node on the given scales."""
            return {
                str(n): Lines(labels=[str(n)],
                              x=[0.0],
                              y=[0.0],
                              # Wrap so more nodes than palette colours
                              # cannot raise IndexError.
                              colors=[colors[i % len(colors)]],
                              tooltip=tt,
                              scales={
                                  'x': sx,
                                  'y': sy
                              })
                for i, n in enumerate(self.other_nodes)
            }

        self.cpu_lines = make_lines(cpu_sx, cpu_sy)
        self.mem_lines = make_lines(mem_sx, mem_sy)
        self.thru_lines = make_lines(thru_sx, thru_sy)

        self.cpu_fig = Figure(marks=list(self.cpu_lines.values()),
                              axes=[cpu_x, cpu_y],
                              title='CPU Usage',
                              animation_duration=50)
        self.mem_fig = Figure(marks=list(self.mem_lines.values()),
                              axes=[mem_x, mem_y],
                              title='Memory Usage',
                              animation_duration=50)
        self.thru_fig = Figure(marks=list(self.thru_lines.values()),
                               axes=[thru_x, thru_y],
                               title='Data Processed',
                               animation_duration=50)

        tab = Tab()
        tab.children = [self.cpu_fig, self.mem_fig, self.thru_fig]
        tab.set_title(0, 'CPU')
        tab.set_title(1, 'Memory')
        tab.set_title(2, 'Throughput')
        display(tab)

        perf_queue = Queue()
        self.exit_perf = Event()

        def wait_for_perf_updates(q, exit_event, cpu_lines, mem_lines,
                                  thru_lines):
            """Drain `q` into the line marks until `exit_event` is set."""
            while not exit_event.is_set():
                messages = []

                # Collect everything currently queued; when the queue is
                # empty, pause briefly and plot what was collected.
                while not exit_event.is_set():
                    try:
                        messages.append(q.get(False))
                    except queue.Empty:
                        time.sleep(0.05)
                        break

                for filter_name, time_val, cpu, mem_info, processed in messages:
                    seconds = time_val.total_seconds()
                    # mem_info[0] and processed are divided by 2**20,
                    # presumably bytes -> MB to match the axis labels.
                    samples = ((cpu_lines, cpu),
                               (mem_lines, mem_info[0] / 2.**20),
                               (thru_lines, processed / 2.**20))
                    for lines, value in samples:
                        line = lines[filter_name]
                        line.x = np.append(line.x, [seconds])
                        line.y = np.append(line.y, [value])

        for n in self.other_nodes:
            n.perf_queue = perf_queue

        self.perf_thread = Thread(target=wait_for_perf_updates,
                                  args=(perf_queue, self.exit_perf,
                                        self.cpu_lines, self.mem_lines,
                                        self.thru_lines))
        self.perf_thread.start()
file_box = Box([save_btn, open_btn])
file_box.layout.display = 'flex'
file_box.layout.justify_content = 'flex-end'
# `align_items` is the Layout trait; the misspelling `align_itmes` only set an
# unused Python attribute and had no effect.
file_box.layout.align_items = 'stretch'

control_box1 = VBox([HBox(m),HBox(params_box+[sound_chk, sound_btn]),
                     HBox([add_btn, art_slt, delete_btn, replace_btn, play_all_btn])])


fig_margin_default = {'top':40, 'bottom':40, 'left':40, 'right':40}
min_height_default = 10
min_width_default = 10
# Create Sound Wave plot container
x_time = LinearScale(min=0., max=100)
y_sound = LinearScale(min=-.5, max=.5)
ax_sound_y = Axis(label='Amplitude', scale=y_sound, orientation='vertical', side='left', grid_lines='solid')
ax_time_x = Axis(label='Time', scale=x_time, grid_lines='solid')
#Initialization
sound_line = Lines(x=[], y=[], colors=['Blue'],
                       scales={'x': x_time, 'y': y_sound}, visible=True)
# Keyword was misspelled `min_widht`; spelled to match `min_height`.
# NOTE(review): confirm plt.figure actually applies min_height/min_width/
# preserve_aspect (they may need to go through the figure's `layout`).
fig_sound = plt.figure(marks=[sound_line], axes=[ax_time_x, ax_sound_y], title='Sound wave', fig_margin = fig_margin_default,
                    min_height = min_height_default, min_width = min_width_default, preserve_aspect=True)

# Create Articulator Position Evolution Container
y_art = LinearScale(min=-3., max=3.)
ax_art_y = Axis(label='Position', scale=y_art, orientation='vertical', side='left', grid_lines='solid')
# ax_time_x = Axis(label='Time', scale=x_time, grid_lines='solid')
#Initialization
art_lines = Lines(x=[], y=[],  #colors=['Blue'],
                       scales={'x': x_time, 'y': y_art}, visible=True)
fig_art = plt.figure(marks=[art_lines], axes=[ax_time_x, ax_art_y], title='Articulators', fig_margin = fig_margin_default,
Beispiel #8
0
def polynomial_regression():
    """Polynomial regression example

    Builds a click-to-add scatter plot plus a dotted fitted polynomial curve.
    A Reset button clears the points, a Degree slider (1-5) sets the
    polynomial degree, and a Dataset dropdown loads points from
    ``fake_datasets``. Returns the assembled VBox.
    """
    settings = {"degree": 1}
    sc_x = LinearScale(min=-100, max=100)
    sc_y = LinearScale(min=-100, max=100)
    scat = Scatter(
        x=[], y=[], scales={"x": sc_x, "y": sc_y}, colors=["orange"], enable_move=True
    )
    lin = Lines(
        x=[], y=[], scales={"x": sc_x, "y": sc_y}, line_style="dotted", colors=["blue"]
    )

    def refresh_fit(change=None):
        """Refit the pipeline to the scatter points and redraw the curve."""
        xs, ys = scat.x, scat.y
        usable = len(xs) > 0 and len(ys) > 0 and xs.shape == ys.shape
        if not usable:
            # Nothing (or mismatched x/y) to fit: clear the curve.
            lin.x = []
            lin.y = []
            return
        model = make_pipeline(
            PolynomialFeatures(degree=settings["degree"]), LinearRegression()
        )
        model.fit(xs.reshape(-1, 1), ys)
        grid = np.linspace(sc_x.min, sc_x.max)
        with lin.hold_sync():
            # The x grid only needs setting once; the scale range is fixed.
            if len(lin.x) == 0:
                lin.x = grid
            lin.y = model.predict(grid.reshape(-1, 1))

    refresh_fit()
    # refit whenever the scatter's x or y data changes
    scat.observe(refresh_fit, names=["x", "y"])
    with scat.hold_sync():
        scat.enable_move = False
        scat.interactions = {"click": "add"}
    ax_x = Axis(scale=sc_x, tick_format="0.0f", label="x")
    ax_y = Axis(scale=sc_y, tick_format="0.0f", orientation="vertical", label="y")

    fig = Figure(marks=[scat, lin], axes=[ax_x, ax_y])

    # button that clears all points and the dataset selection
    reset_button = widgets.Button(description="Reset")

    def reset_points(change=None):
        with scat.hold_sync():
            scat.x = []
            scat.y = []
        dataset_dropdown.value = None

    reset_button.on_click(reset_points)

    # slider selecting the polynomial degree
    degree_slider = widgets.IntSlider(
        value=settings["degree"], min=1, max=5, step=1, description="Degree"
    )

    def on_degree(change):
        settings["degree"] = change["new"]
        refresh_fit()

    degree_slider.observe(on_degree, names="value")

    # dropdown selecting one of the fake datasets
    dataset_dropdown = widgets.Dropdown(
        options=fake_datasets(ndim=2), value=None, description="Dataset:"
    )

    def load_dataset(change):
        if change["type"] != "change" or change["name"] != "value":
            return
        if change["new"] is None:
            return
        x, y = fake_datasets(name=change["new"])
        with scat.hold_sync():
            scat.x = x.flatten()
            scat.y = y.flatten()

    dataset_dropdown.observe(load_dataset)

    children = (
        widgets.HTML("<h1>Polynomial regression</h1>"),
        reset_button,
        dataset_dropdown,
        degree_slider,
        fig,
    )
    return VBox(children)
Beispiel #9
0
    def __init__(self):
        """
        Build and display the Hull-White 1-factor / 2-factor dashboard.

        Creates the chart scales, axes, marks and four figures, the
        zero-coupon rate inputs and the model-parameter sliders, and lays
        them out in two tabs ('HW1F', 'HW2F') under a row of rate inputs.
        It then registers ``self.process_wrapper`` as a 'value' observer on
        every input and displays the form.
        """

        #the model restricts changing step size
        self.steps = 1 / 52

        # Line chart and histo
        self.x_sc = LinearScale()
        self.y_sc = LinearScale()

        self.x_sc_2 = LinearScale()
        self.y_sc_2 = LinearScale()

        self.ax_x = Axis(label='Weeks', scale=self.x_sc, grid_lines='dashed')
        self.ax_y = Axis(label='Rate',
                         scale=self.y_sc,
                         orientation='vertical',
                         grid_lines='dashed')
        self.ax_x_2 = Axis(label='Rate', scale=self.x_sc_2, grid_lines='none')
        self.ax_y_2 = Axis(label='Count',
                           scale=self.y_sc_2,
                           orientation='vertical',
                           grid_lines='dashed')

        #HW1F Charts
        self.line1 = Lines(x=[],
                           y=[[], []],
                           scales={
                               'x': self.x_sc,
                               'y': self.y_sc
                           },
                           stroke_width=3,
                           colors=['red', 'green'])
        self.line2 = Lines(x=[],
                           y=[[]],
                           scales={
                               'x': self.x_sc,
                               'y': self.y_sc
                           },
                           labels=['MC'])
        self.hist1 = Hist(sample=[],
                          scales={
                              'sample': self.x_sc_2,
                              'count': self.y_sc_2
                          },
                          bins=0)
        self.fig1 = Figure(axes=[self.ax_x, self.ax_y],
                           marks=[self.line1, self.line2],
                           title='Hull White 1 Factor Dynamics')
        self.fig2 = Figure(axes=[self.ax_x_2, self.ax_y_2],
                           marks=[self.hist1],
                           title='Final Distribution of Rates')

        #HW2F Charts
        self.line3 = Lines(x=[],
                           y=[[], []],
                           scales={
                               'x': self.x_sc,
                               'y': self.y_sc
                           },
                           stroke_width=3,
                           colors=['red', 'green'])
        self.line4 = Lines(x=[],
                           y=[[]],
                           scales={
                               'x': self.x_sc,
                               'y': self.y_sc
                           })
        self.hist2 = Hist(sample=[],
                          scales={
                              'sample': self.x_sc_2,
                              'count': self.y_sc_2
                          },
                          bins=0)
        self.fig3 = Figure(axes=[self.ax_x, self.ax_y],
                           marks=[self.line3, self.line4],
                           title='Hull White 2 Factor Dynamics')
        self.fig4 = Figure(axes=[self.ax_x_2, self.ax_y_2],
                           marks=[self.hist2],
                           title='Final Distribution of Rates')

        # Input widgets: zero-coupon rates per tenor, all with the same layout.
        def rate_input(label, value):
            return ipw.FloatText(description=label,
                                 value=value,
                                 layout=ipw.Layout(width='17%',
                                                   height='100%'))

        self.ZC1W = rate_input('1W', 2.07)
        self.ZC3M = rate_input('3M', 2.29)
        self.ZC6M = rate_input('6M', 2.45)
        self.ZC1Y = rate_input('1Y', 2.7)
        self.ZC2Y = rate_input('2Y', 3.02)
        self.ZC6Y = rate_input('6Y', 4)
        self.ZC10Y = rate_input('10Y', 4.7)

        self.normalVol = ipw.FloatSlider(value=100,
                                         min=0,
                                         max=1000,
                                         description='Normal Vol (bps)')
        self.normalVol2 = ipw.FloatSlider(value=100,
                                          min=0,
                                          max=1000,
                                          description='Normal Vol 2 (bps)')
        # `step` is the FloatSlider trait; the misspelled `steps=` keyword
        # was ignored, leaving the widget's default step.
        self.meanRev = ipw.FloatSlider(value=1,
                                       min=0.0001,
                                       max=10,
                                       step=0.01,
                                       description='Mean Reversion')
        self.meanRev2 = ipw.FloatSlider(value=3,
                                        min=0.0001,
                                        max=10,
                                        step=0.01,
                                        description='Mean Reversion 2')
        self.numPaths = ipw.IntSlider(value=50,
                                      min=1,
                                      max=500,
                                      step=1,
                                      description='Paths ')
        self.showPaths = ipw.IntSlider(value=5,
                                       min=0,
                                       max=100,
                                       step=1,
                                       description='Display Paths ')
        self.correlation = ipw.FloatSlider(value=.1,
                                           min=-1,
                                           max=1,
                                           step=0.01,
                                           description='Correlation')

        # Layout with tabs. Some sliders appear in both tabs (same widget,
        # two views), so changing them in one tab updates the other.
        self.tab = ipw.Tab()
        self.tab_contents = [0, 1]  # not used within this method
        self.children = [
            ipw.VBox([
                ipw.HBox([
                    self.normalVol, self.meanRev, self.numPaths, self.showPaths
                ]),
                ipw.HBox([self.fig1, self.fig2])
            ]),
            ipw.VBox([
                ipw.HBox([
                    self.normalVol, self.normalVol2, self.meanRev,
                    self.meanRev2
                ]),
                ipw.HBox([self.correlation, self.numPaths, self.showPaths]),
                ipw.HBox([self.fig3, self.fig4])
            ])
        ]

        self.tab.children = self.children
        self.tab.set_title(0, 'HW1F')
        self.tab.set_title(1, 'HW2F')

        #Observers: every input re-runs the model on value change.
        for widget in (self.ZC1W, self.ZC3M, self.ZC6M, self.ZC1Y,
                       self.ZC2Y, self.ZC6Y, self.ZC10Y, self.meanRev,
                       self.meanRev2, self.normalVol, self.normalVol2,
                       self.numPaths, self.showPaths, self.correlation):
            widget.observe(self.process_wrapper, 'value')
        self.paths = self.numPaths.value

        self.form = ipw.VBox([
            ipw.HBox([
                self.ZC1W, self.ZC3M, self.ZC6M, self.ZC1Y, self.ZC2Y,
                self.ZC6Y, self.ZC10Y
            ],
                     layout={'width': '80%'}), self.tab
        ])

        display(self.form)
def simple_optimazation_app():
    """Build an interactive genetic-algorithm demo over ``target_function`` on [0, 1].

    The app draws the reference curve ``target_function(x)`` and overlays the
    current population as a scatter. Widgets let the user:

    * initialise a random population of a chosen size, and
    * run a number of GA generations with configurable crossover, drop and
      mutation rates.

    Each run reports the best individual found so far.

    The GA helpers (``select_and_crossover``, ``encode_all``, ``mutatie_all``,
    ``decode_all``, ``get_best``, ``encode``, ``get_scores``) are defined
    elsewhere in the file. Their exact contracts are assumed, not shown here.

    Returns
    -------
    ipywidgets.VBox
        The assembled application widget.
    """
    # Defaults for ``update``; the sliders override them on every run.
    itter_time = 50
    crossover_rate = 0.1
    drop_rate = 0.5
    mutation_rate = 0.1

    i = 0                 # generations run since the last initialisation
    best_score = 0
    best_ind = []
    best_ind_ready = []   # encoded form of ``best_ind``
    population = []

    # Dynamic figure: reference curve + evolving population.
    X = np.linspace(0, 1, 1000)
    y = np.array([target_function(x) for x in X])

    # bqplot's Axis has no min/max traits, so the x range is pinned on the scale.
    x_sc = LinearScale(min=0, max=1)
    y_sc = LinearScale()

    ref = Lines(x=X, y=y, scales={'x': x_sc, 'y': y_sc})
    scatter = Scatter(x=[],
                      y=[],
                      scales={
                          'x': x_sc,
                          'y': y_sc
                      },
                      colors=['DarkOrange'],
                      stroke='red',
                      stroke_width=0.4,
                      default_size=20)

    x_ax = Axis(label='X', scale=x_sc)
    y_ax = Axis(label='Y', scale=y_sc, orientation='vertical')
    x_ax.num_ticks = 7
    x_ax.grid_color = 'orangered'

    fig = Figure(marks=[ref, scatter],
                 title='A Figure',
                 axes=[x_ax, y_ax],
                 animation_duration=1000)

    run_itter_slider = widgets.IntSlider(
        value=50, description='#Iteration', min=1, max=100, step=1)

    run_btn = widgets.Button(description='Run', icon='play', disabled=True)

    # min=10: an empty population (the old min of 0) cannot be evolved.
    population_cnt_slider = widgets.IntSlider(value=30,
                                              description='#Population',
                                              min=10,
                                              max=100,
                                              step=10)

    init_population_btn = widgets.Button(description='Initialize Population')

    descriptor1 = widgets.Label('crossover_rate')
    crossover_rate_slider = widgets.FloatSlider(value=0.1,
                                                description='',
                                                min=0,
                                                max=1.0,
                                                step=0.1)
    descriptor2 = widgets.Label('drop_rate')
    drop_rate_slider = widgets.FloatSlider(value=0.5,
                                           description='',
                                           min=0,
                                           max=1.0,
                                           step=0.1)
    descriptor3 = widgets.Label('mutation_rate')
    mutation_rate_slider = widgets.FloatSlider(value=0.3,
                                               description='',
                                               min=0,
                                               max=1.0,
                                               step=0.1)
    patch1 = widgets.HBox([descriptor1, crossover_rate_slider])
    patch2 = widgets.HBox([descriptor2, drop_rate_slider])
    patch3 = widgets.HBox([descriptor3, mutation_rate_slider])

    blank = widgets.Label('')

    run_out = widgets.Output(layout={
        'border': '1px solid black',
        'height': '50px'
    })
    row1 = widgets.VBox([population_cnt_slider, init_population_btn])
    row2 = widgets.HBox([patch1, patch2, patch3])
    row_n = widgets.VBox([run_itter_slider, run_btn])

    app = widgets.VBox([row1, blank, row2, blank, row_n, run_out, fig])

    def initialize():
        """Draw a fresh random population and reset counters and best-so-far."""
        nonlocal population, i, best_score, best_ind, best_ind_ready
        population = np.random.rand(population_cnt_slider.value)
        # Send x and y together so the front end never sees mismatched lengths.
        with scatter.hold_sync():
            scatter.x = population
            scatter.y = get_scores(population)
        i = 0
        # The best individual of a previous population must not carry over.
        best_score = 0
        best_ind = []
        best_ind_ready = []
        fig.title = f'迭代{i}次\n'

    @run_out.capture()
    def update(itter_time=itter_time,
               crossover_rate=crossover_rate,
               drop_rate=drop_rate,
               mutation_rate=mutation_rate):
        """Run ``itter_time`` generations, then refresh the plot and best readout."""
        nonlocal population, best_score, best_ind_ready, best_ind, i
        for _ in range(itter_time):
            offspring = select_and_crossover(
                population, crossover_rate=crossover_rate, drop_rate=drop_rate)
            offspring_ready = encode_all(offspring)
            offspring_ready = mutatie_all(offspring_ready,
                                          mutation_rate=mutation_rate)
            # Each generation evolves from the previous one (previously the
            # initial population was re-bred every time).
            population = decode_all(offspring_ready)

            ind, score = get_best(population)
            if score > best_score:
                best_ind = ind
                best_score = score
                best_ind_ready = encode(best_ind)
            i += 1
        # Population size may change between generations; update x/y atomically.
        with scatter.hold_sync():
            scatter.x = population
            scatter.y = get_scores(population)
        fig.title = f'迭代{i}次'
        clear_output(wait=True)
        display(f'最优个体为: {best_ind_ready}; 函数值为:{best_score}')

    def on_click_init(change):
        initialize()
        run_btn.disabled = False

    def on_click_run(change):
        update(itter_time=run_itter_slider.value,
               crossover_rate=crossover_rate_slider.value,
               drop_rate=drop_rate_slider.value,
               mutation_rate=mutation_rate_slider.value)

    init_population_btn.on_click(on_click_init)
    run_btn.on_click(on_click_run)
    return app
def gradient_descent():
    """Build an interactive demo of fitting ``y = m * x + b`` by gradient descent.

    The left figure shows draggable data points (from ``fake_datasets("Linear")``)
    and the current fitted line. A ``Play`` widget takes one gradient step on the
    mean squared error per tick. The right figure plots the MSE per iteration.
    "Reset" draws a new dataset and a random starting line.

    Returns
    -------
    VBox
        The assembled application widget.
    """
    line_params = {"b": 0, "m": 0, "iter": 1}

    sc_x = LinearScale(min=-100, max=100)
    sc_y = LinearScale(min=-100, max=100)
    scat = Scatter(x=[],
                   y=[],
                   scales={
                       "x": sc_x,
                       "y": sc_y
                   },
                   colors=["orange"],
                   enable_move=True)
    lin = Lines(x=[], y=[], scales={"x": sc_x, "y": sc_y}, colors=["blue"])

    ax_x = Axis(scale=sc_x, tick_format="0.0f", label="x")
    ax_y = Axis(scale=sc_y,
                tick_format="0.0f",
                orientation="vertical",
                label="y")

    fig_function = Figure(marks=[scat, lin], axes=[ax_x, ax_y])

    sc_x_cost = LinearScale(min=0, max=100)
    sc_y_cost = LinearScale(min=0, max=100)
    lin_cost = Lines(x=[], y=[], scales={"x": sc_x_cost, "y": sc_y_cost})
    ax_x_cost = Axis(scale=sc_x_cost, tick_format="0.0f", label="iteration")
    ax_y_cost = Axis(
        scale=sc_y_cost,
        tick_format="0.0f",
        orientation="vertical",
        label="Mean Squared Error",
    )

    fig_cost = Figure(marks=[lin_cost], axes=[ax_x_cost, ax_y_cost])

    def draw_line():
        """Redraw the current line ``b + m * x`` over x in [-100, 100]."""
        x = np.linspace(-100, 100)
        y = line_params["b"] + line_params["m"] * x
        with lin.hold_sync():
            lin.x = x
            lin.y = y

    play_button = widgets.Play(
        interval=100,
        value=0,
        min=0,
        max=100,
        step=1,
        repeat=True,
        description="Run gradient descent",
        disabled=False,
    )

    step_slider = widgets.IntSlider(min=0,
                                    max=100,
                                    step=1,
                                    description="Step",
                                    value=0,
                                    disabled=True)

    def mse():
        """Mean squared error of the current line on the scatter points."""
        b = line_params["b"]
        m = line_params["m"]
        return (((scat.x * m + b) - scat.y)**2).mean()

    def play_change(change):
        """Take one gradient-descent step on the MSE and refresh both figures."""
        x = np.asarray(scat.x, dtype=float)
        y = np.asarray(scat.y, dtype=float)
        if x.size == 0:
            # Nothing to fit; the MSE and its gradient are undefined.
            return
        b = line_params["b"]
        m = line_params["m"]
        learning_rate = 0.0001

        # Partial derivatives of mean((y - (m*x + b))**2) with respect to b and m.
        residual = y - (m * x + b)
        b_gradient = -2 * residual.mean()
        m_gradient = -2 * (x * residual).mean()

        # The intercept takes a 500x larger step than the slope; presumably
        # because b starts in roughly +-50 while m starts in +-0.5 (see reset).
        b = b - (learning_rate * 500 * b_gradient)
        m = m - (learning_rate * m_gradient)

        line_params["b"] = b
        line_params["m"] = m
        with lin_cost.hold_sync():
            lin_cost.x = np.append(np.array(lin_cost.x),
                                   np.array([line_params["iter"]]))
            lin_cost.y = np.append(np.array(lin_cost.y), mse())
        sc_x_cost.min = np.min(lin_cost.x)
        sc_x_cost.max = np.max(lin_cost.x)
        sc_y_cost.min = 0
        sc_y_cost.max = np.max(lin_cost.y)

        line_params["iter"] = line_params["iter"] + 1

        draw_line()

    play_button.observe(play_change, "value")
    widgets.jslink((play_button, "value"), (step_slider, "value"))

    reset_button = widgets.Button(description="Reset")

    def on_button_clicked(change=None):
        """Load a new dataset, clear the cost curve and pick a random start line."""
        x, y = fake_datasets("Linear")
        with scat.hold_sync():
            scat.x = x.flatten()
            scat.y = y.flatten()
        with lin_cost.hold_sync():
            lin_cost.x = []
            lin_cost.y = []

        line_params["b"] = (np.random.random() - 0.5) * 100
        line_params["m"] = np.random.random() - 0.5
        line_params["iter"] = 1
        draw_line()

    on_button_clicked()

    reset_button.on_click(on_button_clicked)

    return VBox((
        widgets.HTML("<h1>Gradient Descent</h1>"),
        reset_button,
        HBox((Label("Run gradient descent"), play_button, step_slider)),
        HBox((fig_function, fig_cost)),
    ))
Beispiel #12
0
def plot_pulse_files(metafile, time=True, backend='bqplot'):
    '''
    plot_pulse_files(metafile)

    Plot the AWG sequence files listed in a metafile. A jupyter slider widget
    allows choice of sequence number.

    Parameters
    ----------
    metafile : str
        Path to a JSON file with an "instruments" mapping. Each value is either
        a sequence-file path or a dict of per-channel sequence-file paths.
    time : bool
        Forwarded to ``extract_waveforms``. With the bqplot backend it also
        labels the x axis "Time (ns)" and scales x by 1e9 (presumably from
        seconds -- confirm against ``extract_waveforms``); otherwise the axis
        is labelled "Samples".
    backend : {'bqplot', 'matplotlib'}
        Plotting backend.

    Returns
    -------
    ipywidgets.VBox or None
        Slider + figure for 'bqplot'. None for 'matplotlib', which displays
        through ``interact``.

    Raises
    ------
    ValueError
        If ``backend`` is not 'bqplot' or 'matplotlib'.
    '''
    # Fail fast instead of silently returning None for an unknown backend.
    if backend not in ('bqplot', 'matplotlib'):
        raise ValueError(
            f"Unsupported backend {backend!r}; expected 'bqplot' or 'matplotlib'")

    with open(metafile, 'r', encoding='utf-8') as FID:
        meta_info = json.load(FID)
    fileNames = []
    for el in meta_info["instruments"].values():
        # Accommodate seq_file per instrument and per channel
        if isinstance(el, str):
            fileNames.append(el)
        elif isinstance(el, dict):
            fileNames.extend(el.values())

    line_names, num_seqs, data_dicts = extract_waveforms(fileNames, time=time)

    if backend == 'matplotlib':
        import matplotlib.pyplot as plt
        from ipywidgets import interact, IntSlider

        def update_plot(seq_ind):
            # Draw every channel of the selected sequence.
            for line_name in line_names:
                dat = data_dicts[f"{line_name}_{seq_ind}"]
                plt.plot(dat['x'], dat['y'], label=line_name, linewidth=1.0)

        interact(update_plot,
                 seq_ind=IntSlider(min=1,
                                   max=num_seqs,
                                   step=1,
                                   value=1,
                                   description="Sequence Number"))
        return None

    from bqplot import LinearScale, Axis, Lines, Figure, Tooltip
    from bqplot.colorschemes import CATEGORY10, CATEGORY20
    from ipywidgets import IntSlider, VBox
    sx = LinearScale()
    sy = LinearScale(min=-1.0, max=2 * len(line_names) - 1.0)
    if time:
        ax = Axis(label='Time (ns)', scale=sx)
    else:
        ax = Axis(label="Samples", scale=sx)
    ay = Axis(label='Amplitude', scale=sy, orientation='vertical')

    # CATEGORY10 holds 10 colors, so it suffices up to 10 channels; beyond the
    # palette size the colors wrap instead of raising IndexError.
    colors = CATEGORY10 if len(line_names) <= 10 else CATEGORY20
    lines = []
    tt = Tooltip(fields=['name'], labels=['Channel'])
    x_mult = 1.0e9 if time else 1
    for i, line_name in enumerate(line_names):
        dat = data_dicts[f"{line_name}_1"]
        lines.append(
            Lines(labels=[line_name],
                  x=x_mult * dat['x'],
                  y=dat['y'],
                  scales={
                      'x': sx,
                      'y': sy
                  },
                  tooltip=tt,
                  animate=False,
                  colors=[colors[i % len(colors)]]))

    slider = IntSlider(min=1,
                       max=num_seqs,
                       step=1,
                       description='Segment',
                       value=1)

    def segment_changed(change):
        # Swap every line's data to the selected segment.
        for line, line_name in zip(lines, line_names):
            dat = data_dicts[f"{line_name}_{slider.value}"]
            # Segments may differ in length: send x and y in one message.
            with line.hold_sync():
                line.x = x_mult * dat['x']
                line.y = dat['y']

    slider.observe(segment_changed, 'value')
    fig = Figure(marks=lines,
                 axes=[ax, ay],
                 title='Waveform Plotter',
                 animation_duration=50)
    return VBox([slider, fig])
Beispiel #13
0
    return np.exp(-(x-0.1) ** 2) * np.sin(6 * np.pi * x ** (3 / 4)) ** 2
# Reference curve of target_function on [0, 1] plus a random population of 30.
X = np.linspace(0, 1, 1000)
y = np.array([target_function(x) for x in X])
population = np.random.rand(30)


x_sc = LinearScale()
y_sc = LinearScale()

ref = Lines(x=X, y=y, scales={'x': x_sc, 'y': y_sc})
scatter = Scatter(x=population, y=np.array([target_function(ind) for ind in population]),
                    scales={'x': x_sc, 'y': y_sc},
                    colors=['DarkOrange'], stroke='red',
                    stroke_width=0.4, default_size=20)

x_ax = Axis(label='X', scale=x_sc)
y_ax = Axis(label='Y', scale=y_sc, orientation='vertical')

# bqplot's Axis has no min/max traits (assigning them only adds plain Python
# attributes), so the x range is pinned on the scale instead.
x_sc.min = 0
x_sc.max = 1
x_ax.num_ticks = 7
x_ax.grid_color = 'orangered'

fig = Figure(marks=[ref, scatter], title='A Figure', axes=[x_ax, y_ax],
                animation_duration=1000)
fig
# %%
# Resample the population; send x and y together in one update.
with scatter.hold_sync():
    scatter.x = np.random.rand(30)
    scatter.y = np.array([target_function(ind) for ind in scatter.x])

# %%
Beispiel #14
0
# %% [markdown]
# #### Defining Axes and Scales
#
# Income per capita is heavily right-skewed, so a `LogScale` suits the x axis. Regions are discrete categories rather than a continuous quantity, so they are colored with an `OrdinalColorScale`, which assigns each region its own color.

# %% {"collapsed": true}
x_sc = LogScale(
    min=income_min,
    max=income_max,
)
y_sc = LinearScale(
    min=life_exp_min,
    max=life_exp_max,
)
c_sc = OrdinalColorScale(
    domain=data['region'].unique().tolist(),
    colors=CATEGORY10[:6],
)
size_sc = LinearScale(
    min=pop_min,
    max=pop_max,
)

# %% {"collapsed": true}
ax_y = Axis(
    label='Life Expectancy',
    scale=y_sc,
    orientation='vertical',
    side='left',
    grid_lines='solid',
)
ax_x = Axis(
    label='Income per Capita',
    scale=x_sc,
    grid_lines='solid',
)

# %% [markdown]
# #### Creating the Scatter mark with size and color mappings
#
# Each country's population drives the marker `size` attribute and its region drives the `color` attribute.

# %% {"collapsed": true}
# Seed the chart with the data for the initial year.
cap_income, life_exp, pop = get_data(initial_year)

# %% {"collapsed": true}
wealth_scat = Scatter(x=cap_income,
Beispiel #15
0
    def create_fig(self, ts):
        """Build ``self.fig`` from ``ts``; rebuild ``self.vbox`` if one already exists.

        ``ts`` is sorted by index *in place*. Its index is then reset so that
        ``self.xlabel`` can be read as a column, and ``self.cols`` supplies the
        y series.

        Mark type is chosen by ``self.ptype``:

        * ``'PCA'``: a ``Scatter`` colored with ``self.colors``.
        * falsy: ``Lines`` drawn as bare circle markers (``stroke_width=0``).
        * anything else: solid ``Lines`` with dimmed unselected series.
        """
        ts.sort_index(inplace=True)
        frame = ts.reset_index()

        self.xd = frame[self.xlabel]
        self.yd = frame[self.cols].T

        mark_scales = {'x': self.xScale, 'y': self.yScale}

        if self.ptype == 'PCA':
            pplt = Scatter(x=self.xd,
                           y=self.yd,
                           scales=mark_scales,
                           color=self.colors)
        else:
            line_kwargs = dict(x=self.xd,
                               y=self.yd,
                               scales=mark_scales,
                               labels=self.legends,
                               display_legend=True,
                               line_style='solid')
            if self.ptype:
                # Selection highlighting: unselected series fade out.
                line_kwargs.update(selected_style={'opacity': '1'},
                                   unselected_style={'opacity': '0.2'})
            else:
                # No stroke, only markers: a scatter-like rendering of Lines.
                line_kwargs.update(stroke_width=0, marker='circle')
            pplt = Lines(**line_kwargs)

        x_axis = Axis(scale=self.xScale, label=self.xlabel, grid_lines='none')
        y_axis = Axis(scale=self.yScale,
                      label=self.ylabel,
                      orientation='vertical',
                      grid_lines='none')

        margin = dict(top=0,
                      bottom=40 if self.debug else 50,
                      left=50,
                      right=50)

        self.fig = Figure(marks=[pplt],
                          axes=[x_axis, y_axis],
                          legend_location='top-right',
                          fig_margin=margin)

        if self.debug:
            self.deb = HTML()

        # Only refresh an existing container; the first build happens elsewhere.
        if getattr(self, "vbox", None) is None:
            return

        children = [self.selection_interacts, self.fig]
        if self.debug:
            children.append(self.deb)
        self.vbox = VBox(children,
                         layout=Layout(display='flex',
                                       flex_flow='column',
                                       align_items='stretch'))
Beispiel #16
0
    def __init__(self, brain):
        """Build an interactive colorbar that re-thresholds ``brain``'s overlays.

        Shows a horizontal colorbar (a bqplot ``HeatMap``) with fmin / fmid /
        fmax inputs and an update button. The widget box is appended to the
        current ipyvolume container (``ipv.gcc()``).

        Parameters
        ----------
        brain : object
            The brain view being controlled. Read here:

            * ``brain.data`` keys 'center', 'fmin', 'fmax', 'lut', and on
              update also 'time_idx', 'time', '<hemi>_array' and
              '<hemi>_smooth_mat';
            * ``brain.views``, ``brain.hemis``, ``brain.overlays`` and
              ``brain.update_lut``.
        """
        self._brain = brain
        self._input_fmin = None
        self._input_fmid = None
        self._input_fmax = None
        self._btn_upd_mesh = None
        self._colors = None

        if brain.data['center'] is None:
            # Linear colormap: the bar spans [fmin, fmax].
            dt_min = brain.data['fmin']
            dt_max = brain.data['fmax']
        else:
            # Divergent colormap: the bar is symmetric, [-fmax, fmax].
            dt_min = -brain.data['fmax']
            dt_max = brain.data['fmax']

        self._lut = brain.data['lut']
        self._cbar_data = np.linspace(0, 1, self._lut.N)
        cbar_ticks = np.linspace(dt_min, dt_max, self._lut.N)
        # Two identical rows: a 2 x N heat map drawn as a horizontal bar.
        color = np.array((self._cbar_data, self._cbar_data))
        cbar_w = 500
        cbar_fig_margin = {'top': 15, 'bottom': 15, 'left': 5, 'right': 5}
        # Presumably fills self._colors from self._lut (defined elsewhere).
        self._update_colors()

        x_sc, col_sc = LinearScale(), ColorScale(colors=self._colors)
        ax_x = Axis(scale=x_sc)
        heat = HeatMap(x=cbar_ticks,
                       color=color,
                       scales={'x': x_sc, 'color': col_sc})

        # Presumably creates the fmin/fmid/fmax inputs and the update button
        # (they are None above and used below) -- defined elsewhere.
        self._add_inputs()
        fig_layout = widgets.Layout(width='%dpx' % cbar_w,
                                    height='60px')
        cbar_fig = Figure(axes=[ax_x],
                          marks=[heat],
                          fig_margin=cbar_fig_margin,
                          layout=fig_layout)

        def on_update(but_event):
            """Apply new fmin/fmid/fmax: recolor the overlays and rebuild the colorbar."""
            val_min = self._input_fmin.value
            val_mid = self._input_fmid.value
            val_max = self._input_fmax.value
            center = brain.data['center']
            time_idx = brain.data['time_idx']
            time_arr = brain.data['time']

            if not val_min < val_mid < val_max:
                raise ValueError('Incorrect relationship between' +
                                 ' fmin, fmid, fmax. Given values ' +
                                 '{0}, {1}, {2}'
                                 .format(val_min, val_mid, val_max))
            if center is None:
                # 'hot' or another linear color map
                dt_min = val_min
                dt_max = val_max
            else:
                # 'mne' or another divergent color map
                dt_min = -val_max
                dt_max = val_max

            self._lut = self._brain.update_lut(fmin=val_min, fmid=val_mid,
                                               fmax=val_max)
            # Affine map sending dt_min -> 0 and dt_max -> 1 (LUT input range).
            k = 1 / (dt_max - dt_min)
            b = 1 - k * dt_max
            self._brain.data['k'] = k
            self._brain.data['b'] = b

            for v in brain.views:
                for h in brain.hemis:
                    if (time_arr is None) or (time_idx is None):
                        act_data = brain.data[h + '_array']
                    else:
                        act_data = brain.data[h + '_array'][:, time_idx]

                    smooth_mat = brain.data[h + '_smooth_mat']
                    act_data = smooth_mat.dot(act_data)

                    act_data = k * act_data + b
                    act_data = np.clip(act_data, 0, 1)
                    act_color_new = self._lut(act_data)
                    brain.overlays[h + '_' + v].color = act_color_new
            self._update_colors()
            x_sc, col_sc = LinearScale(), ColorScale(colors=self._colors)
            ax_x = Axis(scale=x_sc)

            # Recompute the tick positions for the new [dt_min, dt_max] range;
            # the initial cbar_ticks would leave the axis labels stale.
            new_ticks = np.linspace(dt_min, dt_max, color.shape[1])
            heat = HeatMap(x=new_ticks,
                           color=color,
                           scales={'x': x_sc, 'color': col_sc})
            cbar_fig.axes = [ax_x]
            cbar_fig.marks = [heat]

        self._btn_upd_mesh.on_click(on_update)

        info_widget = widgets.VBox((cbar_fig,
                                    self._input_fmin,
                                    self._input_fmid,
                                    self._input_fmax,
                                    self._btn_upd_mesh))

        ipv.gcc().children += (info_widget,)
    def __init__(
        self,
        measures,
        x=None,
        y=None,
        c=None,
        mouseover=False,
        host="localhost",
        port=4090,
    ):
        """Interactive scatter plot visualisation - this is a base class,
        use either `ROIScatterViz` for one image with multiple ROIs
        or `ImageScatterViz` for a scatterplot with multiple images.

        ``measures`` is a table whose columns feed the x, y and colour
        dropdowns. ``x``, ``y`` and ``c`` pick the initial columns; when
        falsy, they default to the first, second and third columns.
        ``mouseover`` enables hover thumbnails. ``host`` is passed to the
        ``OMEConnect`` login widget, and ``port`` is stored as ``self.port``.
        """
        self.port = port
        self.measures = measures
        self.columns = list(measures.columns)
        x_col = x or self.columns[0]
        y_col = y or self.columns[1]
        c_col = c or self.columns[2]

        selector_layout = widgets.Layout(height="40px", width="100px")

        def make_selector(initial):
            # One column picker; x, y and colour selectors share this setup.
            return widgets.Dropdown(
                options=self.columns,
                value=initial,
                description="",
                disabled=False,
                layout=selector_layout,
            )

        self.x_selecta = make_selector(x_col)
        self.y_selecta = make_selector(y_col)
        self.c_selecta = make_selector(c_col)

        self.sheet = widgets.Output()
        self.thumbs = {}
        self.goto = widgets.HTML("")

        # is_datetime (defined elsewhere) presumably checks the column dtype;
        # date columns get date scales, everything else linear/continuous.
        x_sc = DateScale() if is_datetime(self.measures, x_col) else LinearScale()
        y_sc = DateScale() if is_datetime(self.measures, y_col) else LinearScale()
        c_sc = (DateColorScale(scheme="viridis")
                if is_datetime(self.measures, c_col)
                else ColorScale())

        table = self.measures
        self.scat = Scatter(
            x=table[self.x_selecta.value],
            y=table[self.y_selecta.value],
            color=table[self.c_selecta.value],
            scales={"x": x_sc, "y": y_sc, "color": c_sc},
            names=table.index,
            display_names=False,
            fill=True,
            default_opacities=[0.8],
        )

        self.ax_x = Axis(scale=x_sc, label=self.x_selecta.value)
        self.ax_y = Axis(
            scale=y_sc,
            label=self.y_selecta.value,
            orientation="vertical",
        )
        # Vertical colour bar positioned via an offset on the y scale.
        self.ax_c = ColorAxis(
            scale=c_sc,
            label=self.c_selecta.value,
            orientation="vertical",
            offset={"scale": y_sc, "value": 100},
        )
        self.fig = Figure(
            marks=[self.scat],
            axes=[self.ax_x, self.ax_y, self.ax_c],
        )

        for click_handler in (self.goto_db, self.show_data):
            self.scat.on_element_click(click_handler)
        if mouseover:
            self.scat.on_hover(self.show_thumb)
            self.scat.tooltip = widgets.HTML("")
        for selector in (self.x_selecta, self.y_selecta, self.c_selecta):
            selector.observe(self.update_scatter)

        # NOTE(review): the connector port is hard-coded to 4064 and ignores
        # the ``port`` argument (stored as self.port) -- confirm intended.
        self.connector = OMEConnect(host=host, port=4064)
        self.connector.gobtn.on_click(self.setup_graph)
        super().__init__([self.connector])