def __init__(self):
        """Build the user-parameters tab.

        Creates an editable widget for each user parameter
        (``self.random_seed`` and ``self.number_of_cells``) and shows each
        one as a row of [name | value | units | description], where every
        cell except the value is a disabled, coloured Button.  A disabled
        divider button separates the two rows.  The assembled layout is
        stored in ``self.tab``.
        """
        # Percent widths so each row scales with its container.
        style = {'description_width': '25%'}
        name_button_layout = {'width': '25%'}
        widget_layout = {'width': '15%'}
        units_button_layout = {'width': '15%'}
        desc_button_layout = {'width': '45%'}
        divider_button_layout = {'width': '40%'}

        # --- row 1: random_seed (lightgreen) ---
        param_name1 = Button(description='random_seed',
                             disabled=True,
                             layout=name_button_layout)
        param_name1.style.button_color = 'lightgreen'

        self.random_seed = IntText(value=0,
                                   step=1,
                                   style=style,
                                   layout=widget_layout)

        units_button1 = Button(description='',
                               disabled=True,
                               layout=units_button_layout)
        units_button1.style.button_color = 'lightgreen'

        desc_button1 = Button(description='',
                              tooltip='',
                              disabled=True,
                              layout=desc_button_layout)
        desc_button1.style.button_color = 'lightgreen'

        div_row1 = Button(description='---Initialization settings---',
                          disabled=True,
                          layout=divider_button_layout)

        # --- row 2: number_of_cells (tan) ---
        param_name2 = Button(description='number_of_cells',
                             disabled=True,
                             layout=name_button_layout)
        param_name2.style.button_color = 'tan'

        self.number_of_cells = IntText(value=50,
                                       step=1,
                                       style=style,
                                       layout=widget_layout)

        units_button2 = Button(description='',
                               disabled=True,
                               layout=units_button_layout)
        units_button2.style.button_color = 'tan'

        desc_button2 = Button(
            description='initial number of cells (for each cell type)',
            tooltip='initial number of cells (for each cell type)',
            disabled=True,
            layout=desc_button_layout)
        desc_button2.style.button_color = 'tan'

        row1 = [param_name1, self.random_seed, units_button1, desc_button1]
        row2 = [param_name2, self.number_of_cells, units_button2, desc_button2]

        box_layout = Layout(display='flex',
                            flex_flow='row',
                            align_items='stretch',
                            width='100%')
        box1 = Box(children=row1, layout=box_layout)
        box2 = Box(children=row2, layout=box_layout)

        self.tab = VBox([
            box1,
            div_row1,
            box2,
        ])
Beispiel #2
0
    def __init__(self,
                 data,
                 figsize=None,
                 title=None,
                 reverse_scroll=False,
                 show=True):
        """Interactive phase-correction widget for a 1D spectrum.

        Adds the following on top of the parent viewer and wires them to
        the plot:

        - a P0 slider and a pivot field;
        - P1 phase controls, plus P2 when ``PHASEQUAD`` is set;
        - phase-list management buttons;
        - Done / Cancel buttons.

        Parameters
        ----------
        data : 1D dataset
            ``data.check1D()`` is called first.  If ``data.itype == 0`` the
            user is alerted with ``jsalert`` and initialisation stops before
            the parent ``__init__`` runs.
        figsize, title, reverse_scroll :
            Forwarded to the parent ``__init__``, which is always called
            with ``show=False``.
        show : bool
            Call ``self.show()`` once the widget is fully built.
        """
        data.check1D()
        # Phase list. Despite the attribute name it is a dict:
        # {pivot_in_points: (ph0, ph1)}
        self.list = {}
        if data.itype == 0:
            jsalert('Data is Real - Please redo Fourier Transform')
            # NOTE: returns before super().__init__(), so the object is only
            # partially initialised in this case.
            return
        super().__init__(data,
                         figsize=figsize,
                         title=title,
                         reverse_scroll=reverse_scroll,
                         show=False)
        self.p0 = widgets.FloatSlider(description='P0:',
                                      min=-200,
                                      max=200,
                                      step=1,
                                      layout=Layout(width='50%'),
                                      continuous_update=HEAVY)
        # P1 / P2 are FloatButt controls.  P2 is only displayed when
        # PHASEQUAD is set, but it is always created and observed.
        self.p1 = FloatButt(4, description='P1  ', layout=Layout(width='100%'))
        self.p2 = FloatButt(4, description='P2  ', layout=Layout(width='100%'))
        P1, P2 = firstguess(data)
        self.p1.value, self.p2.value = round(P1), round(P2)
        # The pivot can be placed anywhere on the axis, itoc(0) .. itoc(size1).
        self.pivot = widgets.BoundedFloatText(description='Pivot',
                                              value=0,
                                              min=self.data.axis1.itoc(0),
                                              max=self.data.axis1.itoc(
                                                  self.data.size1),
                                              step=0.1,
                                              layout=Layout(width='20%'))
        self.cancel = widgets.Button(description="Cancel",
                                     button_style='warning')
        self.cancel.on_click(self.on_cancel)
        # Replace the parent's "done" button with an Apply button labelled Done.
        self.done.close()
        self.apply = widgets.Button(description="Done", button_style='success')
        self.apply.on_click(self.on_Apply)
        # Phase-list management
        self.clearlist = widgets.Button(description="Clear")
        self.clearlist.on_click(self.on_clearlist)
        self.addlist = widgets.Button(description="Add Entry")
        self.addlist.on_click(self.on_addlist)
        self.listlabel = Label(value="0 entry")
        self.printlist = widgets.Button(description="Print")
        self.printlist.on_click(self.on_printlist)

        # Lay out the controls above the figure.
        if PHASEQUAD:
            phbutton = [self.p1, self.p2]
        else:
            phbutton = [self.p1]
        self.children = [
            VBox([
                HBox([
                    self.apply,
                    self.cancel,
                ]),
                HBox([
                    HTML('<i>Phase List Managment</i>'), self.clearlist,
                    self.addlist, self.listlabel, self.printlist
                ]),
                HBox([
                    self.p0, self.pivot,
                    HTML('<i>set with right-click on spectrum</i>')
                ]), *phbutton,
                HBox([
                    VBox([self.blank, self.reset, self.scale]), self.fig.canvas
                ])
            ])
        ]
        # Re-draw on any change of the phase controls or the scale.
        for w in [self.p0, self.p1.field, self.p2.field, self.scale]:
            w.observe(self.ob)
        self.pivot.observe(self.on_movepivot)

        # A right-click (button 3) on the spectrum moves the pivot there.
        def on_press(event):
            if event.button == 3:
                self.pivot.value = event.xdata

        self.fig.canvas.mpl_connect('button_press_event', on_press)
        self.lp0, self.lp1, self.lp2, self.lpv = self.ppivot()
        if show:
            self.show()
Beispiel #3
0
 def __init__(self, gen, hide_graph=False):
     """Notebook master progress bar.

     Passes ``gen`` and the ``NBProgressBar`` class to the parent
     initializer.  It then adds an HTML widget (``self.text``) stacked under
     the first bar's box in ``self.vbox``; ``self.first_bar`` is presumably
     set by the parent initializer.  ``hide_graph`` is stored unchanged.
     """
     super().__init__(gen, NBProgressBar)
     self.hide_graph = hide_graph
     status_html = HTML()
     self.text = status_html
     self.vbox = VBox([self.first_bar.box, status_html])
    def __init__(self,
                 inputs,
                 targets,
                 indexes=None,
                 keyboard_shortcuts=True,
                 save_hook=None,
                 vertical=False):
        """Assemble the annotation UI.

        Steps, in order:

        - wrap ``inputs`` / ``targets`` in a ``DataManager``;
        - build one widget per input and per target;
        - add a control bar (Previous, slider, Next, plus Save when
          ``save_hook`` is truthy);
        - link the slider's value to ``self.index`` and hook up callbacks.

        With ``vertical`` the input and target panes are stacked in a VBox
        instead of being placed side by side in an HBox.
        """
        self.path = ''
        self.dirty_uindexes = set()
        self.save_hook = save_hook
        self.datamanager = DataManager(inputs, targets, indexes)

        self.slider = index_slider = IntSlider(min=0, max=0)

        self.prevbtn = Button(description='< Previous')
        self.nextbtn = Button(description='Next >')

        self.input_widgets = [
            wrapper.get_widget() for wrapper in self.datamanager.get_inputs()
        ]
        self.target_widgets = [
            wrapper.get_widget() for wrapper in self.datamanager.get_targets()
        ]

        self.add_class('innotater-base')

        bar_children = [self.prevbtn, index_slider, self.nextbtn]
        if self.save_hook:
            # Disabled at start (the widget is created with disabled=True).
            self.savebtn = Button(description='Save', disabled=True)
            bar_children.append(self.savebtn)

        control_bar = HBox(bar_children)
        control_bar.add_class('innotater-controlbar')

        if vertical:
            pane_container = VBox
            self.add_class('innotater-base-vertical')
        else:
            pane_container = HBox

        panes = pane_container(
            [VBox(self.input_widgets), VBox(self.target_widgets)])
        super().__init__([panes, control_bar])

        # Keep the slider and self.index in sync on the front end.
        widgets.jslink((index_slider, 'value'), (self, 'index'))

        self._observe_targets(self.datamanager.get_targets())

        for wrapper in list(self.datamanager.get_all()):
            wrapper.post_widget_create(self.datamanager)

        self.prevbtn.on_click(lambda _btn: self.move_slider(-1))
        self.nextbtn.on_click(lambda _btn: self.move_slider(1))

        if self.save_hook:
            self.savebtn.on_click(lambda _btn: self.save_hook_fire())

        self.slider.max = self.datamanager.get_data_len() - 1

        self.index = 0
        self.keyboard_shortcuts = keyboard_shortcuts

        self.on_msg(self.handle_message)

        self.suspend_observed_changes = False
        self.update_ui()
Beispiel #5
0
#### -----
# Empty <h4> HTML widget, presumably used as a visual spacer before Run.
sep = widgets.HTML(value="<h4></h4>")

# Parameter panel, published under two names: ``params`` and ``fem_par``.
params = VBox(
    children=[
        # FEM / geometry / material controls
        fem_header,
        wl_box,
        pola_dropdown,
        angle_slider,
        mesh_slider,
        hx_box,
        hy_box,
        target_x_box,
        target_y_box,
        epsmin_re_box,
        epsmin_im_box,
        epsmax_re_box,
        epsmax_im_box,
        # optimisation controls
        opt_header,
        maxeval_slider,
        Nitmax_slider,
        rfilt_box,
        starting_dropdown,
        p0_slider,
        # footer
        sep,
        run_button,
    ]
)
fem_par = params

# Convergence figure holding one scatter trace filled down to y = 0.
conv_plt = go.FigureWidget()
conv_plt.add_scatter(fill="tozeroy")
Beispiel #6
0
    def _widget(self):
        """ Create IPython widget for display within a notebook

        The widget is built once and cached in ``self._cached_widget``;
        later calls return the cached object.  If ipywidgets cannot be
        imported, ``None`` is cached and returned.

        The widget contains:

        - the cluster class name as a title;
        - a status panel, refreshed every 500 ms by a ``PeriodicCallback``
          stored in ``self.periodic_callbacks["cluster-repr"]``;
        - when ``self._supports_scaling`` is true, manual and adaptive
          scaling controls;
        - Jupyter and dashboard links, when those are set.
        """
        try:
            return self._cached_widget
        except AttributeError:
            pass

        try:
            from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion
        except ImportError:
            self._cached_widget = None
            return None

        layout = Layout(width="150px")

        def link_html(label, url):
            # One "<label>: <url>" paragraph, or "" when no URL is set.
            if not url:
                return ""
            return ('<p><b>%s: </b><a href="%s" target="_blank">%s</a></p>\n'
                    % (label, url, url))

        dashboard_link = link_html("Dashboard", self.dashboard_link)
        jupyter_link = link_html("Jupyter", self.jupyter_link)

        title = HTML("<h2>%s</h2>" % self._cluster_class_name)
        dashboard = HTML(dashboard_link)
        jupyter = HTML(jupyter_link)

        status = HTML(self._widget_status(), layout=Layout(min_width="150px"))

        if self._supports_scaling:
            request = IntText(self.initial_node_count,
                              description="Nodes",
                              layout=layout)
            scale = Button(description="Scale", layout=layout)

            minimum = IntText(0, description="Minimum", layout=layout)
            maximum = IntText(0, description="Maximum", layout=layout)
            adapt = Button(description="Adapt", layout=layout)

            accordion = Accordion(
                [HBox([request, scale]),
                 HBox([minimum, maximum, adapt])],
                layout=Layout(min_width="500px"),
            )
            accordion.selected_index = None
            accordion.set_title(0, "Manual Scaling")
            accordion.set_title(1, "Adaptive Scaling")

            def adapt_cb(b):
                # Wrapped in log_errors() like scale_cb, so failures in
                # the button callback are handled the same way.
                with log_errors():
                    self.adapt(minimum=minimum.value, maximum=maximum.value)
                    update()

            adapt.on_click(adapt_cb)

            def scale_cb(b):
                with log_errors():
                    n = request.value
                    # Stop adaptive scaling (if any) before scaling manually.
                    with suppress(AttributeError):
                        self._adaptive.stop()
                    self.scale(n)
                    update()

            scale.on_click(scale_cb)
        else:
            accordion = HTML("")

        box = VBox([title, HBox([status, accordion]), jupyter, dashboard])

        self._cached_widget = box

        # Defined after the callbacks above; they only look it up when called.
        def update():
            self.close_when_disconnect()
            status.value = self._widget_status()

        pc = PeriodicCallback(update, 500)
        self.periodic_callbacks["cluster-repr"] = pc
        pc.start()

        return box
Beispiel #7
0
def plot_slices_matplotlib(batch,
                           predict=None,
                           select=None,
                           size=20,
                           grid=True):
    ''' Plot slices and mask with interact

    Shows an interactive viewer with three controls: a patient dropdown, a
    slice slider and a nodule dropdown.  Picking a nodule jumps to its
    patient and slice.  Each update draws the panels chosen by ``select``.

    Parameters
    ----------
    batch : batch object
        Must provide ``indices``, ``images_shape``,
        ``fetch_nodules_from_mask()``, ``nodules`` and
        ``get(patient_id, component)``.
    predict : array or None
        Predicted masks.  A singleton channel axis is squeezed (axis 1, or
        else the last axis).  The result is indexed as
        ``predict[patient_index, slice, ...]``.
    select : list of int or None
        Panels to draw (default ``[1, ..., 7]``):

        - 1: image
        - 2: mask
        - 3: mask * image
        - 4: image + 300 * mask
        - 5: prediction
        - 6: prediction * image
        - 7: image + 300 * prediction

        Without ``predict``, panels 5-7 are drawn from zeros.
    size : int or (int, int)
        ``figsize`` passed to ``plt.subplots``.
    grid : bool
        Draw a faint red minor-tick grid every 10 spacing units
        (presumably mm; not checked here).
    '''
    if isinstance(size, int):
        size1, size2 = size, size
    else:
        size1, size2 = size
    select = [1, 2, 3, 4, 5, 6, 7] if select is None else select
    indices = list(batch.indices)
    # Slice count of the first patient; valid slice indices are 0 .. n_s - 1.
    n_s = batch.images_shape[0][0]
    if predict is not None:
        try:
            predict = np.squeeze(predict, axis=1)
        except ValueError:
            predict = np.squeeze(predict, axis=-1)
    batch.fetch_nodules_from_mask()
    nodules = batch.nodules
    center_scaled = (np.abs(nodules.nodule_center - nodules.origin) /
                     nodules.spacing)
    # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
    nod_size_scaled = np.rint(nodules.nodule_size /
                              nodules.spacing).astype(int)
    # One row per nodule: [patient_pos, scaled center (3 coords), size[0]].
    # update_loc below uses center[0] as the slice index.
    nodule_rows = np.concatenate([
        nodules.patient_pos.reshape(-1, 1),
        center_scaled.astype(int),
        nod_size_scaled[:, 0].reshape(-1, 1)
    ],
                                 axis=1).tolist()
    nods = widgets.Dropdown(
        options=nodule_rows,
        value=nodule_rows[0],
        description='Nodule:',
    )
    # max is the last valid slice index (max=n_s would raise IndexError).
    slid = widgets.IntSlider(min=0, max=n_s - 1, value=0)
    patient_id = widgets.Dropdown(
        options=indices,
        value=indices[0],
        description='Patient ID',
    )

    def update_loc(*args):
        # Jump to the selected nodule's slice and patient.
        slid.value = nods.value[1]
        patient_id.value = indices[nods.value[0]]

    nods.observe(update_loc, 'value')

    def upd(patient_id, n_slice, nods):
        img = batch.get(patient_id, 'images')[n_slice]
        mask = batch.get(patient_id, 'masks')[n_slice]
        if predict is None:
            rows, cols = 1, len(select)
            pred = np.zeros_like(img)
        else:
            rows = int(np.ceil(len(select) / 3))
            cols = 3 if rows > 1 else len(select)
            # Panels 5-7 use the prediction for this patient/slice.
            pred = predict[indices.index(patient_id), n_slice, ...]
        fig, axes = plt.subplots(rows,
                                 cols,
                                 squeeze=False,
                                 figsize=(size1, size2))
        if grid:
            # Drop the slice-axis spacing; the rest follow the image's
            # (row, column) axes -- assumes spacing uses the array's axis order.
            inv_spacing = 1 / batch.get(patient_id, 'spacing').reshape(-1)[1:]
            step_mult = 10
            # imshow puts columns on x and rows on y.
            xticks = np.arange(0, img.shape[1], step_mult * inv_spacing[1])
            yticks = np.arange(0, img.shape[0], step_mult * inv_spacing[0])
        panels = {
            1: (img, plt.cm.bone),
            2: (mask, plt.cm.bone),
            3: (mask * img, plt.cm.bone),
            4: (img + mask * 300, plt.cm.seismic),
            5: (pred, plt.cm.bone),
            6: (pred * img, plt.cm.bone),
            7: (img + pred * 300, plt.cm.seismic),
        }
        # Fill axes row by row; extra cells in the last row stay empty.
        for ax, key in zip(axes.flat, select):
            ax.imshow(*panels[key])
            if grid:
                ax.set_xticks(xticks, minor=True)
                ax.set_yticks(yticks, minor=True)
                ax.grid(color='r', linewidth=1.5, alpha=0.15, which='minor')
        fig.subplots_adjust(left=None,
                            bottom=0.1,
                            right=None,
                            top=0.4,
                            wspace=0.1,
                            hspace=0.1)
        flush_figures()

    w = interactive(upd, patient_id=patient_id, n_slice=slid, nods=nods)
    # Put every control but the last on one row, with the last child below.
    w.children = (VBox([HBox(w.children[:-1]), w.children[-1]]), )
    display(w)
Beispiel #8
0
    def __init__(self):
        """Build the user-parameters tab.

        Creates one editable widget per user parameter and exposes it as an
        attribute: ``random_seed``, ``ecm_file``, ``bnd_file``, ``cfg_file``,
        ``init_cells_filename`` and ``x_threshold``.

        Each parameter is shown as a row of
        [name | value | units | description]; every cell except the value
        is a disabled Button.  Row colours alternate lightgreen / tan,
        starting with lightgreen.  The stacked rows are stored in
        ``self.tab``.
        """
        style = {'description_width': '25%'}
        name_button_layout = {'width': '25%'}
        widget_layout = {'width': '15%'}
        units_button_layout = {'width': '15%'}
        desc_button_layout = {'width': '45%'}

        self.random_seed = IntText(
            value=0, step=1, style=style, layout=widget_layout)
        self.ecm_file = Text(
            value='ecm.txt', style=style, layout=widget_layout)
        self.bnd_file = Text(
            value='./config/boolean_network/ECM_mod.bnd',
            style=style, layout=widget_layout)
        self.cfg_file = Text(
            value='./config/boolean_network/ECM_mod.bnd.cfg',
            style=style, layout=widget_layout)
        self.init_cells_filename = Text(
            value='./config/init.txt', style=style, layout=widget_layout)
        self.x_threshold = FloatText(
            value=1., step=0.1, style=style, layout=widget_layout)

        # (label shown on the name button, value widget), in display order.
        params = [
            ('random_seed', self.random_seed),
            ('ecm_file', self.ecm_file),
            ('bnd_file', self.bnd_file),
            ('cfg_file', self.cfg_file),
            ('init_cells_filename', self.init_cells_filename),
            ('x_threshold', self.x_threshold),
        ]
        row_colors = ('lightgreen', 'tan')
        box_layout = Layout(display='flex', flex_flow='row',
                            align_items='stretch', width='100%')

        def disabled_button(description, layout, color, **kwargs):
            """Return a disabled (display-only) Button filled with *color*."""
            button = Button(description=description, disabled=True,
                            layout=layout, **kwargs)
            button.style.button_color = color
            return button

        boxes = []
        for row, (name, value_widget) in enumerate(params):
            color = row_colors[row % 2]
            cells = [
                disabled_button(name, name_button_layout, color),
                value_widget,
                disabled_button('', units_button_layout, color),
                disabled_button('', desc_button_layout, color, tooltip=''),
            ]
            boxes.append(Box(children=cells, layout=box_layout))

        self.tab = VBox(boxes)
Beispiel #9
0
    def __init__(self):
        """Build the substrate-plot tab.

        Sets default plot state: domain extents, field index and dummy mesh
        size.  Then creates:

        - the interactive frame plot (``self.mcds_plot`` driving
          ``self.plot_plots``);
        - a field selector and a colormap selector;
        - optional fixed colormap limits (editable only while "Fix" is
          checked).

        Everything is laid out in ``self.tab``, with a download row added
        when ``hublib_flag`` is true.
        """
        self.output_dir = '.'

        self.use_defaults = True
        self.svg_xmin = 0
        self.svg_xrange = 1500
        self.xmin = -750.
        self.xmax = 750.
        self.ymin = -750.
        self.ymax = 750.
        self.x_range = 1500.
        self.y_range = 1500.
        self.show_nucleus = 0
        self.show_edge = True

        # initial value
        self.field_index = 4

        # Dummy mesh size; set in the tool's primary module.
        self.numx = 0
        self.numy = 0

        constWidth = '180px'
        constWidth2 = '150px'

        max_frames = 1
        self.mcds_plot = interactive(self.plot_plots,
                                     frame=(0, max_frames),
                                     continuous_update=False)

        # Sizes the widget area holding the plot, not the figure (cf. figsize).
        svg_plot_size = '700px'
        self.mcds_plot.layout.width = svg_plot_size
        self.mcds_plot.layout.height = svg_plot_size

        self.fontsize = 20

        self.max_frames = BoundedIntText(
            min=0,
            max=99999,
            value=max_frames,
            description='Max',
            layout=Layout(width='160px'),
        )
        self.max_frames.observe(self.update_max_frames)

        # Per-field colormap limits.  field_dict is the reverse of the
        # Dropdown's options dict (Dropdown value -> field name).
        self.field_min_max = {'dummy': [0., 1.]}
        self.field_dict = {0: 'dummy'}

        self.mcds_field = Dropdown(options={'dummy': 0},
                                   value=0,
                                   layout=Layout(width=constWidth))
        self.mcds_field.observe(self.mcds_field_changed_cb)

        self.field_cmap = Dropdown(options=['viridis', 'jet', 'YlOrRd'],
                                   value='viridis',
                                   layout=Layout(width=constWidth))
        self.field_cmap.observe(self.mcds_field_cb)

        self.cmap_fixed = Checkbox(description='Fix', disabled=False)

        # NOTE(review): save_min_max is created and wired but is not placed in
        # any displayed container below -- confirm whether it should be shown.
        self.save_min_max = Button(
            description='Save',
            button_style='success',
            tooltip='Save min/max for this substrate',
            disabled=True,
            layout=Layout(width='90px'))

        def save_min_max_cb(b):
            # Remember the current Min/Max for the selected field.
            field_name = self.field_dict[self.mcds_field.value]
            self.field_min_max[field_name][0] = self.cmap_min.value
            self.field_min_max[field_name][1] = self.cmap_max.value

        self.save_min_max.on_click(save_min_max_cb)

        self.cmap_min = FloatText(
            description='Min',
            value=0,
            step=0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_min.observe(self.mcds_field_cb)

        self.cmap_max = FloatText(
            description='Max',
            value=38,
            step=0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_max.observe(self.mcds_field_cb)

        def cmap_fixed_cb(b):
            # Min, Max and Save are editable only while "Fix" is checked.
            locked = not self.cmap_fixed.value
            self.cmap_min.disabled = locked
            self.cmap_max.disabled = locked
            self.save_min_max.disabled = locked

        self.cmap_fixed.observe(cmap_fixed_cb)

        def row_layout():
            # Fresh Layout per row (Layouts are widgets; don't share state).
            return Layout(border='0px solid black',
                          width='50%',
                          height='',
                          align_items='stretch',
                          flex_direction='row',
                          display='flex')

        help_label = Label('select slider: drag or left/right arrows')
        row1 = Box([
            help_label,
            Box([self.max_frames, self.mcds_field, self.field_cmap],
                layout=row_layout())
        ])
        row2 = Box([self.cmap_fixed, self.cmap_min, self.cmap_max],
                   layout=row_layout())
        if hublib_flag:
            self.download_button = Download('mcds.zip',
                                            style='warning',
                                            icon='cloud-download',
                                            tooltip='Download data',
                                            cb=self.download_cb)
            download_row = HBox([
                self.download_button.w,
                Label(
                    "Download all substrate data (browser must allow pop-ups)."
                )
            ])
            self.tab = VBox([row1, row2, self.mcds_plot, download_row])
        else:
            self.tab = VBox([row1, row2, self.mcds_plot])
Beispiel #10
0
                    if especie == d:
                        posicion_sp = e
                        break
                colores[posicion_sp] = 'purple'
            else:
                pass

            F.data[16].x = x
            F.data[16].y = y
            F.data[16].hovertext = [('Reads='+str(int(e))+' || '+LinajE_otu[i]) for e, i in zip(dff[muestra_sel], dff.Species)]

            F.data[16].marker.color = colores
            F.data[16].name = '<b>'+muestra_sel+'</b>'

            F.layout.annotations[-1].text = '<b>'+muestra_sel+'</b>'
            F.layout.annotations[-1].x = int(len(dff)/2)
            F.layout.annotations[-1].y = np.log2(dff2[muestra_sel][0])/2
        
        F.layout.template = bg_color[tema.value]


# Re-run ``res`` whenever any of these controls changes its value.
for _watched in (EspecieS_asv, EspecieS_otu, limite, tema, centro, tipo,
                 all_samples):
    _watched.observe(res, names="value")
del _watched

# Layout: sample selector on top; below it, controls/outputs on the left
# and the figure ``F`` with ``OUT3`` on the right.
interactiveITS = VBox([
    HBox([widgets.Label('Samples:'), all_samples]),
    HBox([
        VBox([
            HBox([boton_data, VBox([EspecieS_asv, EspecieS_otu])]),
            limite,
            tema,
            centro,
            OUT,
            mostratdf,
        ]),
        VBox([F, OUT3]),
    ]),
])
Beispiel #11
0
def interact_gravity_Dike():
    """Assemble the interactive dike gravity-profile explorer.

    Six sliders (density contrast, two depths, width, dip angle and station
    spacing) and a "keep previous plots" toggle are connected to
    ``drawfunction`` through ``interactive_output``.

    Returns:
        VBox: the control row stacked above the plot output area.
    """

    def make_slider(label, lower, upper, increment, initial, **extra):
        # Every slider shares continuous_update=False, so the profile is only
        # redrawn once the handle is released.
        return FloatSlider(
            description=label,
            min=lower,
            max=upper,
            step=increment,
            value=initial,
            continuous_update=False,
            **extra,
        )

    rho_slider = make_slider(r"$\Delta\rho$", -5.0, 5.0, 0.1, 1.0)
    z1_slider = make_slider(r"$z_1$", 0.1, 4.0, 0.1, 1 / 3)
    z2_slider = make_slider(r"$z_2$", 0.1, 5.0, 0.1, 4 / 3)
    width_slider = make_slider("b", 0.1, 5.0, 0.1, 1.0)
    dip_slider = make_slider(r"$\beta$", -85, 85, 5, 45)
    spacing_slider = make_slider("Step", 0.005, 0.10, 0.005, 0.01,
                                 readout_format=".3f")
    keep_toggle = ToggleButton(
        value=True,
        description="keep previous plots",
        disabled=False,
        button_style="",  # one of 'success', 'info', 'warning', 'danger' or ''
        tooltip="Click me",
        layout=Layout(width="20%"),
    )

    controls = HBox([
        VBox([rho_slider, z1_slider, z2_slider]),
        VBox([width_slider, dip_slider, spacing_slider]),
        keep_toggle,
    ])
    plot_area = interactive_output(
        drawfunction,
        {
            "delta_rho": rho_slider,
            "z1": z1_slider,
            "z2": z2_slider,
            "b": width_slider,
            "beta": dip_slider,
            "stationSpacing": spacing_slider,
            "B": keep_toggle,
        },
    )
    return VBox([controls, plot_area])
Beispiel #12
0
    def __init__(self):
        """Build the cell-plot tab.

        Creates:
          * ``self.cells_plot``: an ``interactive`` wrapper around
            ``self.plot_cells`` with a ``frame`` control over (0, max_frames).
          * ``self.max_frames``: a "Max" box wired to ``self.update_max_frames``.
          * nucleus / edge / tracks checkboxes. The nucleus and edge boxes are
            wired to their callbacks; the tracks box is not.
          * ``self.tab``: the instruction label and Max box above the plot.
        """
        self.output_dir = '.'

        constWidth = '180px'

        max_frames = 1
        # NOTE(review): `interactive` treats keyword arguments as widget
        # abbreviations for plot_cells' parameters; confirm that
        # continuous_update=False actually reaches the frame slider.
        self.cells_plot = interactive(self.plot_cells,
                                      frame=(0, max_frames),
                                      continuous_update=False)

        # "plot_size" controls the size of the tab height, not the plot
        # (cf. figsize for that). Other values tried: '500px' (small),
        # '700px' / '750px' (medium).
        plot_size = '600px'
        self.cells_plot.layout.width = plot_size
        self.cells_plot.layout.height = plot_size
        self.use_defaults = True
        self.show_nucleus = 1  # 0->False, 1->True in Checkbox!
        self.show_edge = 1  # 0->False, 1->True in Checkbox!
        self.show_tracks = 0  # 0->False, 1->True in Checkbox!
        # dictionary to hold cell IDs and their tracks: (x,y) pairs
        self.trackd = {}
        # Plot extent (an earlier default was -1000..1000).
        self.axes_min = -500.0
        self.axes_max = 500.  # TODO: get from input file

        self.max_frames = BoundedIntText(
            min=0,
            max=99999,
            value=max_frames,
            description='Max',
            layout=Layout(width='160px'),
        )
        # Only react to value edits. A bare observe() also fires on internal
        # trait syncs such as '_property_lock'.
        self.max_frames.observe(self.update_max_frames, names='value')

        self.show_nucleus_checkbox = Checkbox(
            description='nucleus',
            value=True,
            disabled=False,
            layout=Layout(width=constWidth),
        )
        self.show_nucleus_checkbox.observe(self.show_nucleus_cb, names='value')

        self.show_edge_checkbox = Checkbox(
            description='edge',
            value=True,
            disabled=False,
            layout=Layout(width=constWidth),
        )
        self.show_edge_checkbox.observe(self.show_edge_cb, names='value')

        # NOTE(review): this box starts checked while self.show_tracks is 0,
        # and it is neither observed nor shown in the tab; confirm intent
        # before wiring it up.
        self.show_tracks_checkbox = Checkbox(
            description='tracks',
            value=True,
            disabled=False,
            layout=Layout(width=constWidth),
        )

        items_auto = [
            Label('select slider: drag or left/right arrows'),
            self.max_frames,
        ]
        box_layout = Layout(display='flex',
                            flex_flow='row',
                            align_items='stretch',
                            width='70%')
        row1 = Box(children=items_auto, layout=box_layout)

        self.tab = VBox([row1, self.cells_plot])
Beispiel #13
0
    def __init__(self, data, title=None, figsize=None):
        """Build the 2D display: level/cursor/logo checkboxes and a 2x2 figure.

        The figure has four axes:
          * a top projection (``self.top_ax``);
          * the main spectrum (``self.spec_ax``, also exposed as ``self.ax``);
          * a side projection (``self.side_ax``) sharing the main y axis;
          * a logo axis (``self.axlogo``), initially hidden.

        Args:
            data: dataset to display. Projections come from ``data.projF2`` /
                ``data.projF1`` when those attributes exist. Otherwise they
                are computed with ``data.proj(axis=...).real()``.
            title: forwarded to the parent constructor.
            figsize: matplotlib figure size. When None, (10, 5) is used for
                DOSY data (``data.axis1`` is an ``NPKData.LaplaceAxis``) and
                (8, 8) otherwise.
        """
        super().__init__(data, title=title, create_children=False)
        self.isDOSY = isinstance(data.axis1, NPKData.LaplaceAxis)
        # Prefer precomputed projections; compute them only when absent.
        try:
            self.proj2 = data.projF2
        except AttributeError:
            self.proj2 = data.proj(axis=2).real()
        try:
            self.proj1 = data.projF1
        except AttributeError:
            self.proj1 = data.proj(axis=1).real()

        # Controls
        self.scale.min = 0.2
        self.posview = widgets.Checkbox(value=True,
                                        description='Positive',
                                        tooltip='Display Positive levels',
                                        layout=Layout(width='20%'))
        self.negview = widgets.Checkbox(value=False,
                                        description='Negative',
                                        tooltip='Display Negative levels',
                                        layout=Layout(width='20%'))
        self.cursors = widgets.Checkbox(
            value=False,
            description='Cursors',
            tooltip='show cursors (cpu intensive !)',
            layout=Layout(width='20%'))
        # NOTE(review): this checkbox starts True while self.axlogo is hidden
        # below; confirm whether the parent's display code reconciles the two.
        self.showlogo = widgets.Checkbox(description="Logo", value=True)

        def switchlogo(e):
            # Mirror the checkbox state onto the logo axis visibility.
            self.axlogo.set_visible(bool(self.showlogo.value))

        self.showlogo.observe(switchlogo)
        for w in (self.scale, self.posview, self.negview, self.cursors):
            w.observe(self.ob)

        # Grid: a short top row (ratio 1:4) and a narrow right column.
        grid = {'height_ratios': [1, 4], 'hspace': 0, 'wspace': 0}
        if self.isDOSY:
            default_size = (10, 5)
            grid['width_ratios'] = [7, 1]
        else:
            default_size = (8, 8)
            grid['width_ratios'] = [4, 1]
        fsize = default_size if figsize is None else figsize

        # Figure: created with interactive mode off. Interactive mode is
        # switched back on afterwards, even if figure creation fails.
        plt.ioff()
        try:
            self.fig = plt.figure(figsize=fsize,
                                  constrained_layout=False,
                                  tight_layout=True)
        finally:
            plt.ion()
        self.fig.canvas.toolbar_position = 'left'
        spec2 = gridspec.GridSpec(ncols=2, nrows=2, figure=self.fig, **grid)
        axarr = np.empty((2, 2), dtype=object)
        axarr[0, 0] = self.fig.add_subplot(spec2[0, 0])
        axarr[1, 0] = self.fig.add_subplot(spec2[1, 0], sharex=axarr[0, 0])
        axarr[1, 1] = self.fig.add_subplot(spec2[1, 1], sharey=axarr[1, 0])
        axarr[0, 1] = self.fig.add_subplot(spec2[0, 1])
        self.top_ax = axarr[0, 0]
        self.spec_ax = axarr[1, 0]
        self.side_ax = axarr[1, 1]
        self.axlogo = axarr[0, 1]
        self.axlogo.set_visible(False)
        self.multitop = None
        self.multiside = None
        self.ax = self.spec_ax

        # Children
        self.topbar = HBox([self.posview, self.negview, self.cursors])
        self.controlbar = VBox(
            [self.reset, self.scale, self.savepdf, self.done])
        self.middlebar = HBox([self.controlbar, self.fig.canvas])

        self.children = [VBox([self.topbar, self.middlebar])]
        self.set_on_redraw()
        self.disp(new=True)
Beispiel #14
0
    def __init__(self, data):
        """Build the 2D phase-correction tool.

        Adds P0/P1 sliders and a pivot box for each of F1 and F2, relabels
        ``done`` as "Apply", adds a "Cancel" button, and installs a
        right-click handler that moves both pivots to the clicked point.

        Args:
            data: dataset to phase. It must have ``itype == 3``; otherwise a
                message is printed and construction stops early.
        """
        if data.itype != 3:
            # NOTE(review): returning here leaves the instance without any of
            # the attributes below (super().__init__ is never run). Raising
            # would be safer but changes caller-visible behaviour, so the
            # print-and-return is kept.
            print(
                'Dataset should be complex along both axes, Phasing is not possible'
            )
            return
        super().__init__(data)
        self.data_ref = data

        # create additional widgets
        slidersize = Layout(width='500px')

        def phase_slider(limit, label):
            # Symmetric [-limit, limit] slider in 1-degree steps.
            return widgets.FloatSlider(min=-limit,
                                       max=limit,
                                       step=1.0,
                                       description=label,
                                       continuous_update=HEAVY,
                                       layout=slidersize)

        self.F1p0 = phase_slider(180, 'P0')
        self.F1p1 = phase_slider(250, 'P1')
        self.F2p0 = phase_slider(180, 'P0')
        self.F2p1 = phase_slider(250, 'P1')

        pivotsize = Layout(width='200px')

        def pivot_box(axis, size):
            # Pivot entry bounded by the axis' end points and initialised at
            # the axis midpoint (converted through axis.itoc).
            return widgets.BoundedFloatText(
                description='Pivot',
                value=round(axis.itoc(0.5 * size), 2),
                min=axis.itoc(size),
                max=axis.itoc(0),
                format='%.2f',
                layout=pivotsize,
                step=0.1)

        self.pivotF1 = pivot_box(self.data.axis1, self.data.size1)
        self.pivotF2 = pivot_box(self.data.axis2, self.data.size2)

        # modify defaults
        self.negview.value = True
        self.done.description = 'Apply'
        for w in [
                self.F1p0, self.F1p1, self.F2p0, self.F2p1, self.pivotF1,
                self.pivotF2
        ]:
            w.observe(self.ob)
        self.cancel = widgets.Button(description="Cancel",
                                     button_style='warning',
                                     layout=self.blay)
        self.cancel.on_click(self.on_cancel)

        # Two-column grid: F1 controls on the left, F2 controls on the right.
        stcenter = "<b>F%d</b>"
        grid_layout = widgets.Layout(grid_template_columns="40% 40%",
                                     justify_items='center')
        self.phasebar = widgets.GridBox(
            [
                widgets.HTML(stcenter % 1), widgets.HTML(stcenter % 2),
                self.F1p0, self.F2p0,
                self.F1p1, self.F2p1,
                self.pivotF1, self.pivotF2,
            ],
            layout=grid_layout)
        doc = widgets.HTML("""
            <p>Use the sliders to adjust the phase parameters, &nbsp; the pivot can be set with a right click on the spectrum<br>
            Top and Side spectra are taken at the pivot level.<br>
            </p>
            """)

        self.topbar = HBox([self.posview, self.negview])
        self.controlbar = VBox(
            [self.reset, self.scale, self.savepdf, self.done, self.cancel])
        self.middlebar = HBox([self.controlbar, self.fig.canvas])

        self.children = [
            VBox([self.phasebar, doc, self.topbar, self.middlebar])
        ]
        self.pivotF1.observe(self.on_movepivot)
        self.pivotF2.observe(self.on_movepivot)

        # add right-click event on spectral window
        def on_press(event):
            # xdata/ydata are None when the click lands outside every axes;
            # ignore such clicks instead of failing in round(None).
            if event.button != 3 or event.xdata is None or event.ydata is None:
                return
            self.pivotF1.value = round(event.ydata, 4)
            self.pivotF2.value = round(event.xdata, 4)

        self.fig.canvas.mpl_connect('button_press_event', on_press)
Beispiel #15
0
def interactive_selection():
    """Display a clickable hold grid followed by "Predict difficulty!" and "Reload" buttons.

    One button is shown per entry of ``generate_hold_key_matrix()``, one HBox
    per matrix row. Clicking a hold colours it light green and appends its
    label to the module-level ``hold_selection`` list.
    "Predict difficulty!" calls ``predict_grade_JF(hold_selection)``.
    "Reload" clears every highlighted button and resets ``hold_selection``.
    """
    from ipywidgets import Button, HBox, VBox, widgets, Layout
    from IPython.display import display

    # Selection is kept at module level so other code can read it.
    global hold_selection
    hold_selection = []

    hold_key_matrix = generate_hold_key_matrix()

    def on_button_clicked(b):
        # NOTE(review): clicking the same hold twice appends it twice.
        b.style.button_color = 'lightgreen'
        global hold_selection
        hold_selection.append(b.description)

    # Build one HBox per matrix row, so the grid follows the matrix size
    # instead of a hard-coded row count.
    rows = []
    for hold_row in hold_key_matrix.tolist():
        buttons = [
            Button(description=h, layout=Layout(width='45px', height='60%'))
            for h in hold_row
        ]
        for button in buttons:
            button.on_click(on_button_clicked)
        rows.append(HBox(buttons))

    whole_grid = VBox(rows)
    display(whole_grid)

    ## generate Termination buttons

    def end_button_clicked(b):
        predict_grade_JF(hold_selection)

    def reload_on_clicked(b):
        # Un-highlight every cell whose label was selected, then reset.
        global hold_selection
        for hold in hold_selection:
            for row_idx, col_idx in zip(*np.where(hold_key_matrix == hold)):
                cell = whole_grid.children[int(row_idx)].children[int(col_idx)]
                cell.style.button_color = None
        hold_selection = []

    end_button = widgets.Button(description='Predict difficulty!',
                                button_style='danger')
    end_button.on_click(end_button_clicked)

    reload_button = widgets.Button(description='Reload',
                                   button_style='primary')
    reload_button.on_click(reload_on_clicked)

    final_buttons = HBox([end_button, reload_button])
    display(final_buttons)
Beispiel #16
0
def get_interactive_logistic_regression_advanced(X,
                                                 y,
                                                 X_test=None,
                                                 y_test=None):
    """Build an interactive explorer for a 2-D threshold (linear) classifier.

    The unit computes ``h = w1*x1 + w2*x2 - theta`` and predicts
    ``y_hat = int(h >= 0)``. Sliders set the weights, the threshold and a
    probe point (x1, x2). Checkboxes toggle the training/test scatter, the
    decision boundary and the shaded prediction regions (purple = class 0,
    yellow = class 1).

    Args:
        X: training inputs, indexed as ``X[:, 0]`` / ``X[:, 1]``. Its
            per-column min/max set the plot limits.
        y: training labels, used as scatter colours.
        X_test, y_test: optional test set, drawn only when both are given.

    Returns:
        tuple: ``(interactive_plot, ui)``, i.e. the ``interactive_output``
        widget that redraws the figure and the HBox of controls.
    """
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(1, 1, 1)

    # Parameters for the first drawing; update() redraws with slider values.
    w1 = 0.5
    w2 = 0.5
    bias = 0.0

    x1 = 0.0
    x2 = 0.0
    h = x1 * w1 + x2 * w2 - bias
    y_hat = int(h >= 0)

    def make_slider(value, low, high, description, continuous_update):
        # All sliders share step 0.1 and a one-decimal readout.
        return widgets.FloatSlider(
            value=value,
            min=low,
            max=high,
            step=0.1,
            description=description,
            disabled=False,
            continuous_update=continuous_update,
            readout=True,
            readout_format='.1f',
        )

    w1_slider = make_slider(0.5, -5.0, 5.0, "w1", False)
    w2_slider = make_slider(-1.0, -5.0, 5.0, "w2", False)
    bias_slider = make_slider(0.0, -5.0, 5.0, r'$\theta$', False)
    x1_slider = make_slider(0.0, -2.0, 2.0, "x1", True)
    x2_slider = make_slider(0.0, -2.0, 2.0, "x2", True)

    show_train = widgets.Checkbox(value=False,
                                  description='Show train data',
                                  disabled=False)
    show_test = widgets.Checkbox(value=False,
                                 description='Show test data',
                                 disabled=False)
    show_boundary = widgets.Checkbox(value=False,
                                     description='Show decision boundary',
                                     disabled=False)
    show_prediction = widgets.Checkbox(value=False,
                                       description='Show prediction',
                                       disabled=False)

    caption1 = widgets.Label(
        value=r"$h = w_1 \cdot x_1 + w_2 \cdot x_2 - \theta$")
    caption2 = widgets.Label(
        value=
        f"{format(h, '.3f')} = ({w1}) * ({x1}) + ({w2}) * ({x2}) - ({bias})")
    caption3 = widgets.Label(value=r"$\hat{y} = f(h)$")
    caption4 = widgets.Label(value=f"{y_hat} = f({h})")

    box1 = VBox(children=[x1_slider, x2_slider, show_train, show_test])
    box2 = VBox(children=[
        w1_slider, w2_slider, bias_slider, show_boundary, show_prediction
    ])
    box3 = VBox(children=[caption1, caption2, caption3, caption4])

    ui = HBox(children=[box3, box2, box1])

    # Axis limits: data range padded by 10% on each side.
    xmin = X.min(axis=0)
    xmax = X.max(axis=0)
    xrange_ = xmax - xmin
    lim_x = (xmin[0] - 0.1 * xrange_[0], xmax[0] + 0.1 * xrange_[0])
    lim_y = (xmin[1] - 0.1 * xrange_[1], xmax[1] + 0.1 * xrange_[1])

    ax.set_xlim(lim_x[0], lim_x[1])
    ax.set_ylim(lim_y[0], lim_y[1])
    ax.set_xlabel(r"$x_1$")
    ax.set_ylabel(r"$x_2$")
    ax.set_aspect("equal")

    # Data scatters start invisible (alpha 0) and are toggled in update().
    training_data_handle = ax.scatter(X[:, 0],
                                      X[:, 1],
                                      c=y,
                                      alpha=0.0,
                                      marker="x",
                                      s=15)
    has_test = X_test is not None and y_test is not None
    if has_test:
        test_data_handle = ax.scatter(X_test[:, 0],
                                      X_test[:, 1],
                                      c=y_test,
                                      alpha=0.0,
                                      marker="D",
                                      s=15)

    # Probe point: a black ring plus two filled dots, one per class; only
    # the dot of the predicted class is opaque.
    test_point_handle1 = ax.scatter([x1], [x2],
                                    s=150,
                                    linewidth=2,
                                    facecolors='none',
                                    edgecolors='black')
    test_point_handle2 = ax.scatter([x1], [x2],
                                    s=50,
                                    edgecolors='none',
                                    alpha=float(y_hat == 0),
                                    c=0.0,
                                    vmin=0.0,
                                    vmax=1.0)
    test_point_handle3 = ax.scatter([x1], [x2],
                                    s=50,
                                    edgecolors='none',
                                    alpha=float(y_hat == 1),
                                    c=1.0,
                                    vmin=0.0,
                                    vmax=1.0)

    decision_boundary, = ax.plot([0, -w2], [0, w1], color="red", alpha=0.0)

    xx = np.linspace(lim_x[0], lim_x[1], num=100)

    def draw_prediction_regions(w1, w2, bias, alpha):
        """Shade both predicted half-planes; return the two new fill artists.

        Purple marks y_hat = 0 (h < 0) and yellow marks y_hat = 1 (h >= 0).
        """
        if w2:
            # Boundary x2 = -(w1/w2)*x1 + bias/w2; h < 0 below it when w2 > 0.
            yy = -(w1 / w2) * xx + bias / w2
            below, above = (("purple", "yellow") if w2 > 0 else
                            ("yellow", "purple"))
            return (ax.fill_between(xx, y1=-10, y2=yy, color=below,
                                    alpha=alpha),
                    ax.fill_between(xx, y1=yy, y2=10, color=above,
                                    alpha=alpha))
        if w1:
            # Vertical boundary x1 = bias/w1; h < 0 to its left when w1 > 0.
            cut = bias / w1
            left, right = (("purple", "yellow") if w1 > 0 else
                           ("yellow", "purple"))
            return (ax.fill_betweenx(list(lim_y), x1=-10, x2=cut,
                                     color=left, alpha=alpha),
                    ax.fill_betweenx(list(lim_y), x1=cut, x2=10,
                                     color=right, alpha=alpha))
        # w1 == w2 == 0: h = -bias everywhere, so one class covers the plane.
        purple_alpha, yellow_alpha = ((alpha, 0.0) if bias > 0 else
                                      (0.0, alpha))
        return (ax.fill_betweenx(list(lim_y), x1=-10, x2=10, color="purple",
                                 alpha=purple_alpha),
                ax.fill_betweenx(list(lim_y), x1=-10, x2=10, color="yellow",
                                 alpha=yellow_alpha))

    # Current shading artists; update() removes and replaces them.
    fill_handles = list(draw_prediction_regions(w1, w2, bias, 0.0))

    def update(w1=0.5,
               w2=0.5,
               bias=0.0,
               x1=0.0,
               x2=0.0,
               show_train=False,
               show_test=False,
               show_boundary=False,
               show_prediction=False):
        """Redraw every artist and caption for the current widget values."""
        h = x1 * w1 + x2 * w2 - bias
        y_hat = int(h >= 0)

        # Data scatters: alpha toggles visibility.
        training_data_handle.set_alpha(0.75 if show_train else 0.0)
        if has_test:
            test_data_handle.set_alpha(0.75 if show_test else 0.0)

        # Probe point: move it and show the dot of the predicted class.
        for handle in (test_point_handle1, test_point_handle2,
                       test_point_handle3):
            handle.set_offsets([x1, x2])
        test_point_handle2.set_alpha(float(y_hat == 0))
        test_point_handle3.set_alpha(float(y_hat == 1))

        caption2.value = f"{format(round(h, 3), '.3f')} = ({round(w1, 2)}) * ({round(x1, 2)}) + ({round(w2, 2)}) * ({round(x2, 2)}) - ({round(bias, 2)})"
        caption4.value = f"{y_hat} = f({round(h, 2)})"

        if w2:
            decision_boundary.set_data([lim_x[0], lim_x[1]], [
                -w1 / w2 * lim_x[0] + bias / w2,
                -w1 / w2 * lim_x[1] + bias / w2
            ])
        elif w1:
            # Vertical line: span the y-limits, not the x-limits.
            decision_boundary.set_data([bias / w1, bias / w1],
                                       [lim_y[0], lim_y[1]])
        else:
            decision_boundary.set_data([], [])
        decision_boundary.set_alpha(1.0 if show_boundary else 0.0)

        # Replace the previous shading. Axes.collections is read-only in
        # recent Matplotlib, so artists are removed individually.
        for artist in fill_handles:
            artist.remove()
        fill_handles[:] = draw_prediction_regions(
            w1, w2, bias, 0.1 if show_prediction else 0.0)

        fig.canvas.draw_idle()

    interactive_plot = interactive_output(
        update, {
            "w1": w1_slider,
            "w2": w2_slider,
            "bias": bias_slider,
            "x1": x1_slider,
            "x2": x2_slider,
            "show_train": show_train,
            "show_test": show_test,
            "show_boundary": show_boundary,
            "show_prediction": show_prediction
        })

    return interactive_plot, ui