Exemple #1
1
    def create_param_widget(self, param, value):
        """Build an ``HBox`` with the editing controls for one parameter.

        The controls depend on the Python type of ``value``:

        * ``bool``: a label and a toggle button whose caption mirrors the
          current state.
        * ``float`` / ``int``: a label, a disabled text box linked to the
          slider value, an editable lower bound, the slider, and an
          editable upper bound that starts at ``2 * value``.
        * any other type: a single empty ``HBox``.

        Each user change is stored in ``self.params[self._method][param]``
        and followed by a call to ``self.replot_peaks()``.
        """
        from ipywidgets import Layout, HBox

        def store_and_replot(change):
            # Shared tail of every value handler: persist, then redraw.
            self.params[self._method][param] = change['new']
            self.replot_peaks()

        # Fallback content for value types that have no dedicated editor.
        children = (HBox(),)

        if isinstance(value, bool):
            # bool is tested first because it is also an instance of int.
            from ipywidgets import Label, ToggleButton
            name_label = Label(value=param, layout=Layout(width='10%'))
            toggle = ToggleButton(description=str(value), value=value)

            def on_bool_change(change):
                toggle.description = str(change['new'])
                store_and_replot(change)

            toggle.observe(on_bool_change, names='value')
            children = (name_label, toggle)

        elif isinstance(value, (float, int)):
            from traitlets import link
            if isinstance(value, float):
                from ipywidgets import (FloatSlider as Slider,
                                        FloatText as Text,
                                        BoundedFloatText as BoundedText,
                                        Label)

                def step_between(low, high):
                    # One percent of the current slider range.
                    return np.abs(high - low) * 0.01
            else:
                from ipywidgets import (IntSlider as Slider,
                                        IntText as Text,
                                        BoundedIntText as BoundedText,
                                        Label)

                def step_between(low, high):
                    # Integer sliders always move in unit steps.
                    return 1

            def narrow():
                # A fresh layout object for every fixed-width cell.
                return Layout(flex='0 1 auto', width='10%')

            name_label = Label(value=param, layout=narrow())
            lower = BoundedText(value=0, min=1e-10,
                                layout=narrow(),
                                font_weight='bold')
            upper = Text(value=2 * value, layout=narrow())
            slider = Slider(value=value, min=lower.value, max=upper.value,
                            step=step_between(lower.value, upper.value),
                            layout=Layout(flex='1 1 auto', width='60%'))
            readout = Text(value=slider.value,
                           layout=narrow(),
                           disabled=True)
            link((slider, 'value'), (readout, 'value'))

            def on_min_change(change):
                # Only accept a lower bound that stays below the upper one.
                if change['new'] < slider.max:
                    slider.min = change['new']
                    slider.step = step_between(slider.min, slider.max)

            def on_max_change(change):
                # Only accept an upper bound that stays above the lower one.
                if change['new'] > slider.min:
                    slider.max = change['new']
                    slider.step = step_between(slider.min, slider.max)

            lower.observe(on_min_change, names='value')
            slider.observe(store_and_replot, names='value')
            upper.observe(on_max_change, names='value')
            children = (name_label, readout, lower, slider, upper)

        return HBox(children)
Exemple #2
0
class PSLGEditor(widgets.VBox):
    """Vertical box combining a ``Graph`` canvas with its editing controls.

    Children, top to bottom: the graph, a boundary-type dropdown, a
    region-type dropdown, an "Add" mode dropdown, x and y coordinate
    fields, and a help label. Each control pushes its value onto the
    matching attribute of ``self.graph`` whenever it changes.
    """

    # Selectable values for the two type dropdowns.
    boundaryTypes = List([1, 0]).tag(sync=True)
    regionTypes = List([1, 0]).tag(sync=True)

    def __init__(self, *args, **kwargs):
        """Create the graph and controls; arguments go to both bases.

        ``*args`` / ``**kwargs`` are forwarded to ``VBox.__init__`` and to
        ``Graph`` (which additionally receives ``parent=self``).
        """
        super().__init__(*args, **kwargs)
        self.graph = Graph(*args, **kwargs, parent=self)

        # ``disabled`` is the trait name; the previous ``disable=False``
        # spelling was an unrecognised keyword, not a widget setting.
        self.select_boundary = Dropdown(options=self.boundaryTypes,
                                        description='Boundary:',
                                        disabled=False)
        self.select_region = Dropdown(options=self.regionTypes,
                                      description='Region:',
                                      disabled=False)
        self.select_add = Dropdown(
            options=[u'Vertex \u25cf', u'Region \u25a0', u'Hole \u25b2'],
            description='Add:',
            disabled=False)
        self.enter_x = FloatText(description='x:')
        self.enter_y = FloatText(description='y:')
        self.help = Label(
            'Click to add vertex, region or hole. Press Delete to remove selection. Press CTRL to drag.'
        )

        def mode_from_label(label):
            # Drop the trailing " <symbol>" (two characters) and lowercase,
            # e.g. u'Vertex \u25cf' -> 'vertex'.
            return label[:-2].lower()

        def on_boundary_change(change):
            self.graph.boundary_type = change['new']

        def on_region_change(change):
            self.graph.region_type = change['new']

        def on_add_change(change):
            self.graph.add_new = mode_from_label(change['new'])

        def on_x_change(change):
            # Replace only the x component, keep the current y.
            self.graph.xy = [change['new'], self.graph.xy[1]]

        def on_y_change(change):
            # Replace only the y component, keep the current x.
            self.graph.xy = [self.graph.xy[0], change['new']]

        self.select_boundary.observe(on_boundary_change, names='value')
        self.select_region.observe(on_region_change, names='value')
        self.select_add.observe(on_add_change, names='value')
        self.enter_x.observe(on_x_change, names='value')
        self.enter_y.observe(on_y_change, names='value')

        # Seed the graph with the initial dropdown selections.
        self.graph.boundary_type = self.select_boundary.value
        self.graph.region_type = self.select_region.value
        self.graph.add_new = mode_from_label(self.select_add.value)

        self.children = [
            self.graph, self.select_boundary, self.select_region,
            self.select_add, self.enter_x, self.enter_y, self.help
        ]
Exemple #3
0
    def _create_notebook_widget(self, index=None):
        """Return an ``HBox`` holding ``min`` box, value slider, ``max`` box.

        Initial bounds and step come from
        ``self._interactive_slider_bounds(index=index)``. With ``index``
        given, the slider starts at ``self.value[index]`` and its label is
        ``self.name`` followed by ``_<index>``; otherwise ``self.value`` and
        ``self.name`` are used as they are. Slider changes are forwarded to
        ``self._interactive_update`` with ``index`` bound as a keyword.
        """
        from ipywidgets import (FloatSlider, FloatText, Layout, HBox)

        bounds = self._interactive_slider_bounds(index=index)

        def bound_box(which):
            # Editable text field showing one end of the slider range.
            return FloatText(
                value=bounds[which],
                description=which,
                layout=Layout(flex='0 1 auto', width='auto'),
            )

        lower_box = bound_box('min')
        upper_box = bound_box('max')

        start_value = self.value if index is None else self.value[index]
        slider_label = self.name
        if index is not None:
            slider_label = slider_label + '_{}'.format(index)

        slider = FloatSlider(value=start_value,
                             min=lower_box.value,
                             max=upper_box.value,
                             step=bounds['step'],
                             description=slider_label,
                             layout=Layout(flex='1 1 auto', width='auto'))

        def rescale_step():
            # Keep the step at one thousandth of the current range.
            slider.step = np.abs(slider.max - slider.min) * 0.001

        def on_min_change(change):
            proposed = change['new']
            # A new minimum is only applied while it stays below the max.
            if proposed < slider.max:
                slider.min = proposed
                rescale_step()

        def on_max_change(change):
            proposed = change['new']
            # A new maximum is only applied while it stays above the min.
            if proposed > slider.min:
                slider.max = proposed
                rescale_step()

        lower_box.observe(on_min_change, names='value')
        upper_box.observe(on_max_change, names='value')
        slider.observe(
            functools.partial(self._interactive_update, index=index),
            names='value')

        return HBox((lower_box, slider, upper_box))
Exemple #4
0
    def _create_notebook_widget(self, index=None):
        """Build the notebook control: a slider between two bound editors.

        The ``min`` / ``max`` text boxes start from the values returned by
        ``self._interactive_slider_bounds(index=index)`` and, when edited,
        move the corresponding slider limit and reset the slider step to
        0.001 of the new range. A limit that would cross the opposite one
        is ignored. Slider value changes call ``self._interactive_update``
        with ``index`` passed as a keyword argument.

        Returns the three widgets wrapped in an ``HBox``.
        """
        from ipywidgets import (FloatSlider, FloatText, Layout, HBox)

        limits = self._interactive_slider_bounds(index=index)

        min_editor = FloatText(value=limits['min'],
                               description='min',
                               layout=Layout(flex='0 1 auto', width='auto'))
        max_editor = FloatText(value=limits['max'],
                               description='max',
                               layout=Layout(flex='0 1 auto', width='auto'))

        if index is None:
            initial = self.value
            caption = self.name
        else:
            initial = self.value[index]
            caption = self.name + '_{}'.format(index)

        value_slider = FloatSlider(
            value=initial,
            min=min_editor.value,
            max=max_editor.value,
            step=limits['step'],
            description=caption,
            layout=Layout(flex='1 1 auto', width='auto'))

        def limit_handler(which):
            """Return the observer that updates the ``which`` slider limit."""
            def handler(change):
                candidate = change['new']
                if which == 'min':
                    acceptable = value_slider.max > candidate
                else:
                    acceptable = value_slider.min < candidate
                if acceptable:
                    setattr(value_slider, which, candidate)
                    span = value_slider.max - value_slider.min
                    value_slider.step = np.abs(span) * 0.001
            return handler

        min_editor.observe(limit_handler('min'), names='value')
        max_editor.observe(limit_handler('max'), names='value')

        forward_value = functools.partial(self._interactive_update,
                                          index=index)
        value_slider.observe(forward_value, names='value')

        return HBox((min_editor, value_slider, max_editor))
Exemple #5
0
class SubstrateTab(object):

    def __init__(self):
        """Set plotting defaults and build the tab's widgets and layout.

        The finished layout is stored in ``self.tab``: two rows of controls
        above the interactive plot ``self.i_plot``, plus a row of download
        buttons when ``hublib_flag`` is true.

        NOTE: every ``observe`` call below is made without ``names=...``,
        so the handlers run for any trait change on the widget (for
        example ``disabled``), not only for ``value`` changes. Several
        handlers therefore trigger one another; the registration order is
        kept as it was.
        """
        self.output_dir = '.'

        # Figure sizes; update_params() rescales them to the domain shape.
        self.figsize_width_substrate = 15.0  # allow extra for colormap
        self.figsize_height_substrate = 12.5
        self.figsize_width_svg = 12.0
        self.figsize_height_svg = 12.0

        self.first_time = True
        self.modulo = 1

        self.use_defaults = True

        self.svg_delta_t = 1
        self.substrate_delta_t = 1
        self.svg_frame = 1
        self.substrate_frame = 1

        self.customized_output_freq = False
        self.therapy_activation_time = 1000000
        self.max_svg_frame_pre_therapy = 1000000
        self.max_substrate_frame_pre_therapy = 1000000

        self.svg_xmin = 0

        self.show_nucleus = True
        self.show_edge = True

        # Initial value; the dropdown callbacks use "dropdown value + 4".
        self.field_index = 4

        # While True, mcds_field_cb() returns without doing anything.
        self.skip_cb = False

        # Dummy mesh size; update_params() computes the real one.
        self.numx = 0
        self.numy = 0

        self.title_str = ''

        constWidth = '180px'
        constWidth2 = '150px'

        max_frames = 1
        self.i_plot = interactive(self.plot_substrate, frame=(0, max_frames), continuous_update=False)

        # This sizes the plot's widget area, not the figure (see figsize_*).
        svg_plot_size = '900px'
        self.i_plot.layout.width = svg_plot_size
        self.i_plot.layout.height = svg_plot_size

        self.fontsize = 20

        self.max_frames = BoundedIntText(
            min=0, max=99999, value=max_frames,
            description='# frames',
            layout=Layout(width='160px'),
        )
        self.max_frames.observe(self.update_max_frames)

        # Per-field [min, max, fixed?]; set by hand here and replaced by
        # update_dropdown_fields() once initial.xml has been read.
        self.field_min_max = {'director signal':[0.,1.,False], 'cargo signal':[0.,1.,False] }
        # Reverse mapping (index -> name) of the dropdown options below.
        self.field_dict = {0:'director signal', 1:'cargo signal'}

        self.mcds_field = Dropdown(
            options={'director signal': 0, 'cargo signal':1},
            value=0,
            layout=Layout(width=constWidth)
        )
        self.mcds_field.observe(self.mcds_field_changed_cb)

        self.field_cmap = Dropdown(
            options=['viridis', 'jet', 'YlOrRd'],
            value='YlOrRd',
            layout=Layout(width=constWidth)
        )
        self.field_cmap.observe(self.mcds_field_cb)

        self.cmap_fixed_toggle = Checkbox(
            description='Fix',
            disabled=False,
        )
        self.cmap_fixed_toggle.observe(self.mcds_field_cb)

        self.cmap_min = FloatText(
            description='Min',
            value=0,
            step = 0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_min.observe(self.mcds_field_cb)

        self.cmap_max = FloatText(
            description='Max',
            value=38,
            step = 0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_max.observe(self.mcds_field_cb)

        def cmap_fixed_toggle_cb(b):
            # Min/Max are editable only while the range is fixed; fixing
            # also stores the current Min/Max for the selected field.
            field_name = self.field_dict[self.mcds_field.value]
            fixed = bool(self.cmap_fixed_toggle.value)
            self.cmap_min.disabled = not fixed
            self.cmap_max.disabled = not fixed
            if fixed:
                self.field_min_max[field_name][0] = self.cmap_min.value
                self.field_min_max[field_name][1] = self.cmap_max.value
            self.field_min_max[field_name][2] = fixed
            self.i_plot.update()

        self.cmap_fixed_toggle.observe(cmap_fixed_toggle_cb)

        #---------------------
        self.cell_nucleus_toggle = Checkbox(
            description='nuclei',
            disabled=False,
            value = self.show_nucleus,
        )

        def cell_nucleus_toggle_cb(b):
            self.show_nucleus = bool(self.cell_nucleus_toggle.value)
            self.i_plot.update()

        self.cell_nucleus_toggle.observe(cell_nucleus_toggle_cb)

        #----
        self.cell_edges_toggle = Checkbox(
            description='edges',
            disabled=False,
            value=self.show_edge,
        )

        def cell_edges_toggle_cb(b):
            self.show_edge = bool(self.cell_edges_toggle.value)
            self.i_plot.update()

        self.cell_edges_toggle.observe(cell_edges_toggle_cb)

        self.cells_toggle = Checkbox(
            description='Cells',
            disabled=False,
            value=True,
        )

        def cells_toggle_cb(b):
            self.i_plot.update()
            # The per-cell options only make sense while cells are drawn.
            cells_hidden = not self.cells_toggle.value
            self.cell_edges_toggle.disabled = cells_hidden
            self.cell_nucleus_toggle.disabled = cells_hidden

        self.cells_toggle.observe(cells_toggle_cb)

        #---------------------
        self.substrates_toggle = Checkbox(
            description='Substrates',
            disabled=False,
            value=True,
        )

        def substrates_toggle_cb(b):
            # disabled=True greys a control out, hence the negation.
            substrates_hidden = not self.substrates_toggle.value
            self.cmap_fixed_toggle.disabled = substrates_hidden
            # Min/Max stay locked unless "Fix" is ticked, the same rule
            # cmap_fixed_toggle_cb applies. Previously they were always
            # re-enabled here, even with "Fix" off.
            range_locked = substrates_hidden or not self.cmap_fixed_toggle.value
            self.cmap_min.disabled = range_locked
            self.cmap_max.disabled = range_locked
            self.mcds_field.disabled = substrates_hidden
            self.field_cmap.disabled = substrates_hidden

        self.substrates_toggle.observe(substrates_toggle_cb)

        # Created and observed, but not placed in any row below.
        self.grid_toggle = Checkbox(
            description='grid',
            disabled=False,
            value=True,
        )

        def grid_toggle_cb(b):
            self.i_plot.update()

        self.grid_toggle.observe(grid_toggle_cb)

        def bordered_row(children):
            # Half-width flex row with a thin border; new Layout each call.
            return Box(children, layout=Layout(border='1px solid black',
                            width='50%',
                            height='',
                            align_items='stretch',
                            flex_direction='row',
                            display='flex'))

        row1a = bordered_row([self.max_frames, self.mcds_field, self.field_cmap])
        row1b = bordered_row([self.cells_toggle, self.cell_nucleus_toggle, self.cell_edges_toggle])
        row1 = HBox( [row1a, Label('.....'), row1b])

        row2a = bordered_row([self.cmap_fixed_toggle, self.cmap_min, self.cmap_max])
        row2b = bordered_row([self.substrates_toggle, ])
        row2 = HBox( [row2a, Label('.....'), row2b])

        if (hublib_flag):
            self.download_button = Download('mcds.zip', style='warning', icon='cloud-download', 
                                                tooltip='Download data', cb=self.download_cb)

            self.download_svg_button = Download('svg.zip', style='warning', icon='cloud-download', 
                                            tooltip='You need to allow pop-ups in your browser', cb=self.download_svg_cb)
            download_row = HBox([self.download_button.w, self.download_svg_button.w, Label("Download all cell plots (browser must allow pop-ups).")])

            controls_box = VBox([row1, row2])
            self.tab = VBox([controls_box, self.i_plot, download_row])
        else:
            self.tab = VBox([row1, row2, self.i_plot])

    #---------------------------------------------------
    def update_dropdown_fields(self, data_dir):
        # print('update_dropdown_fields called --------')
        self.output_dir = data_dir
        tree = None
        try:
            fname = os.path.join(self.output_dir, "initial.xml")
            tree = ET.parse(fname)
            xml_root = tree.getroot()
        except:
            print("Cannot open ",fname," to read info, e.g., names of substrate fields.")
            return

        xml_root = tree.getroot()
        self.field_min_max = {}
        self.field_dict = {}
        dropdown_options = {}
        uep = xml_root.find('.//variables')
        comment_str = ""
        field_idx = 0
        if (uep):
            for elm in uep.findall('variable'):
                # print("-----> ",elm.attrib['name'])
                field_name = elm.attrib['name']
                self.field_min_max[field_name] = [0., 1., False]
                self.field_dict[field_idx] = field_name
                dropdown_options[field_name] = field_idx

                self.field_min_max[field_name][0] = 0   
                self.field_min_max[field_name][1] = 1

                # self.field_min_max[field_name][0] = field_idx   #rwh: helps debug
                # self.field_min_max[field_name][1] = field_idx+1   
                self.field_min_max[field_name][2] = False
                field_idx += 1

#        constWidth = '180px'
        # print('options=',dropdown_options)
        # print(self.field_min_max)  # debug
        self.mcds_field.value = 0
        self.mcds_field.options = dropdown_options
#         self.mcds_field = Dropdown(
# #            options={'oxygen': 0, 'glucose': 1},
#             options=dropdown_options,
#             value=0,
#             #     description='Field',
#            layout=Layout(width=constWidth)
#         )

    # def update_max_frames_expected(self, value):  # called when beginning an interactive Run
    #     self.max_frames.value = value  # assumes naming scheme: "snapshot%08d.svg"
    #     self.mcds_plot.children[0].max = self.max_frames.value

#------------------------------------------------------------------------------
    def update_params(self, config_tab, user_params_tab):
        # xml_root.find(".//x_min").text = str(self.xmin.value)
        # xml_root.find(".//x_max").text = str(self.xmax.value)
        # xml_root.find(".//dx").text = str(self.xdelta.value)
        # xml_root.find(".//y_min").text = str(self.ymin.value)
        # xml_root.find(".//y_max").text = str(self.ymax.value)
        # xml_root.find(".//dy").text = str(self.ydelta.value)
        # xml_root.find(".//z_min").text = str(self.zmin.value)
        # xml_root.find(".//z_max").text = str(self.zmax.value)
        # xml_root.find(".//dz").text = str(self.zdelta.value)

        self.xmin = config_tab.xmin.value 
        self.xmax = config_tab.xmax.value 
        self.x_range = self.xmax - self.xmin
        self.svg_xrange = self.xmax - self.xmin
        self.ymin = config_tab.ymin.value
        self.ymax = config_tab.ymax.value 
        self.y_range = self.ymax - self.ymin

        self.numx =  math.ceil( (self.xmax - self.xmin) / config_tab.xdelta.value)
        self.numy =  math.ceil( (self.ymax - self.ymin) / config_tab.ydelta.value)

        if (self.x_range > self.y_range):  
            ratio = self.y_range / self.x_range
            self.figsize_width_substrate = 15.0  # allow extra for colormap
            self.figsize_height_substrate = 12.5 * ratio
            self.figsize_width_svg = 12.0
            self.figsize_height_svg = 12.0 * ratio
        else:   # x < y
            ratio = self.x_range / self.y_range
            self.figsize_width_substrate = 15.0 * ratio 
            self.figsize_height_substrate = 12.5
            self.figsize_width_svg = 12.0 * ratio
            self.figsize_height_svg = 12.0 

        self.svg_flag = config_tab.toggle_svg.value
        self.substrates_flag = config_tab.toggle_mcds.value
        # print("substrates: update_params(): svg_flag, toggle=",self.svg_flag,config_tab.toggle_svg.value)        
        # print("substrates: update_params(): self.substrates_flag = ",self.substrates_flag)
        self.svg_delta_t = config_tab.svg_interval.value
        self.substrate_delta_t = config_tab.mcds_interval.value
        self.modulo = int(self.substrate_delta_t / self.svg_delta_t)
        # print("substrates: update_params(): modulo=",self.modulo)        

        if self.customized_output_freq:
#            self.therapy_activation_time = user_params_tab.therapy_activation_time.value   # NOTE: edit for user param name
            # print("substrates: update_params(): therapy_activation_time=",self.therapy_activation_time)
            self.max_svg_frame_pre_therapy = int(self.therapy_activation_time/self.svg_delta_t)
            self.max_substrate_frame_pre_therapy = int(self.therapy_activation_time/self.substrate_delta_t)

#------------------------------------------------------------------------------
#    def update(self, rdir):
#   Called from driver module (e.g., pc4*.py) (among other places?)
    def update(self, rdir=''):
        # with debug_view:
        #     print("substrates: update rdir=", rdir)        
        # print("substrates: update rdir=", rdir)        

        if rdir:
            self.output_dir = rdir

        # print('update(): self.output_dir = ', self.output_dir)

        if self.first_time:
        # if True:
            self.first_time = False
            full_xml_filename = Path(os.path.join(self.output_dir, 'config.xml'))
            # print("substrates: update(), config.xml = ",full_xml_filename)        
            # self.num_svgs = len(glob.glob(os.path.join(self.output_dir, 'snap*.svg')))
            # self.num_substrates = len(glob.glob(os.path.join(self.output_dir, 'output*.xml')))
            # print("substrates: num_svgs,num_substrates =",self.num_svgs,self.num_substrates)        
            # argh - no! If no files created, then denom = -1
            # self.modulo = int((self.num_svgs - 1) / (self.num_substrates - 1))
            # print("substrates: update(): modulo=",self.modulo)        
            if full_xml_filename.is_file():
                tree = ET.parse(str(full_xml_filename))  # this file cannot be overwritten; part of tool distro
                xml_root = tree.getroot()
                self.svg_delta_t = float(xml_root.find(".//SVG//interval").text)
                self.substrate_delta_t = float(xml_root.find(".//full_data//interval").text)
                # print("substrates: svg,substrate delta_t values=",self.svg_delta_t,self.substrate_delta_t)        
                self.modulo = int(self.substrate_delta_t / self.svg_delta_t)
                # print("substrates: update(): modulo=",self.modulo)        


        # all_files = sorted(glob.glob(os.path.join(self.output_dir, 'output*.xml')))  # if the substrates/MCDS

        all_files = sorted(glob.glob(os.path.join(self.output_dir, 'snap*.svg')))   # if .svg
        if len(all_files) > 0:
            last_file = all_files[-1]
            self.max_frames.value = int(last_file[-12:-4])  # assumes naming scheme: "snapshot%08d.svg"
        else:
            substrate_files = sorted(glob.glob(os.path.join(self.output_dir, 'output*.xml')))
            if len(substrate_files) > 0:
                last_file = substrate_files[-1]
                self.max_frames.value = int(last_file[-12:-4])

    def download_svg_cb(self):
        """Pack every ``*.svg`` file found in ``self.output_dir`` into ``svg.zip``.

        The archive is created (overwriting any previous one) under the
        relative name ``svg.zip``, i.e. in the current working directory.
        """
        svg_pattern = os.path.join(self.output_dir, '*.svg')
        with zipfile.ZipFile('svg.zip', 'w') as archive:
            for svg_path in glob.glob(svg_pattern):
                # Store only the base name so the archive has no directory tree.
                archive.write(svg_path, arcname=os.path.basename(svg_path))

    def download_cb(self):
        """Pack the ``*.xml`` and ``*.mat`` files of ``self.output_dir`` into ``mcds.zip``.

        The archive is created (overwriting any previous one) under the
        relative name ``mcds.zip``, i.e. in the current working directory.
        All XML files are added first, followed by all MAT files.
        """
        patterns = (
            os.path.join(self.output_dir, '*.xml'),
            os.path.join(self.output_dir, '*.mat'),
        )
        with zipfile.ZipFile('mcds.zip', 'w') as archive:
            for pattern in patterns:
                for data_path in glob.glob(pattern):
                    # Store only the base name so the archive has no directory tree.
                    archive.write(data_path, arcname=os.path.basename(data_path))

    def update_max_frames(self,_b):
        """Copy the "max frames" widget value onto the frame slider's upper bound.

        ``_b`` is the change notification passed by the widget observer; it is
        not used. The slider is taken to be the first child of ``self.i_plot``.
        """
        new_upper_bound = self.max_frames.value
        frame_slider = self.i_plot.children[0]
        frame_slider.max = new_upper_bound

    # called if user selected different substrate in dropdown
    def mcds_field_changed_cb(self, b):
        """Load the saved colormap settings for the newly selected substrate.

        Does nothing while the dropdown has no selection. Otherwise it updates
        ``self.field_index`` (dropdown value + 4), pushes the stored
        ``[min, max, fixed]`` entry of ``self.field_min_max`` into the
        colormap widgets, and redraws the plot.

        Parameters
        ----------
        b : object
            Change notification supplied by the widget observer (unused).
        """
        selected = self.mcds_field.value
        if selected is None:
            return
        self.field_index = selected + 4

        field_name = self.field_dict[selected]

        # BEWARE: assigning to these widgets triggers mcds_field_cb(), which
        # would write the half-updated values back into field_min_max. The
        # "skip_cb" flag suppresses it; the finally clause guarantees the flag
        # is cleared even if a lookup or assignment raises, so the min/max
        # callback is never left permanently disabled.
        self.skip_cb = True
        try:
            saved = self.field_min_max[field_name]
            self.cmap_min.value = saved[0]
            self.cmap_max.value = saved[1]
            self.cmap_fixed_toggle.value = bool(saved[2])
        finally:
            self.skip_cb = False

        self.i_plot.update()

    # called if user provided different min/max values for colormap, or a different colormap
    def mcds_field_cb(self, b):
        """Store the colormap widget values for the current substrate and redraw.

        Returns immediately while ``self.skip_cb`` is set (the widgets are then
        being updated programmatically). Otherwise the min, max and "fixed"
        widget values are written into slots 0, 1 and 2 of the selected
        field's ``self.field_min_max`` entry.

        Parameters
        ----------
        b : object
            Change notification supplied by the widget observer (unused).
        """
        if self.skip_cb:
            return

        selected = self.mcds_field.value
        self.field_index = selected + 4

        limits = self.field_min_max[self.field_dict[selected]]
        limits[0] = self.cmap_min.value
        limits[1] = self.cmap_max.value
        limits[2] = self.cmap_fixed_toggle.value

        self.i_plot.update()


    #---------------------------------------------------------------------------
    def circles(self, x, y, s, c='b', vmin=None, vmax=None, **kwargs):
        """
        See https://gist.github.com/syrte/592a062c562cd2a98a83

        Make a scatter plot of circles on the current axes.
        Similar to plt.scatter, but the sizes of the circles are in data scale.

        Parameters
        ----------
        x, y : scalar or array_like, shape (n, )
            Input data
        s : scalar or array_like, shape (n, )
            Radius of circles.
        c : color or sequence of color, optional, default : 'b'
            `c` can be a single color format string, or a sequence of color
            specifications of length `N`, or a sequence of `N` numbers to be
            mapped to colors using the `cmap` and `norm` specified via kwargs.
            Note that `c` should not be a single numeric RGB or RGBA sequence
            because that is indistinguishable from an array of values
            to be colormapped. (If you insist, use `color` instead.)
            `c` can be a 2-D array in which the rows are RGB or RGBA, however.
        vmin, vmax : scalar, optional, default: None
            Passed to ``set_clim`` on the collection when `c` is an array of
            values to be colormapped; ignored when `c` is a scalar.
        kwargs : `~matplotlib.collections.Collection` properties
            Eg. alpha, edgecolor(ec), facecolor(fc), linewidth(lw), linestyle(ls),
            norm, cmap, transform, etc.

        Returns
        -------
        collection : `~matplotlib.collections.PatchCollection`
            The collection that was added to the current axes.

        Examples
        --------
        a = np.arange(11)
        circles(a, a, s=a*0.2, c=a, alpha=0.5, ec='none')
        plt.colorbar()

        License
        --------
        This code is under [The BSD 3-Clause License]
        (http://opensource.org/licenses/BSD-3-Clause)
        """
        # A scalar `c` (e.g. the default 'b') is a plain color, not data to be
        # colormapped; an explicit `color=` keyword takes precedence over it.
        if np.isscalar(c):
            kwargs.setdefault('color', c)
            c = None

        # Expand the short aliases; an explicitly given long name wins.
        # You can set `facecolor` with an array for each patch,
        # while you can only set `facecolors` with a value for all.
        for alias, full_name in (('fc', 'facecolor'), ('ec', 'edgecolor'),
                                 ('ls', 'linestyle'), ('lw', 'linewidth')):
            if alias in kwargs:
                kwargs.setdefault(full_name, kwargs.pop(alias))

        zipped = np.broadcast(x, y, s)
        patches = [Circle((x_, y_), s_) for x_, y_, s_ in zipped]
        collection = PatchCollection(patches, **kwargs)
        if c is not None:
            c = np.broadcast_to(c, zipped.shape).ravel()
            collection.set_array(c)
            collection.set_clim(vmin, vmax)

        ax = plt.gca()
        ax.add_collection(collection)
        ax.autoscale_view()
        if c is not None:
            # Make the collection the "current image" so plt.colorbar() works.
            plt.sci(collection)
        # The docstring has always promised this return value; callers that
        # ignore it are unaffected.
        return collection

    #------------------------------------------------------------
    # def plot_svg(self, frame, rdel=''):
    def plot_svg(self, frame):
        """Draw the cells stored in ``snapshot<frame>.svg`` on the current axes.

        The SVG file is looked up in ``self.output_dir``. If it does not exist,
        or contains no ``id="cells"`` group, a message is printed and nothing
        is drawn. Otherwise every cell's circle(s) are read (position, radius,
        fill colour), the elapsed time and agent count are appended to
        ``self.title_str``, the axes limits are set from ``self.xmin/xmax`` and
        ``self.ymin/ymax``, and the circles are drawn via ``self.circles``.

        Side effects: sets the module-level ``current_frame``; may set
        ``self.axes_max`` when ``self.use_defaults`` is true.

        Parameters
        ----------
        frame : int
            Snapshot number, formatted into the file name as ``%08d``.
        """
        global current_frame
        current_frame = frame
        full_fname = os.path.join(self.output_dir, "snapshot%08d.svg" % frame)
        if not os.path.isfile(full_fname):
            print("Once output files are generated, click the slider.")
            return

        root = ET.parse(full_fname).getroot()

        # Scan the top-level children up to (and including) the first element
        # carrying an 'id' attribute; that element is taken as the tissue group.
        tissue_parent = None
        for child in root:
            if self.use_defaults and 'width' in child.attrib:
                self.axes_max = float(child.attrib['width'])
            if child.text and "Current time" in child.text:
                svals = child.text.split()
                # remove the ".00" on minutes
                self.title_str += "   cells: " + svals[2] + "d, " + svals[4] + "h, " + svals[7][:-3] + "m"
            if 'id' in child.attrib:
                tissue_parent = child
                break

        cells_parent = None
        if tissue_parent is not None:
            for child in tissue_parent:
                if child.attrib.get('id') == 'cells':
                    cells_parent = child
                    break

        if cells_parent is None:
            # Previously this fell through to an UnboundLocalError/TypeError.
            print("plot_svg: no 'cells' group found in", full_fname)
            return

        # test for bogus x,y locations (rwh TODO: use max of domain?)
        too_large_val = 10000.

        xlist = []
        ylist = []
        rlist = []
        rgb_list = []
        num_cells = 0
        for cell in cells_parent:
            for circle in cell:  # two circles in each child: outer + nucleus
                xval = float(circle.attrib['cx'])
                # map SVG coords into comp domain
                # NOTE(review): dividing and multiplying by the same range only
                # adds the offset; confirm whether a real SVG->domain rescale
                # (different source range) was intended.
                xval = xval/self.x_range * self.x_range + self.xmin

                fill = circle.attrib['fill']
                if fill[0:3] == "rgb":  # an rgb string, e.g. "rgb(175,175,80)"
                    rgb = [int(v) / 255. for v in fill[4:-1].split(",")]
                else:     # otherwise, must be a color name
                    rgb = list(mplc.to_rgb(mplc.cnames[fill]))

                if np.fabs(xval) > too_large_val:
                    print("bogus xval=", xval)
                    break
                yval = float(circle.attrib['cy'])
                yval = yval/self.y_range * self.y_range + self.ymin
                if np.fabs(yval) > too_large_val:
                    print("bogus yval=", yval)
                    break

                rval = float(circle.attrib['r'])
                xlist.append(xval)
                ylist.append(yval)
                rlist.append(rval)
                rgb_list.append(rgb)

                # For .svg files with cells that *have* a nucleus, there will
                # be a 2nd circle; skip it unless nuclei are to be shown.
                if not self.show_nucleus:
                    break

            num_cells += 1

        xvals = np.array(xlist)
        yvals = np.array(ylist)
        rvals = np.array(rlist)
        rgbs = np.array(rgb_list)

        self.title_str += " (" + str(num_cells) + " agents)"
        plt.title(self.title_str)

        plt.xlim(self.xmin, self.xmax)
        plt.ylim(self.ymin, self.ymax)

        if self.show_edge:
            try:
                self.circles(xvals, yvals, s=rvals, color=rgbs, edgecolor='black', linewidth=0.5)
            except ValueError:
                # Deliberate best-effort (original author's "temp fix"): a
                # ValueError was only seen with edges toggled on; skip drawing
                # the cells for this frame rather than break the UI.
                pass
        else:
            self.circles(xvals, yvals, s=rvals, color=rgbs)

    #---------------------------------------------------------------------------
    # assume "frame" is cell frame #, unless Cells is togggled off, then it's the substrate frame #
    # def plot_substrate(self, frame, grid):
    def plot_substrate(self, frame):
        """Plot the substrate contour and/or the cells for slider position ``frame``.

        When the substrates toggle is on, a new figure is created, the
        substrate output number is derived from ``frame``, and the selected
        field (row ``self.field_index`` of the ``multiscale_microenvironment``
        matrix in ``output<N>_microenvironment0.mat``) is drawn as a filled
        contour plot with a colorbar. The title is built from ``current_time``
        in ``output<N>.xml``. If either output file is missing a message is
        printed and the method returns without drawing anything (cells
        included).

        When the cells toggle is on, the cells of SVG snapshot ``frame`` are
        drawn via ``plot_svg`` (on top of the substrate, if any).

        Parameters
        ----------
        frame : int
            Cell (SVG) frame number from the slider.
        """
        self.title_str = ''

        # Assume: # .svg files >= # substrate files
        if self.substrates_toggle.value:
            self.fig = plt.figure(figsize=(self.figsize_width_substrate, self.figsize_height_substrate))

            # rwh - funky way to figure out substrate frame for pc4cancerbots
            # (due to user-defined "save_interval*"): past the pre-therapy
            # range, SVG and substrate frames advance one-to-one.
            if self.customized_output_freq and (frame > self.max_svg_frame_pre_therapy):
                self.substrate_frame = self.max_substrate_frame_pre_therapy + (frame - self.max_svg_frame_pre_therapy)
            else:
                self.substrate_frame = int(frame / self.modulo)

            fname = "output%08d_microenvironment0.mat" % self.substrate_frame
            xml_fname = "output%08d.xml" % self.substrate_frame
            full_fname = os.path.join(self.output_dir, fname)
            full_xml_fname = os.path.join(self.output_dir, xml_fname)

            # Both files are read below, so require both (the XML used to be
            # parsed unchecked).
            if not (os.path.isfile(full_fname) and os.path.isfile(full_xml_fname)):
                print("Once output files are generated, click the slider.")  # No:  output00000000_microenvironment0.mat
                return

            xml_root = ET.parse(full_xml_fname).getroot()
            mins = int(float(xml_root.find(".//current_time").text))  # TODO: check units = mins
            self.substrate_mins = mins

            hrs = int(mins/60)
            days = int(hrs/24)
            self.title_str = 'substrate: %dd, %dh, %dm' % (days, hrs % 24, mins - (hrs*60))

            info_dict = {}
            scipy.io.loadmat(full_fname, info_dict)
            M = info_dict['multiscale_microenvironment']
            # Rows 0 and 1 are used as the x/y mesh below; field_index selects
            # the substrate row.
            field_values = M[self.field_index, :]

            contour_ok = True
            try:
                xgrid = M[0, :].reshape(self.numy, self.numx)
                ygrid = M[1, :].reshape(self.numy, self.numx)
                zgrid = field_values.reshape(self.numy, self.numx)
            except (ValueError, TypeError):
                print("substrates.py: mismatched mesh size for reshape: numx,numy=",self.numx, self.numy)
                contour_ok = False

            num_contours = 15
            if contour_ok:
                try:
                    if self.cmap_fixed_toggle.value:
                        # Fixed colormap range: explicit levels spanning the
                        # user-supplied min/max. (contourf has no `fontsize`
                        # parameter, so it is no longer passed.)
                        levels = MaxNLocator(nbins=num_contours).tick_values(self.cmap_min.value, self.cmap_max.value)
                        substrate_plot = plt.contourf(xgrid, ygrid, zgrid, levels=levels, extend='both', cmap=self.field_cmap.value)
                    else:
                        substrate_plot = plt.contourf(xgrid, ygrid, zgrid, num_contours, cmap=self.field_cmap.value)
                except Exception as err:
                    # Best-effort: keep the UI alive and still draw the cells,
                    # but say why the contour is missing.
                    print("substrates.py: contour plot failed:", err)
                    contour_ok = False

            if contour_ok:
                plt.title(self.title_str, fontsize=self.fontsize)
                cbar = self.fig.colorbar(substrate_plot)
                cbar.ax.tick_params(labelsize=self.fontsize)

            plt.xlim(self.xmin, self.xmax)
            plt.ylim(self.ymin, self.ymax)

        # Now plot the cells (possibly on top of the substrate)
        if self.cells_toggle.value:
            if not self.substrates_toggle.value:
                self.fig = plt.figure(figsize=(self.figsize_width_svg, self.figsize_height_svg))
            self.svg_frame = frame
            self.plot_svg(self.svg_frame)
Exemple #6
0
class SubstrateTab(object):
    """Notebook tab showing contour plots of substrate output files.

    The tab (``self.tab``) combines a column of controls (substrate dropdown,
    colormap dropdown, "Fix" checkbox, min/max fields with a Save button, and
    a max-frames field) with an interactive plot driven by a frame slider.
    Data are read from ``self.output_dir``.
    """

    def __init__(self):
        """Build all widgets and register their callbacks."""
        self.output_dir = '.'

        # initial value: dropdown value 0, plus the offset of 4 used by the callbacks
        self.field_index = 4

        tab_height = '500px'
        constWidth = '180px'
        constWidth2 = '150px'
        tab_layout = Layout(width='900px',   # border='2px solid black',
                            height=tab_height, ) #overflow_y='scroll')

        max_frames = 253   # first time + 30240 / 120
        self.mcds_plot = interactive(self.plot_substrate, frame=(0, max_frames), continuous_update=False)
        svg_plot_size = '700px'
        self.mcds_plot.layout.width = svg_plot_size
        self.mcds_plot.layout.height = svg_plot_size

        self.max_frames = BoundedIntText(
            min=0, max=99999, value=max_frames,
            description='Max frames',
            layout=Layout(width='160px'),
        )
        self.max_frames.observe(self.update_max_frames)

        # Saved [min, max] colormap range per substrate name.
        self.field_min_max = {'oxygen': [0., 38.], 'glucose': [0.8, 1.], 'H+ ions': [0., 1.],
                                'ECM': [0., 1.], 'NP1': [0., 1.], 'NP2': [0., 0.1]}
        # hacky I know, but make a dict that's got (key,value) reversed from the dict in the Dropdown below
        self.field_dict = {0:'oxygen', 1:'glucose', 2:'H+ ions', 3:'ECM', 4:'NP1', 5:'NP2'}
        self.mcds_field = Dropdown(
            options={'oxygen': 0, 'glucose': 1, 'H+ ions': 2, 'ECM': 3, 'NP1': 4, 'NP2': 5},
            value=0,
            #     description='Field',
            layout=Layout(width=constWidth)
        )
        self.mcds_field.observe(self.mcds_field_changed_cb)

        self.field_cmap = Dropdown(
            options=['viridis', 'jet', 'YlOrRd'],
            value='viridis',
            #     description='Field',
            layout=Layout(width=constWidth)
        )
        self.field_cmap.observe(self.mcds_field_cb)

        self.cmap_fixed = Checkbox(
            description='Fix',
            disabled=False,
            layout=Layout(width=constWidth2),
        )

        self.save_min_max= Button(
            description='Save', #style={'description_width': 'initial'},
            button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Save min/max for this substrate',
            disabled=True,
            layout=Layout(width='90px')
        )

        def save_min_max_cb(b):
            # Remember the current min/max fields for the selected substrate.
            field_name = self.field_dict[self.mcds_field.value]
            self.field_min_max[field_name][0] = self.cmap_min.value
            self.field_min_max[field_name][1] = self.cmap_max.value

        self.save_min_max.on_click(save_min_max_cb)

        self.cmap_min = FloatText(
            description='Min',
            value=0,
            step = 0.1,
            disabled=True,
            #layout=Layout(width=constWidth2),
        )
        self.cmap_min.observe(self.mcds_field_cb)

        self.cmap_max = FloatText(
            description='Max',
            value=38,
            step = 0.1,
            disabled=True,
            #layout=Layout(width=constWidth2),
        )
        self.cmap_max.observe(self.mcds_field_cb)

        def cmap_fixed_cb(b):
            # The min/max fields and the Save button are only usable while
            # the "Fix" checkbox is ticked.
            if (self.cmap_fixed.value):
                self.cmap_min.disabled = False
                self.cmap_max.disabled = False
                self.save_min_max.disabled = False
            else:
                self.cmap_min.disabled = True
                self.cmap_max.disabled = True
                self.save_min_max.disabled = True

        self.cmap_fixed.observe(cmap_fixed_cb)

        field_cmap_row2 = HBox([self.field_cmap, self.cmap_fixed])

        items_auto = [
            self.save_min_max, #layout=Layout(flex='3 1 auto', width='auto'),
            self.cmap_min,
            self.cmap_max,
         ]
        box_layout = Layout(display='flex',
                    flex_flow='row',
                    align_items='stretch',
                    width='80%')
        field_cmap_row3 = Box(children=items_auto, layout=box_layout)

        mcds_params = VBox([self.mcds_field, field_cmap_row2, field_cmap_row3, self.max_frames])  # mcds_dir

        self.tab = HBox([mcds_params, self.mcds_plot], layout=tab_layout)

    @staticmethod
    def _format_elapsed(mins):
        """Return whole minutes ``mins`` formatted as ``'<d>d, <h>h, <m>m'``."""
        total_hours, minutes = divmod(int(mins), 60)
        days, hours = divmod(total_hours, 24)
        return '%dd, %dh, %dm' % (days, hours, minutes)

    def update_max_frames(self,_b):
        """Copy the max-frames field onto the frame slider's upper bound."""
        self.mcds_plot.children[0].max = self.max_frames.value

    def mcds_field_changed_cb(self, b):
        """Switch substrate: restore its saved min/max and redraw."""
        self.field_index = self.mcds_field.value + 4

        field_name = self.field_dict[self.mcds_field.value]
        self.cmap_min.value = self.field_min_max[field_name][0]
        self.cmap_max.value = self.field_min_max[field_name][1]
        self.mcds_plot.update()

    def mcds_field_cb(self, b):
        """Redraw after a colormap or min/max change."""
        self.field_index = self.mcds_field.value + 4
        self.mcds_plot.update()

    def plot_substrate(self, frame):
        """Draw a filled contour plot of the selected substrate for ``frame``.

        Reads ``output<frame>_microenvironment0.mat`` and ``output<frame>.xml``
        from ``self.output_dir``. If the ``.mat`` file is missing, a message is
        printed and nothing is drawn. The mesh is assumed to be square (its
        side is the integer square root of the number of voxels).
        """
        fname = "output%08d_microenvironment0.mat" % frame
        xml_fname = "output%08d.xml" % frame
        full_fname = os.path.join(self.output_dir, fname)
        full_xml_fname = os.path.join(self.output_dir, xml_fname)

        if not os.path.isfile(full_fname):
            print("No: ", full_fname)
            return

        tree = ET.parse(full_xml_fname)
        xml_root = tree.getroot()
        mins = int(float(xml_root.find(".//current_time").text))  # TODO: check units = mins
        # Integer arithmetic: with float hours the minutes part was always 0.
        title_str = self._format_elapsed(mins)

        info_dict = {}
        scipy.io.loadmat(full_fname, info_dict)
        M = info_dict['multiscale_microenvironment']
        field_values = M[self.field_index, :]   # 4=tumor cells field, 5=blood vessel density, 6=growth substrate

        plt.figure(figsize=(7.2,6))  # this strange figsize results in a ~square contour plot

        N = int(math.sqrt(len(M[0,:])))
        grid2D = M[0, :].reshape(N,N)
        # First row of the x-coordinate grid; reused for both axes.
        xvec = grid2D[0, :]

        num_contours = 15
        levels = MaxNLocator(nbins=num_contours).tick_values(self.cmap_min.value, self.cmap_max.value)
        if (self.cmap_fixed.value):
            my_plot = plt.contourf(xvec, xvec, field_values.reshape(N,N), levels=levels, extend='both', cmap=self.field_cmap.value)
        else:
            my_plot = plt.contourf(xvec, xvec, field_values.reshape(N,N), num_contours, cmap=self.field_cmap.value)

        plt.title(title_str)
        plt.colorbar(my_plot)
Exemple #7
0
class SubstrateTab(object):

    def __init__(self):
        """Create the substrate-tab widgets, wire their callbacks and build ``self.tab``.

        The widgets created here are: the interactive plot (``self.mcds_plot``),
        the maximum-frame box (``self.max_frames``), the field and colormap
        dropdowns (``self.mcds_field``, ``self.field_cmap``), the "Fix" checkbox
        (``self.cmap_fixed``), the min/max boxes (``self.cmap_min``,
        ``self.cmap_max``) and the "Save" button (``self.save_min_max``).

        ``hublib_flag`` and ``Download`` are read from outside this method; when
        ``hublib_flag`` is true a download row is appended to the tab.
        """
        
        self.output_dir = '.'
#        self.output_dir = 'tmpdir'

        # self.fig = plt.figure(figsize=(7.2,6))  # this strange figsize results in a ~square contour plot

        # initial value
        self.field_index = 4
        # self.field_index = self.mcds_field.value + 4

        # define dummy size of mesh (set in the tool's primary module)
        # plot_substrate() parses config.xml itself while numx is still 0.
        self.numx = 0
        self.numy = 0

        tab_height = '500px'
        constWidth = '180px'
        constWidth2 = '150px'
        # NOTE(review): tab_layout is only referenced by commented-out code below.
        tab_layout = Layout(width='900px',   # border='2px solid black',
                            height=tab_height, ) #overflow_y='scroll')

        max_frames = 1   
        # children[0] of this widget is used elsewhere in the class as the frame
        # control whose ``max`` is kept in sync with self.max_frames.
        self.mcds_plot = interactive(self.plot_substrate, frame=(0, max_frames), continuous_update=False)  
        # Only the last of the three assignments below ('750px') takes effect.
        svg_plot_size = '500px'  # small: controls the size of the tab height, not the plot (rf. figsize for that)
        svg_plot_size = '800px'  # medium
        svg_plot_size = '750px'  # medium
        self.mcds_plot.layout.width = svg_plot_size
        self.mcds_plot.layout.height = svg_plot_size

        self.max_frames = BoundedIntText(
            min=0, max=99999, value=max_frames,
            description='Max',
           layout=Layout(width='160px'),
        )
        # NOTE(review): no ``names=`` filter is passed to observe() here or below,
        # so these callbacks presumably fire for any trait change, not only
        # ``value`` -- confirm against the traitlets documentation before changing.
        self.max_frames.observe(self.update_max_frames)

        self.field_min_max = {'dummy': [0., 1.]}
        # hacky I know, but make a dict that's got (key,value) reversed from the dict in the Dropdown below
        self.field_dict = {0:'dummy'}

        self.mcds_field = Dropdown(
            options={'dummy': 0},
            value=0,
            #     description='Field',
           layout=Layout(width=constWidth)
        )
        # print("substrate __init__: self.mcds_field.value=",self.mcds_field.value)
#        self.mcds_field.observe(self.mcds_field_cb)
        self.mcds_field.observe(self.mcds_field_changed_cb)

        # self.field_cmap = Text(
        #     value='viridis',
        #     description='Colormap',
        #     disabled=True,
        #     layout=Layout(width=constWidth),
        # )
        self.field_cmap = Dropdown(
            options=['viridis', 'jet', 'YlOrRd'],
            value='viridis',
            #     description='Field',
           layout=Layout(width=constWidth)
        )
        #self.field_cmap.observe(self.plot_substrate)
#        self.field_cmap.observe(self.plot_substrate)
        self.field_cmap.observe(self.mcds_field_cb)

        self.cmap_fixed = Checkbox(
            description='Fix',
            disabled=False,
#           layout=Layout(width=constWidth2),
        )

        self.save_min_max= Button(
            description='Save', #style={'description_width': 'initial'},
            button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Save min/max for this substrate',
            disabled=True,
           layout=Layout(width='90px')
        )

        def save_min_max_cb(b):
            # Stores the current Min/Max box values as the saved range of the
            # selected field.  self.cmap_min / self.cmap_max are created further
            # down; they are only looked up when the button is clicked.
#            field_name = self.mcds_field.options[]
#            field_name = next(key for key, value in self.mcds_field.options.items() if value == self.mcds_field.value)
            field_name = self.field_dict[self.mcds_field.value]
#            print(field_name)
#            self.field_min_max = {'oxygen': [0., 30.], 'glucose': [0., 1.], 'H+ ions': [0., 1.], 'ECM': [0., 1.], 'NP1': [0., 1.], 'NP2': [0., 1.]}
            self.field_min_max[field_name][0] = self.cmap_min.value
            self.field_min_max[field_name][1] = self.cmap_max.value
#            print(self.field_min_max)

        self.save_min_max.on_click(save_min_max_cb)

        self.cmap_min = FloatText(
            description='Min',
            value=0,
            step = 0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_min.observe(self.mcds_field_cb)

        self.cmap_max = FloatText(
            description='Max',
            value=38,
            step = 0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_max.observe(self.mcds_field_cb)

        def cmap_fixed_cb(b):
            # Min/Max boxes and the Save button are editable only while "Fix" is checked.
            if (self.cmap_fixed.value):
                self.cmap_min.disabled = False
                self.cmap_max.disabled = False
                self.save_min_max.disabled = False
            else:
                self.cmap_min.disabled = True
                self.cmap_max.disabled = True
                self.save_min_max.disabled = True
#            self.mcds_field_cb()

        self.cmap_fixed.observe(cmap_fixed_cb)

        # NOTE(review): field_cmap_row2, field_cmap_row3 and mcds_params are built
        # but only referenced by each other / commented-out code; self.tab below
        # is assembled from row1, row2 and self.mcds_plot instead.
        field_cmap_row2 = HBox([self.field_cmap, self.cmap_fixed])

#        field_cmap_row3 = HBox([self.save_min_max, self.cmap_min, self.cmap_max])
        items_auto = [
            self.save_min_max, #layout=Layout(flex='3 1 auto', width='auto'),
            self.cmap_min, 
            self.cmap_max,  
         ]
        box_layout = Layout(display='flex',
                    flex_flow='row',
                    align_items='stretch',
                    width='80%')
        field_cmap_row3 = Box(children=items_auto, layout=box_layout)

#        field_cmap_row3 = Box([self.save_min_max, self.cmap_min, self.cmap_max])

        # mcds_tab = widgets.VBox([mcds_dir, mcds_plot, mcds_play], layout=tab_layout)
        mcds_params = VBox([self.mcds_field, field_cmap_row2, field_cmap_row3, self.max_frames])  # mcds_dir
#        mcds_params = VBox([self.mcds_field, field_cmap_row2, field_cmap_row3,])  # mcds_dir

#        self.tab = HBox([mcds_params, self.mcds_plot], layout=tab_layout)
#        self.tab = HBox([mcds_params, self.mcds_plot])

        help_label = Label('select slider: drag or left/right arrows')
        # row1: help text + (max frames, field, colormap); row2: (fix, min, max).
        row1 = Box([help_label, Box( [self.max_frames, self.mcds_field, self.field_cmap], layout=Layout(border='0px solid black',
                            width='50%',
                            height='',
                            align_items='stretch',
                            flex_direction='row',
                            display='flex'))] )
        row2 = Box([self.cmap_fixed, self.cmap_min, self.cmap_max], layout=Layout(border='0px solid black',
                            width='50%',
                            height='',
                            align_items='stretch',
                            flex_direction='row',
                            display='flex'))
        if (hublib_flag):
            # The download button invokes self.download_cb through its ``cb`` argument.
            self.download_button = Download('mcds.zip', style='warning', icon='cloud-download', 
                                                tooltip='Download data', cb=self.download_cb)
            download_row = HBox([self.download_button.w, Label("Download all substrate data (browser must allow pop-ups).")])

    #        self.tab = VBox([row1, row2, self.mcds_plot])
            self.tab = VBox([row1, row2, self.mcds_plot, download_row])
        else:
            # self.tab = VBox([row1, row2])
            self.tab = VBox([row1, row2, self.mcds_plot])

    #---------------------------------------------------
    def update_dropdown_fields(self, data_dir):
        # print('update_dropdown_fields called --------')
        self.output_dir = data_dir
        tree = None
        try:
            fname = os.path.join(self.output_dir, "initial.xml")
            tree = ET.parse(fname)
            xml_root = tree.getroot()
        except:
            print("Cannot open ",fname," to read info, e.g., names of substrate fields.")
            return

        xml_root = tree.getroot()
        self.field_min_max = {}
        self.field_dict = {}
        dropdown_options = {}
        uep = xml_root.find('.//variables')
        comment_str = ""
        field_idx = 0
        if (uep):
            for elm in uep.findall('variable'):
                # print("-----> ",elm.attrib['name'])
                self.field_min_max[elm.attrib['name']] = [0., 1.]
                self.field_dict[field_idx] = elm.attrib['name']
                dropdown_options[elm.attrib['name']] = field_idx
                field_idx += 1

#        constWidth = '180px'
        # print('options=',dropdown_options)
        self.mcds_field.value=0
        self.mcds_field.options=dropdown_options
#         self.mcds_field = Dropdown(
# #            options={'oxygen': 0, 'glucose': 1},
#             options=dropdown_options,
#             value=0,
#             #     description='Field',
#            layout=Layout(width=constWidth)
#         )

    def update_max_frames_expected(self, value):  # called when beginning an interactive Run
        self.max_frames.value = value  # assumes naming scheme: "snapshot%08d.svg"
        self.mcds_plot.children[0].max = self.max_frames.value

#    def update(self, rdir):
    def update(self, rdir=''):
        # with debug_view:
        #     print("substrates: update rdir=", rdir)        

        if rdir:
            self.output_dir = rdir

        all_files = sorted(glob.glob(os.path.join(self.output_dir, 'output*.xml')))
        if len(all_files) > 0:
            last_file = all_files[-1]
            self.max_frames.value = int(last_file[-12:-4])  # assumes naming scheme: "snapshot%08d.svg"

        # with debug_view:
        #     print("substrates: added %s files" % len(all_files))


        # self.output_dir = rdir
        # if rdir == '':
        #     # self.max_frames.value = 0
        #     tmpdir = os.path.abspath('tmpdir')
        #     self.output_dir = tmpdir
        #     all_files = sorted(glob.glob(os.path.join(tmpdir, 'output*.xml')))
        #     if len(all_files) > 0:
        #         last_file = all_files[-1]
        #         self.max_frames.value = int(last_file[-12:-4])  # assumes naming scheme: "output%08d.xml"
        #         self.mcds_plot.update()
        #     return

        # all_files = sorted(glob.glob(os.path.join(rdir, 'output*.xml')))
        # if len(all_files) > 0:
        #     last_file = all_files[-1]
        #     self.max_frames.value = int(last_file[-12:-4])  # assumes naming scheme: "output%08d.xml"
        #     self.mcds_plot.update()

    def download_cb(self):
        file_xml = os.path.join(self.output_dir, '*.xml')
        file_mat = os.path.join(self.output_dir, '*.mat')
        # print('zip up all ',file_str)
        with zipfile.ZipFile('mcds.zip', 'w') as myzip:
            for f in glob.glob(file_xml):
                myzip.write(f, os.path.basename(f)) # 2nd arg avoids full filename path in the archive
            for f in glob.glob(file_mat):
                myzip.write(f, os.path.basename(f))

    def update_max_frames(self,_b):
        self.mcds_plot.children[0].max = self.max_frames.value

    def mcds_field_changed_cb(self, b):
        # print("mcds_field_changed_cb: self.mcds_field.value=",self.mcds_field.value)
        if (self.mcds_field.value == None):
            return
        self.field_index = self.mcds_field.value + 4

        field_name = self.field_dict[self.mcds_field.value]
#        print('mcds_field_cb: '+field_name)
        self.cmap_min.value = self.field_min_max[field_name][0]
        self.cmap_max.value = self.field_min_max[field_name][1]
        self.mcds_plot.update()

    def mcds_field_cb(self, b):
        #self.field_index = self.mcds_field.value
#        self.field_index = self.mcds_field.options.index(self.mcds_field.value) + 4
#        self.field_index = self.mcds_field.options[self.mcds_field.value]
        self.field_index = self.mcds_field.value + 4

        # field_name = self.mcds_field.options[self.mcds_field.value]
        # self.cmap_min.value = self.field_min_max[field_name][0]  # oxygen, etc
        # self.cmap_max.value = self.field_min_max[field_name][1]  # oxygen, etc

#        self.field_index = self.mcds_field.value + 4

#        print('field_index=',self.field_index)
        self.mcds_plot.update()

    def plot_substrate(self, frame):
        # global current_idx, axes_max, gFileId, field_index
        fname = "output%08d_microenvironment0.mat" % frame
        xml_fname = "output%08d.xml" % frame
        # fullname = output_dir_str + fname

#        fullname = fname
        full_fname = os.path.join(self.output_dir, fname)
        full_xml_fname = os.path.join(self.output_dir, xml_fname)
#        self.output_dir = '.'

#        if not os.path.isfile(fullname):
        if not os.path.isfile(full_fname):
            print("Once output files are generated, click the slider.")  # No:  output00000000_microenvironment0.mat
            return

#        tree = ET.parse(xml_fname)
        tree = ET.parse(full_xml_fname)
        xml_root = tree.getroot()
        mins= round(int(float(xml_root.find(".//current_time").text)))  # TODO: check units = mins
        hrs = int(mins/60)
        days = int(hrs/24)
        title_str = '%dd, %dh, %dm' % (int(days),(hrs%24), mins - (hrs*60))


        info_dict = {}
#        scipy.io.loadmat(fullname, info_dict)
        scipy.io.loadmat(full_fname, info_dict)
        M = info_dict['multiscale_microenvironment']
        #     global_field_index = int(mcds_field.value)
        #     print('plot_substrate: field_index =',field_index)
        f = M[self.field_index, :]   # 4=tumor cells field, 5=blood vessel density, 6=growth substrate
        # plt.clf()
        # my_plot = plt.imshow(f.reshape(400,400), cmap='jet', extent=[0,20, 0,20])
    
        # self.fig = plt.figure(figsize=(7.2,6))  # this strange figsize results in a ~square contour plot
        self.fig = plt.figure(figsize=(24.0,20))  # this strange figsize results in a ~square contour plot
        # self.fig = plt.figure(figsize=(28.8,24))  # this strange figsize results in a ~square contour plot
        #     fig.set_tight_layout(True)
        #     ax = plt.axes([0, 0.05, 0.9, 0.9 ]) #left, bottom, width, height
        #     ax = plt.axes([0, 0.0, 1, 1 ])
        #     cmap = plt.cm.viridis # Blues, YlOrBr, ...
        #     im = ax.imshow(f.reshape(100,100), interpolation='nearest', cmap=cmap, extent=[0,20, 0,20])
        #     ax.grid(False)

        # print("substrates.py: ------- numx, numy = ", self.numx, self.numy )
        if (self.numx == 0):   # need to parse vals from the config.xml
            fname = os.path.join(self.output_dir, "config.xml")
            tree = ET.parse(fname)
            xml_root = tree.getroot()
            xmin = float(xml_root.find(".//x_min").text)
            xmax = float(xml_root.find(".//x_max").text)
            dx = float(xml_root.find(".//dx").text)
            ymin = float(xml_root.find(".//y_min").text)
            ymax = float(xml_root.find(".//y_max").text)
            dy = float(xml_root.find(".//dy").text)
            self.numx =  math.ceil( (xmax - xmin) / dx)
            self.numy =  math.ceil( (ymax - ymin) / dy)

        xgrid = M[0, :].reshape(self.numy, self.numx)
        ygrid = M[1, :].reshape(self.numy, self.numx)

        num_contours = 15
        levels = MaxNLocator(nbins=num_contours).tick_values(self.cmap_min.value, self.cmap_max.value)
        contour_ok = True
        if (self.cmap_fixed.value):
            try:
                my_plot = plt.contourf(xgrid, ygrid, M[self.field_index, :].reshape(self.numy, self.numx), levels=levels, extend='both', cmap=self.field_cmap.value)
            except:
                contour_ok = False
                # print('got error on contourf 1.')
        else:    
            try:
                my_plot = plt.contourf(xgrid, ygrid, M[self.field_index, :].reshape(self.numy,self.numx), num_contours, cmap=self.field_cmap.value)
            except:
                contour_ok = False
                # print('got error on contourf 2.')

        if (contour_ok):
            plt.title(title_str)
            plt.colorbar(my_plot)
        axes_min = 0
        axes_max = 2000
Exemple #8
0
class Dashboard(VBox):
    """
    Build the dashboard for Jupyter widgets. Requires running
    in a notebook/jupyterlab.
    """
    def __init__(self, net, width="95%", height="550px", play_rate=0.5):
        """Create the dashboard widgets for *net* and assemble them into this box.

        *width* and *height* are stored on the instance; *width* is also used for
        the layout of the SVG display.  *play_rate* is passed to the ``_Player``
        instance, which is started immediately.
        """
        self._ignore_layer_updates = False
        self.player = _Player(self, play_rate)
        self.player.start()
        self.net = net
        # Random suffix for the id used to address this dashboard's pictures.
        unique_suffix = random.randint(1, 1000000)
        self.class_id = "picture-dashboard-%s-%s" % (self.net.name, unique_suffix)
        self._width = width
        self._height = height
        ## Global widgets:
        description_style = {"description_width": "initial"}
        columns_default = self.net.config["dashboard.features.columns"]
        self.feature_columns = IntText(description="Feature columns:",
                                       value=columns_default,
                                       min=0,
                                       max=1024,
                                       style=description_style)
        scale_default = self.net.config["dashboard.features.scale"]
        self.feature_scale = FloatText(description="Feature scale:",
                                       value=scale_default,
                                       min=0.1,
                                       max=10,
                                       style=description_style)
        for feature_widget in (self.feature_columns, self.feature_scale):
            feature_widget.observe(self.regenerate, names='value')
        ## Hack to center SVG as justify-content is broken:
        empty_picture = """<p style="text-align:center">%s</p>""" % ("",)
        svg_layout = Layout(width=self._width, overflow_x='auto', overflow_y="auto",
                            justify_content="center")
        self.net_svg = HTML(value=empty_picture, layout=svg_layout)
        # Make controls first:
        self.output = Output()
        controls = self.make_controls()
        config = self.make_config()
        super().__init__([config, controls, self.net_svg, self.output])

    def propagate(self, inputs):
        """
        Propagate inputs through the dashboard view of the network.

        When ``dynamic_pictures_check()`` is true, the inputs are propagated
        through the network with picture updates and the network's result is
        returned.  Otherwise the dashboard picture is regenerated for *inputs*
        and ``None`` is returned.
        """
        if dynamic_pictures_check():
            return self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        # Pass the *inputs* argument; the builtin ``input`` function was being
        # passed here by mistake.
        self.regenerate(inputs=inputs)

    def goto(self, position):
        """Move the dataset slider to *position* and mirror it in the position box.

        *position* is one of ``"begin"``, ``"end"``, ``"prev"`` or ``"next"``;
        ``"prev"`` and ``"next"`` wrap around at the ends.  Any other value
        leaves the slider alone and only re-syncs the position box.

        Nothing happens when the dataset has no inputs or targets, when the
        dataset selector is neither ``"Train"`` nor ``"Test"``, or when the
        selected subset is empty.  The last two cases used to raise
        ``UnboundLocalError`` or set the slider to -1.
        """
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            return
        selection = self.control_select.value
        if selection == "Train":
            length = len(dataset.train_inputs)
        elif selection == "Test":
            length = len(dataset.test_inputs)
        else:
            return
        if length == 0:
            return
        last = length - 1
        #### Position it:
        if position == "begin":
            self.control_slider.value = 0
        elif position == "end":
            self.control_slider.value = last
        elif position == "prev":
            current = self.control_slider.value
            # wrap around to the end when stepping back from the first item
            self.control_slider.value = last if current - 1 < 0 else current - 1
        elif position == "next":
            current = self.control_slider.value
            # wrap around to the start when stepping past the last item
            self.control_slider.value = 0 if current + 1 > last else current + 1
        self.position_text.value = self.control_slider.value


    def change_select(self, change=None):
        """
        React to a change of the dataset selector.

        First re-syncs the slider controls (forwarding *change*), then
        regenerates the picture.
        """
        self.update_control_slider(change)
        self.regenerate()

    def update_control_slider(self, change=None):
        """Re-sync the slider, position box, total label and buttons with the dataset.

        Stores the current selector value in ``net.config["dashboard.dataset"]``.
        With an empty dataset, everything is reset to 0 and disabled.  Otherwise
        the slider range becomes ``0 .. len(subset) - 1`` for the selected
        subset (``"Train"`` or ``"Test"``), an out-of-range slider value is
        reset to 0, and the controls are disabled only when the subset is empty.
        The button whose ``icon`` is ``"refresh"`` is never touched.

        A selector value other than ``"Train"``/``"Test"`` leaves the widgets
        untouched (it used to raise ``UnboundLocalError``).

        *change* is accepted for use as an observer callback and is ignored.
        """
        selection = self.control_select.value
        self.net.config["dashboard.dataset"] = selection
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            self.total_text.value = "of 0"
            self.control_slider.value = 0
            self.position_text.value = 0
            disabled = True
        else:
            if selection == "Test":
                subset = dataset.test_inputs
            elif selection == "Train":
                subset = dataset.train_inputs
            else:
                return
            count = len(subset)
            self.total_text.value = "of %s" % count
            upper = max(count - 1, 0)
            if not (0 <= self.control_slider.value <= upper):
                self.control_slider.value = 0
            self.control_slider.min = 0
            self.control_slider.max = upper
            self.position_text.value = self.control_slider.value
            disabled = count == 0
        self.control_slider.disabled = disabled
        # Was misspelled ``disbaled``, which silently set an unused attribute
        # and never enabled/disabled the position box.
        self.position_text.disabled = disabled
        for child in self.control_buttons.children:
            if not hasattr(child, "icon") or child.icon != "refresh":
                child.disabled = disabled

    def update_zoom_slider(self, change):
        """Observer callback: store the zoom slider value as ``svg_scale`` and redraw.

        Only reacts to notifications whose ``"name"`` is ``"value"``.
        """
        if change["name"] != "value":
            return
        self.net.config["svg_scale"] = self.zoom_slider.value
        self.regenerate()

    def update_position_text(self, change):
        """Observer callback: move the slider to the position typed into the box.

        *change* is a change notification; an example of its shape is
        ``{'name': 'value', 'old': 2, 'new': 3, 'owner': IntText(...), 'type': 'change'}``.
        Only its ``"new"`` entry is used.
        """
        requested_position = change["new"]
        self.control_slider.value = requested_position

    def get_current_input(self):
        """Return the input at the slider position of the selected subset.

        Returns ``None`` when the selector is neither ``"Train"`` nor ``"Test"``
        or when the selected subset has no targets.
        """
        dataset = self.net.dataset
        mode = self.control_select.value
        if mode == "Train":
            if len(dataset.train_targets) > 0:
                return dataset.train_inputs[self.control_slider.value]
        elif mode == "Test":
            if len(dataset.test_targets) > 0:
                return dataset.test_inputs[self.control_slider.value]
        return None

    def get_current_targets(self):
        """Return the targets at the slider position of the selected subset.

        Returns ``None`` when the selector is neither ``"Train"`` nor ``"Test"``
        or when the selected subset has no targets.
        """
        dataset = self.net.dataset
        mode = self.control_select.value
        if mode == "Train":
            if len(dataset.train_targets) > 0:
                return dataset.train_targets[self.control_slider.value]
        elif mode == "Test":
            if len(dataset.test_targets) > 0:
                return dataset.test_targets[self.control_slider.value]
        return None

    def update_slider_control(self, change):
        """Observer callback for the dataset slider: show the pattern at the new index.

        With an empty dataset the total label is set to ``"of 0"`` and nothing
        else happens.  Notifications whose ``"name"`` is not ``"value"`` are
        ignored.  Otherwise the position box is synced with the slider and, for
        the selected subset (``"Train"`` or ``"Test"``, provided it has targets):

        * the total label is updated;
        * nothing more is done when the network has no model;
        * without dynamic pictures the whole picture is regenerated;
        * with dynamic pictures the input is propagated, the feature bank is
          refreshed if it names a layer, and targets/errors are displayed when
          the ``show_targets`` / ``show_errors`` config entries are set.

        The Train and Test paths now share one implementation.  The Test path
        previously always wrapped the targets in a list, even for networks with
        several output banks; it now follows the same single-/multi-bank rule
        as the Train path.
        """
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            self.total_text.value = "of 0"
            return
        if change["name"] != "value":
            return
        self.position_text.value = self.control_slider.value

        mode = self.control_select.value
        if mode == "Train" and len(dataset.train_targets) > 0:
            all_inputs, all_targets = dataset.train_inputs, dataset.train_targets
        elif mode == "Test" and len(dataset.test_targets) > 0:
            all_inputs, all_targets = dataset.test_inputs, dataset.test_targets
        else:
            return

        self.total_text.value = "of %s" % len(all_inputs)
        if self.net.model is None:
            return
        index = self.control_slider.value
        if not dynamic_pictures_check():
            self.regenerate(inputs=all_inputs[index], targets=all_targets[index])
            return

        output = self.net.propagate(all_inputs[index],
                                    class_id=self.class_id, update_pictures=True)
        if self.feature_bank.value in self.net.layer_dict.keys():
            self.net.propagate_to_features(self.feature_bank.value, all_inputs[index],
                                           cols=self.feature_columns.value,
                                           scale=self.feature_scale.value, html=False)
        # A single output bank is displayed as a one-element list of banks.
        bank_count = len(self.net.output_bank_order)
        if self.net.config["show_targets"]: ## FIXME: use minmax of output bank
            targets = all_targets[index]
            self.net.display_component([targets] if bank_count == 1 else targets,
                                       "targets",
                                       class_id=self.class_id,
                                       minmax=(-1, 1))
        if self.net.config["show_errors"]: ## minmax is error
            targets = all_targets[index]
            if bank_count == 1:
                errors = [(np.array(output) - np.array(targets)).tolist()]
            else:
                errors = [np.array(output[bank]) - np.array(targets[bank])
                          for bank in range(bank_count)]
            self.net.display_component(errors, "errors", class_id=self.class_id, minmax=(-1, 1))

    def toggle_play(self, button):
        """Click handler for the play button: switch between playing and paused.

        The button's description doubles as the state flag: while it reads
        ``"Play"`` a click resumes the player, otherwise a click pauses it.
        *button* is ignored.
        """
        play_button = self.button_play
        if play_button.description != "Play":
            # Currently playing: switch back to the paused look and pause.
            play_button.description = "Play"
            play_button.icon = "play"
            self.player.pause()
            return
        play_button.description = "Stop"
        play_button.icon = "pause"
        self.player.resume()

    def prop_one(self, button=None):
        """Re-run the slider handler for the current position.

        Calls ``update_slider_control`` with a synthetic ``{"name": "value"}``
        notification.  *button* is ignored.
        """
        synthetic_change = {"name": "value"}
        self.update_slider_control(synthetic_change)

    def regenerate(self, button=None, inputs=None, targets=None):
        ## Protection when deleting object on shutdown:
        if isinstance(button, dict) and 'new' in button and button['new'] is None:
            return
        ## Update the config:
        self.net.config["dashboard.features.bank"] = self.feature_bank.value
        self.net.config["dashboard.features.columns"] = self.feature_columns.value
        self.net.config["dashboard.features.scale"] = self.feature_scale.value
        inputs = inputs if inputs is not None else self.get_current_input()
        targets = targets if targets is not None else self.get_current_targets()
        features = None
        if self.feature_bank.value in self.net.layer_dict.keys() and inputs is not None:
            if self.net.model is not None:
                features = self.net.propagate_to_features(self.feature_bank.value, inputs,
                                                          cols=self.feature_columns.value,
                                                          scale=self.feature_scale.value, display=False)
        svg = """<p style="text-align:center">%s</p>""" % (self.net.to_svg(inputs=inputs, targets=targets,
                                                                           class_id=self.class_id),)
        if inputs is not None and features is not None:
            html_horizontal = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top" style="width: 50%%;">%s</td>
  <td valign="top" align="center" style="width: 50%%;"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            html_vertical = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top">%s</td>
</tr>
<tr>
  <td valign="top" align="center"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            self.net_svg.value = (html_vertical if self.net.config["svg_rotate"] else html_horizontal) % (
                svg, "%s features" % self.feature_bank.value, features)
        else:
            self.net_svg.value = svg

    def make_colormap_image(self, colormap_name):
        """Return a 300x25 image that previews the colormap *colormap_name*.

        A falsy *colormap_name* is replaced by the result of ``get_colormap()``.
        The preview is rendered by a throw-away ``Layer``, sweeping its
        activation range in steps of 0.01.
        """
        from .layers import Layer
        if not colormap_name:
            colormap_name = get_colormap()
        preview_layer = Layer("Colormap", 100)
        act_range = preview_layer.get_act_minmax()
        sweep = np.arange(act_range[0], act_range[1], .01)
        render_config = {"pixels_per_unit": 1,
                         "svg_rotate": self.net.config["svg_rotate"]}
        rendered = preview_layer.make_image(sweep, colormap_name, render_config)
        return rendered.resize((300, 25))

    def set_attr(self, obj, attr, value):
        """Store *value* under *attr* on *obj* and regenerate the picture.

        *obj* may be a dict (item assignment) or any other object (attribute
        assignment).  When *value* is a dict, its ``"value"`` entry is what
        gets stored.  An empty dict or ``None`` makes the call a no-op.
        """
        if value in [{}, None]: ## value is None when shutting down
            return
        new_value = value["value"] if isinstance(value, dict) else value
        if isinstance(obj, dict):
            obj[attr] = new_value
        else:
            setattr(obj, attr, new_value)
        ## was crashing on Widgets.__del__, if get_ipython() no longer existed
        self.regenerate()

    def make_controls(self):
        """
        Build the dataset-navigation controls and return them as a VBox.

        Creates and stores on ``self``: ``button_play``, ``position_text``,
        ``control_buttons``, ``control_slider``, ``total_text`` and
        ``zoom_slider``, and hooks up their callbacks. ``zoom_slider`` is
        only created here; it is not part of the returned container.

        The slider's upper bound and the "of N" label are both derived from
        the dataset named by ``net.config["dashboard.dataset"]`` so that the
        two always agree (previously the slider was always sized from the
        training inputs, whatever dataset was configured).
        """
        button_begin = Button(icon="fast-backward", layout=Layout(width='100%'))
        button_prev = Button(icon="backward", layout=Layout(width='100%'))
        button_next = Button(icon="forward", layout=Layout(width='100%'))
        button_end = Button(icon="fast-forward", layout=Layout(width='100%'))
        self.button_play = Button(icon="play", description="Play", layout=Layout(width="100%"))
        refresh_button = Button(icon="refresh", layout=Layout(width="25%"))

        self.position_text = IntText(value=0, layout=Layout(width="100%"))

        self.control_buttons = HBox([
            button_begin,
            button_prev,
            self.position_text,
            button_next,
            button_end,
            self.button_play,
            refresh_button
        ], layout=Layout(width='100%', height="50px"))

        ## Size of whichever dataset the dashboard is configured to browse:
        if self.net.config["dashboard.dataset"] == "Train":
            length = len(self.net.dataset.train_inputs)
        else:
            length = len(self.net.dataset.test_inputs)
        self.control_slider = IntSlider(description="Dataset index",
                                        continuous_update=False,
                                        min=0,
                                        ## last valid index; 0 for an empty dataset
                                        max=max(length - 1, 0),
                                        value=0,
                                        layout=Layout(width='100%'))
        self.total_text = Label(value="of %s" % length, layout=Layout(width="100px"))
        self.zoom_slider = FloatSlider(description="Zoom",
                                       continuous_update=False,
                                       min=0, max=1.0,
                                       style={"description_width": 'initial'},
                                       layout=Layout(width="65%"),
                                       value=self.net.config["svg_scale"] if self.net.config["svg_scale"] is not None else 0.5)

        def on_refresh(widget):
            ## Re-sync the slider, clear captured output, then redraw.
            self.update_control_slider()
            self.output.clear_output()
            self.regenerate()

        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names='value')
        refresh_button.on_click(on_refresh)
        self.zoom_slider.observe(self.update_zoom_slider, names='value')
        self.position_text.observe(self.update_position_text, names='value')
        # Put them together:
        controls = VBox([HBox([self.control_slider, self.total_text], layout=Layout(height="40px")),
                         self.control_buttons], layout=Layout(width='100%'))

        ## Draw the network as soon as the controls appear in the notebook:
        controls.on_displayed(lambda widget: self.regenerate())
        return controls

    def make_config(self):
        """
        Build the collapsible configuration panel and return it.

        The panel is an ``Accordion`` (initially collapsed, titled with the
        network name) holding two columns:

        * column 1 - network-wide display settings (dataset, zoom, spacing,
          target/error visibility, feature bank/columns/scale);
        * column 2 - settings of one selected layer (visibility, colormap,
          color range, feature index) plus the rotate checkbox and the
          save-config button.

        Relies on ``self.zoom_slider``, ``self.feature_columns`` and
        ``self.feature_scale`` already existing. Stores the widgets it needs
        later as attributes (``feature_bank``, ``control_select``,
        ``layer_select``, ``layer_visible_checkbox``, ``layer_colormap``,
        ``layer_colormap_image``, ``layer_mindim``, ``layer_maxdim``,
        ``layer_feature``, ``svg_rotate``, ``save_config_button``).
        """
        layout = Layout()
        style = {"description_width": "initial"}
        checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                             layout=layout, style=style)
        checkbox1.observe(lambda change: self.set_attr(self.net.config, "show_targets", change["new"]), names='value')
        checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                             layout=layout, style=style)
        checkbox2.observe(lambda change: self.set_attr(self.net.config, "show_errors", change["new"]), names='value')

        hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                         style=style, layout=layout)
        hspace.observe(lambda change: self.set_attr(self.net.config, "hspace", change["new"]), names='value')
        vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                         style=style, layout=layout)
        vspace.observe(lambda change: self.set_attr(self.net.config, "vspace", change["new"]), names='value')
        self.feature_bank = Select(description="Features:", value=self.net.config["dashboard.features.bank"],
                              options=[""] + [layer.name for layer in self.net.layers if self.net._layer_has_features(layer.name)],
                              rows=1)
        self.feature_bank.observe(self.regenerate, names='value')
        self.control_select = Select(
            options=['Test', 'Train'],
            value=self.net.config["dashboard.dataset"],
            description='Dataset:',
            rows=1
        )
        self.control_select.observe(self.change_select, names='value')
        column1 = [self.control_select,
                   self.zoom_slider,
                   hspace,
                   vspace,
                   HBox([checkbox1, checkbox2]),
                   self.feature_bank,
                   self.feature_columns,
                   self.feature_scale
        ]
        ## Make layer selectable, and update-able:
        column2 = []
        ## The last layer is the one selected initially:
        layer = self.net.layers[-1]
        self.layer_select = Select(description="Layer:", value=layer.name,
                                   options=[layer.name for layer in
                                            self.net.layers],
                                   rows=1)
        self.layer_select.observe(self.update_layer_selection, names='value')
        column2.append(self.layer_select)
        self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
        self.layer_visible_checkbox.observe(self.update_layer, names='value')
        column2.append(self.layer_visible_checkbox)
        self.layer_colormap = Select(description="Colormap:",
                                     options=[""] + AVAILABLE_COLORMAPS,
                                     value=layer.colormap if layer.colormap is not None else "", layout=layout, rows=1)
        self.layer_colormap_image = HTML(value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
        self.layer_colormap.observe(self.update_layer, names='value')
        column2.append(self.layer_colormap)
        column2.append(self.layer_colormap_image)
        ## get dynamic minmax; if you change it it will set it in layer as override:
        minmax = layer.get_act_minmax()
        self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
        self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
        self.layer_mindim.observe(self.update_layer, names='value')
        self.layer_maxdim.observe(self.update_layer, names='value')
        column2.append(self.layer_mindim)
        column2.append(self.layer_maxdim)
        self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
        self.layer_feature.observe(self.update_layer, names='value')
        column2.append(self.layer_feature)
        ## Network-wide rotate toggle; shares a row with the save button:
        self.svg_rotate = Checkbox(description="Rotate network",
                                   value=self.net.config["svg_rotate"],
                                   style={"description_width": 'initial'},
                                   layout=Layout(width="52%"))
        self.svg_rotate.observe(lambda change: self.set_attr(self.net.config, "svg_rotate", change["new"]), names='value')
        self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
        self.save_config_button.on_click(self.save_config)
        column2.append(HBox([self.svg_rotate, self.save_config_button]))
        config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                                VBox(column2, layout=Layout(width="100%"))])
        accordion = Accordion(children=[config_children])
        accordion.set_title(0, self.net.name)
        ## Start collapsed:
        accordion.selected_index = None
        return accordion

    def save_config(self, widget=None):
        """
        Button callback: ask the network to save its configuration.

        ``widget`` is accepted so this can be used as an ``on_click``
        handler; it is not used.
        """
        network = self.net
        network.save_config()

    def update_layer(self, change):
        """
        Update the layer object, and redisplay.

        Observer callback for the per-layer widgets. Copies the widget
        values onto the layer currently chosen in ``self.layer_select`` and
        then refreshes the display:

        * if the description of the widget that changed contains "color"
          (the colormap selector and the two "... color maps to:" boxes),
          the color range and colormap are stored on the layer, the colormap
          swatch is re-rendered and ``prop_one()`` is called;
        * otherwise ``regenerate()`` is called.

        Does nothing while ``self._ignore_layer_updates`` is set.
        """
        if self._ignore_layer_updates:
            return
        ## The rest indicates a change to a display variable.
        ## We need to save the value in the layer, and regenerate
        ## the display.
        # Get the layer:
        layer = self.net[self.layer_select.value]
        # Save the changed value in the layer:
        layer.feature = self.layer_feature.value
        layer.visible = self.layer_visible_checkbox.value
        ## These three, dealing with colors of activations,
        ## can be done with a prop_one():
        if "color" in change["owner"].description.lower():
            ## Matches: Colormap, leftmost color, rightmost color
            ## overriding dynamic minmax!
            layer.minmax = (self.layer_mindim.value, self.layer_maxdim.value)
            ## An empty selection means "no explicit colormap":
            layer.colormap = self.layer_colormap.value if self.layer_colormap.value else None
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            self.prop_one()
        else:
            self.regenerate()

    def update_layer_selection(self, change):
        """
        Just update the widgets; don't redraw anything.

        Observer callback for ``self.layer_select``. Loads the settings of
        the newly selected layer into the per-layer widgets. While doing so
        ``self._ignore_layer_updates`` is raised so that ``update_layer``
        ignores the resulting widget events; the flag is always lowered
        again, even if loading fails.

        A layer without an explicit colormap (``None``) is shown as the
        empty selection, matching how the colormap selector is initialised.
        """
        ## No need to redisplay anything
        self._ignore_layer_updates = True
        try:
            ## First, get the new layer selected:
            layer = self.net[self.layer_select.value]
            ## Now, let's update all of the values without updating:
            self.layer_visible_checkbox.value = layer.visible
            self.layer_colormap.value = layer.colormap if layer.colormap is not None else ""
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            minmax = layer.get_act_minmax()
            self.layer_mindim.value = minmax[0]
            self.layer_maxdim.value = minmax[1]
            self.layer_feature.value = layer.feature
        finally:
            ## Never leave updates suppressed, or the layer widgets go dead:
            self._ignore_layer_updates = False
Exemple #9
0
    def build_options(self):
        """
        Build the option widgets for the plot and lay them out in a grid.

        Creates one widget per plotting option, wires the observers that
        enable/disable dependent widgets and keep the custom range
        consistent, and places everything in a 10x2 ``GridspecLayout``.

        Returns:
            tuple: ``(options_map, grid)`` where ``options_map`` maps the
            option name (e.g. ``'feature'``, ``'num_grid_points'``) to its
            widget and ``grid`` is the populated layout.
        """
        grid = GridspecLayout(10, 2)
        options_map = {}
        style = {'description_width': '60%', 'width': 'auto'}

        # feature
        feature = Combobox(description='Feature to plot:',
                           style=style,
                           options=list(self.feature_names),
                           ensure_option=True,
                           value=self.feature_names[0])
        options_map['feature'] = feature

        # num_grid_points
        num_grid_points = BoundedIntText(
            value=10,
            min=1,
            max=999999,
            step=1,
            description='Number of grid points:',
            style=style,
            description_tooltip='Number of grid points for numeric feature')
        options_map['num_grid_points'] = num_grid_points

        # grid_type
        grid_type = Dropdown(
            description='Grid type:',
            options=['percentile', 'equal'],
            style=style,
            description_tooltip='Type of grid points for numeric feature')
        options_map['grid_type'] = grid_type

        # cust_range
        cust_range = Checkbox(description='Custom grid range', value=False)
        options_map['cust_range'] = cust_range

        # range_min
        range_min = FloatText(
            description='Custom range minimum:',
            style=style,
            description_tooltip=
            'Percentile (when grid_type="percentile") or value (when grid_type="equal") '
            'lower bound of range to investigate (for numeric feature)\n'
            ' - Enabled only when custom grid range is True and variable with grid points is None',
            disabled=True)
        options_map['range_min'] = range_min

        # range_max
        range_max = FloatText(
            description='Custom range maximum:',
            style=style,
            description_tooltip=
            'Percentile (when grid_type="percentile") or value (when grid_type="equal") '
            'upper bound of range to investigate (for numeric feature)\n'
            ' - Enabled only when custom grid range is True and variable with grid points is None',
            disabled=True)
        options_map['range_max'] = range_max

        # cust_grid_points
        cust_grid_points = UpdatingCombobox(
            options_keys=self.globals_options,
            description='Variable with grid points:',
            style=style,
            description_tooltip=
            'Name of variable (or None) with customized list of grid points for numeric feature',
            value='None',
            disabled=True)
        cust_grid_points.lookup_in_kernel = True
        options_map['cust_grid_points'] = cust_grid_points

        # set up disabling of range inputs, when user doesn't want custom range
        def disable_ranges(change):
            custom_off = not change['new']
            cust_grid_points.disabled = custom_off
            # if cust_grid_points has a value filled in, keep range_min and
            # range_max disabled even though custom range is on
            ranges_off = custom_off or cust_grid_points.value != 'None'
            range_min.disabled = ranges_off
            range_max.disabled = ranges_off

        cust_range.observe(disable_ranges, names=['value'])

        # set up disabling of range_max and range_min if user specifies custom grid points
        def disable_max_min(change):
            # The range boxes are usable only when custom range is switched
            # on AND no grid-points variable is given (see their tooltips).
            ranges_on = cust_range.value and change['new'] == 'None'
            range_max.disabled = not ranges_on
            range_min.disabled = not ranges_on

        cust_grid_points.observe(disable_max_min, names=['value'])

        # set up links between upper and lower ranges
        def set_ranges(change):
            if grid_type.value == 'percentile':
                # keep the two bounds at least num_grid_points apart
                if change['owner'] == range_min or change[
                        'owner'] == num_grid_points:
                    range_max.value = max(
                        range_max.value,
                        range_min.value + num_grid_points.value)
                if change['owner'] == range_max:
                    range_min.value = min(
                        range_min.value,
                        range_max.value - num_grid_points.value)
            else:
                # only keep minimum <= maximum
                if change['owner'] == range_min:
                    range_max.value = max(range_max.value, range_min.value)
                if change['owner'] == range_max:
                    range_min.value = min(range_min.value, range_max.value)

        range_min.observe(set_ranges, names=['value'])
        range_max.observe(set_ranges, names=['value'])
        num_grid_points.observe(set_ranges, names=['value'])

        # center
        center = Checkbox(description='Center the plot', value=True)
        options_map['center'] = center

        # plot_pts_dist
        plot_pts_dist = Checkbox(description='Plot data points distribution',
                                 value=True)
        options_map['plot_pts_dist'] = plot_pts_dist

        # x_quantile
        x_quantile = Checkbox(description='X-axis as quantiles', value=False)
        options_map['x_quantile'] = x_quantile

        # show_percentile
        show_percentile = Checkbox(description='Show percentile buckets',
                                   value=False)
        options_map['show_percentile'] = show_percentile

        # lines
        lines = Checkbox(description='Plot lines - ICE plot', value=False)
        options_map['lines'] = lines

        # frac_to_plot
        frac_to_plot = BoundedFloatText(
            description='Lines to plot:',
            value=1,
            description_tooltip=
            'How many lines to plot, can be a integer or a float.\n'
            ' - integer values higher than 1 are interpreted as absolute amount\n'
            ' - floats are interpreted as fraction (e.g. 0.5 means half of all possible lines)',
            style=style,
            disabled=True)
        options_map['frac_to_plot'] = frac_to_plot

        # cluster
        cluster = Checkbox(description='Cluster lines',
                           value=False,
                           disabled=True)
        options_map['cluster'] = cluster

        # n_cluster_centers
        n_cluster_centers = BoundedIntText(
            value=10,
            min=1,
            max=999999,
            step=1,
            description='Number of cluster centers:',
            style=style,
            description_tooltip='Number of cluster centers for lines',
            disabled=True)
        options_map['n_cluster_centers'] = n_cluster_centers

        # cluster method
        cluster_method = Dropdown(
            description='Cluster method',
            style=style,
            options={
                'KMeans': 'accurate',
                'MiniBatchKMeans': 'approx'
            },
            description_tooltip='Method to use for clustering of lines',
            disabled=True)
        options_map['cluster_method'] = cluster_method

        # set up disabling of lines related options
        def disable_lines(change):
            frac_to_plot.disabled = not change['new']
            cluster.disabled = not change['new']
            n_cluster_centers.disabled = not (change['new'] and cluster.value)
            cluster_method.disabled = not (change['new'] and cluster.value)

        lines.observe(disable_lines, names=['value'])

        # set up disabling of clustering options
        def disable_clustering(change):
            # Clustering settings need both "lines" and "cluster" switched
            # on; change['new'] is the new state of the cluster checkbox.
            clustering_on = lines.value and change['new']
            n_cluster_centers.disabled = not clustering_on
            cluster_method.disabled = not clustering_on

        cluster.observe(disable_clustering, names=['value'])

        grid[0, :] = feature
        grid[1, 0] = num_grid_points
        grid[1, 1] = grid_type
        grid[2, 0] = cust_range
        grid[2, 1] = cust_grid_points
        grid[3, 0] = range_min
        grid[3, 1] = range_max
        grid[4, 0] = center
        grid[4, 1] = plot_pts_dist
        grid[5, 0] = x_quantile
        grid[5, 1] = show_percentile
        grid[6, :] = lines
        grid[7, :] = frac_to_plot
        grid[8, :] = cluster
        grid[9, 0] = n_cluster_centers
        grid[9, 1] = cluster_method

        return options_map, grid
def _get_value_widget(obj, index=None):
    """Build a "min | slider | max" row for editing a value of ``obj``.

    With ``index=None`` the slider edits ``obj.value`` and is labelled with
    ``obj.name``; otherwise it edits element ``index`` of the tuple
    ``obj.value`` and is labelled with the index.

    The two text boxes set the slider's range (initial bounds and step come
    from ``_interactive_slider_bounds``) and follow ``obj.bmin`` /
    ``obj.bmax`` through one-way links.

    Returns a dict with the ``HBox`` container under ``"widget"`` and the
    individual widgets under ``"wdict"`` (keys ``"value"``, ``"min"``,
    ``"max"``).
    """
    bounds = _interactive_slider_bounds(obj, index=index)
    lower_box = FloatText(
        value=bounds['min'],
        description='min',
        layout=Layout(flex='0 1 auto', width='auto'),
    )
    upper_box = FloatText(
        value=bounds['max'],
        description='max',
        layout=Layout(flex='0 1 auto', width='auto'),
    )
    if index is None:
        start_value = obj.value
        label = obj.name
    else:
        start_value = obj.value[index]
        label = '{}'.format(index)
    slider = FloatSlider(value=start_value,
                         min=lower_box.value,
                         max=upper_box.value,
                         step=bounds['step'],
                         description=label,
                         layout=Layout(flex='1 1 auto', width='auto'))

    def _rescale_step():
        # Keep the step at a thousandth of the current slider range.
        slider.step = np.abs(slider.max - slider.min) * 0.001

    def on_min_change(change):
        # Ignore a new minimum that would not stay below the maximum.
        if change['new'] < slider.max:
            slider.min = change['new']
            _rescale_step()

    def on_max_change(change):
        # Ignore a new maximum that would not stay above the minimum.
        if change['new'] > slider.min:
            slider.max = change['new']
            _rescale_step()

    lower_box.observe(on_min_change, names='value')
    upper_box.observe(on_max_change, names='value')
    # We store the link in the widget so that they are not deleted by the
    # garbage collector
    lower_box._link = dlink((obj, "bmin"), (lower_box, "value"))
    upper_box._link = dlink((obj, "bmax"), (upper_box, "value"))

    if index is None:
        link((obj, "value"), (slider, "value"))
    else:
        # value is a tuple: rebuild it with the edited element replaced
        def _interactive_tuple_update(value):
            """Callback function for the widgets, to update the value
            """
            current = obj.value
            obj.value = current[:index] + (value['new'],) + current[index + 1:]

        slider.observe(_interactive_tuple_update, names='value')

    return {
        "widget": HBox((lower_box, slider, upper_box)),
        "wdict": {"value": slider, "min": lower_box, "max": upper_box},
    }
Exemple #11
0
class SubstrateTab(object):
    def __init__(self):
        """
        Create the widgets of the substrate tab and assemble ``self.tab``.

        Widgets stored on the instance:

        * ``mcds_plot``    - ``interactive`` wrapper around ``plot_substrate``
        * ``max_frames``   - upper bound for the frame control
        * ``mcds_field``   - dropdown selecting the substrate field
        * ``field_cmap``   - dropdown selecting the colormap
        * ``cmap_fixed``   - checkbox enabling a fixed color range
        * ``cmap_min`` / ``cmap_max`` - the fixed range (disabled until
          ``cmap_fixed`` is ticked)
        * ``save_min_max`` - button storing the range for the current field

        ``field_min_max`` maps field name -> ``[min, max]`` and
        ``field_dict`` maps dropdown value -> field name; both start with a
        single ``'dummy'`` entry.
        """
        self.output_dir = '.'

        # initial value
        self.field_index = 4

        constWidth = '180px'
        constWidth2 = '150px'

        max_frames = 1
        self.mcds_plot = interactive(self.plot_substrate,
                                     frame=(0, max_frames),
                                     continuous_update=False)
        svg_plot_size = '700px'
        self.mcds_plot.layout.width = svg_plot_size
        self.mcds_plot.layout.height = svg_plot_size

        self.max_frames = BoundedIntText(
            min=0,
            max=99999,
            value=max_frames,
            description='Max',
            layout=Layout(width='160px'),
        )
        self.max_frames.observe(self.update_max_frames)

        self.field_min_max = {'dummy': [0., 1.]}
        # hacky I know, but make a dict that's got (key,value) reversed from the dict in the Dropdown below
        self.field_dict = {0: 'dummy'}

        self.mcds_field = Dropdown(
            options={'dummy': 0},
            value=0,
            layout=Layout(width=constWidth))
        self.mcds_field.observe(self.mcds_field_changed_cb)

        self.field_cmap = Dropdown(
            options=['viridis', 'jet', 'YlOrRd'],
            value='viridis',
            layout=Layout(width=constWidth))
        self.field_cmap.observe(self.mcds_field_cb)

        self.cmap_fixed = Checkbox(
            description='Fix',
            disabled=False,
        )

        # NOTE(review): this button is created and wired up but is not part
        # of any row placed in self.tab below - confirm whether it should be
        # displayed.
        self.save_min_max = Button(
            description='Save',
            button_style=
            'success',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Save min/max for this substrate',
            disabled=True,
            layout=Layout(width='90px'))

        def save_min_max_cb(b):
            # Remember the current fixed range for the selected field.
            field_name = self.field_dict[self.mcds_field.value]
            self.field_min_max[field_name][0] = self.cmap_min.value
            self.field_min_max[field_name][1] = self.cmap_max.value

        self.save_min_max.on_click(save_min_max_cb)

        self.cmap_min = FloatText(
            description='Min',
            value=0,
            step=0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_min.observe(self.mcds_field_cb)

        self.cmap_max = FloatText(
            description='Max',
            value=38,
            step=0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_max.observe(self.mcds_field_cb)

        def cmap_fixed_cb(b):
            # The range inputs and the save button are usable only while
            # "Fix" is ticked.
            locked = not self.cmap_fixed.value
            self.cmap_min.disabled = locked
            self.cmap_max.disabled = locked
            self.save_min_max.disabled = locked

        self.cmap_fixed.observe(cmap_fixed_cb)

        help_label = Label('select slider: drag or left/right arrows')
        row1 = Box([
            help_label,
            Box([self.max_frames, self.mcds_field, self.field_cmap],
                layout=Layout(border='0px solid black',
                              width='50%',
                              height='',
                              align_items='stretch',
                              flex_direction='row',
                              display='flex'))
        ])
        row2 = Box([self.cmap_fixed, self.cmap_min, self.cmap_max],
                   layout=Layout(border='0px solid black',
                                 width='50%',
                                 height='',
                                 align_items='stretch',
                                 flex_direction='row',
                                 display='flex'))
        self.tab = VBox([row1, row2, self.mcds_plot])

    #---------------------------------------------------
    def update_dropdown_fields(self, data_dir):
        # print('update_dropdown_fields called --------')
        self.output_dir = data_dir
        tree = None
        try:
            fname = os.path.join(self.output_dir, "initial.xml")
            tree = ET.parse(fname)
#            return
        except:
            print("Cannot open ", fname, " to get names of substrate fields.")
            return

        xml_root = tree.getroot()
        self.field_min_max = {}
        self.field_dict = {}
        dropdown_options = {}
        uep = xml_root.find('.//variables')
        comment_str = ""
        field_idx = 0
        if (uep):
            for elm in uep.findall('variable'):
                # print("-----> ",elm.attrib['name'])
                self.field_min_max[elm.attrib['name']] = [0., 1.]
                self.field_dict[field_idx] = elm.attrib['name']
                dropdown_options[elm.attrib['name']] = field_idx
                field_idx += 1

#        constWidth = '180px'
# print('options=',dropdown_options)
        self.mcds_field.value = 0
        self.mcds_field.options = dropdown_options
#         self.mcds_field = Dropdown(
# #            options={'oxygen': 0, 'glucose': 1},
#             options=dropdown_options,
#             value=0,
#             #     description='Field',
#            layout=Layout(width=constWidth)
#         )

    def update_max_frames_expected(
            self, value):  # called when beginning an interactive Run
        self.max_frames.value = value  # assumes naming scheme: "snapshot%08d.svg"
        self.mcds_plot.children[0].max = self.max_frames.value

    def update(self, rdir):
        self.output_dir = rdir
        if rdir == '':
            # self.max_frames.value = 0
            tmpdir = os.path.abspath('tmpdir')
            self.output_dir = tmpdir
            all_files = sorted(glob.glob(os.path.join(tmpdir, 'output*.xml')))
            if len(all_files) > 0:
                last_file = all_files[-1]
                self.max_frames.value = int(
                    last_file[-12:-4]
                )  # assumes naming scheme: "output%08d.xml"
                self.mcds_plot.update()
            return

        all_files = sorted(glob.glob(os.path.join(rdir, 'output*.xml')))
        if len(all_files) > 0:
            last_file = all_files[-1]
            self.max_frames.value = int(
                last_file[-12:-4])  # assumes naming scheme: "output%08d.xml"
            self.mcds_plot.update()

    def update_max_frames(self, _b):
        self.mcds_plot.children[0].max = self.max_frames.value

    def mcds_field_changed_cb(self, b):
        # print("mcds_field_changed_cb: self.mcds_field.value=",self.mcds_field.value)
        if (self.mcds_field.value == None):
            return
        self.field_index = self.mcds_field.value + 4

        field_name = self.field_dict[self.mcds_field.value]
        #        print('mcds_field_cb: '+field_name)
        self.cmap_min.value = self.field_min_max[field_name][0]
        self.cmap_max.value = self.field_min_max[field_name][1]
        self.mcds_plot.update()

    def mcds_field_cb(self, b):
        #self.field_index = self.mcds_field.value
        #        self.field_index = self.mcds_field.options.index(self.mcds_field.value) + 4
        #        self.field_index = self.mcds_field.options[self.mcds_field.value]
        self.field_index = self.mcds_field.value + 4

        # field_name = self.mcds_field.options[self.mcds_field.value]
        # self.cmap_min.value = self.field_min_max[field_name][0]  # oxygen, etc
        # self.cmap_max.value = self.field_min_max[field_name][1]  # oxygen, etc

        #        self.field_index = self.mcds_field.value + 4

        #        print('field_index=',self.field_index)
        self.mcds_plot.update()

    def plot_substrate(self, frame):
        """
        Draw a filled contour plot of the selected field for one frame.

        Reads ``output%08d_microenvironment0.mat`` and ``output%08d.xml``
        (both formatted with ``frame``) from ``self.output_dir``. If the
        ``.mat`` file does not exist a hint is printed and nothing is drawn.

        The plot title is derived from the XML ``current_time`` value, shown
        as days/hours/minutes. Row ``self.field_index`` of the
        ``multiscale_microenvironment`` matrix is plotted on an N x N grid,
        with N the integer square root of the number of columns; the first
        row of the reshaped row 0 supplies the coordinates of both axes.
        When ``cmap_fixed`` is ticked the contour levels span
        ``cmap_min``..``cmap_max``; otherwise matplotlib picks them.
        """
        # global current_idx, axes_max, gFileId, field_index
        fname = "output%08d_microenvironment0.mat" % frame
        xml_fname = "output%08d.xml" % frame
        # fullname = output_dir_str + fname

        #        fullname = fname
        full_fname = os.path.join(self.output_dir, fname)
        full_xml_fname = os.path.join(self.output_dir, xml_fname)
        #        self.output_dir = '.'

        # Nothing to plot until the simulation has written this frame.
        #        if not os.path.isfile(fullname):
        if not os.path.isfile(full_fname):
            #            print("File does not exist: ", full_fname)
            #            print("No: ", full_fname)
            print("Once output files are generated, click the slider."
                  )  # No:  output00000000_microenvironment0.mat

            return

#        tree = ET.parse(xml_fname)
        tree = ET.parse(full_xml_fname)
        xml_root = tree.getroot()
        mins = round(int(float(xml_root.find(
            ".//current_time").text)))  # TODO: check units = mins
        hrs = int(mins / 60)
        days = int(hrs / 24)
        # days, hours within the day, minutes within the hour
        title_str = '%dd, %dh, %dm' % (int(days),
                                       (hrs % 24), mins - (hrs * 60))

        # loadmat fills info_dict in place with the variables of the file
        info_dict = {}
        #        scipy.io.loadmat(fullname, info_dict)
        scipy.io.loadmat(full_fname, info_dict)
        M = info_dict['multiscale_microenvironment']
        #     global_field_index = int(mcds_field.value)
        #     print('plot_substrate: field_index =',field_index)
        # NOTE(review): f is not used in the visible part of this method.
        f = M[
            self.
            field_index, :]  # 4=tumor cells field, 5=blood vessel density, 6=growth substrate
        # plt.clf()
        # my_plot = plt.imshow(f.reshape(400,400), cmap='jet', extent=[0,20, 0,20])

        self.fig = plt.figure(figsize=(
            7.2, 6))  # this strange figsize results in a ~square contour plot
        #     fig.set_tight_layout(True)
        #     ax = plt.axes([0, 0.05, 0.9, 0.9 ]) #left, bottom, width, height
        #     ax = plt.axes([0, 0.0, 1, 1 ])
        #     cmap = plt.cm.viridis # Blues, YlOrBr, ...
        #     im = ax.imshow(f.reshape(100,100), interpolation='nearest', cmap=cmap, extent=[0,20, 0,20])
        #     ax.grid(False)

        # assumes the data form a square grid (N*N columns) - TODO confirm
        N = int(math.sqrt(len(M[0, :])))
        grid2D = M[0, :].reshape(N, N)
        xvec = grid2D[0, :]

        num_contours = 15
        #        levels = MaxNLocator(nbins=10).tick_values(vmin, vmax)
        levels = MaxNLocator(nbins=num_contours).tick_values(
            self.cmap_min.value, self.cmap_max.value)
        if (self.cmap_fixed.value):
            # fixed range: explicit levels, out-of-range values drawn with
            # the extended end colors
            my_plot = plt.contourf(xvec,
                                   xvec,
                                   M[self.field_index, :].reshape(N, N),
                                   levels=levels,
                                   extend='both',
                                   cmap=self.field_cmap.value)
        else:
            #        my_plot = plt.contourf(xvec, xvec, M[self.field_index, :].reshape(N,N), num_contours, cmap=self.field_cmap.value)
            my_plot = plt.contourf(xvec,
                                   xvec,
                                   M[self.field_index, :].reshape(N, N),
                                   num_contours,
                                   cmap=self.field_cmap.value)

        plt.title(title_str)
        plt.colorbar(my_plot)
        # NOTE(review): assigned but not used in the visible part of this method.
        axes_min = 0
        axes_max = 2000
Exemple #12
0
class Dashboard(VBox):
    """
    Build the dashboard for Jupyter widgets. Requires running
    in a notebook/jupyterlab.
    """
    def __init__(self, net, width="95%", height="550px", play_rate=0.5):
        """
        Assemble the dashboard for ``net``.

        Args:
            net: network to display; its ``name`` and ``config`` are read here.
            width: CSS width applied to the SVG display area.
            height: kept on the instance as ``_height``.
            play_rate: handed to the ``_Player`` constructor.
        """
        self._ignore_layer_updates = False
        ## The player is created and started before any widget is built:
        self.player = _Player(self, play_rate)
        self.player.start()
        self.net = net
        unique_suffix = random.randint(1, 1000000)
        self.class_id = "picture-dashboard-%s-%s" % (self.net.name, unique_suffix)
        self._width = width
        self._height = height
        ## Global widgets:
        label_style = {"description_width": "initial"}
        self.feature_columns = IntText(
            description="Detail columns:",
            value=self.net.config["dashboard.features.columns"],
            min=0,
            max=1024,
            style=label_style,
        )
        self.feature_scale = FloatText(
            description="Detail scale:",
            value=self.net.config["dashboard.features.scale"],
            min=0.1,
            max=10,
            style=label_style,
        )
        ## Any change to either detail widget redraws the display:
        for detail_widget in (self.feature_columns, self.feature_scale):
            detail_widget.observe(self.regenerate, names='value')
        ## Hack to center SVG as justify-content is broken:
        svg_layout = Layout(width=self._width,
                            overflow_x='auto',
                            overflow_y="auto",
                            justify_content="center")
        self.net_svg = HTML(value="""<p style="text-align:center">%s</p>""" % ("",),
                            layout=svg_layout)
        ## Controls are made first: make_config() places self.zoom_slider,
        ## which make_controls() creates.
        self.output = Output()
        controls = self.make_controls()
        config = self.make_config()
        super().__init__([config, controls, self.net_svg, self.output])

    def propagate(self, inputs):
        """
        Propagate inputs through the dashboard view of the network.

        When ``dynamic_pictures_check()`` is true the network is propagated
        with picture updates and its result is returned. Otherwise the
        display is regenerated for ``inputs`` and None is returned.
        """
        if dynamic_pictures_check():
            return self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        ## Pass the caller's inputs (this used to pass the builtin ``input``):
        self.regenerate(inputs=inputs)

    def goto(self, position):
        """
        Move the dataset slider to a named position.

        ``position`` is one of "begin", "end", "prev" or "next"; "prev" and
        "next" wrap around at the ends of the selected dataset. Any other
        value only re-syncs the position text box. Nothing happens when the
        dataset has no inputs or no targets.
        """
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            return
        selected = self.control_select.value
        if selected == "Train":
            length = len(dataset.train_inputs)
        elif selected == "Test":
            length = len(dataset.test_inputs)
        #### Position it:
        slider = self.control_slider
        if position == "begin":
            slider.value = 0
        elif position == "end":
            slider.value = length - 1
        elif position == "prev":
            ## stepping back from the first item wraps to the last one
            slider.value = (length - 1) if slider.value < 1 else slider.value - 1
        elif position == "next":
            ## stepping past the last item wraps to the first one
            slider.value = 0 if slider.value + 1 > length - 1 else slider.value + 1
        self.position_text.value = slider.value


    def change_select(self, change=None):
        """
        Observer for the dataset selector: re-sync the slider controls for
        the newly selected dataset, then redraw the display.
        """
        self.update_control_slider(change)
        self.regenerate()

    def update_control_slider(self, change=None):
        """
        Re-sync the slider, position box, total label and navigation buttons
        with the currently selected dataset.

        The selection is stored in ``net.config["dashboard.dataset"]``. With
        an empty dataset (no inputs or no targets) everything except the
        refresh button is disabled. Otherwise the slider range becomes
        ``0 .. len(selected inputs) - 1`` and the controls are disabled only
        when the selected split is empty. ``change`` is accepted so the
        method can be used as an observer; it is not read.
        """
        self.net.config["dashboard.dataset"] = self.control_select.value
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            self.total_text.value = "of 0"
            self.control_slider.value = 0
            self.position_text.value = 0
            self.control_slider.disabled = True
            self.position_text.disabled = True
            for child in self.control_buttons.children:
                if not hasattr(child, "icon") or child.icon != "refresh":
                    child.disabled = True
            return
        ## Same convention as make_controls(): anything but "Train" is Test.
        if self.control_select.value == "Train":
            selected_inputs = dataset.train_inputs
        else:
            selected_inputs = dataset.test_inputs
        count = len(selected_inputs)
        self.total_text.value = "of %s" % count
        low, high = 0, max(count - 1, 0)
        ## Reset an out-of-range position before narrowing the range:
        if not (low <= self.control_slider.value <= high):
            self.control_slider.value = 0
        self.control_slider.min = low
        self.control_slider.max = high
        disabled = count == 0
        self.control_slider.disabled = disabled
        ## (was misspelled ``disbaled``, which only set a stray attribute)
        self.position_text.disabled = disabled
        self.position_text.value = self.control_slider.value
        for child in self.control_buttons.children:
            ## the refresh button always stays usable
            if not hasattr(child, "icon") or child.icon != "refresh":
                child.disabled = disabled

    def update_zoom_slider(self, change):
        """
        Observer for the zoom slider: store the new scale in
        ``net.config["svg_scale"]`` and redraw. Changes to anything other
        than the ``value`` trait are ignored.
        """
        if change["name"] != "value":
            return
        self.net.config["svg_scale"] = self.zoom_slider.value
        self.regenerate()

    def update_position_text(self, change):
        """
        Observer for the position box: copy its new value onto the slider.
        """
        ## ``change`` looks like:
        ## {'name': 'value', 'old': 2, 'new': 3, 'owner': IntText(value=3, layout=Layout(width='100%')), 'type': 'change'}
        new_position = change["new"]
        self.control_slider.value = new_position

    def get_current_input(self):
        """
        Return the input at the slider position from the selected dataset
        split, or None when that split has no targets.
        """
        dataset = self.net.dataset
        selected = self.control_select.value
        if selected == "Train" and len(dataset.train_targets) > 0:
            return dataset.train_inputs[self.control_slider.value]
        if selected == "Test" and len(dataset.test_targets) > 0:
            return dataset.test_inputs[self.control_slider.value]
        return None

    def get_current_targets(self):
        """
        Return the targets at the slider position from the selected dataset
        split, or None when that split has no targets.
        """
        dataset = self.net.dataset
        selected = self.control_select.value
        if selected == "Train" and len(dataset.train_targets) > 0:
            return dataset.train_targets[self.control_slider.value]
        if selected == "Test" and len(dataset.test_targets) > 0:
            return dataset.test_targets[self.control_slider.value]
        return None

    def update_slider_control(self, change):
        """
        Observer for the dataset slider: show the pattern at the new position.

        Steps, for a change of the ``value`` trait:
          * sync the position box and the "of N" label;
          * stop if the network has no model yet;
          * without dynamic pictures, fall back to a full ``regenerate``;
          * otherwise propagate the input with picture updates, refresh the
            feature details, and display targets and errors when enabled
            in ``net.config``.

        Train and Test share one code path. Targets are wrapped in a list
        only for single-output-bank networks; the Test path used to wrap
        them unconditionally, unlike the Train path.
        """
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            self.total_text.value = "of 0"
            return
        if change["name"] != "value":
            return
        self.position_text.value = self.control_slider.value
        selected = self.control_select.value
        if selected == "Train" and len(dataset.train_targets) > 0:
            all_inputs, all_targets = dataset.train_inputs, dataset.train_targets
        elif selected == "Test" and len(dataset.test_targets) > 0:
            all_inputs, all_targets = dataset.test_inputs, dataset.test_targets
        else:
            return
        self.total_text.value = "of %s" % len(all_inputs)
        if self.net.model is None:
            return
        index = self.control_slider.value
        inputs = all_inputs[index]
        if not dynamic_pictures_check():
            self.regenerate(inputs=inputs, targets=all_targets[index])
            return
        output = self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        if self.feature_bank.value in self.net.layer_dict.keys():
            self.net.propagate_to_features(self.feature_bank.value, inputs,
                                           cols=self.feature_columns.value,
                                           scale=self.feature_scale.value, html=False)
        single_bank = len(self.net.output_bank_order) == 1
        if self.net.config["show_targets"]: ## FIXME: use minmax of output bank
            targets = all_targets[index]
            ## display_component is given one entry per output bank
            self.net.display_component([targets] if single_bank else targets,
                                       "targets",
                                       class_id=self.class_id,
                                       minmax=(-1, 1))
        if self.net.config["show_errors"]: ## minmax is error
            targets = all_targets[index]
            if single_bank:
                errors = [(np.array(output) - np.array(targets)).tolist()]
            else:
                errors = [np.array(output[bank]) - np.array(targets[bank])
                          for bank in range(len(self.net.output_bank_order))]
            self.net.display_component(errors, "errors", class_id=self.class_id, minmax=(-1, 1))

    def toggle_play(self, button):
        """
        Click handler for the play button: flip between playing and paused,
        updating the button's label and icon and resuming/pausing the player.
        """
        play_button = self.button_play
        if play_button.description != "Play":
            ## currently playing -> stop
            play_button.description = "Play"
            play_button.icon = "play"
            self.player.pause()
            return
        ## currently stopped -> play
        play_button.description = "Stop"
        play_button.icon = "pause"
        self.player.resume()

    def prop_one(self, button=None):
        """
        Re-run the slider handler for the current position by sending it a
        minimal ``value`` change. ``button`` is ignored.
        """
        value_change = {"name": "value"}
        self.update_slider_control(value_change)

    def regenerate(self, button=None, inputs=None, targets=None):
        ## Protection when deleting object on shutdown:
        if isinstance(button, dict) and 'new' in button and button['new'] is None:
            return
        ## Update the config:
        self.net.config["dashboard.features.bank"] = self.feature_bank.value
        self.net.config["dashboard.features.columns"] = self.feature_columns.value
        self.net.config["dashboard.features.scale"] = self.feature_scale.value
        inputs = inputs if inputs is not None else self.get_current_input()
        targets = targets if targets is not None else self.get_current_targets()
        features = None
        if self.feature_bank.value in self.net.layer_dict.keys() and inputs is not None:
            if self.net.model is not None:
                features = self.net.propagate_to_features(self.feature_bank.value, inputs,
                                                          cols=self.feature_columns.value,
                                                          scale=self.feature_scale.value, display=False)
        svg = """<p style="text-align:center">%s</p>""" % (self.net.to_svg(
            inputs=inputs,
            targets=targets,
            class_id=self.class_id,
            highlights={self.feature_bank.value: {
                "border_color": "orange",
                "border_width": 30,
            }}))
        if inputs is not None and features is not None:
            html_horizontal = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top" style="width: 50%%;">%s</td>
  <td valign="top" align="center" style="width: 50%%;"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            html_vertical = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top">%s</td>
</tr>
<tr>
  <td valign="top" align="center"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            self.net_svg.value = (html_vertical if self.net.config["svg_rotate"] else html_horizontal) % (
                svg, "%s details" % self.feature_bank.value, features)
        else:
            self.net_svg.value = svg

    def make_colormap_image(self, colormap_name):
        """
        Render a 300x25 swatch image of ``colormap_name``.

        A falsy name falls back to ``get_colormap()``. The swatch is drawn by
        a throwaway ``Layer`` over its activation range in steps of 0.01.
        """
        from .layers import Layer
        colormap_name = colormap_name or get_colormap()
        swatch_layer = Layer("Colormap", 100)
        act_range = swatch_layer.get_act_minmax()
        ramp = np.arange(act_range[0], act_range[1], .01)
        render_config = {"pixels_per_unit": 1,
                         "svg_rotate": self.net.config["svg_rotate"]}
        swatch = swatch_layer.make_image(ramp, colormap_name, render_config)
        return swatch.resize((300, 25))

    def set_attr(self, obj, attr, value):
        """
        Store ``value`` under ``attr`` on ``obj`` and redraw.

        ``obj`` may be a dict (item assignment) or any other object
        (``setattr``). A dict ``value`` is unwrapped through its "value"
        key. Empty-dict and None values are ignored.
        """
        if value in [{}, None]: ## value is None when shutting down
            return
        if isinstance(value, dict):
            value = value["value"]
        if not isinstance(obj, dict):
            setattr(obj, attr, value)
        else:
            obj[attr] = value
        ## was crashing on Widgets.__del__, if get_ipython() no longer existed
        self.regenerate()

    def make_controls(self):
        """
        Build the navigation controls and return them as a ``VBox``.

        Creates and stores on the instance: ``button_play``,
        ``position_text``, ``control_buttons``, ``control_slider``,
        ``total_text`` and ``zoom_slider``, then wires their callbacks.

        The slider range and the "of N" label are both taken from the
        dataset split named in ``net.config["dashboard.dataset"]``.
        Previously the slider range always came from the training split.
        """
        layout = Layout(width='100%', height="100%")
        button_begin = Button(icon="fast-backward", layout=layout)
        button_prev = Button(icon="backward", layout=layout)
        button_next = Button(icon="forward", layout=layout)
        button_end = Button(icon="fast-forward", layout=layout)
        #button_prop = Button(description="Propagate", layout=Layout(width='100%'))
        #button_train = Button(description="Train", layout=Layout(width='100%'))
        self.button_play = Button(icon="play", description="Play", layout=layout)
        step_down = Button(icon="sort-down", layout=Layout(width="95%", height="100%"))
        step_up = Button(icon="sort-up", layout=Layout(width="95%", height="100%"))
        up_down = HBox([step_down, step_up], layout=Layout(width="100%", height="100%"))
        refresh_button = Button(icon="refresh", layout=Layout(width="25%", height="100%"))

        self.position_text = IntText(value=0, layout=layout)

        self.control_buttons = HBox([
            button_begin,
            button_prev,
            #button_train,
            self.position_text,
            button_next,
            button_end,
            self.button_play,
            up_down,
            refresh_button
        ], layout=Layout(width='100%', height="100%"))
        ## Size the slider and the label from the same, configured split:
        if self.net.config["dashboard.dataset"] == "Train":
            selected_inputs = self.net.dataset.train_inputs
        else:
            selected_inputs = self.net.dataset.test_inputs
        length = len(selected_inputs)
        self.control_slider = IntSlider(description="Dataset index",
                                   continuous_update=False,
                                   min=0,
                                   max=max(length - 1, 0),
                                   value=0,
                                   layout=Layout(width='100%'))
        self.total_text = Label(value="of %s" % length, layout=Layout(width="100px"))
        self.zoom_slider = FloatSlider(description="Zoom",
                                       continuous_update=False,
                                       min=0, max=1.0,
                                       style={"description_width": 'initial'},
                                       layout=Layout(width="65%"),
                                       value=self.net.config["svg_scale"] if self.net.config["svg_scale"] is not None else 0.5)

        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names='value')
        ## Refresh: re-sync the controls, clear the output area, redraw.
        refresh_button.on_click(lambda widget: (self.update_control_slider(),
                                                self.output.clear_output(),
                                                self.regenerate()))
        step_down.on_click(lambda widget: self.move_step("down"))
        step_up.on_click(lambda widget: self.move_step("up"))
        self.zoom_slider.observe(self.update_zoom_slider, names='value')
        self.position_text.observe(self.update_position_text, names='value')
        # Put them together:
        controls = VBox([HBox([self.control_slider, self.total_text], layout=Layout(height="40px")),
                         self.control_buttons], layout=Layout(width='100%'))

        #net_page = VBox([control, self.net_svg], layout=Layout(width='95%'))
        ## Draw the network once the controls appear:
        controls.on_displayed(lambda widget: self.regenerate())
        return controls

    def move_step(self, direction):
        """
        Step the detail-bank selection one layer up or down the network,
        wrapping around; the empty entry ("" = no bank) is part of the cycle.
        Any ``direction`` other than "up" is treated as down.
        """
        names = [""]
        names.extend(layer.name for layer in self.net.layers)
        current = names.index(self.feature_bank.value)
        offset = 1 if direction == "up" else -1
        self.feature_bank.value = names[(current + offset) % len(names)]
        self.regenerate()

    def make_config(self):
        """
        Build the collapsible configuration panel and return it as an
        ``Accordion`` titled with the network name (initially collapsed).

        The left column holds network-wide settings, the right column the
        per-layer display settings for the layer chosen in ``layer_select``
        (initially the last layer). Must run after ``make_controls``,
        because it places ``self.zoom_slider`` in the left column.
        """
        layout = Layout()
        style = {"description_width": "initial"}
        checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                             layout=layout, style=style)
        checkbox1.observe(lambda change: self.set_attr(self.net.config, "show_targets", change["new"]), names='value')
        checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                             layout=layout, style=style)
        checkbox2.observe(lambda change: self.set_attr(self.net.config, "show_errors", change["new"]), names='value')

        hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                         style=style, layout=layout)
        hspace.observe(lambda change: self.set_attr(self.net.config, "hspace", change["new"]), names='value')
        vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                         style=style, layout=layout)
        vspace.observe(lambda change: self.set_attr(self.net.config, "vspace", change["new"]), names='value')
        self.feature_bank = Select(description="Details:", value=self.net.config["dashboard.features.bank"],
                              options=[""] + [layer.name for layer in self.net.layers],
                              rows=1)
        self.feature_bank.observe(self.regenerate, names='value')
        self.control_select = Select(
            options=['Test', 'Train'],
            value=self.net.config["dashboard.dataset"],
            description='Dataset:',
            rows=1
        )
        self.control_select.observe(self.change_select, names='value')
        column1 = [self.control_select,
                   self.zoom_slider,
                   hspace,
                   vspace,
                   HBox([checkbox1, checkbox2]),
                   self.feature_bank,
                   self.feature_columns,
                   self.feature_scale
        ]
        ## Make layer selectable, and update-able:
        column2 = []
        layer = self.net.layers[-1]
        self.layer_select = Select(description="Layer:", value=layer.name,
                                   options=[layer.name for layer in
                                            self.net.layers],
                                   rows=1)
        self.layer_select.observe(self.update_layer_selection, names='value')
        column2.append(self.layer_select)
        self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
        self.layer_visible_checkbox.observe(self.update_layer, names='value')
        column2.append(self.layer_visible_checkbox)
        self.layer_colormap = Select(description="Colormap:",
                                     options=[""] + AVAILABLE_COLORMAPS,
                                     value=layer.colormap if layer.colormap is not None else "", layout=layout, rows=1)
        self.layer_colormap_image = HTML(value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
        self.layer_colormap.observe(self.update_layer, names='value')
        column2.append(self.layer_colormap)
        column2.append(self.layer_colormap_image)
        ## get dynamic minmax; if you change it it will set it in layer as override:
        minmax = layer.get_act_minmax()
        self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
        self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
        self.layer_mindim.observe(self.update_layer, names='value')
        self.layer_maxdim.observe(self.update_layer, names='value')
        column2.append(self.layer_mindim)
        column2.append(self.layer_maxdim)
        ## (an unused output-shape lookup and a throwaway "Rotate" checkbox
        ## that was immediately overwritten below used to live here)
        self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
        self.layer_feature.observe(self.update_layer, names='value')
        column2.append(self.layer_feature)
        self.svg_rotate = Checkbox(description="Rotate network",
                                   value=self.net.config["svg_rotate"],
                                   style={"description_width": 'initial'},
                                   layout=Layout(width="52%"))
        self.svg_rotate.observe(lambda change: self.set_attr(self.net.config, "svg_rotate", change["new"]), names='value')
        self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
        self.save_config_button.on_click(self.save_config)
        column2.append(HBox([self.svg_rotate, self.save_config_button]))
        config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                                VBox(column2, layout=Layout(width="100%"))])
        accordion = Accordion(children=[config_children])
        accordion.set_title(0, self.net.name)
        ## start collapsed
        accordion.selected_index = None
        return accordion

    def save_config(self, widget=None):
        """
        Click handler for the save button: delegate to ``net.save_config()``.
        ``widget`` is ignored.
        """
        network = self.net
        network.save_config()

    def update_layer(self, change):
        """
        Update the layer object, and redisplay.

        Observer for the per-layer widgets. Does nothing while
        ``update_layer_selection`` is repopulating those widgets. The
        feature and visibility settings are always copied to the selected
        layer. A change from a widget whose description mentions "color"
        also stores the colormap and the min/max override and then only
        re-propagates the current pattern; any other change triggers a
        full regenerate.
        """
        if self._ignore_layer_updates:
            return
        ## The rest indicates a change to a display variable.
        ## We need to save the value in the layer, and regenerate
        ## the display.
        # Get the layer:
        layer = self.net[self.layer_select.value]
        # Save the changed value in the layer:
        layer.feature = self.layer_feature.value
        layer.visible = self.layer_visible_checkbox.value
        ## These three, dealing with colors of activations,
        ## can be done with a prop_one():
        if "color" in change["owner"].description.lower():
            ## Matches: Colormap, leftmost color, rightmost color
            ## overriding dynamic minmax!
            layer.minmax = (self.layer_mindim.value, self.layer_maxdim.value)
            ## the "" entry of the selector means "no colormap"
            layer.colormap = self.layer_colormap.value if self.layer_colormap.value else None
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            self.prop_one()
        else:
            self.regenerate()

    def update_layer_selection(self, change):
        """
        Just update the widgets; don't redraw anything.

        Observer for the layer selector. Copies the newly selected layer's
        settings into the per-layer widgets while ``_ignore_layer_updates``
        is set, so that ``update_layer`` ignores the resulting change
        events. The flag is always cleared again, even if an update fails.
        """
        ## No need to redisplay anything
        self._ignore_layer_updates = True
        try:
            ## First, get the new layer selected:
            layer = self.net[self.layer_select.value]
            ## Now, let's update all of the values without updating:
            self.layer_visible_checkbox.value = layer.visible
            ## A layer without a colormap maps to the "" option, as in
            ## make_config() (the old ``!= ""`` test let None through):
            self.layer_colormap.value = layer.colormap if layer.colormap is not None else ""
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            minmax = layer.get_act_minmax()
            self.layer_mindim.value = minmax[0]
            self.layer_maxdim.value = minmax[1]
            self.layer_feature.value = layer.feature
        finally:
            self._ignore_layer_updates = False
Exemple #13
0
class PapayaConfigWidget(VBox):
    """A widget that displays widgets to adjust NLPapayaViewer image parameters."""

    lut_options = [
        "Grayscale",
        "Red Overlay",
        "Green Overlay",
        "Blue Overlay",
        "Gold",
        "Spectrum",
        "Overlay (Positives)",
        "Overlay (Negatives)",
    ]

    def __init__(self, viewer, *args, **kwargs):
        """
        Create the config widgets and lay them out: the histogram on top,
        the parameter widgets underneath.

        Parameters
        ----------
        viewer: NlPapayaViewer
            associated viewer.
        *args, **kwargs:
            forwarded to the ``VBox`` constructor.
        """
        super().__init__(*args, **kwargs)

        self._viewer = viewer
        self._init_widgets()

        histogram_box = VBox(
            [self._hist],
            layout=Layout(
                height="auto",
                margin="0px 0px 0px 0px",
                padding="5px 5px 5px 5px",
            ),
        )
        parameter_widgets = [
            self._alpha,
            self._lut,
            self._nlut,
            self._min,
            self._minp,
            self._max,
            self._maxp,
            self._sym,
        ]
        parameter_box = VBox(parameter_widgets, layout=Layout(width="230px"))
        self.children = [VBox([histogram_box, parameter_box])]

    def _init_widgets(self):
        """Initializes all configuration widgets.

        Creates one widget per image config parameter (alpha, lut,
        negative lut, min, min %, max, max %, symmetric), a histogram
        figure with two traces (all data / data without zeros), and the
        ``_handlers`` mapping.
        """
        layout = Layout(width="200px", max_width="200px")

        self._alpha = FloatSlider(
            value=1,
            min=0,
            max=1.0,
            step=0.1,
            description="alpha:",
            description_tooltip="Overlay image alpha level (0 to 1).",
            disabled=False,
            continuous_update=True,
            orientation="horizontal",
            readout=True,
            readout_format=".1f",
            layout=layout,
        )

        self._lut = Dropdown(
            options=PapayaConfigWidget.lut_options,
            value="Red Overlay",
            description="lut:",
            description_tooltip="The color table name.",
            layout=layout,
        )

        self._nlut = Dropdown(
            options=PapayaConfigWidget.lut_options,
            value="Red Overlay",
            description="negative-lut:",
            description_tooltip=
            "The color table name used by the negative side of the parametric pair.",
            layout=layout,
        )

        self._min = FloatText(
            value=None,
            description="min:",
            description_tooltip="The display range minimum.",
            step=0.01,
            continuous_update=True,
            disabled=False,
            layout=layout,
        )

        self._minp = BoundedFloatText(
            value=None,
            min=0,
            max=100,
            step=1,
            continuous_update=True,
            description="min %:",
            description_tooltip=
            "The display range minimum as a percentage of image max.",
            disabled=False,
            layout=layout,
        )

        self._max = FloatText(
            value=None,
            description="max:",
            description_tooltip="The display range maximum.",
            step=0.01,
            continuous_update=True,
            disabled=False,
            layout=layout,
        )

        self._maxp = BoundedFloatText(
            value=None,
            min=0,
            max=100,
            step=1,
            continuous_update=True,
            description="max %:",
            # tooltip used to be a copy of the "min %" one
            description_tooltip=
            "The display range maximum as a percentage of image max.",
            disabled=False,
            layout=layout,
        )

        self._sym = Checkbox(
            value=False,
            description="symmetric",
            description_tooltip=
            "When selected, sets the negative range of a parametric pair to the same size as the positive range.",
            disabled=False,
            layout=layout,
        )

        # figure to display histogram of image data
        fig = Figure()
        fig.update_layout(
            height=300,
            margin=dict(l=15, t=15, b=15, r=15, pad=4),
            showlegend=True,
            legend_orientation="h",
        )

        self._hist = FigureWidget(fig)
        # trace 0: all data, hidden until enabled from the legend
        self._hist.add_trace(
            Histogram(x=[], name="All image data", visible="legendonly"))
        # trace 1: data with the zeros left out
        self._hist.add_trace(Histogram(x=[], name="Image data without 0s"))

        # no default factory is given, so this behaves like a plain dict
        self._handlers = defaultdict()

    def _set_values(self, config, range, data):
        """Sets config values from the specified `config` and creates histogram for `data`.

        Parameters
        ----------
        config : dict
            configuration parameters for the image. Possible keywords are:
            alpha : int
                the overlay image alpha level (0 to 1).
            lut : str
                the color table name.
            negative_lut : str
                the color table name used by the negative side of the parametric pair.
            max : int
                the display range maximum.
            maxPercent : int
                the display range maximum as a percentage of image max.
            min : int
                the display range minimum.
            minPercent : int
                the display range minimum as a percentage of image min.
            symmetric : str
                "true" to set the negative range of a parametric pair to the
                same size as the positive range; any other value unchecks it.
        range: float
            range of image values.
        data: array-like or None
            flattened image data; None or empty data clears the
            "without 0s" histogram trace.
        """
        default_lut = PapayaConfigWidget.lut_options[1]
        self._alpha.value = config.get("alpha", 1)
        self._lut.value = config.get("lut", default_lut)
        self._nlut.value = config.get("negative_lut", default_lut)
        self._min.value = config.get("min", 0)
        self._minp.value = self._get_per_from_value(range,
                                                    config.get("min", 0))
        self._max.value = config.get("max", 0.1)
        self._maxp.value = self._get_per_from_value(range,
                                                    config.get("max", 0.1))
        # the flag is compared as the string "true", not as a bool
        self._sym.value = config.get("symmetric", "false") == "true"

        # set histogram data
        self._hist.data[0].x = data
        # leave out 0 values; test for emptiness by length, because
        # `data == []` is an elementwise comparison on arrays
        if data is None or len(data) == 0:
            nonzero = []
        elif isinstance(data, (list, tuple)):
            nonzero = [value for value in data if value != 0]
        else:
            # array-like: boolean-mask indexing
            nonzero = data[data != 0]
        self._hist.data[1].x = nonzero

    def _add_handlers(self, image):
        """Add config widget event handlers to change the config values for the specified `image`.

        Parameters
        ----------
        image: neurolang_ipywidgets.PapayaImage
            image whose config values will be viewed/modified using this config widget.
        """

        # Dropdown does not support resetting event handlers after Dropdown.unobserve_all is called
        # So handlers are stored to be removed individually
        # github issue https://github.com/jupyter-widgets/ipywidgets/issues/1868

        self._handlers["alpha"] = partial(self._config_changed,
                                          image=image,
                                          name="alpha")
        self._handlers["lut"] = partial(self._config_changed,
                                        image=image,
                                        name="lut")
        self._handlers["nlut"] = partial(self._config_changed,
                                         image=image,
                                         name="negative_lut")
        self._handlers["min"] = partial(self._config_changed,
                                        image=image,
                                        name="min")
        self._handlers["minp"] = partial(self._set_min_max,
                                         image=image,
                                         name="minPercent")
        self._handlers["max"] = partial(self._config_changed,
                                        image=image,
                                        name="max")
        self._handlers["maxp"] = partial(self._set_min_max,
                                         image=image,
                                         name="maxPercent")
        self._handlers["sym"] = partial(self._config_bool_changed,
                                        image=image,
                                        name="symmetric")

        self._alpha.observe(self._handlers["alpha"], names="value")

        self._lut.observe(self._handlers["lut"], names="value")

        self._nlut.observe(self._handlers["nlut"], names="value")

        self._min.observe(self._handlers["min"], names="value")

        self._minp.observe(self._handlers["minp"], names="value")

        self._max.observe(self._handlers["max"], names="value")

        self._maxp.observe(self._handlers["maxp"], names="value")

        self._sym.observe(self._handlers["sym"], names="value")

    def _remove_handlers(self):
        """Removes all event handlers set for the config widgets."""
        if len(self._handlers):
            self._alpha.unobserve(self._handlers["alpha"], names="value")
            self._lut.unobserve(self._handlers["lut"], names="value")
            self._nlut.unobserve(self._handlers["nlut"], names="value")
            self._min.unobserve(self._handlers["min"], names="value")
            self._minp.unobserve(self._handlers["minp"], names="value")
            self._max.unobserve(self._handlers["max"], names="value")
            self._maxp.unobserve(self._handlers["maxp"], names="value")
            self._sym.unobserve(self._handlers["sym"], names="value")

            self._handlers = defaultdict()

    @debounce(0.5)
    def _config_changed(self, change, image, name):
        """Stores a changed widget value in the config of `image`.

        When `name` is "min" or "max", the matching percentage widget is
        first updated to the percentage of `image.range` that corresponds to
        the new value.

        NOTE(review): the method is wrapped by `debounce(0.5)`, which is
        defined outside this block; presumably it delays/coalesces rapid
        calls -- confirm.

        Parameters
        ----------
        change:
            change notification of the observed widget; only `change.new`
            (the new value) is read.
        image: neurolang_ipywidgets.PapayaImage
            image whose config is modified.
        name: str
            config key to store the new value under.
        """
        if name == "min":
            percent_widget, handler = self._minp, self._handlers["minp"]
        elif name == "max":
            percent_widget, handler = self._maxp, self._handlers["maxp"]
        else:
            percent_widget = handler = None

        if percent_widget is not None:
            # The percentage handler is detached while the widget is updated so
            # that this programmatic change does not trigger it.
            percent_widget.unobserve(handler, names="value")
            try:
                percent_widget.value = self._get_per_from_value(
                    image.range, change.new)
            finally:
                # Always re-attach, even when the update fails; otherwise the
                # widget would silently stop reacting to user input.
                percent_widget.observe(handler, names="value")

        self._set_config(image, name, change.new)

    @debounce(0.5)
    def _set_min_max(self, change, image, name):
        """Applies a changed percentage to the matching absolute min/max widget and to the config of `image`.

        For "minPercent" the `min` widget, and for "maxPercent" the `max`
        widget, is set to the value that corresponds to the new percentage of
        `image.range`; that value is then stored in the image config under
        "min" / "max". Any other `name` is ignored.

        NOTE(review): the method is wrapped by `debounce(0.5)`, which is
        defined outside this block; presumably it delays/coalesces rapid
        calls -- confirm.

        Parameters
        ----------
        change:
            change notification of the observed widget; only `change.new`
            (the new percentage) is read.
        image: neurolang_ipywidgets.PapayaImage
            image whose config is modified.
        name: str
            "minPercent" or "maxPercent".
        """
        if name == "minPercent":
            value_widget, key = self._min, "min"
        elif name == "maxPercent":
            value_widget, key = self._max, "max"
        else:
            return

        handler = self._handlers[key]
        # The absolute-value handler is detached while the widget is updated so
        # that this programmatic change does not trigger it.
        value_widget.unobserve(handler, names="value")
        try:
            value_widget.value = self._get_value_from_per(
                image.range, change.new)
            self._set_config(image, key, value_widget.value)
        finally:
            # Always re-attach, even when the update fails; otherwise the
            # widget would silently stop reacting to user input.
            value_widget.observe(handler, names="value")

    def _config_bool_changed(self, change, image, name):
        value = "false"
        if change.new:
            value = "true"
        self._set_config(image, name, value)

    def _set_config(self, image, key, value):
        image.config[key] = value
        self._viewer.set_images()

    def _get_per_from_value(self, range, value):
        return round(value * 100 / range, 0)

    def _get_value_from_per(self, range, per):
        return round(per * range / 100, 2)

    def set_image(self, image):
        """Sets the image whose config values will be viewed/modified using this config widget.
        If image is `None` (or otherwise falsy), all config values are reset.

        Parameters
        ----------
        image: neurolang_ipywidgets.PapayaImage
            image whose config values will be viewed/modified using this config widget.
        """
        if not image:
            self.reset()
            return

        # Handlers are detached before the widgets are filled in and attached
        # again afterwards, bound to the new image.
        self._remove_handlers()
        config, value_range = image.config, image.range
        flat_data = image.image.get_fdata().flatten()
        self._set_values(config, value_range, flat_data)
        self._add_handlers(image)

    def reset(self):
        """Resets values for all config widgets and hides this widget.

        All handlers are removed, the widgets are filled from an empty
        config (i.e. with their defaults) using a range of 100 and no
        histogram data, and the layout visibility is set to "hidden".
        """
        empty_config, default_range, no_data = {}, 100, []

        self._remove_handlers()
        self._set_values(empty_config, default_range, no_data)
        self.layout.visibility = "hidden"