Пример #1
0
    def make_config(self):
        """
        Build the collapsible configuration panel for the dashboard.

        Left column: network-wide display settings (dataset split, zoom,
        bank/layer spacing, target/error overlays, feature bank, feature
        columns and scale). Right column: display settings for the layer
        chosen in the "Layer:" selector, initially the last layer of the
        network. Each widget's observer writes the new value back into
        ``self.net.config`` or into the selected layer.

        Returns:
            An ``Accordion`` titled with the network name, initially collapsed.
        """
        layout = Layout()
        style = {"description_width": "initial"}
        ## Network-wide display options, stored in self.net.config:
        checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                             layout=layout, style=style)
        checkbox1.observe(lambda change: self.set_attr(self.net.config, "show_targets", change["new"]), names='value')
        checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                             layout=layout, style=style)
        checkbox2.observe(lambda change: self.set_attr(self.net.config, "show_errors", change["new"]), names='value')

        hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                         style=style, layout=layout)
        hspace.observe(lambda change: self.set_attr(self.net.config, "hspace", change["new"]), names='value')
        vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                         style=style, layout=layout)
        vspace.observe(lambda change: self.set_attr(self.net.config, "vspace", change["new"]), names='value')
        ## Only layers for which the net reports features are offered ("" = none):
        feature_layers = [each.name for each in self.net.layers
                          if self.net._layer_has_features(each.name)]
        self.feature_bank = Select(description="Features:", value=self.net.config["dashboard.features.bank"],
                                   options=[""] + feature_layers,
                                   rows=1)
        self.feature_bank.observe(self.regenerate, names='value')
        self.control_select = Select(
            options=['Test', 'Train'],
            value=self.net.config["dashboard.dataset"],
            description='Dataset:',
            rows=1
        )
        self.control_select.observe(self.change_select, names='value')
        column1 = [self.control_select,
                   self.zoom_slider,
                   hspace,
                   vspace,
                   HBox([checkbox1, checkbox2]),
                   self.feature_bank,
                   self.feature_columns,
                   self.feature_scale,
                   ]
        ## Make layer selectable, and update-able:
        column2 = []
        layer = self.net.layers[-1]
        self.layer_select = Select(description="Layer:", value=layer.name,
                                   options=[each.name for each in self.net.layers],
                                   rows=1)
        self.layer_select.observe(self.update_layer_selection, names='value')
        column2.append(self.layer_select)
        self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
        self.layer_visible_checkbox.observe(self.update_layer, names='value')
        column2.append(self.layer_visible_checkbox)
        ## "" in the selector stands for "no explicit colormap" (layer.colormap is None):
        self.layer_colormap = Select(description="Colormap:",
                                     options=[""] + AVAILABLE_COLORMAPS,
                                     value=layer.colormap if layer.colormap is not None else "", layout=layout, rows=1)
        self.layer_colormap_image = HTML(value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
        self.layer_colormap.observe(self.update_layer, names='value')
        column2.append(self.layer_colormap)
        column2.append(self.layer_colormap_image)
        ## get dynamic minmax; if you change it it will set it in layer as override:
        minmax = layer.get_act_minmax()
        self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
        self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
        self.layer_mindim.observe(self.update_layer, names='value')
        self.layer_maxdim.observe(self.update_layer, names='value')
        column2.append(self.layer_mindim)
        column2.append(self.layer_maxdim)
        self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
        self.layer_feature.observe(self.update_layer, names='value')
        column2.append(self.layer_feature)
        ## Network-wide SVG rotation plus the button that persists the config:
        self.svg_rotate = Checkbox(description="Rotate network",
                                   value=self.net.config["svg_rotate"],
                                   style={"description_width": 'initial'},
                                   layout=Layout(width="52%"))
        self.svg_rotate.observe(lambda change: self.set_attr(self.net.config, "svg_rotate", change["new"]), names='value')
        self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
        self.save_config_button.on_click(self.save_config)
        column2.append(HBox([self.svg_rotate, self.save_config_button]))
        config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                                VBox(column2, layout=Layout(width="100%"))])
        accordion = Accordion(children=[config_children])
        accordion.set_title(0, self.net.name)
        accordion.selected_index = None
        return accordion
Пример #2
0
class Dashboard(VBox):
    """
    Build the dashboard for Jupyter widgets. Requires running
    in a notebook/jupyterlab.

    Layout, top to bottom:
    - a collapsible configuration accordion;
    - the dataset navigation controls;
    - an HTML area holding the network SVG, plus optional feature images;
    - an ``Output`` widget.
    """
    def __init__(self, net, width="95%", height="550px", play_rate=0.5):
        """
        Args:
            net: the network to visualize.
            width: CSS width of the SVG display area.
            height: kept as ``self._height``; not read elsewhere in this class.
            play_rate: passed to the background ``_Player``.
        """
        ## Set while update_layer_selection() syncs the layer widgets, so
        ## update_layer() ignores those programmatic changes:
        self._ignore_layer_updates = False
        self.player = _Player(self, play_rate)
        self.player.start()
        self.net = net
        ## Random suffix, presumably so several dashboards on the same net
        ## get distinct class ids in the page:
        r = random.randint(1, 1000000)
        self.class_id = "picture-dashboard-%s-%s" % (self.net.name, r)
        self._width = width
        self._height = height
        ## Global widgets:
        style = {"description_width": "initial"}
        self.feature_columns = IntText(description="Feature columns:",
                                       value=self.net.config["dashboard.features.columns"],
                                       min=0,
                                       max=1024,
                                       style=style)
        self.feature_scale = FloatText(description="Feature scale:",
                                       value=self.net.config["dashboard.features.scale"],
                                       min=0.1,
                                       max=10,
                                       style=style)
        self.feature_columns.observe(self.regenerate, names='value')
        self.feature_scale.observe(self.regenerate, names='value')
        ## Hack to center SVG as justify-content is broken:
        self.net_svg = HTML(value="""<p style="text-align:center">%s</p>""" % ("",), layout=Layout(
            width=self._width, overflow_x='auto', overflow_y="auto",
            justify_content="center"))
        # Make controls first (make_config() uses self.zoom_slider from make_controls()):
        self.output = Output()
        controls = self.make_controls()
        config = self.make_config()
        super().__init__([config, controls, self.net_svg, self.output])

    def propagate(self, inputs):
        """
        Propagate inputs through the dashboard view of the network.

        If ``dynamic_pictures_check()`` is true, the network updates the
        pictures in place and its propagate() result is returned. Otherwise
        the whole display is regenerated for ``inputs`` and None is returned.
        """
        if dynamic_pictures_check():
            return self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        ## Fallback: redraw the full SVG for these inputs.
        self.regenerate(inputs=inputs)
        return None

    def _selected_split(self):
        """
        Return ``(inputs, targets)`` for the split named in the "Dataset:"
        selector ("Train" or "Test"). Return two empty tuples if neither
        split is selected.
        """
        dataset = self.net.dataset
        if self.control_select.value == "Train":
            return dataset.train_inputs, dataset.train_targets
        if self.control_select.value == "Test":
            return dataset.test_inputs, dataset.test_targets
        return (), ()

    def goto(self, position):
        """
        Move the dataset slider within the selected split.

        Args:
            position: "begin", "end", "prev" or "next". "prev" and "next"
                wrap around at the ends. Other values leave the slider alone.
        """
        if len(self.net.dataset.inputs) == 0 or len(self.net.dataset.targets) == 0:
            return
        split_inputs, _ = self._selected_split()
        length = len(split_inputs)
        if length == 0:
            return
        #### Position it:
        current = self.control_slider.value
        if position == "begin":
            self.control_slider.value = 0
        elif position == "end":
            self.control_slider.value = length - 1
        elif position == "prev":
            self.control_slider.value = length - 1 if current - 1 < 0 else current - 1  # wrap around
        elif position == "next":
            self.control_slider.value = 0 if current + 1 > length - 1 else current + 1  # wrap around
        self.position_text.value = self.control_slider.value

    def change_select(self, change=None):
        """
        Handle a new dataset split in the "Dataset:" selector.

        Resyncs the slider controls with the new split, then redraws.
        """
        self.update_control_slider(change)
        self.regenerate()

    def _set_navigation_disabled(self, disabled):
        """
        Enable or disable the slider, the position box, and every control
        button except the refresh button.
        """
        self.control_slider.disabled = disabled
        self.position_text.disabled = disabled
        for child in self.control_buttons.children:
            if not hasattr(child, "icon") or child.icon != "refresh":
                child.disabled = disabled

    def update_control_slider(self, change=None):
        """
        Sync the slider range, position box, total label and navigation
        buttons with the selected split.

        Also records the selected split in ``self.net.config["dashboard.dataset"]``.
        """
        self.net.config["dashboard.dataset"] = self.control_select.value
        if len(self.net.dataset.inputs) == 0 or len(self.net.dataset.targets) == 0:
            self.total_text.value = "of 0"
            self.control_slider.value = 0
            self.position_text.value = 0
            self._set_navigation_disabled(True)
            return
        split_inputs, _ = self._selected_split()
        count = len(split_inputs)
        self.total_text.value = "of %s" % count
        upper = max(count - 1, 0)
        ## Reset an out-of-range index before shrinking the slider range:
        if not 0 <= self.control_slider.value <= upper:
            self.control_slider.value = 0
        self.control_slider.min = 0
        self.control_slider.max = upper
        self._set_navigation_disabled(count == 0)
        self.position_text.value = self.control_slider.value

    def update_zoom_slider(self, change):
        """Store the new zoom in ``config["svg_scale"]`` and redraw."""
        if change["name"] == "value":
            self.net.config["svg_scale"] = self.zoom_slider.value
            self.regenerate()

    def update_position_text(self, change):
        """Move the slider to the index typed into the position box."""
        # {'name': 'value', 'old': 2, 'new': 3, 'owner': IntText(value=3, layout=Layout(width='100%')), 'type': 'change'}
        self.control_slider.value = change["new"]

    def get_current_input(self):
        """
        Return the input at the slider index in the selected split, or None
        if that split has no targets.
        """
        split_inputs, split_targets = self._selected_split()
        if len(split_targets) > 0:
            return split_inputs[self.control_slider.value]
        return None

    def get_current_targets(self):
        """
        Return the targets at the slider index in the selected split, or
        None if that split has no targets.
        """
        _, split_targets = self._selected_split()
        if len(split_targets) > 0:
            return split_targets[self.control_slider.value]
        return None

    def update_slider_control(self, change):
        """
        Observer for the dataset slider. Shows the pattern at the new index.

        Only acts when ``change["name"] == "value"`` and the selected split
        has targets.
        """
        if len(self.net.dataset.inputs) == 0 or len(self.net.dataset.targets) == 0:
            self.total_text.value = "of 0"
            return
        if change["name"] != "value":
            return
        self.position_text.value = self.control_slider.value
        split_inputs, split_targets = self._selected_split()
        if len(split_targets) == 0:
            return
        self.total_text.value = "of %s" % len(split_inputs)
        index = self.control_slider.value
        self._show_dataset_item(split_inputs[index], split_targets[index])

    def _show_dataset_item(self, inputs, targets):
        """
        Propagate one dataset pattern and refresh the display.

        Refreshes the pictures, the feature images, and (if enabled in
        config) the target and error overlays. Does nothing while the
        network has no model.
        """
        if self.net.model is None:
            return
        if not dynamic_pictures_check():
            self.regenerate(inputs=inputs, targets=targets)
            return
        output = self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        if self.feature_bank.value in self.net.layer_dict.keys():
            self.net.propagate_to_features(self.feature_bank.value, inputs,
                                           cols=self.feature_columns.value,
                                           scale=self.feature_scale.value, html=False)
        ## A single output bank needs its targets/errors wrapped in a list;
        ## multiple banks already supply one entry per bank:
        single_bank = len(self.net.output_bank_order) == 1
        if self.net.config["show_targets"]:  ## FIXME: use minmax of output bank
            self.net.display_component([targets] if single_bank else targets,
                                       "targets",
                                       class_id=self.class_id,
                                       minmax=(-1, 1))
        if self.net.config["show_errors"]:  ## minmax is error
            if single_bank:
                errors = [(np.array(output) - np.array(targets)).tolist()]
            else:
                errors = [np.array(output[bank]) - np.array(targets[bank])
                          for bank in range(len(self.net.output_bank_order))]
            self.net.display_component(errors, "errors", class_id=self.class_id, minmax=(-1, 1))

    def toggle_play(self, button):
        """Toggle the Play/Stop button and resume or pause the background player."""
        ## toggle
        if self.button_play.description == "Play":
            self.button_play.description = "Stop"
            self.button_play.icon = "pause"
            self.player.resume()
        else:
            self.button_play.description = "Play"
            self.button_play.icon = "play"
            self.player.pause()

    def prop_one(self, button=None):
        """Re-show the pattern at the current slider index."""
        self.update_slider_control({"name": "value"})

    def regenerate(self, button=None, inputs=None, targets=None):
        """
        Redraw the network SVG, and feature images if a feature bank is set.

        Args:
            button: the triggering widget or change dict (unused except for
                the shutdown guard below).
            inputs: inputs to draw. Defaults to the current slider pattern.
            targets: targets to draw. Defaults to the current slider pattern.
        """
        ## Protection when deleting object on shutdown:
        if isinstance(button, dict) and 'new' in button and button['new'] is None:
            return
        ## Update the config:
        self.net.config["dashboard.features.bank"] = self.feature_bank.value
        self.net.config["dashboard.features.columns"] = self.feature_columns.value
        self.net.config["dashboard.features.scale"] = self.feature_scale.value
        inputs = inputs if inputs is not None else self.get_current_input()
        targets = targets if targets is not None else self.get_current_targets()
        features = None
        if self.feature_bank.value in self.net.layer_dict.keys() and inputs is not None:
            if self.net.model is not None:
                features = self.net.propagate_to_features(self.feature_bank.value, inputs,
                                                          cols=self.feature_columns.value,
                                                          scale=self.feature_scale.value, display=False)
        svg = """<p style="text-align:center">%s</p>""" % (self.net.to_svg(inputs=inputs, targets=targets,
                                                                           class_id=self.class_id),)
        if inputs is not None and features is not None:
            ## Features go beside the SVG normally, below it when rotated:
            html_horizontal = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top" style="width: 50%%;">%s</td>
  <td valign="top" align="center" style="width: 50%%;"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            html_vertical = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top">%s</td>
</tr>
<tr>
  <td valign="top" align="center"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            self.net_svg.value = (html_vertical if self.net.config["svg_rotate"] else html_horizontal) % (
                svg, "%s features" % self.feature_bank.value, features)
        else:
            self.net_svg.value = svg

    def make_colormap_image(self, colormap_name):
        """
        Render a 300x25 strip showing ``colormap_name``.

        Falls back to ``get_colormap()`` when the name is empty or None.
        """
        from .layers import Layer
        if not colormap_name:
            colormap_name = get_colormap()
        layer = Layer("Colormap", 100)
        minmax = layer.get_act_minmax()
        image = layer.make_image(np.arange(minmax[0], minmax[1], .01),
                                 colormap_name,
                                 {"pixels_per_unit": 1,
                                  "svg_rotate": self.net.config["svg_rotate"]}).resize((300, 25))
        return image

    def set_attr(self, obj, attr, value):
        """
        Store ``value`` under ``attr`` (dict key or attribute) and redraw.

        Ignores ``{}`` and None. If ``value`` is a dict, its "value" entry is
        stored instead.
        """
        if value not in [{}, None]: ## value is None when shutting down
            if isinstance(value, dict):
                value = value["value"]
            if isinstance(obj, dict):
                obj[attr] = value
            else:
                setattr(obj, attr, value)
            ## was crashing on Widgets.__del__, if get_ipython() no longer existed
            self.regenerate()

    def make_controls(self):
        """
        Build the dataset navigation controls.

        The controls are the index slider, the total label, and the
        begin/prev/position/next/end/play/refresh buttons. The zoom slider
        is created here too, but make_config() places it.

        Returns:
            A ``VBox`` that calls regenerate() when displayed.
        """
        button_begin = Button(icon="fast-backward", layout=Layout(width='100%'))
        button_prev = Button(icon="backward", layout=Layout(width='100%'))
        button_next = Button(icon="forward", layout=Layout(width='100%'))
        button_end = Button(icon="fast-forward", layout=Layout(width='100%'))
        self.button_play = Button(icon="play", description="Play", layout=Layout(width="100%"))
        refresh_button = Button(icon="refresh", layout=Layout(width="25%"))

        self.position_text = IntText(value=0, layout=Layout(width="100%"))

        self.control_buttons = HBox([
            button_begin,
            button_prev,
            self.position_text,
            button_next,
            button_end,
            self.button_play,
            refresh_button
        ], layout=Layout(width='100%', height="50px"))
        ## Size the slider and the total label from the same configured split:
        if self.net.config["dashboard.dataset"] == "Train":
            length = len(self.net.dataset.train_inputs)
        else:
            length = len(self.net.dataset.test_inputs)
        self.control_slider = IntSlider(description="Dataset index",
                                        continuous_update=False,
                                        min=0,
                                        max=max(length - 1, 0),
                                        value=0,
                                        layout=Layout(width='100%'))
        self.total_text = Label(value="of %s" % length, layout=Layout(width="100px"))
        self.zoom_slider = FloatSlider(description="Zoom",
                                       continuous_update=False,
                                       min=0, max=1.0,
                                       style={"description_width": 'initial'},
                                       layout=Layout(width="65%"),
                                       value=self.net.config["svg_scale"] if self.net.config["svg_scale"] is not None else 0.5)

        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names='value')
        refresh_button.on_click(lambda widget: (self.update_control_slider(),
                                                self.output.clear_output(),
                                                self.regenerate()))
        self.zoom_slider.observe(self.update_zoom_slider, names='value')
        self.position_text.observe(self.update_position_text, names='value')
        # Put them together:
        controls = VBox([HBox([self.control_slider, self.total_text], layout=Layout(height="40px")),
                         self.control_buttons], layout=Layout(width='100%'))

        controls.on_displayed(lambda widget: self.regenerate())
        return controls

    def make_config(self):
        """
        Build the collapsible configuration panel for the dashboard.

        Left column: network-wide display settings (dataset split, zoom,
        spacing, target/error overlays, feature bank/columns/scale). Right
        column: settings for the layer chosen in "Layer:", initially the
        last layer. Observers write edits back into ``self.net.config`` or
        the selected layer.

        Returns:
            An ``Accordion`` titled with the network name, initially collapsed.
        """
        layout = Layout()
        style = {"description_width": "initial"}
        ## Network-wide display options, stored in self.net.config:
        checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                             layout=layout, style=style)
        checkbox1.observe(lambda change: self.set_attr(self.net.config, "show_targets", change["new"]), names='value')
        checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                             layout=layout, style=style)
        checkbox2.observe(lambda change: self.set_attr(self.net.config, "show_errors", change["new"]), names='value')

        hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                         style=style, layout=layout)
        hspace.observe(lambda change: self.set_attr(self.net.config, "hspace", change["new"]), names='value')
        vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                         style=style, layout=layout)
        vspace.observe(lambda change: self.set_attr(self.net.config, "vspace", change["new"]), names='value')
        ## Only layers for which the net reports features are offered ("" = none):
        feature_layers = [each.name for each in self.net.layers
                          if self.net._layer_has_features(each.name)]
        self.feature_bank = Select(description="Features:", value=self.net.config["dashboard.features.bank"],
                                   options=[""] + feature_layers,
                                   rows=1)
        self.feature_bank.observe(self.regenerate, names='value')
        self.control_select = Select(
            options=['Test', 'Train'],
            value=self.net.config["dashboard.dataset"],
            description='Dataset:',
            rows=1
        )
        self.control_select.observe(self.change_select, names='value')
        column1 = [self.control_select,
                   self.zoom_slider,
                   hspace,
                   vspace,
                   HBox([checkbox1, checkbox2]),
                   self.feature_bank,
                   self.feature_columns,
                   self.feature_scale,
                   ]
        ## Make layer selectable, and update-able:
        column2 = []
        layer = self.net.layers[-1]
        self.layer_select = Select(description="Layer:", value=layer.name,
                                   options=[each.name for each in self.net.layers],
                                   rows=1)
        self.layer_select.observe(self.update_layer_selection, names='value')
        column2.append(self.layer_select)
        self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
        self.layer_visible_checkbox.observe(self.update_layer, names='value')
        column2.append(self.layer_visible_checkbox)
        ## "" in the selector stands for "no explicit colormap" (layer.colormap is None):
        self.layer_colormap = Select(description="Colormap:",
                                     options=[""] + AVAILABLE_COLORMAPS,
                                     value=layer.colormap if layer.colormap is not None else "", layout=layout, rows=1)
        self.layer_colormap_image = HTML(value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
        self.layer_colormap.observe(self.update_layer, names='value')
        column2.append(self.layer_colormap)
        column2.append(self.layer_colormap_image)
        ## get dynamic minmax; if you change it it will set it in layer as override:
        minmax = layer.get_act_minmax()
        self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
        self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
        self.layer_mindim.observe(self.update_layer, names='value')
        self.layer_maxdim.observe(self.update_layer, names='value')
        column2.append(self.layer_mindim)
        column2.append(self.layer_maxdim)
        self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
        self.layer_feature.observe(self.update_layer, names='value')
        column2.append(self.layer_feature)
        ## Network-wide SVG rotation plus the button that persists the config:
        self.svg_rotate = Checkbox(description="Rotate network",
                                   value=self.net.config["svg_rotate"],
                                   style={"description_width": 'initial'},
                                   layout=Layout(width="52%"))
        self.svg_rotate.observe(lambda change: self.set_attr(self.net.config, "svg_rotate", change["new"]), names='value')
        self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
        self.save_config_button.on_click(self.save_config)
        column2.append(HBox([self.svg_rotate, self.save_config_button]))
        config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                                VBox(column2, layout=Layout(width="100%"))])
        accordion = Accordion(children=[config_children])
        accordion.set_title(0, self.net.name)
        accordion.selected_index = None
        return accordion

    def save_config(self, widget=None):
        """Persist the network's config via ``self.net.save_config()``."""
        self.net.save_config()

    def update_layer(self, change):
        """
        Update the layer object, and redisplay.

        Copies the per-layer widget values into the selected layer. Color
        edits (colormap, leftmost/rightmost color) also override the layer's
        minmax, refresh the colormap preview, and re-show the current
        pattern. Any other edit triggers a full regenerate().
        """
        if self._ignore_layer_updates:
            return
        ## The rest indicates a change to a display variable.
        ## We need to save the value in the layer, and regenerate
        ## the display.
        # Get the layer:
        layer = self.net[self.layer_select.value]
        # Save the changed value in the layer:
        layer.feature = self.layer_feature.value
        layer.visible = self.layer_visible_checkbox.value
        ## These three, dealing with colors of activations,
        ## can be done with a prop_one():
        if "color" in change["owner"].description.lower():
            ## Matches: Colormap, leftmost color, rightmost color
            ## overriding dynamic minmax!
            layer.minmax = (self.layer_mindim.value, self.layer_maxdim.value)
            layer.colormap = self.layer_colormap.value if self.layer_colormap.value else None
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            self.prop_one()
        else:
            self.regenerate()

    def update_layer_selection(self, change):
        """
        Just update the widgets; don't redraw anything.

        Loads the newly selected layer's settings into the per-layer widgets.
        ``_ignore_layer_updates`` stops update_layer() from treating these
        programmatic changes as edits. The flag is always cleared, even if a
        widget update raises.
        """
        ## No need to redisplay anything
        self._ignore_layer_updates = True
        try:
            ## First, get the new layer selected:
            layer = self.net[self.layer_select.value]
            ## Now, let's update all of the values without updating:
            self.layer_visible_checkbox.value = layer.visible
            ## None (no explicit colormap) is shown as "" in the selector:
            self.layer_colormap.value = layer.colormap if layer.colormap is not None else ""
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            minmax = layer.get_act_minmax()
            self.layer_mindim.value = minmax[0]
            self.layer_maxdim.value = minmax[1]
            self.layer_feature.value = layer.feature
        finally:
            self._ignore_layer_updates = False
class CompleteWordEmbeddingVisualizer:
    """
    Jupyter widget that plots a sample of a fastText model's vocabulary in 3-D.

    The words come from ``fasttext_model.get_labels()`` and are shuffled once,
    so every plot uses a prefix of that shuffled list. A log-scaled slider
    (mirrored by an integer text box) chooses the sample size; the
    "Plot points" button projects the sampled word vectors onto their first
    three principal components and draws them either as points or as word
    labels. The button colour escalates as the sample size grows.
    """

    def __init__(self, fasttext_model):
        """
        Build the widgets, wire them together and display them.

        :param fasttext_model: object providing ``get_labels()`` and
            ``get_word_vector(word)`` (presumably a ``fasttext`` model).
        """
        self.fasttext_model = fasttext_model
        self.n_samples = self.fig = None
        # Re-entrancy guard: the slider and the text box update each other,
        # and each programmatic update fires the other widget's observer.
        self._hold = False

        # Get vocabulary (shuffled in place so any prefix is a random sample)
        self.vocabulary = self.fasttext_model.get_labels()
        random.shuffle(self.vocabulary)
        self.n_words = len(self.vocabulary)

        # Slider positions index _n_ticks sample sizes, log-spaced from 10 up
        # to the vocabulary size.
        self._n_ticks = 100
        self._ticks = np.logspace(start=1,
                                  stop=np.log10(self.n_words),
                                  endpoint=True,
                                  base=10,
                                  num=self._n_ticks)

        # Slider for how many to plot (its value is an index into _ticks)
        self.n_samples_slider = IntSlider(
            value=5,
            min=0,
            max=self._n_ticks - 1,
            step=1,
            description='# Samples:',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=False,
            layout=dict(width="50%"),
        )

        # Text box showing (and accepting) the actual number of samples
        self.n_samples_text = IntText(
            value=0,
            layout=dict(width="15%"),
        )

        # Button
        self.button = Button(
            description='Plot points',
            disabled=False,
            button_style=
            'success',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Click me',
            layout=dict(width="15%"),
        )
        self.button.on_click(self._button_clicked)

        # Checkbox for showing words instead of points
        self.show_words = Checkbox(
            value=False,
            description='Show words',
            disabled=False,
            layout=dict(width="15%", padding="0pt 0pt 0pt 10pt"),
            indent=False,
        )
        # Only react to value changes, not to every trait synchronisation.
        self.show_words.observe(self._update_button_color, names="value")

        # Initialise n_samples from the slider, then keep both widgets in sync
        self._slider_moved()
        self.n_samples_slider.observe(self._slider_moved, names="value")
        self.n_samples_text.observe(self._int_text_edited, names="value")

        self.display()

    def display(self):
        """Clear the cell output and show the controls (plus the figure, if any)."""
        clear_output(wait=True)

        box = HBox((self.n_samples_slider, self.n_samples_text, self.button,
                    self.show_words))
        if self.fig is not None:
            # noinspection PyTypeChecker
            display(box, self.fig)
        else:
            # noinspection PyTypeChecker
            display(box)

    def _update_button_color(self, _=None):
        """
        Colour the button by how large the current sample is.

        Thresholds are the ticks at 1/2, 2/3 and 4/5 of the slider range; they
        are halved (in tick-index space) when words are drawn instead of points.
        """
        factor = 0.5 if self.show_words.value else 1.0

        if self.n_samples > self._ticks[int(self._n_ticks * factor * 4 / 5)]:
            self.button.button_style = "danger"
        elif self.n_samples > self._ticks[int(self._n_ticks * factor * 2 / 3)]:
            self.button.button_style = "warning"
        elif self.n_samples > self._ticks[int(self._n_ticks * factor * 1 / 2)]:
            self.button.button_style = "info"
        else:
            self.button.button_style = "success"

    def _int_text_edited(self, _=None):
        """
        Sync the slider after the user types a sample count.

        The count is clamped to [10, largest tick] and the slider is moved to
        the first tick above it. ``_hold`` is only released by the call that
        acquired it, so nested observer calls cannot re-enable the slider
        handler half-way through.
        """
        if not self._hold:
            self._hold = True
            try:
                self.n_samples = self.n_samples_text.value

                # Ensure minimum
                if self.n_samples < 10:
                    self.n_samples = self.n_samples_text.value = 10

                # Determine tick for slider and ensure maximum
                largest = self._ticks[-1]
                if largest > self.n_samples:
                    loc = next(idx for idx, val in enumerate(self._ticks)
                               if val > self.n_samples)
                else:
                    # Keep n_samples a Python int: it is used as a slice bound.
                    self.n_samples = int(round(largest))
                    self.n_samples_text.value = self.n_samples
                    loc = self._n_ticks - 1  # last valid slider position

                self.n_samples_slider.value = loc
            finally:
                self._hold = False

        self._update_button_color()

    def _slider_moved(self, _=None):
        """Set the sample count (and text box) from the slider's tick."""
        if not self._hold:
            self._hold = True
            try:
                self.n_samples = int(self._ticks[self.n_samples_slider.value])
                self.n_samples_text.value = self.n_samples
            finally:
                self._hold = False

        self._update_button_color()

    def _disable_button(self):
        """Grey out the plot button while work is in progress."""
        self.button.button_style = ""
        self.button.disabled = True

    def _enable_button(self):
        """Re-enable the plot button and restore its size-dependent colour."""
        self.button.disabled = False
        self._update_button_color()

    def _button_clicked(self, _=None):
        """
        Plot the first ``n_samples`` shuffled words projected onto 3 PCA axes.

        The button is re-enabled even if building the plot fails.
        """
        self._disable_button()
        sleep(0.1)

        try:
            # Get a sample of words
            words = self.vocabulary[:int(self.n_samples)]

            # Get vectors
            vectors = np.array(
                [self.fasttext_model.get_word_vector(word) for word in words])

            # Fit PCA and project onto the first three components
            pca = PCA(n_components=3)
            pca.fit(vectors)
            projections = vectors.dot(pca.components_.T)

            # Make figure. Axes3D(fig) no longer attaches itself to the figure
            # in recent matplotlib, so request a 3-D subplot explicitly.
            plt.close("all")
            self.fig = plt.figure(figsize=(8, 6))
            ax = self.fig.add_subplot(111, projection="3d")

            if self.show_words.value:
                for loc, word in zip(projections, words):
                    ax.text(x=loc[0],
                            y=loc[1],
                            z=loc[2],
                            s=word,
                            ha="center",
                            va="center")
            else:
                ax.scatter(
                    xs=projections[:, 0],
                    ys=projections[:, 1],
                    zs=projections[:, 2],
                )

            # Limits
            ax.set_xlim(projections[:, 0].min(), projections[:, 0].max())
            ax.set_ylim(projections[:, 1].min(), projections[:, 1].max())
            ax.set_zlim(projections[:, 2].min(), projections[:, 2].max())
        finally:
            # Re-enable button
            self._enable_button()

        self.display()
Пример #4
0
    def make_config(self):
        """
        Build the collapsible configuration panel.

        The left column holds the global settings: dataset split, zoom,
        bank/layer spacing, the show-targets/errors checkboxes and the
        "Details" layer with its column/scale controls. Each change is written
        into ``self.net.config`` and/or triggers a redraw. The right column
        edits one layer at a time: visibility, colormap plus a preview image,
        the values the leftmost/rightmost colors map to, and the feature to
        show. It also holds the "Rotate network" toggle and a save-config button.

        :return: an ``Accordion`` titled with the network name, initially collapsed.
        """
        layout = Layout()
        style = {"description_width": "initial"}
        checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                             layout=layout, style=style)
        checkbox1.observe(lambda change: self.set_attr(self.net.config, "show_targets", change["new"]), names='value')
        checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                             layout=layout, style=style)
        checkbox2.observe(lambda change: self.set_attr(self.net.config, "show_errors", change["new"]), names='value')

        hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                         style=style, layout=layout)
        hspace.observe(lambda change: self.set_attr(self.net.config, "hspace", change["new"]), names='value')
        vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                         style=style, layout=layout)
        vspace.observe(lambda change: self.set_attr(self.net.config, "vspace", change["new"]), names='value')
        self.feature_bank = Select(description="Details:", value=self.net.config["dashboard.features.bank"],
                                   options=[""] + [lyr.name for lyr in self.net.layers],
                                   rows=1)
        self.feature_bank.observe(self.regenerate, names='value')
        self.control_select = Select(
            options=['Test', 'Train'],
            value=self.net.config["dashboard.dataset"],
            description='Dataset:',
            rows=1
        )
        self.control_select.observe(self.change_select, names='value')
        column1 = [self.control_select,
                   self.zoom_slider,
                   hspace,
                   vspace,
                   HBox([checkbox1, checkbox2]),
                   self.feature_bank,
                   self.feature_columns,
                   self.feature_scale
        ]
        ## Make layer selectable, and update-able (starts on the last layer):
        column2 = []
        layer = self.net.layers[-1]
        self.layer_select = Select(description="Layer:", value=layer.name,
                                   options=[lyr.name for lyr in self.net.layers],
                                   rows=1)
        self.layer_select.observe(self.update_layer_selection, names='value')
        column2.append(self.layer_select)
        self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
        self.layer_visible_checkbox.observe(self.update_layer, names='value')
        column2.append(self.layer_visible_checkbox)
        # "" in the Select stands for "no explicit colormap" (layer stores None).
        self.layer_colormap = Select(description="Colormap:",
                                     options=[""] + AVAILABLE_COLORMAPS,
                                     value=layer.colormap if layer.colormap is not None else "", layout=layout, rows=1)
        self.layer_colormap_image = HTML(value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
        self.layer_colormap.observe(self.update_layer, names='value')
        column2.append(self.layer_colormap)
        column2.append(self.layer_colormap_image)
        ## get dynamic minmax; if you change it it will set it in layer as override:
        minmax = layer.get_act_minmax()
        self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
        self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
        self.layer_mindim.observe(self.update_layer, names='value')
        self.layer_maxdim.observe(self.update_layer, names='value')
        column2.append(self.layer_mindim)
        column2.append(self.layer_maxdim)
        self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
        self.layer_feature.observe(self.update_layer, names='value')
        column2.append(self.layer_feature)
        # Network-wide rotation toggle, shown next to the save button.
        self.svg_rotate = Checkbox(description="Rotate network",
                                   value=self.net.config["svg_rotate"],
                                   style={"description_width": 'initial'},
                                   layout=Layout(width="52%"))
        self.svg_rotate.observe(lambda change: self.set_attr(self.net.config, "svg_rotate", change["new"]), names='value')
        self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
        self.save_config_button.on_click(self.save_config)
        column2.append(HBox([self.svg_rotate, self.save_config_button]))
        config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                                VBox(column2, layout=Layout(width="100%"))])
        accordion = Accordion(children=[config_children])
        accordion.set_title(0, self.net.name)
        accordion.selected_index = None  # start collapsed
        return accordion
Пример #5
0
class Dashboard(VBox):
    """
    Build the dashboard for Jupyter widgets. Requires running
    in a notebook/jupyterlab.
    """
    def __init__(self, net, width="95%", height="550px", play_rate=0.5):
        """
        Create the dashboard widgets for ``net``.

        :param net: the network being displayed and controlled.
        :param width: CSS width used for the network picture area.
        :param height: kept as ``self._height``.
        :param play_rate: handed to the background ``_Player``.
        """
        self._ignore_layer_updates = False
        # The player is created and started before the rest of the state.
        self.player = _Player(self, play_rate)
        self.player.start()
        self.net = net
        # Random suffix makes class-id collisions between dashboards unlikely.
        suffix = random.randint(1, 1000000)
        self.class_id = "picture-dashboard-{}-{}".format(self.net.name, suffix)
        self._width, self._height = width, height

        ## Global widgets (the "Details" feature-panel settings):
        wide_labels = {"description_width": "initial"}
        config = self.net.config
        self.feature_columns = IntText(description="Detail columns:",
                                       value=config["dashboard.features.columns"],
                                       min=0, max=1024,
                                       style=wide_labels)
        self.feature_scale = FloatText(description="Detail scale:",
                                       value=config["dashboard.features.scale"],
                                       min=0.1, max=10,
                                       style=wide_labels)
        for detail_widget in (self.feature_columns, self.feature_scale):
            detail_widget.observe(self.regenerate, names='value')

        ## Hack to center SVG as justify-content is broken:
        svg_layout = Layout(width=self._width, overflow_x='auto', overflow_y="auto",
                            justify_content="center")
        self.net_svg = HTML(value='<p style="text-align:center"></p>', layout=svg_layout)

        # Make controls first (make_config uses the zoom slider they create):
        self.output = Output()
        controls = self.make_controls()
        config_panel = self.make_config()
        super().__init__([config_panel, controls, self.net_svg, self.output])

    def propagate(self, inputs):
        """
        Propagate inputs through the dashboard view of the network.

        When ``dynamic_pictures_check()`` is true, the network updates this
        dashboard's pictures in place and its output is returned. Otherwise
        the whole display is regenerated for ``inputs`` and None is returned.
        """
        if dynamic_pictures_check():
            return self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        # Fallback: redraw everything for these inputs. (This previously passed
        # the builtin ``input`` function instead of ``inputs``.)
        self.regenerate(inputs=inputs)
        return None

    def goto(self, position):
        """
        Move the dataset slider and keep the position box in sync.

        :param position: "begin", "end", "prev" or "next". "prev" and "next"
            wrap around the ends of the selected split; any other value only
            copies the slider value into the position box.

        Does nothing when the dataset has no inputs or no targets.
        """
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            return
        split = self.control_select.value
        if split == "Train":
            length = len(dataset.train_inputs)
        elif split == "Test":
            length = len(dataset.test_inputs)
        slider = self.control_slider
        current = slider.value
        #### Position it:
        if position == "begin":
            slider.value = 0
        elif position == "end":
            slider.value = length - 1
        elif position == "prev":
            # Step back, wrapping from the first item to the last.
            slider.value = max(current - 1, 0) if current >= 1 else length - 1
        elif position == "next":
            # Step forward, wrapping from the last item to the first.
            slider.value = min(current + 1, length - 1) if current + 1 <= length - 1 else 0
        self.position_text.value = slider.value


    def change_select(self, change=None):
        """
        Observer for the Train/Test selector.

        Re-syncs the slider and labels with the newly chosen split, then
        redraws the network.
        """
        self.update_control_slider(change)
        # The slider state is settled; now redraw for the current item.
        self.regenerate()

    def update_control_slider(self, change=None):
        """
        Re-sync the dataset slider, position box, "of N" label and navigation
        buttons with the currently selected dataset split.

        Also records the selected split in ``net.config["dashboard.dataset"]``.
        Everything except the refresh button is disabled when there is nothing
        to step through.

        :param change: optional traitlets change notification (unused).
        """
        self.net.config["dashboard.dataset"] = self.control_select.value
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            self.total_text.value = "of 0"
            self.control_slider.value = 0
            self.position_text.value = 0
            self._set_navigation_disabled(True)
            return
        selection = self.control_select.value
        if selection == "Test":
            split_inputs = dataset.test_inputs
        elif selection == "Train":
            split_inputs = dataset.train_inputs
        else:
            # Only "Test" and "Train" are offered; nothing to sync otherwise.
            return
        length = len(split_inputs)
        self.total_text.value = "of %s" % length
        minmax = (0, max(length - 1, 0))
        # Reset an out-of-range position before narrowing the slider bounds.
        if not (minmax[0] <= self.control_slider.value <= minmax[1]):
            self.control_slider.value = 0
        self.control_slider.min = minmax[0]
        self.control_slider.max = minmax[1]
        self._set_navigation_disabled(length == 0)
        self.position_text.value = self.control_slider.value

    def _set_navigation_disabled(self, disabled):
        """Enable/disable the slider, position box and all control buttons except refresh."""
        self.control_slider.disabled = disabled
        # (Previously misspelled ``disbaled``, which only created a stray attribute.)
        self.position_text.disabled = disabled
        for child in self.control_buttons.children:
            if not hasattr(child, "icon") or child.icon != "refresh":
                child.disabled = disabled

    def update_zoom_slider(self, change):
        """Store the zoom slider value as ``svg_scale`` in the config and redraw."""
        if change["name"] != "value":
            return
        self.net.config["svg_scale"] = self.zoom_slider.value
        self.regenerate()

    def update_position_text(self, change):
        """Move the dataset slider to the index typed into the position box."""
        # e.g. {'name': 'value', 'old': 2, 'new': 3, 'owner': IntText(...), 'type': 'change'}
        typed_index = change["new"]
        self.control_slider.value = typed_index

    def get_current_input(self):
        """
        Return the input at the slider position in the selected split.

        Returns None when the selected split has no targets.
        """
        dataset = self.net.dataset
        split = self.control_select.value
        if split == "Train" and len(dataset.train_targets) > 0:
            return dataset.train_inputs[self.control_slider.value]
        if split == "Test" and len(dataset.test_targets) > 0:
            return dataset.test_inputs[self.control_slider.value]
        return None

    def get_current_targets(self):
        """
        Return the targets at the slider position in the selected split.

        Returns None when the selected split has no targets.
        """
        dataset = self.net.dataset
        split = self.control_select.value
        if split == "Train" and len(dataset.train_targets) > 0:
            return dataset.train_targets[self.control_slider.value]
        if split == "Test" and len(dataset.test_targets) > 0:
            return dataset.test_targets[self.control_slider.value]
        return None

    def update_slider_control(self, change):
        """
        Show the dataset item selected by the slider.

        Observer for ``control_slider`` (also called by ``prop_one``). Updates
        the position box and "of N" label. When the network has a model, it
        then either regenerates the whole display (``dynamic_pictures_check()``
        false) or propagates the item and updates pictures in place. Train and
        Test items are now displayed identically.

        :param change: traitlets change notification; only "value" changes act.
        """
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            self.total_text.value = "of 0"
            return
        if change["name"] != "value":
            return
        self.position_text.value = self.control_slider.value
        selection = self.control_select.value
        if selection == "Train" and len(dataset.train_targets) > 0:
            split_inputs, split_targets = dataset.train_inputs, dataset.train_targets
        elif selection == "Test" and len(dataset.test_targets) > 0:
            split_inputs, split_targets = dataset.test_inputs, dataset.test_targets
        else:
            return
        self.total_text.value = "of %s" % len(split_inputs)
        if self.net.model is None:
            return
        index = self.control_slider.value
        if not dynamic_pictures_check():
            self.regenerate(inputs=split_inputs[index], targets=split_targets[index])
            return
        self._show_dataset_pattern(split_inputs[index], split_targets[index])

    def _show_dataset_pattern(self, inputs, targets):
        """
        Propagate one dataset item and update the dashboard pictures in place.

        This covers the activations, the "Details" features, and the targets
        and errors when they are enabled in the config.
        """
        output = self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        if self.feature_bank.value in self.net.layer_dict.keys():
            self.net.propagate_to_features(self.feature_bank.value, inputs,
                                           cols=self.feature_columns.value,
                                           scale=self.feature_scale.value, html=False)
        # With one output bank, targets/errors are wrapped in a list. With
        # several, targets are already one entry per bank (indexed by bank below).
        single_bank = len(self.net.output_bank_order) == 1
        if self.net.config["show_targets"]:  ## FIXME: use minmax of output bank
            self.net.display_component([targets] if single_bank else targets,
                                       "targets",
                                       class_id=self.class_id,
                                       minmax=(-1, 1))
        if self.net.config["show_errors"]:  ## minmax is error
            if single_bank:
                errors = [(np.array(output) - np.array(targets)).tolist()]
            else:
                errors = [np.array(output[bank]) - np.array(targets[bank])
                          for bank in range(len(self.net.output_bank_order))]
            self.net.display_component(errors, "errors",
                                       class_id=self.class_id, minmax=(-1, 1))

    def toggle_play(self, button):
        """
        Flip the Play/Stop button and resume or pause the background player.

        :param button: the clicked Button (unused).
        """
        play = self.button_play
        if play.description != "Play":
            play.description = "Play"
            play.icon = "play"
            self.player.pause()
            return
        play.description = "Stop"
        play.icon = "pause"
        self.player.resume()

    def prop_one(self, button=None):
        """Re-show the current dataset item, as if the slider value had just changed."""
        fake_change = {"name": "value"}
        self.update_slider_control(fake_change)

    def regenerate(self, button=None, inputs=None, targets=None):
        ## Protection when deleting object on shutdown:
        if isinstance(button, dict) and 'new' in button and button['new'] is None:
            return
        ## Update the config:
        self.net.config["dashboard.features.bank"] = self.feature_bank.value
        self.net.config["dashboard.features.columns"] = self.feature_columns.value
        self.net.config["dashboard.features.scale"] = self.feature_scale.value
        inputs = inputs if inputs is not None else self.get_current_input()
        targets = targets if targets is not None else self.get_current_targets()
        features = None
        if self.feature_bank.value in self.net.layer_dict.keys() and inputs is not None:
            if self.net.model is not None:
                features = self.net.propagate_to_features(self.feature_bank.value, inputs,
                                                          cols=self.feature_columns.value,
                                                          scale=self.feature_scale.value, display=False)
        svg = """<p style="text-align:center">%s</p>""" % (self.net.to_svg(
            inputs=inputs,
            targets=targets,
            class_id=self.class_id,
            highlights={self.feature_bank.value: {
                "border_color": "orange",
                "border_width": 30,
            }}))
        if inputs is not None and features is not None:
            html_horizontal = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top" style="width: 50%%;">%s</td>
  <td valign="top" align="center" style="width: 50%%;"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            html_vertical = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top">%s</td>
</tr>
<tr>
  <td valign="top" align="center"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            self.net_svg.value = (html_vertical if self.net.config["svg_rotate"] else html_horizontal) % (
                svg, "%s details" % self.feature_bank.value, features)
        else:
            self.net_svg.value = svg

    def make_colormap_image(self, colormap_name):
        """
        Render a preview strip for a colormap, resized to (300, 25).

        An empty or None ``colormap_name`` falls back to ``get_colormap()``.
        A temporary ``Layer("Colormap", 100)`` supplies the activation range,
        which is sampled in steps of 0.01. The image honours the config's
        ``svg_rotate`` setting.
        """
        from .layers import Layer
        name = colormap_name if colormap_name else get_colormap()
        preview_layer = Layer("Colormap", 100)
        bounds = preview_layer.get_act_minmax()
        samples = np.arange(bounds[0], bounds[1], .01)
        options = {"pixels_per_unit": 1,
                   "svg_rotate": self.net.config["svg_rotate"]}
        strip = preview_layer.make_image(samples, name, options)
        return strip.resize((300, 25))

    def set_attr(self, obj, attr, value):
        """
        Store ``value`` as ``obj[attr]`` (dicts) or ``obj.attr``, then redraw.

        ``{}`` and None are ignored; None arrives when shutting down. A dict
        value is unwrapped through its "value" key first.
        """
        if value in [{}, None]:  ## value is None when shutting down
            return
        new_value = value["value"] if isinstance(value, dict) else value
        if isinstance(obj, dict):
            obj[attr] = new_value
        else:
            setattr(obj, attr, new_value)
        ## was crashing on Widgets.__del__, if get_ipython() no longer existed
        self.regenerate()

    def make_controls(self):
        """
        Build the navigation controls.

        These are the dataset-index slider with its "of N" label, and the row
        of begin/prev/position/next/end, play, layer-stepper and refresh
        buttons. Also creates ``self.zoom_slider``, which the config panel
        displays, and wires every widget to its handler. The slider range and
        the "of N" label are both sized from the split named in
        ``net.config["dashboard.dataset"]``.

        :return: a VBox holding the slider row and the button row.
        """
        layout = Layout(width='100%', height="100%")
        button_begin = Button(icon="fast-backward", layout=layout)
        button_prev = Button(icon="backward", layout=layout)
        button_next = Button(icon="forward", layout=layout)
        button_end = Button(icon="fast-forward", layout=layout)
        self.button_play = Button(icon="play", description="Play", layout=layout)
        # Step the "Details" layer selection down/up through the network:
        step_down = Button(icon="sort-down", layout=Layout(width="95%", height="100%"))
        step_up = Button(icon="sort-up", layout=Layout(width="95%", height="100%"))
        up_down = HBox([step_down, step_up], layout=Layout(width="100%", height="100%"))
        refresh_button = Button(icon="refresh", layout=Layout(width="25%", height="100%"))

        self.position_text = IntText(value=0, layout=layout)

        self.control_buttons = HBox([
            button_begin,
            button_prev,
            self.position_text,
            button_next,
            button_end,
            self.button_play,
            up_down,
            refresh_button
        ], layout=Layout(width='100%', height="100%"))
        # Size the slider and the "of N" label from the same split. Previously
        # the slider always used the train split.
        if self.net.config["dashboard.dataset"] == "Train":
            length = len(self.net.dataset.train_inputs)
        else:
            length = len(self.net.dataset.test_inputs)
        self.control_slider = IntSlider(description="Dataset index",
                                        continuous_update=False,
                                        min=0,
                                        max=max(length - 1, 0),
                                        value=0,
                                        layout=Layout(width='100%'))
        self.total_text = Label(value="of %s" % length, layout=Layout(width="100px"))
        self.zoom_slider = FloatSlider(description="Zoom",
                                       continuous_update=False,
                                       min=0, max=1.0,
                                       style={"description_width": 'initial'},
                                       layout=Layout(width="65%"),
                                       value=self.net.config["svg_scale"] if self.net.config["svg_scale"] is not None else 0.5)

        def refresh(widget):
            """Re-sync the slider with the dataset, clear the output area and redraw."""
            self.update_control_slider()
            self.output.clear_output()
            self.regenerate()

        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names='value')
        refresh_button.on_click(refresh)
        step_down.on_click(lambda widget: self.move_step("down"))
        step_up.on_click(lambda widget: self.move_step("up"))
        self.zoom_slider.observe(self.update_zoom_slider, names='value')
        self.position_text.observe(self.update_position_text, names='value')
        # Put them together:
        controls = VBox([HBox([self.control_slider, self.total_text], layout=Layout(height="40px")),
                         self.control_buttons], layout=Layout(width='100%'))

        # Draw the network once the controls appear.
        controls.on_displayed(lambda widget: self.regenerate())
        return controls

    def move_step(self, direction):
        """
        Move the layer stepper up/down through network

        Steps the "Details" layer selection one layer forward ("up") or back
        (any other direction), wrapping around. The empty choice "" (no
        details) is part of the cycle. Redraws afterwards.
        """
        choices = [""]
        choices.extend(lyr.name for lyr in self.net.layers)
        offset = 1 if direction == "up" else -1
        current = choices.index(self.feature_bank.value)
        self.feature_bank.value = choices[(current + offset) % len(choices)]
        self.regenerate()

    def make_config(self):
        """
        Build the collapsible configuration panel for the network display.

        The first column holds network-wide settings: dataset, zoom, bank and
        layer spacing, target/error toggles, feature bank, and feature layout
        controls. The second column holds settings for one layer at a time
        (chosen with ``self.layer_select``) plus the rotate and save-config
        controls. Widgets that other methods read or update are stored on
        ``self``.

        Returns:
            A collapsed ``Accordion`` whose single page is titled with the
            network's name.
        """
        layout = Layout()
        style = {"description_width": "initial"}
        checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                             layout=layout, style=style)
        checkbox1.observe(lambda change: self.set_attr(self.net.config, "show_targets", change["new"]), names='value')
        checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                             layout=layout, style=style)
        checkbox2.observe(lambda change: self.set_attr(self.net.config, "show_errors", change["new"]), names='value')

        hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                         style=style, layout=layout)
        hspace.observe(lambda change: self.set_attr(self.net.config, "hspace", change["new"]), names='value')
        vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                         style=style, layout=layout)
        vspace.observe(lambda change: self.set_attr(self.net.config, "vspace", change["new"]), names='value')
        bank_options = [""] + [lyr.name for lyr in self.net.layers]
        bank_value = self.net.config["dashboard.features.bank"]
        if bank_value not in bank_options:
            # A saved config may name a layer this network does not have, and
            # a Select raises for a value outside its options: show no bank.
            bank_value = ""
        self.feature_bank = Select(description="Details:", value=bank_value,
                                   options=bank_options,
                                   rows=1)
        self.feature_bank.observe(self.regenerate, names='value')
        self.control_select = Select(
            options=['Test', 'Train'],
            value=self.net.config["dashboard.dataset"],
            description='Dataset:',
            rows=1
        )
        self.control_select.observe(self.change_select, names='value')
        column1 = [self.control_select,
                   self.zoom_slider,
                   hspace,
                   vspace,
                   HBox([checkbox1, checkbox2]),
                   self.feature_bank,
                   self.feature_columns,
                   self.feature_scale
        ]
        ## Make layer selectable, and update-able:
        column2 = []
        # The per-layer widgets initially show the last layer of the network.
        layer = self.net.layers[-1]
        self.layer_select = Select(description="Layer:", value=layer.name,
                                   options=[lyr.name for lyr in self.net.layers],
                                   rows=1)
        self.layer_select.observe(self.update_layer_selection, names='value')
        column2.append(self.layer_select)
        self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
        self.layer_visible_checkbox.observe(self.update_layer, names='value')
        column2.append(self.layer_visible_checkbox)
        # A layer without a colormap is shown as the empty "" option.
        self.layer_colormap = Select(description="Colormap:",
                                     options=[""] + AVAILABLE_COLORMAPS,
                                     value=layer.colormap if layer.colormap is not None else "", layout=layout, rows=1)
        self.layer_colormap_image = HTML(value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
        self.layer_colormap.observe(self.update_layer, names='value')
        column2.append(self.layer_colormap)
        column2.append(self.layer_colormap_image)
        ## get dynamic minmax; if you change it it will set it in layer as override:
        minmax = layer.get_act_minmax()
        self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
        self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
        self.layer_mindim.observe(self.update_layer, names='value')
        self.layer_maxdim.observe(self.update_layer, names='value')
        column2.append(self.layer_mindim)
        column2.append(self.layer_maxdim)
        self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
        self.layer_feature.observe(self.update_layer, names='value')
        column2.append(self.layer_feature)
        self.svg_rotate = Checkbox(description="Rotate network",
                                   value=self.net.config["svg_rotate"],
                                   style={"description_width": 'initial'},
                                   layout=Layout(width="52%"))
        self.svg_rotate.observe(lambda change: self.set_attr(self.net.config, "svg_rotate", change["new"]), names='value')
        self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
        self.save_config_button.on_click(self.save_config)
        column2.append(HBox([self.svg_rotate, self.save_config_button]))
        config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                                VBox(column2, layout=Layout(width="100%"))])
        accordion = Accordion(children=[config_children])
        accordion.set_title(0, self.net.name)
        # Start collapsed.
        accordion.selected_index = None
        return accordion

    def save_config(self, widget=None):
        """
        Persist the network's current configuration.

        Serves as the save button's click handler; ``widget`` (the clicked
        button) is accepted for that purpose and otherwise ignored.
        """
        network = self.net
        network.save_config()

    def update_layer(self, change):
        """
        Copy the per-layer widget values into the selected layer and redraw.

        Does nothing while ``_ignore_layer_updates`` is set, i.e. while the
        widgets are being repopulated for a newly selected layer.

        Changes from the colormap / leftmost / rightmost widgets (whose
        descriptions contain "color") only affect activation coloring. They
        store an explicit ``minmax`` override and the colormap, refresh the
        colormap preview, and call ``prop_one()``. Any other change calls
        ``regenerate()``.

        Args:
            change: traitlets change dict; only ``change["owner"]`` is read.
        """
        if getattr(self, "_ignore_layer_updates", False):
            return
        layer = self.net[self.layer_select.value]
        layer.feature = self.layer_feature.value
        layer.visible = self.layer_visible_checkbox.value
        if "color" in change["owner"].description.lower():
            # Matches: Colormap, Leftmost color, Rightmost color.
            # Storing minmax overrides the layer's dynamic min/max.
            layer.minmax = (self.layer_mindim.value, self.layer_maxdim.value)
            # The empty "" option means "no colormap".
            layer.colormap = self.layer_colormap.value or None
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(
                self.make_colormap_image(layer.colormap))
            self.prop_one()
        else:
            self.regenerate()

    def update_layer_selection(self, change):
        """
        Refresh the per-layer widgets to show the newly selected layer.

        Only widget values change; nothing is redrawn. ``update_layer`` is
        suppressed while the widgets are being set, and the suppression is
        lifted even if reading the layer raises.

        Args:
            change: traitlets change dict (unused; the selection is read
                from ``self.layer_select``).
        """
        self._ignore_layer_updates = True
        try:
            layer = self.net[self.layer_select.value]
            self.layer_visible_checkbox.value = layer.visible
            # Same mapping as the widget's construction: a layer without a
            # colormap shows the empty "" option rather than no selection.
            self.layer_colormap.value = layer.colormap if layer.colormap is not None else ""
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(
                self.make_colormap_image(layer.colormap))
            minmax = layer.get_act_minmax()
            self.layer_mindim.value = minmax[0]
            self.layer_maxdim.value = minmax[1]
            self.layer_feature.value = layer.feature
        finally:
            self._ignore_layer_updates = False
class InteractiveLogging():
    def __init__(self, settings, test_name=None, default_window='hanning'):
        """
        Build and display the interactive data-logging dashboard.

        The constructor:
        - creates the measurement, view, calculation, axis, legend and save
          controls;
        - displays them around an embedded plot;
        - wires every widget to its handler;
        - finally tries to start the acquisition stream.

        Args:
            settings: acquisition settings. ``settings.stored_time`` feeds
                the frame-length label, and the object is passed to
                ``streams.start_stream``.
            test_name: optional name kept on ``self.test_name`` and passed to
                ``acquisition.log_data`` when measuring.
            default_window: initial value of the window dropdown; ``None`` is
                mapped to the ``'None'`` option.
        """
        if default_window is None:
            default_window = 'None'
        self.settings = settings
        self.test_name = test_name
        self.dataset = datastructure.DataSet()

        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out = Output()
        self.out_top = Output(layout=Layout(height='120px'))

        # Initialise variables
        self.current_view = 'Time'
        self.N_frames = 1
        self.overlap = 0.5
        self.iw_fft_power = 0
        self.iw_tf_power = 0
        self.legend_loc = 'lower right'

        # puts plot inside widget so can have buttons next to plot
        self.outplot = Output(layout=Layout(width='85%'))

        # BUTTONS
        items_measure = [
            'Log Data', 'Delete Last Measurement', 'Reset (clears all data)',
            'Load Data'
        ]
        items_view = ['View Time', 'View FFT', 'View TF']
        items_calc = ['Calc FFT', 'Calc TF', 'Calc TF average']
        items_axes = ['xmin', 'xmax', 'ymin', 'ymax']
        items_save = ['Save Dataset', 'Save Figure']
        items_iw = ['multiply iw', 'divide iw']

        items_legend = ['Left', 'on/off', 'Right']

        self.buttons_measure = [
            Button(description=i, layout=Layout(width='33%'))
            for i in items_measure
        ]
        self.buttons_measure[0].button_style = 'success'
        self.buttons_measure[0].style.font_weight = 'bold'
        self.buttons_measure[1].button_style = 'warning'
        self.buttons_measure[1].style.font_weight = 'bold'
        self.buttons_measure[2].button_style = 'danger'
        self.buttons_measure[2].style.font_weight = 'bold'
        self.buttons_measure[3].button_style = 'primary'
        self.buttons_measure[3].style.font_weight = 'bold'

        self.buttons_view = [
            Button(description=i, layout=Layout(width='95%'))
            for i in items_view
        ]
        self.buttons_view[0].button_style = 'info'
        self.buttons_view[1].button_style = 'info'
        self.buttons_view[2].button_style = 'info'

        self.buttons_calc = [
            Button(description=i, layout=Layout(width='99%'))
            for i in items_calc
        ]
        self.buttons_calc[0].button_style = 'primary'
        self.buttons_calc[1].button_style = 'primary'
        self.buttons_calc[2].button_style = 'primary'

        self.buttons_iw_fft = [
            Button(description=i, layout=Layout(width='50%')) for i in items_iw
        ]
        self.buttons_iw_fft[0].button_style = 'info'
        self.buttons_iw_fft[1].button_style = 'info'

        self.buttons_iw_tf = [
            Button(description=i, layout=Layout(width='50%')) for i in items_iw
        ]
        self.buttons_iw_tf[0].button_style = 'info'
        self.buttons_iw_tf[1].button_style = 'info'

        self.buttons_match = Button(description='Match Amplitudes',
                                    layout=Layout(width='99%'))
        self.buttons_match.button_style = 'info'

        self.buttons_save = [
            Button(description=i, layout=Layout(width='50%'))
            for i in items_save
        ]
        self.buttons_save[0].button_style = 'success'
        self.buttons_save[0].style.font_weight = 'bold'
        self.buttons_save[1].button_style = 'success'
        self.buttons_save[1].style.font_weight = 'bold'

        self.button_warning = Button(
            description=
            'WARNING: Data may be clipped. Press here to delete last measurement.',
            layout=Layout(width='100%'))
        self.button_warning.button_style = 'danger'
        self.button_warning.style.font_weight = 'bold'

        self.button_X = Button(description='Auto X',
                               layout=Layout(width='95%'))
        self.button_Y = Button(description='Auto Y',
                               layout=Layout(width='95%'))
        self.button_X.button_style = 'success'
        self.button_Y.button_style = 'success'

        self.buttons_legend = [
            Button(description=i, layout=Layout(width='31%'))
            for i in items_legend
        ]
        self.buttons_legend_toggle = ToggleButton(description='Click-and-drag',
                                                  layout=Layout(width='95%'))

        # TEXT/LABELS/DROPDOWNS
        self.item_iw_fft_label = Label(value='iw power={}'.format(0),
                                       layout=Layout(width='100%'))
        self.item_iw_tf_label = Label(value='iw power={}'.format(0),
                                      layout=Layout(width='100%'))
        # Length of one of N_frames frames that overlap by the fraction
        # `overlap` and together span stored_time.
        self.item_label = Label(value="Frame length = {:.2f} seconds.".format(
            settings.stored_time /
            (self.N_frames - self.overlap * self.N_frames + self.overlap)))
        self.item_axis_label = Label(value="Axes control:",
                                     layout=Layout(width='95%'))
        self.item_view_label = Label(value="View data:",
                                     layout=Layout(width='95%'))
        self.item_legend_label = Label(value="Legend position:",
                                       layout=Layout(width='95%'))
        self.item_blank_label = Label(value="", layout=Layout(width='95%'))

        self.text_axes = [
            FloatText(value=0, description=i, layout=Layout(width='95%'))
            for i in items_axes
        ]
        # Indices: 0 = Auto X, 1 = Auto Y, 2..5 = xmin, xmax, ymin, ymax.
        self.text_axes = [self.button_X] + [self.button_Y] + self.text_axes
        self.drop_window = Dropdown(options=['None', 'hanning'],
                                    value=default_window,
                                    description='Window:',
                                    layout=Layout(width='99%'))
        self.slide_Nframes = IntSlider(value=1,
                                       min=1,
                                       max=30,
                                       step=1,
                                       description='N_frames:',
                                       continuous_update=True,
                                       readout=False,
                                       layout=Layout(width='99%'))
        self.text_Nframes = IntText(value=1,
                                    description='N_frames:',
                                    layout=Layout(width='99%'))

        # VERTICAL GROUPS

        group0 = VBox([
            self.buttons_calc[0], self.drop_window,
            HBox(self.buttons_iw_fft)
        ],
                      layout=Layout(width='33%'))
        group1 = VBox([
            self.buttons_calc[1], self.drop_window, self.slide_Nframes,
            self.text_Nframes, self.item_label,
            HBox(self.buttons_iw_tf), self.buttons_match
        ],
                      layout=Layout(width='33%'))
        group2 = VBox([
            self.buttons_calc[2], self.drop_window,
            HBox(self.buttons_iw_tf), self.buttons_match
        ],
                      layout=Layout(width='33%'))

        group_view = VBox(
            [self.item_axis_label] + self.text_axes +
            [self.item_legend_label, self.buttons_legend_toggle] +
            [HBox(self.buttons_legend), self.item_view_label] +
            self.buttons_view,
            layout=Layout(width='20%'))

        # ASSEMBLE
        display(self.out_top)
        display(HBox([self.button_warning]))
        display(HBox(self.buttons_measure))
        display(HBox([self.outplot, group_view]))
        display(HBox([group0, group1, group2]))
        display(HBox(self.buttons_save))
        self.button_warning.layout.visibility = 'hidden'

        # second part to putting plot inside widget
        with self.outplot:
            self.p = plotting.PlotData(figsize=(7.5, 4))

        ## Make buttons/boxes interactive

        self.text_axes[2].observe(self.xmin, names='value')
        self.text_axes[3].observe(self.xmax, names='value')
        self.text_axes[4].observe(self.ymin, names='value')
        self.text_axes[5].observe(self.ymax, names='value')

        self.button_X.on_click(self.auto_x)
        self.button_Y.on_click(self.auto_y)

        self.buttons_legend[0].on_click(self.legend_left)
        self.buttons_legend[1].on_click(self.legend_onoff)
        self.buttons_legend[2].on_click(self.legend_right)
        # Without names='value' these handlers would also fire for every
        # other trait change on the widget (e.g. repeated TF recomputation).
        self.buttons_legend_toggle.observe(self.legend_toggle, names='value')

        self.slide_Nframes.observe(self.nframes_slide, names='value')
        self.text_Nframes.observe(self.nframes_text, names='value')

        self.buttons_measure[0].on_click(self.measure)
        self.buttons_measure[1].on_click(self.undo)
        self.buttons_measure[2].on_click(self.reset)
        self.buttons_measure[3].on_click(self.load_data)

        self.buttons_view[0].on_click(self.view_time)
        self.buttons_view[1].on_click(self.view_fft)
        self.buttons_view[2].on_click(self.view_tf)

        self.buttons_calc[0].on_click(self.fft)
        self.buttons_calc[1].on_click(self.tf)
        self.buttons_calc[2].on_click(self.tf_av)

        self.buttons_iw_fft[0].on_click(self.xiw_fft)
        self.buttons_iw_fft[1].on_click(self.diw_fft)
        self.buttons_iw_tf[0].on_click(self.xiw_tf)
        self.buttons_iw_tf[1].on_click(self.diw_tf)

        self.buttons_match.on_click(self.match)

        self.buttons_save[0].on_click(self.save_data)
        self.buttons_save[1].on_click(self.save_fig)

        self.button_warning.on_click(self.undo)

        self.refresh_buttons()

        # Put output text at bottom of display
        display(self.out)

        with self.out_top:
            try:
                streams.start_stream(settings)
                self.rec = streams.REC
            except Exception:
                # Acquisition is optional: the dashboard is still usable for
                # loading and analysing data, so report instead of failing.
                print('Data stream not initialised.')
                print(
                    'Possible reasons: pyaudio or PyDAQmx not installed, or acquisition hardware not connected.'
                )
                print("Please note that it won't be possible to log data.")

    def xmin(self, v):
        xmin = self.text_axes[2].value
        xlim = self.p.ax.get_xlim()
        self.p.ax.set_xlim([xmin, xlim[1]])

    def xmax(self, v):
        xmax = self.text_axes[3].value
        xlim = self.p.ax.get_xlim()
        self.p.ax.set_xlim([xlim[0], xmax])

    def ymin(self, v):
        ymin = self.text_axes[4].value
        ylim = self.p.ax.get_ylim()
        self.p.ax.set_ylim([ymin, ylim[1]])

    def ymax(self, v):
        ymax = self.text_axes[5].value
        ylim = self.p.ax.get_ylim()
        self.p.ax.set_ylim([ylim[0], ymax])

    def auto_x(self, b):
        self.p.auto_x()

    def auto_y(self, b):
        self.p.auto_y()

    def legend_left(self, b):
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.legend_loc = 'lower left'
            self.p.update_legend(self.legend_loc)

    def legend_right(self, b):
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.legend_loc = 'lower right'
            self.p.update_legend(self.legend_loc)

    def legend_onoff(self, b):
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            visibility = self.p.ax.get_legend().get_visible()
            self.p.ax.get_legend().set_visible(not visibility)

    def legend_toggle(self, b):
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.p.ax.get_legend().set_visible(True)
            self.p.legend.set_draggable(self.buttons_legend_toggle.value)

    def nframes_slide(self, v):
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.N_frames = self.slide_Nframes.value
            self.text_Nframes.value = self.N_frames
            if len(self.dataset.time_data_list) is not 0:
                stored_time = self.dataset.time_data_list[
                    0].settings.stored_time
            elif len(self.dataset.tf_data_list) is not 0:
                stored_time = self.dataset.tf_data_list[0].settings.stored_time
            else:
                stored_time = 0
                print('Time or TF data settings not found')

            self.item_label.value = "Frame length = {:.2f} seconds.".format(
                stored_time /
                (self.N_frames - self.overlap * self.N_frames + self.overlap))
            if self.current_view is 'TF':
                self.tf(None)

    def nframes_text(self, v):
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.N_frames = self.text_Nframes.value
            self.slide_Nframes.value = self.N_frames
            if len(self.dataset.time_data_list) is not 0:
                stored_time = self.dataset.time_data_list[
                    0].settings.stored_time
            elif len(self.dataset.tf_data_list) is not 0:
                stored_time = self.dataset.tf_data_list[0].settings.stored_time
            else:
                stored_time = 0
                print('Time or TF data settings not found')
            self.item_label.value = "Frame length = {:.2f} seconds.".format(
                stored_time /
                (self.N_frames - self.overlap * self.N_frames + self.overlap))
            if self.current_view is 'TF':
                self.tf(None)

    def measure(self, b):
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.rec.trigger_detected = False
            self.buttons_measure[0].button_style = ''
            if self.settings.pretrig_samples is None:
                self.buttons_measure[0].description = 'Logging ({}s)'.format(
                    self.settings.stored_time)
            else:
                self.buttons_measure[
                    0].description = 'Logging ({}s, with trigger)'.format(
                        self.settings.stored_time)

            d = acquisition.log_data(self.settings,
                                     test_name=self.test_name,
                                     rec=self.rec)
            self.dataset.add_to_dataset(d.time_data_list)
            N = len(self.dataset.time_data_list)
            self.p.update(self.dataset.time_data_list,
                          sets=[N - 1],
                          channels='all')
            self.p.auto_x()
            #            self.p.auto_y()
            self.p.ax.set_ylim([-1, 1])
            self.current_view = 'Time'
            self.buttons_measure[0].button_style = 'success'
            self.buttons_measure[0].description = 'Log Data'

            if np.any(np.abs(d.time_data_list[-1].time_data) > 0.95):
                self.button_warning.layout.visibility = 'visible'
            else:
                self.button_warning.layout.visibility = 'hidden'

            xlim = self.p.ax.get_xlim()
            ylim = self.p.ax.get_ylim()
            self.text_axes[2].value = xlim[0]
            self.text_axes[3].value = xlim[1]
            self.text_axes[4].value = ylim[0]
            self.text_axes[5].value = ylim[1]
            self.refresh_buttons()

    def undo(self, b):
        """
        Click handler: delete the most recent time-data measurement.

        The FFT and TF lists are replaced with empty ones, so the (iw)
        power counters that describe them are reset to 0. The remaining
        time data are redrawn, the clipping warning is hidden, and the
        buttons are refreshed.
        """
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.dataset.remove_last_data_item('TimeData')
            self.dataset.freq_data_list = datastructure.FreqDataList()
            self.dataset.tf_data_list = datastructure.TfDataList()
            # The derived data these counters referred to are gone.
            self.iw_fft_power = 0
            self.iw_tf_power = 0
            N = len(self.dataset.time_data_list)
            self.p.update(self.dataset.time_data_list,
                          sets=[N - 1],
                          channels='all')
            self.button_warning.layout.visibility = 'hidden'
            self.refresh_buttons()

    def reset(self, b):
        """
        Click handler: discard all data and start with an empty dataset.

        The (iw) power counters are reset to 0 because no FFT/TF data
        remain. The (now empty) time data are redrawn and the buttons
        refreshed.
        """
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.dataset = datastructure.DataSet()
        self.iw_fft_power = 0
        self.iw_tf_power = 0
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            N = len(self.dataset.time_data_list)
            self.p.update(self.dataset.time_data_list,
                          sets=[N - 1],
                          channels='all')
            self.refresh_buttons()

    def load_data(self, b):
        """
        Click handler: load saved data and add it to the current dataset.

        Every data list of the loaded object is merged into the dataset. The
        first non-empty list among time, FFT and TF data is plotted. The
        axes are then auto-scaled and the axis text boxes synced to the plot
        limits.
        """
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            # Returns None when nothing was loaded (e.g. presumably a
            # cancelled file dialog -- confirm in file.load_data).
            d = file.load_data()
            if d is not None:
                self.dataset.add_to_dataset(d.time_data_list)
                self.dataset.add_to_dataset(d.freq_data_list)
                self.dataset.add_to_dataset(d.tf_data_list)
                self.dataset.add_to_dataset(d.cross_spec_data_list)
                self.dataset.add_to_dataset(d.sono_data_list)
            else:
                print('No data loaded')

            if self.dataset.time_data_list:
                self.p.update(self.dataset.time_data_list,
                              sets='all',
                              channels='all')
            elif self.dataset.freq_data_list:
                self.p.update(self.dataset.freq_data_list,
                              sets='all',
                              channels='all')
            elif self.dataset.tf_data_list:
                self.p.update(self.dataset.tf_data_list,
                              sets='all',
                              channels='all')
            else:
                print('No data to view')

            self.refresh_buttons()

            self.p.auto_x()
            self.p.auto_y()

            xlim = self.p.ax.get_xlim()
            ylim = self.p.ax.get_ylim()
            self.text_axes[2].value = xlim[0]
            self.text_axes[3].value = xlim[1]
            self.text_axes[4].value = ylim[0]
            self.text_axes[5].value = ylim[1]

    def view_time(self, b):
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.refresh_buttons()
            N = len(self.dataset.time_data_list)
            if N is not 0:
                self.p.update(self.dataset.time_data_list)
                if self.current_view is not 'Time':
                    self.current_view = 'Time'
                    self.p.auto_x()
                    self.p.auto_y()

                xlim = self.p.ax.get_xlim()
                ylim = self.p.ax.get_ylim()
                self.text_axes[2].value = xlim[0]
                self.text_axes[3].value = xlim[1]
                self.text_axes[4].value = ylim[0]
                self.text_axes[5].value = ylim[1]
            else:
                print('no time data to display')

    def view_fft(self, b):
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            N = len(self.dataset.freq_data_list)
            self.refresh_buttons()
            if N is not 0:
                self.p.update(self.dataset.freq_data_list)
                if self.current_view is not 'FFT':
                    self.current_view = 'FFT'
                    self.p.auto_x()
                    self.p.auto_y()

                xlim = self.p.ax.get_xlim()
                ylim = self.p.ax.get_ylim()
                self.text_axes[2].value = xlim[0]
                self.text_axes[3].value = xlim[1]
                self.text_axes[4].value = ylim[0]
                self.text_axes[5].value = ylim[1]
            else:
                print('no FFT data to display')

    def view_tf(self, b):
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            N = len(self.dataset.tf_data_list)
            self.refresh_buttons()
            if N is not 0:
                self.p.update(self.dataset.tf_data_list)
                if self.current_view is not 'TF':
                    self.current_view = 'TF'
                    self.p.auto_x()
                    self.p.auto_y()
                xlim = self.p.ax.get_xlim()
                ylim = self.p.ax.get_ylim()
                self.text_axes[2].value = xlim[0]
                self.text_axes[3].value = xlim[1]
                self.text_axes[4].value = ylim[0]
                self.text_axes[5].value = ylim[1]
            else:
                print('no TF data to display')

    def fft(self, b):
        window = self.drop_window.value
        if window is 'None':
            window = None
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.dataset.calculate_fft_set(window=window)
            self.p.update(self.dataset.freq_data_list)
            if self.current_view is not 'FFT':
                self.current_view = 'FFT'
                self.p.auto_x()
                self.p.auto_y()

            xlim = self.p.ax.get_xlim()
            ylim = self.p.ax.get_ylim()
            self.text_axes[2].value = xlim[0]
            self.text_axes[3].value = xlim[1]
            self.text_axes[4].value = ylim[0]
            self.text_axes[5].value = ylim[1]
            self.refresh_buttons()

    def tf(self, b):
        N_frames = self.N_frames
        window = self.drop_window.value
        if window is 'None':
            window = None
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.dataset.calculate_tf_set(window=window,
                                          N_frames=N_frames,
                                          overlap=self.overlap)
            self.p.update(self.dataset.tf_data_list)
            if self.current_view is not 'TF':
                self.current_view = 'TF'
                self.p.auto_x()
                self.p.auto_y()
            xlim = self.p.ax.get_xlim()
            ylim = self.p.ax.get_ylim()
            self.text_axes[2].value = xlim[0]
            self.text_axes[3].value = xlim[1]
            self.text_axes[4].value = ylim[0]
            self.text_axes[5].value = ylim[1]
            self.refresh_buttons()

    def tf_av(self, b):
        window = self.drop_window.value
        if window is 'None':
            window = None
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            self.dataset.calculate_tf_averaged(window=window)
            self.p.update(self.dataset.tf_data_list)
            if self.current_view is not 'TFAV':
                self.current_view = 'TFAV'
                self.p.auto_x()
                self.p.auto_y()

            xlim = self.p.ax.get_xlim()
            ylim = self.p.ax.get_ylim()
            self.text_axes[2].value = xlim[0]
            self.text_axes[3].value = xlim[1]
            self.text_axes[4].value = ylim[0]
            self.text_axes[5].value = ylim[1]
            self.refresh_buttons()

    def xiw_fft(self, b):
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            if self.current_view is 'FFT':
                s = self.p.get_selected_channels()
                n_sets, n_chans = np.shape(s)
                for ns in range(n_sets):
                    newdata = analysis.multiply_by_power_of_iw(
                        self.dataset.freq_data_list[ns],
                        power=1,
                        channel_list=s[ns, :])
                    self.dataset.freq_data_list[ns] = newdata
                self.p.update(self.dataset.freq_data_list)
                self.p.auto_y()
                self.iw_fft_power += 1
                print('Multiplied by (iw)**{}'.format(self.iw_fft_power))
            else:
                print('First press <Calc FFT>')

    def diw_fft(self, b):
        """Multiply the selected FFT data by (iw)**-1 once more.

        Click callback; ``b`` is the widget passed by the callback and is not
        used. Only acts when the current view is ``'FFT'``; otherwise a hint
        is printed. For every set ``ns`` in ``self.dataset.freq_data_list`` the
        entry is replaced by ``analysis.multiply_by_power_of_iw(..., power=-1,
        channel_list=<row ns of self.p.get_selected_channels()>)``. The plot is
        then refreshed and ``self.iw_fft_power`` is decremented, so it tracks
        the cumulative power applied.
        """
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            # `==` compares string values; `is` would test object identity and
            # only works by accident of string interning
            if self.current_view == 'FFT':
                s = self.p.get_selected_channels()
                # unpacking also asserts the selection is 2-D (sets x channels)
                n_sets, _ = np.shape(s)
                for ns in range(n_sets):
                    newdata = analysis.multiply_by_power_of_iw(
                        self.dataset.freq_data_list[ns],
                        power=-1,
                        channel_list=s[ns, :])
                    self.dataset.freq_data_list[ns] = newdata
                self.p.update(self.dataset.freq_data_list)
                self.p.auto_y()
                self.iw_fft_power -= 1
                print('Multiplied by (iw)**{}'.format(self.iw_fft_power))
            else:
                print('First press <Calc FFT>')

    def xiw_tf(self, b):
        """Multiply the selected transfer-function data by (iw) once more.

        Click callback; ``b`` is the widget passed by the callback and is not
        used. Only acts when the current view is ``'TF'`` or ``'TFAV'``;
        otherwise a hint is printed. For every set ``ns`` in
        ``self.dataset.tf_data_list`` the entry is replaced by
        ``analysis.multiply_by_power_of_iw(..., power=1,
        channel_list=<row ns of self.p.get_selected_channels()>)``. The plot is
        then refreshed and ``self.iw_tf_power`` is incremented, so it tracks
        the cumulative power applied.
        """
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            # membership compares string values; `is` would test object
            # identity and only works by accident of string interning
            if self.current_view in ('TF', 'TFAV'):
                s = self.p.get_selected_channels()
                # unpacking also asserts the selection is 2-D (sets x channels)
                n_sets, _ = np.shape(s)
                for ns in range(n_sets):
                    newdata = analysis.multiply_by_power_of_iw(
                        self.dataset.tf_data_list[ns],
                        power=1,
                        channel_list=s[ns, :])
                    self.dataset.tf_data_list[ns] = newdata
                self.p.update(self.dataset.tf_data_list)
                self.p.auto_y()
                self.iw_tf_power += 1
                print('Multiplied selected channel by (iw)**{}'.format(
                    self.iw_tf_power))
            else:
                print('First press <Calc TF> or <Calc TF average>')

    def diw_tf(self, b):
        """Multiply the selected transfer-function data by (iw)**-1 once more.

        Click callback; ``b`` is the widget passed by the callback and is not
        used. Only acts when the current view is ``'TF'`` or ``'TFAV'``;
        otherwise a hint is printed. For every set ``ns`` in
        ``self.dataset.tf_data_list`` the entry is replaced by
        ``analysis.multiply_by_power_of_iw(..., power=-1,
        channel_list=<row ns of self.p.get_selected_channels()>)``. The plot is
        then refreshed and ``self.iw_tf_power`` is decremented, so it tracks
        the cumulative power applied.
        """
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            # membership compares string values; `is` would test object
            # identity and only works by accident of string interning
            if self.current_view in ('TF', 'TFAV'):
                s = self.p.get_selected_channels()
                # unpacking also asserts the selection is 2-D (sets x channels)
                n_sets, _ = np.shape(s)
                for ns in range(n_sets):
                    newdata = analysis.multiply_by_power_of_iw(
                        self.dataset.tf_data_list[ns],
                        power=-1,
                        channel_list=s[ns, :])
                    self.dataset.tf_data_list[ns] = newdata
                self.p.update(self.dataset.tf_data_list)
                self.p.auto_y()
                self.iw_tf_power -= 1
                print('Multiplied selected channel by (iw)**{}'.format(
                    self.iw_tf_power))
            else:
                print('First press <Calc TF> or <Calc TF average>')

    def match(self, b):
        """Rescale transfer-function calibration factors to best match set 0.

        Click callback; ``b`` is the widget passed by the callback and is not
        used. Only acts when the current view is ``'TF'`` or ``'TFAV'``;
        otherwise a hint is printed. It calls ``analysis.best_match`` on
        ``self.dataset.tf_data_list`` over the plot's current x-axis limits,
        with ``set_ref=0, ch_ref=0``. The returned factors are multiplied by
        the existing calibration factor at ``[0][0]``, applied with
        ``set_calibration_factors_all`` and printed.
        """
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        self.out_top.clear_output(wait=False)
        with self.out_top:
            # membership compares string values; `is` would test object
            # identity and only works by accident of string interning
            if self.current_view in ('TF', 'TFAV'):
                # match only over the frequency range currently shown
                freq_range = self.p.ax.get_xlim()
                current_calibration_factors = (
                    self.dataset.tf_data_list.get_calibration_factors())
                # keep the reference channel's existing calibration so the
                # overall scale is preserved after matching
                reference = current_calibration_factors[0][0]
                factors = analysis.best_match(self.dataset.tf_data_list,
                                              freq_range=freq_range,
                                              set_ref=0,
                                              ch_ref=0)
                factors = [reference * x for x in factors]
                self.dataset.tf_data_list.set_calibration_factors_all(factors)
                self.p.update(self.dataset.tf_data_list)
                print('scale factors:')
                print(factors)
                # y-axis limits are intentionally not auto-rescaled here
            else:
                print(
                    'First press <View TF> or <Calc TF> or <Calc TF average>')

    def save_data(self, b):
        """Print a summary of the current dataset, then ask it to save itself.

        Click callback; the widget argument ``b`` is ignored. The saving is
        delegated entirely to ``self.dataset.save_data()``.
        """
        # Wipe both text areas first so repeated clicks don't accumulate
        # output in the widget display.
        for area in (self.out, self.out_top):
            area.clear_output(wait=False)
        with self.out_top:
            ds = self.dataset
            print('Saving dataset:')
            print(ds)
            ds.save_data()

    def save_fig(self, b):
        """Save the current plot through the ``file.save_fig`` helper.

        Click callback; the widget argument ``b`` is ignored. The plot object
        ``self.p`` is passed with a fixed ``figsize`` of (9, 5). Anything the
        helper prints appears in ``self.out_top``.
        """
        # Wipe both text areas first so repeated clicks don't accumulate
        # output in the widget display.
        for area in (self.out, self.out_top):
            area.clear_output(wait=False)
        fig_size = (9, 5)
        with self.out_top:
            file.save_fig(self.p, figsize=fig_size)

    def refresh_buttons(self):
        """Highlight the view buttons whose data lists are non-empty.

        ``self.buttons_view[0]``, ``[1]`` and ``[2]`` are paired with
        ``self.dataset.time_data_list``, ``freq_data_list`` and
        ``tf_data_list`` respectively. A button gets ``button_style`` of
        ``'info'`` when its list has entries and ``''`` when it is empty.
        """
        data_lists = (self.dataset.time_data_list,
                      self.dataset.freq_data_list,
                      self.dataset.tf_data_list)
        # index explicitly (rather than zip) so a too-short buttons_view still
        # raises IndexError instead of silently skipping buttons
        for i, data_list in enumerate(data_lists):
            # `== 0`, not `is 0`: identity against an int literal is
            # implementation-dependent and a SyntaxWarning on Python 3.8+
            if len(data_list) == 0:
                self.buttons_view[i].button_style = ''
            else:
                self.buttons_view[i].button_style = 'info'