def make_config(self):
    """
    Build the collapsible configuration panel and return it.

    The panel has two columns: network-wide display settings (dataset,
    zoom, spacing, targets/errors, feature bank) and settings for a
    single selected layer (visibility, colormap, color range, feature).

    Side effects: stores the widgets that other handlers read as
    attributes on self (feature_bank, control_select, layer_select,
    layer_visible_checkbox, layer_colormap, layer_colormap_image,
    layer_mindim, layer_maxdim, layer_feature, svg_rotate,
    save_config_button). Expects self.zoom_slider, self.feature_columns
    and self.feature_scale to exist already.

    Returns:
        An Accordion holding the panel, initially collapsed.
    """
    layout = Layout()
    style = {"description_width": "initial"}

    def bind_config(widget, key):
        # Mirror every value change of the widget into net.config[key]
        # by way of self.set_attr.
        widget.observe(lambda change: self.set_attr(self.net.config, key, change["new"]),
                       names='value')

    ## Column 1: network-wide settings
    checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                         layout=layout, style=style)
    bind_config(checkbox1, "show_targets")
    checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                         layout=layout, style=style)
    bind_config(checkbox2, "show_errors")
    hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                     style=style, layout=layout)
    bind_config(hspace, "hspace")
    vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                     style=style, layout=layout)
    bind_config(vspace, "vspace")
    feature_layers = [layer.name for layer in self.net.layers
                      if self.net._layer_has_features(layer.name)]
    self.feature_bank = Select(description="Features:",
                               value=self.net.config["dashboard.features.bank"],
                               options=[""] + feature_layers,
                               rows=1)
    self.feature_bank.observe(self.regenerate, names='value')
    self.control_select = Select(
        options=['Test', 'Train'],
        value=self.net.config["dashboard.dataset"],
        description='Dataset:',
        rows=1
    )
    self.control_select.observe(self.change_select, names='value')
    column1 = [self.control_select,
               self.zoom_slider,
               hspace,
               vspace,
               HBox([checkbox1, checkbox2]),
               self.feature_bank,
               self.feature_columns,
               self.feature_scale]

    ## Column 2: make layer selectable, and update-able.
    ## The last layer is the one selected initially.
    layer = self.net.layers[-1]
    self.layer_select = Select(description="Layer:", value=layer.name,
                               options=[layer.name for layer in self.net.layers],
                               rows=1)
    self.layer_select.observe(self.update_layer_selection, names='value')
    self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
    self.layer_visible_checkbox.observe(self.update_layer, names='value')
    self.layer_colormap = Select(description="Colormap:",
                                 options=[""] + AVAILABLE_COLORMAPS,
                                 value=layer.colormap if layer.colormap is not None else "",
                                 layout=layout, rows=1)
    self.layer_colormap_image = HTML(
        value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
    self.layer_colormap.observe(self.update_layer, names='value')
    ## get dynamic minmax; if you change it it will set it in layer as override:
    minmax = layer.get_act_minmax()
    self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
    self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
    self.layer_mindim.observe(self.update_layer, names='value')
    self.layer_maxdim.observe(self.update_layer, names='value')
    self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
    self.layer_feature.observe(self.update_layer, names='value')
    # NOTE: a throwaway "Rotate" checkbox (seeded from layer.visible) used to
    # be created here and was immediately overwritten; only this one is real.
    self.svg_rotate = Checkbox(description="Rotate network",
                               value=self.net.config["svg_rotate"],
                               style={"description_width": 'initial'},
                               layout=Layout(width="52%"))
    bind_config(self.svg_rotate, "svg_rotate")
    self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
    self.save_config_button.on_click(self.save_config)
    column2 = [self.layer_select,
               self.layer_visible_checkbox,
               self.layer_colormap,
               self.layer_colormap_image,
               self.layer_mindim,
               self.layer_maxdim,
               self.layer_feature,
               HBox([self.svg_rotate, self.save_config_button])]

    config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                            VBox(column2, layout=Layout(width="100%"))])
    accordion = Accordion(children=[config_children])
    accordion.set_title(0, self.net.name)
    # None keeps the accordion collapsed until the user opens it.
    accordion.selected_index = None
    return accordion
class Dashboard(VBox):
    """
    Build the dashboard for Jupyter widgets. Requires running
    in a notebook/jupyterlab.

    The dashboard stacks four children: a collapsible configuration
    panel, the dataset navigation controls, the HTML widget holding the
    network picture, and an Output widget.
    """
    def __init__(self, net, width="95%", height="550px", play_rate=0.5):
        """
        Args:
            net: the network to display; its config, dataset, layers and
                picture-generation methods are used throughout.
            width: CSS width given to the picture widget.
            height: stored as self._height.
            play_rate: handed to the _Player that drives "Play".
        """
        self._ignore_layer_updates = False
        self.player = _Player(self, play_rate)
        self.player.start()
        self.net = net
        # Random suffix keeps the class id distinct between dashboards.
        r = random.randint(1, 1000000)
        self.class_id = "picture-dashboard-%s-%s" % (self.net.name, r)
        self._width = width
        self._height = height
        ## Global widgets:
        style = {"description_width": "initial"}
        self.feature_columns = IntText(description="Feature columns:",
                                       value=self.net.config["dashboard.features.columns"],
                                       min=0,
                                       max=1024,
                                       style=style)
        self.feature_scale = FloatText(description="Feature scale:",
                                       value=self.net.config["dashboard.features.scale"],
                                       min=0.1,
                                       max=10,
                                       style=style)
        self.feature_columns.observe(self.regenerate, names='value')
        self.feature_scale.observe(self.regenerate, names='value')
        ## Hack to center SVG as justify-content is broken:
        self.net_svg = HTML(value="""<p style="text-align:center">%s</p>""" % ("",),
                            layout=Layout(width=self._width,
                                          overflow_x='auto',
                                          overflow_y="auto",
                                          justify_content="center"))
        # Make controls first: make_config() places self.zoom_slider,
        # which make_controls() creates.
        self.output = Output()
        controls = self.make_controls()
        config = self.make_config()
        super().__init__([config, controls, self.net_svg, self.output])

    def propagate(self, inputs):
        """
        Propagate inputs through the dashboard view of the network.

        Returns the network output when dynamic pictures are available;
        otherwise redraws the picture for these inputs and returns None.
        """
        if dynamic_pictures_check():
            return self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        else:
            # Fixed: this used to pass the builtin `input` function.
            self.regenerate(inputs=inputs)

    def _current_split(self):
        """
        Return (inputs, targets) of the selected dataset split, or None
        when the selected split has no targets.
        """
        dataset = self.net.dataset
        selected = self.control_select.value
        if selected == "Train" and len(dataset.train_targets) > 0:
            return dataset.train_inputs, dataset.train_targets
        elif selected == "Test" and len(dataset.test_targets) > 0:
            return dataset.test_inputs, dataset.test_targets
        return None

    def _set_navigation_disabled(self, disabled):
        """
        Enable/disable the slider, the position box and every control
        button except the refresh button.
        """
        self.control_slider.disabled = disabled
        self.position_text.disabled = disabled
        for child in self.control_buttons.children:
            if not hasattr(child, "icon") or child.icon != "refresh":
                child.disabled = disabled

    def goto(self, position):
        """
        Move the dataset index. position is one of "begin", "end",
        "prev" or "next"; "prev" and "next" wrap around.
        """
        if len(self.net.dataset.inputs) == 0 or len(self.net.dataset.targets) == 0:
            return
        if self.control_select.value == "Train":
            length = len(self.net.dataset.train_inputs)
        elif self.control_select.value == "Test":
            length = len(self.net.dataset.test_inputs)
        else:
            # Unknown selection: nothing to navigate.
            return
        last = length - 1
        current = self.control_slider.value
        #### Position it:
        if position == "begin":
            self.control_slider.value = 0
        elif position == "end":
            self.control_slider.value = last
        elif position == "prev":
            # wrap around to the end when stepping back from the start
            self.control_slider.value = last if current - 1 < 0 else current - 1
        elif position == "next":
            # wrap around to the start when stepping past the end
            self.control_slider.value = 0 if current + 1 > last else current + 1
        self.position_text.value = self.control_slider.value

    def change_select(self, change=None):
        """
        Handler for the Dataset selector: resize the navigation controls
        for the newly selected split, then redraw.
        """
        self.update_control_slider(change)
        self.regenerate()

    def update_control_slider(self, change=None):
        """
        Record the selected split in the config and fit the slider range,
        the "of N" label and the enabled state of the controls to it.
        """
        self.net.config["dashboard.dataset"] = self.control_select.value
        if len(self.net.dataset.inputs) == 0 or len(self.net.dataset.targets) == 0:
            self.total_text.value = "of 0"
            self.control_slider.value = 0
            self.position_text.value = 0
            self._set_navigation_disabled(True)
            return
        if self.control_select.value == "Test":
            count = len(self.net.dataset.test_inputs)
        elif self.control_select.value == "Train":
            count = len(self.net.dataset.train_inputs)
        else:
            return
        self.total_text.value = "of %s" % count
        last = max(count - 1, 0)
        # Reset an out-of-range index before narrowing the slider bounds.
        if not (0 <= self.control_slider.value <= last):
            self.control_slider.value = 0
        self.control_slider.min = 0
        self.control_slider.max = last
        self.position_text.value = self.control_slider.value
        # Fixed: the position box was "disabled" through a misspelled
        # attribute name (disbaled).
        self._set_navigation_disabled(count == 0)

    def update_zoom_slider(self, change):
        """Handler for the zoom slider: store the scale and redraw."""
        if change["name"] == "value":
            self.net.config["svg_scale"] = self.zoom_slider.value
            self.regenerate()

    def update_position_text(self, change):
        """Handler for the position box: move the slider to the typed index."""
        # {'name': 'value', 'old': 2, 'new': 3, 'owner': IntText(value=3, layout=Layout(width='100%')), 'type': 'change'}
        self.control_slider.value = change["new"]

    def get_current_input(self):
        """Return the input at the slider index of the selected split, or None."""
        split = self._current_split()
        if split is not None:
            return split[0][self.control_slider.value]

    def get_current_targets(self):
        """Return the targets at the slider index of the selected split, or None."""
        split = self._current_split()
        if split is not None:
            return split[1][self.control_slider.value]

    def update_slider_control(self, change):
        """
        Handler for the dataset slider: propagate the selected pattern
        and refresh the pictures, features, targets and errors.
        """
        if len(self.net.dataset.inputs) == 0 or len(self.net.dataset.targets) == 0:
            self.total_text.value = "of 0"
            return
        if change["name"] != "value":
            return
        self.position_text.value = self.control_slider.value
        split = self._current_split()
        if split is None:
            return
        split_inputs, split_targets = split
        self.total_text.value = "of %s" % len(split_inputs)
        if self.net.model is None:
            return
        index = self.control_slider.value
        inputs = split_inputs[index]
        targets = split_targets[index]
        if not dynamic_pictures_check():
            # No dynamic pictures: rebuild the whole picture instead.
            self.regenerate(inputs=inputs, targets=targets)
            return
        output = self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        if self.feature_bank.value in self.net.layer_dict.keys():
            self.net.propagate_to_features(self.feature_bank.value, inputs,
                                           cols=self.feature_columns.value,
                                           scale=self.feature_scale.value, html=False)
        # A single output bank is wrapped in a list so display_component
        # always receives one entry per bank. The Test split previously
        # wrapped targets unconditionally; it now matches the Train split.
        single_bank = len(self.net.output_bank_order) == 1
        if self.net.config["show_targets"]:
            ## FIXME: use minmax of output bank
            self.net.display_component([targets] if single_bank else targets,
                                       "targets",
                                       class_id=self.class_id,
                                       minmax=(-1, 1))
        if self.net.config["show_errors"]:  ## minmax is error
            if single_bank:
                errors = [(np.array(output) - np.array(targets)).tolist()]
            else:
                errors = [np.array(output[bank]) - np.array(targets[bank])
                          for bank in range(len(self.net.output_bank_order))]
            self.net.display_component(errors,
                                       "errors",
                                       class_id=self.class_id,
                                       minmax=(-1, 1))

    def toggle_play(self, button):
        """Handler for the Play/Stop button: flip the label and the player."""
        ## toggle
        if self.button_play.description == "Play":
            self.button_play.description = "Stop"
            self.button_play.icon = "pause"
            self.player.resume()
        else:
            self.button_play.description = "Play"
            self.button_play.icon = "play"
            self.player.pause()

    def prop_one(self, button=None):
        """Re-propagate the currently selected pattern."""
        self.update_slider_control({"name": "value"})

    def regenerate(self, button=None, inputs=None, targets=None):
        """
        Rebuild the network picture (plus the feature view, if a feature
        bank is selected) and store the feature settings in the config.

        Args:
            button: the widget or change dict when used as a callback.
            inputs, targets: pattern to draw; default to the pattern at
                the current slider position.
        """
        ## Protection when deleting object on shutdown:
        if isinstance(button, dict) and 'new' in button and button['new'] is None:
            return
        ## Update the config:
        self.net.config["dashboard.features.bank"] = self.feature_bank.value
        self.net.config["dashboard.features.columns"] = self.feature_columns.value
        self.net.config["dashboard.features.scale"] = self.feature_scale.value
        inputs = inputs if inputs is not None else self.get_current_input()
        targets = targets if targets is not None else self.get_current_targets()
        features = None
        if self.feature_bank.value in self.net.layer_dict.keys() and inputs is not None:
            if self.net.model is not None:
                features = self.net.propagate_to_features(self.feature_bank.value, inputs,
                                                          cols=self.feature_columns.value,
                                                          scale=self.feature_scale.value, display=False)
        svg = """<p style="text-align:center">%s</p>""" % (self.net.to_svg(
            inputs=inputs,
            targets=targets,
            class_id=self.class_id),)
        if inputs is not None and features is not None:
            # Picture and features side by side ...
            html_horizontal = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top" style="width: 50%%;">%s</td>
  <td valign="top" align="center" style="width: 50%%;"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            # ... or stacked when the network is drawn rotated.
            html_vertical = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top">%s</td>
</tr>
<tr>
  <td valign="top" align="center"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            self.net_svg.value = (html_vertical if self.net.config["svg_rotate"] else html_horizontal) % (
                svg, "%s features" % self.feature_bank.value, features)
        else:
            self.net_svg.value = svg

    def make_colormap_image(self, colormap_name):
        """
        Return a 300x25 image previewing a colormap; an empty name falls
        back to get_colormap().
        """
        from .layers import Layer
        if not colormap_name:
            colormap_name = get_colormap()
        # A scratch layer is used only for its activation range and
        # make_image().
        layer = Layer("Colormap", 100)
        minmax = layer.get_act_minmax()
        image = layer.make_image(np.arange(minmax[0], minmax[1], .01),
                                 colormap_name,
                                 {"pixels_per_unit": 1,
                                  "svg_rotate": self.net.config["svg_rotate"]}).resize((300, 25))
        return image

    def set_attr(self, obj, attr, value):
        """
        Store value under attr on obj (dict item or attribute) and
        redraw. Empty-dict and None values are ignored.
        """
        if value not in [{}, None]:  ## value is None when shutting down
            if isinstance(value, dict):
                value = value["value"]
            if isinstance(obj, dict):
                obj[attr] = value
            else:
                setattr(obj, attr, value)
            ## was crashing on Widgets.__del__, if get_ipython() no longer existed
            self.regenerate()

    def make_controls(self):
        """
        Build the dataset navigation controls (slider, buttons, position
        box, zoom slider), hook up their handlers and return the box.
        """
        button_begin = Button(icon="fast-backward", layout=Layout(width='100%'))
        button_prev = Button(icon="backward", layout=Layout(width='100%'))
        button_next = Button(icon="forward", layout=Layout(width='100%'))
        button_end = Button(icon="fast-forward", layout=Layout(width='100%'))
        self.button_play = Button(icon="play", description="Play", layout=Layout(width="100%"))
        refresh_button = Button(icon="refresh", layout=Layout(width="25%"))

        self.position_text = IntText(value=0, layout=Layout(width="100%"))

        self.control_buttons = HBox([
            button_begin,
            button_prev,
            self.position_text,
            button_next,
            button_end,
            self.button_play,
            refresh_button
        ], layout=Layout(width='100%', height="50px"))
        # Size the slider from the configured split. It used to use the
        # train split always, which let the index run past a shorter
        # test split.
        if self.net.config["dashboard.dataset"] == "Train":
            length = len(self.net.dataset.train_inputs)
        else:
            length = len(self.net.dataset.test_inputs)
        self.control_slider = IntSlider(description="Dataset index",
                                        continuous_update=False,
                                        min=0,
                                        max=max(length - 1, 0),
                                        value=0,
                                        layout=Layout(width='100%'))
        self.total_text = Label(value="of %s" % length, layout=Layout(width="100px"))
        self.zoom_slider = FloatSlider(description="Zoom",
                                       continuous_update=False,
                                       min=0, max=1.0,
                                       style={"description_width": 'initial'},
                                       layout=Layout(width="65%"),
                                       value=self.net.config["svg_scale"] if self.net.config["svg_scale"] is not None else 0.5)

        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names='value')
        refresh_button.on_click(lambda widget: (self.update_control_slider(),
                                                self.output.clear_output(),
                                                self.regenerate()))
        self.zoom_slider.observe(self.update_zoom_slider, names='value')
        self.position_text.observe(self.update_position_text, names='value')
        # Put them together:
        controls = VBox([HBox([self.control_slider, self.total_text], layout=Layout(height="40px")),
                         self.control_buttons], layout=Layout(width='100%'))
        # Draw the picture once the controls appear.
        controls.on_displayed(lambda widget: self.regenerate())
        return controls

    def make_config(self):
        """
        Build the collapsible configuration panel (network-wide settings
        in one column, per-layer settings in the other) and return it as
        an Accordion, initially collapsed.
        """
        layout = Layout()
        style = {"description_width": "initial"}

        def bind_config(widget, key):
            # Mirror value changes into net.config[key]; set_attr redraws.
            widget.observe(lambda change: self.set_attr(self.net.config, key, change["new"]),
                           names='value')

        checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                             layout=layout, style=style)
        bind_config(checkbox1, "show_targets")
        checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                             layout=layout, style=style)
        bind_config(checkbox2, "show_errors")
        hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                         style=style, layout=layout)
        bind_config(hspace, "hspace")
        vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                         style=style, layout=layout)
        bind_config(vspace, "vspace")
        feature_layers = [layer.name for layer in self.net.layers
                          if self.net._layer_has_features(layer.name)]
        self.feature_bank = Select(description="Features:",
                                   value=self.net.config["dashboard.features.bank"],
                                   options=[""] + feature_layers,
                                   rows=1)
        self.feature_bank.observe(self.regenerate, names='value')
        self.control_select = Select(
            options=['Test', 'Train'],
            value=self.net.config["dashboard.dataset"],
            description='Dataset:',
            rows=1
        )
        self.control_select.observe(self.change_select, names='value')
        column1 = [self.control_select,
                   self.zoom_slider,
                   hspace,
                   vspace,
                   HBox([checkbox1, checkbox2]),
                   self.feature_bank,
                   self.feature_columns,
                   self.feature_scale]
        ## Make layer selectable, and update-able:
        layer = self.net.layers[-1]
        self.layer_select = Select(description="Layer:", value=layer.name,
                                   options=[layer.name for layer in self.net.layers],
                                   rows=1)
        self.layer_select.observe(self.update_layer_selection, names='value')
        self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
        self.layer_visible_checkbox.observe(self.update_layer, names='value')
        self.layer_colormap = Select(description="Colormap:",
                                     options=[""] + AVAILABLE_COLORMAPS,
                                     value=layer.colormap if layer.colormap is not None else "",
                                     layout=layout, rows=1)
        self.layer_colormap_image = HTML(
            value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
        self.layer_colormap.observe(self.update_layer, names='value')
        ## get dynamic minmax; if you change it it will set it in layer as override:
        minmax = layer.get_act_minmax()
        self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
        self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
        self.layer_mindim.observe(self.update_layer, names='value')
        self.layer_maxdim.observe(self.update_layer, names='value')
        self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
        self.layer_feature.observe(self.update_layer, names='value')
        # NOTE: a throwaway "Rotate" checkbox (seeded from layer.visible)
        # used to be created here and was immediately overwritten.
        self.svg_rotate = Checkbox(description="Rotate network",
                                   value=self.net.config["svg_rotate"],
                                   style={"description_width": 'initial'},
                                   layout=Layout(width="52%"))
        bind_config(self.svg_rotate, "svg_rotate")
        self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
        self.save_config_button.on_click(self.save_config)
        column2 = [self.layer_select,
                   self.layer_visible_checkbox,
                   self.layer_colormap,
                   self.layer_colormap_image,
                   self.layer_mindim,
                   self.layer_maxdim,
                   self.layer_feature,
                   HBox([self.svg_rotate, self.save_config_button])]
        config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                                VBox(column2, layout=Layout(width="100%"))])
        accordion = Accordion(children=[config_children])
        accordion.set_title(0, self.net.name)
        # None keeps the accordion collapsed until the user opens it.
        accordion.selected_index = None
        return accordion

    def save_config(self, widget=None):
        """Handler for the save button: persist the network config."""
        self.net.save_config()

    def update_layer(self, change):
        """
        Update the layer object, and redisplay.
        """
        if self._ignore_layer_updates:
            return
        ## The rest indicates a change to a display variable.
        ## We need to save the value in the layer, and regenerate
        ## the display.
        # Get the layer:
        layer = self.net[self.layer_select.value]
        # Save the changed value in the layer:
        layer.feature = self.layer_feature.value
        layer.visible = self.layer_visible_checkbox.value
        ## These three, dealing with colors of activations,
        ## can be done with a prop_one():
        if "color" in change["owner"].description.lower():
            ## Matches: Colormap, leftmost color, rightmost color
            ## overriding dynamic minmax!
            layer.minmax = (self.layer_mindim.value, self.layer_maxdim.value)
            layer.colormap = self.layer_colormap.value if self.layer_colormap.value else None
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(
                self.make_colormap_image(layer.colormap))
            self.prop_one()
        else:
            self.regenerate()

    def update_layer_selection(self, change):
        """
        Just update the widgets; don't redraw anything.
        """
        ## No need to redisplay anything
        self._ignore_layer_updates = True
        try:
            ## First, get the new layer selected:
            layer = self.net[self.layer_select.value]
            ## Now, let's update all of the values without updating:
            self.layer_visible_checkbox.value = layer.visible
            # A layer without a colormap maps to the "" option, as in
            # make_config (the old `!= ""` test was a no-op).
            self.layer_colormap.value = layer.colormap if layer.colormap is not None else ""
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(
                self.make_colormap_image(layer.colormap))
            minmax = layer.get_act_minmax()
            self.layer_mindim.value = minmax[0]
            self.layer_maxdim.value = minmax[1]
            self.layer_feature.value = layer.feature
        finally:
            # Always re-arm update_layer, even if a widget rejected a value.
            self._ignore_layer_updates = False
class CompleteWordEmbeddingVisualizer:
    """
    Notebook widget that plots a 3D PCA projection of word vectors.

    A slider (log-spaced ticks) and an integer box choose how many words
    from the shuffled vocabulary are plotted; a button triggers the plot
    and a checkbox switches between points and word labels. The button
    color warns when the chosen sample is large.
    """

    def __init__(self, fasttext_model):
        """
        Args:
            fasttext_model: object providing get_labels() (the vocabulary)
                and get_word_vector(word).
        """
        self.fasttext_model = fasttext_model
        self.n_samples = self.fig = None
        # Re-entrancy guard: slider and text box update each other.
        self._hold = False

        # Get vocabulary
        self.vocabulary = self.fasttext_model.get_labels()
        random.shuffle(self.vocabulary)
        self.n_words = len(self.vocabulary)

        # Ticks for slider: log-spaced sample counts from 10 up to n_words
        self._n_ticks = 100
        self._ticks = np.logspace(start=1, stop=np.log10(self.n_words), endpoint=True, base=10,
                                  num=self._n_ticks)

        # Slider for how many to plot (its value is an index into _ticks)
        self.n_samples_slider = IntSlider(
            value=5,
            min=0,
            max=self._n_ticks - 1,
            step=1,
            description='# Samples:',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=False,
            layout=dict(width="50%"),
        )

        # Label
        self.n_samples_text = IntText(
            value=0,
            layout=dict(width="15%"),
        )

        # Button
        self.button = Button(
            description='Plot points',
            disabled=False,
            button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Click me',
            layout=dict(width="15%"),
        )
        self.button.on_click(self._button_clicked)

        # Checkbox for showing words instead of points
        self.show_words = Checkbox(
            value=False,
            description='Show words',
            disabled=False,
            layout=dict(width="15%", padding="0pt 0pt 0pt 10pt"),
            indent=False,
        )
        self.show_words.observe(self._update_button_color)

        # Observations
        self._slider_moved()
        self.n_samples_slider.observe(self._slider_moved)
        self.n_samples_text.observe(self._int_text_edited)

        self.display()

    def display(self):
        """Clear the cell output and show the controls plus the figure, if any."""
        clear_output(wait=True)
        box = HBox((self.n_samples_slider, self.n_samples_text, self.button, self.show_words))
        if self.fig is not None:
            # noinspection PyTypeChecker
            display(box, self.fig)
        else:
            # noinspection PyTypeChecker
            display(box)

    def _update_button_color(self, _=None):
        """Color the button by sample size; thresholds halve when showing words."""
        factor = 0.5 if self.show_words.value else 1.0
        # (fraction of the tick range, style), checked from the largest down.
        levels = ((4 / 5, "danger"), (2 / 3, "warning"), (1 / 2, "info"))
        for fraction, button_style in levels:
            if self.n_samples > self._ticks[int(self._n_ticks * factor * fraction)]:
                self.button.button_style = button_style
                return
        self.button.button_style = "success"

    def _int_text_edited(self, _=None):
        """
        Handler for the integer box: clamp the request to [10, n_words]
        and move the slider to the matching tick.
        """
        if self._hold:
            return
        self._hold = True
        try:
            # Ensure minimum
            requested = max(self.n_samples_text.value, 10)

            # Determine tick for slider and ensure maximum
            if self._ticks[-1] > requested:
                # First tick above the request; the default guards against
                # an empty match instead of raising StopIteration.
                loc = next((idx for idx, val in enumerate(self._ticks) if val > requested),
                           self._n_ticks - 1)
            else:
                # Fixed: this used to store the numpy float _ticks[-1], which
                # cannot slice the vocabulary, and a slider index one past
                # the last tick.
                requested = self.n_words
                loc = self._n_ticks - 1

            self.n_samples = int(requested)
            self.n_samples_text.value = self.n_samples
            self.n_samples_slider.value = loc
            self._update_button_color()
        finally:
            self._hold = False

    def _slider_moved(self, _=None):
        """Handler for the slider: copy the tick's sample count to the box."""
        if self._hold:
            return
        self._hold = True
        try:
            self.n_samples = int(self._ticks[self.n_samples_slider.value])
            self.n_samples_text.value = self.n_samples
            self._update_button_color()
        finally:
            self._hold = False

    def _disable_button(self):
        """Grey out the plot button while a plot is being computed."""
        self.button.button_style = ""
        self.button.disabled = True

    def _enable_button(self):
        """Re-enable the plot button and restore its warning color."""
        self.button.disabled = False
        self._update_button_color()

    def _button_clicked(self, _=None):
        """Handler for the plot button: PCA-project the sample and draw it."""
        self._disable_button()
        try:
            sleep(0.1)

            # Get a sample of words
            words = self.vocabulary[:self.n_samples]

            # Get vectors
            vectors = np.array([self.fasttext_model.get_word_vector(word) for word in words])

            # Fit PCA
            pca = PCA(n_components=3)
            _ = pca.fit(vectors)

            # Get projections
            projection_matrix = pca.components_
            projections = vectors.dot(projection_matrix.T)

            # Make figure
            plt.close("all")
            self.fig = plt.figure(figsize=(8, 6))
            # Axes3D(fig) no longer attaches itself to the figure on current
            # Matplotlib; add_axes with the same full-figure rectangle does.
            ax = self.fig.add_axes([0.0, 0.0, 1.0, 1.0], projection="3d")
            if self.show_words.value:
                for loc, word in zip(projections, words):
                    ax.text(x=loc[0], y=loc[1], z=loc[2], s=word, ha="center", va="center")
            else:
                ax.scatter(
                    xs=projections[:, 0],
                    ys=projections[:, 1],
                    zs=projections[:, 2],
                )

            # Limits
            ax.set_xlim(projections[:, 0].min(), projections[:, 0].max())
            ax.set_ylim(projections[:, 1].min(), projections[:, 1].max())
            ax.set_zlim(projections[:, 2].min(), projections[:, 2].max())
        finally:
            # Re-enable button, even when plotting failed
            self._enable_button()
        self.display()
def make_config(self):
    """
    Build the collapsible configuration panel and return it.

    Column one holds network-wide display settings (dataset, zoom,
    spacing, targets/errors, and the "Details" bank, which may be any
    layer). Column two holds the settings of a single selected layer
    (visibility, colormap, color range, feature).

    Side effects: stores the widgets that other handlers read as
    attributes on self (feature_bank, control_select, layer_select,
    layer_visible_checkbox, layer_colormap, layer_colormap_image,
    layer_mindim, layer_maxdim, layer_feature, svg_rotate,
    save_config_button). Expects self.zoom_slider, self.feature_columns
    and self.feature_scale to exist already.

    Returns:
        An Accordion holding the panel, initially collapsed.
    """
    layout = Layout()
    style = {"description_width": "initial"}

    def bind_config(widget, key):
        # Mirror every value change of the widget into net.config[key]
        # by way of self.set_attr.
        widget.observe(lambda change: self.set_attr(self.net.config, key, change["new"]),
                       names='value')

    ## Column 1: network-wide settings
    checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                         layout=layout, style=style)
    bind_config(checkbox1, "show_targets")
    checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                         layout=layout, style=style)
    bind_config(checkbox2, "show_errors")
    hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                     style=style, layout=layout)
    bind_config(hspace, "hspace")
    vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                     style=style, layout=layout)
    bind_config(vspace, "vspace")
    self.feature_bank = Select(description="Details:",
                               value=self.net.config["dashboard.features.bank"],
                               options=[""] + [layer.name for layer in self.net.layers],
                               rows=1)
    self.feature_bank.observe(self.regenerate, names='value')
    self.control_select = Select(
        options=['Test', 'Train'],
        value=self.net.config["dashboard.dataset"],
        description='Dataset:',
        rows=1
    )
    self.control_select.observe(self.change_select, names='value')
    column1 = [self.control_select,
               self.zoom_slider,
               hspace,
               vspace,
               HBox([checkbox1, checkbox2]),
               self.feature_bank,
               self.feature_columns,
               self.feature_scale]

    ## Column 2: make layer selectable, and update-able.
    ## The last layer is the one selected initially.
    layer = self.net.layers[-1]
    self.layer_select = Select(description="Layer:", value=layer.name,
                               options=[layer.name for layer in self.net.layers],
                               rows=1)
    self.layer_select.observe(self.update_layer_selection, names='value')
    self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
    self.layer_visible_checkbox.observe(self.update_layer, names='value')
    self.layer_colormap = Select(description="Colormap:",
                                 options=[""] + AVAILABLE_COLORMAPS,
                                 value=layer.colormap if layer.colormap is not None else "",
                                 layout=layout, rows=1)
    self.layer_colormap_image = HTML(
        value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
    self.layer_colormap.observe(self.update_layer, names='value')
    ## get dynamic minmax; if you change it it will set it in layer as override:
    minmax = layer.get_act_minmax()
    self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
    self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
    self.layer_mindim.observe(self.update_layer, names='value')
    self.layer_maxdim.observe(self.update_layer, names='value')
    self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
    self.layer_feature.observe(self.update_layer, names='value')
    # NOTE: a throwaway "Rotate" checkbox (seeded from layer.visible) used to
    # be created here and was immediately overwritten; only this one is real.
    self.svg_rotate = Checkbox(description="Rotate network",
                               value=self.net.config["svg_rotate"],
                               style={"description_width": 'initial'},
                               layout=Layout(width="52%"))
    bind_config(self.svg_rotate, "svg_rotate")
    self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
    self.save_config_button.on_click(self.save_config)
    column2 = [self.layer_select,
               self.layer_visible_checkbox,
               self.layer_colormap,
               self.layer_colormap_image,
               self.layer_mindim,
               self.layer_maxdim,
               self.layer_feature,
               HBox([self.svg_rotate, self.save_config_button])]

    config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                            VBox(column2, layout=Layout(width="100%"))])
    accordion = Accordion(children=[config_children])
    accordion.set_title(0, self.net.name)
    # None keeps the accordion collapsed until the user opens it.
    accordion.selected_index = None
    return accordion
class Dashboard(VBox):
    """
    Build the dashboard for Jupyter widgets. Requires running
    in a notebook/jupyterlab.

    The dashboard is a VBox holding, from top to bottom: a collapsible
    configuration accordion, the dataset navigation controls, the HTML
    widget that shows the network picture, and an Output widget.
    """

    def __init__(self, net, width="95%", height="550px", play_rate=0.5):
        """
        Create the widgets for `net` and assemble them.

        Args:
            net: the network to display; its `config` mapping supplies the
                initial widget values and receives every change.
            width: CSS width given to the network picture widget.
            height: stored on the instance as `_height`.
            play_rate: handed to the `_Player` that drives "Play".
        """
        self._ignore_layer_updates = False
        self.player = _Player(self, play_rate)
        self.player.start()
        self.net = net
        # A random suffix keeps the class id distinct between dashboards.
        suffix = random.randint(1, 1000000)
        self.class_id = "picture-dashboard-%s-%s" % (self.net.name, suffix)
        self._width = width
        self._height = height
        ## Global widgets:
        style = {"description_width": "initial"}
        self.feature_columns = IntText(description="Detail columns:",
                                       value=self.net.config["dashboard.features.columns"],
                                       min=0,
                                       max=1024,
                                       style=style)
        self.feature_scale = FloatText(description="Detail scale:",
                                       value=self.net.config["dashboard.features.scale"],
                                       min=0.1,
                                       max=10,
                                       style=style)
        self.feature_columns.observe(self.regenerate, names='value')
        self.feature_scale.observe(self.regenerate, names='value')
        ## Hack to center SVG as justify-content is broken:
        self.net_svg = HTML(value="""<p style="text-align:center">%s</p>""" % ("",),
                            layout=Layout(width=self._width,
                                          overflow_x='auto',
                                          overflow_y="auto",
                                          justify_content="center"))
        # Make controls first: make_config() places self.zoom_slider,
        # which make_controls() creates, into its first column.
        self.output = Output()
        controls = self.make_controls()
        config = self.make_config()
        super().__init__([config, controls, self.net_svg, self.output])

    def propagate(self, inputs):
        """
        Propagate inputs through the dashboard view of the network.

        Returns the network's propagate() result when dynamic pictures are
        available; otherwise redraws the picture for `inputs` and returns
        None.
        """
        if dynamic_pictures_check():
            return self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        else:
            # Pass the argument itself, not the builtin `input`.
            self.regenerate(inputs=inputs)

    def goto(self, position):
        """
        Move the dataset slider.

        Args:
            position: one of "begin", "end", "prev" or "next"; "prev" and
                "next" wrap around at the ends of the selected dataset.
        """
        if len(self.net.dataset.inputs) == 0 or len(self.net.dataset.targets) == 0:
            return
        if self.control_select.value == "Train":
            length = len(self.net.dataset.train_inputs)
        elif self.control_select.value == "Test":
            length = len(self.net.dataset.test_inputs)
        #### Position it:
        if position == "begin":
            self.control_slider.value = 0
        elif position == "end":
            self.control_slider.value = length - 1
        elif position == "prev":
            if self.control_slider.value - 1 < 0:
                self.control_slider.value = length - 1  # wrap around
            else:
                self.control_slider.value = max(self.control_slider.value - 1, 0)
        elif position == "next":
            if self.control_slider.value + 1 > length - 1:
                self.control_slider.value = 0  # wrap around
            else:
                self.control_slider.value = min(self.control_slider.value + 1, length - 1)
        self.position_text.value = self.control_slider.value

    def change_select(self, change=None):
        """
        Observer for the Train/Test selector: resync the slider with the
        newly selected dataset, then redraw.
        """
        self.update_control_slider(change)
        self.regenerate()

    def update_control_slider(self, change=None):
        """
        Resynchronise the slider, position box, total label and the
        navigation buttons with the currently selected dataset.

        The selection is also saved in the network config. The refresh
        button is never disabled.
        """
        self.net.config["dashboard.dataset"] = self.control_select.value
        if len(self.net.dataset.inputs) == 0 or len(self.net.dataset.targets) == 0:
            self.total_text.value = "of 0"
            self.control_slider.value = 0
            self.position_text.value = 0
            self.control_slider.disabled = True
            self.position_text.disabled = True
            for child in self.control_buttons.children:
                if not hasattr(child, "icon") or child.icon != "refresh":
                    child.disabled = True
            return
        # The selector only offers 'Test' and 'Train'.
        if self.control_select.value == "Test":
            selected_inputs = self.net.dataset.test_inputs
        else:
            selected_inputs = self.net.dataset.train_inputs
        count = len(selected_inputs)
        self.total_text.value = "of %s" % count
        last_index = max(count - 1, 0)
        # Reset an out-of-range position before changing the bounds.
        if not (0 <= self.control_slider.value <= last_index):
            self.control_slider.value = 0
        self.control_slider.min = 0
        self.control_slider.max = last_index
        disabled = count == 0
        self.control_slider.disabled = disabled
        # Was misspelled "disbaled", so the box was never re-enabled.
        self.position_text.disabled = disabled
        self.position_text.value = self.control_slider.value
        for child in self.control_buttons.children:
            if not hasattr(child, "icon") or child.icon != "refresh":
                child.disabled = disabled

    def update_zoom_slider(self, change):
        """Observer for the zoom slider: save the scale and redraw."""
        if change["name"] == "value":
            self.net.config["svg_scale"] = self.zoom_slider.value
            self.regenerate()

    def update_position_text(self, change):
        """Observer for the position box: move the slider to the new value."""
        # {'name': 'value', 'old': 2, 'new': 3, 'owner': IntText(value=3, layout=Layout(width='100%')), 'type': 'change'}
        self.control_slider.value = change["new"]

    def get_current_input(self):
        """
        Return the input at the slider position of the selected dataset,
        or None when that dataset has no targets.
        """
        if self.control_select.value == "Train" and len(self.net.dataset.train_targets) > 0:
            return self.net.dataset.train_inputs[self.control_slider.value]
        elif self.control_select.value == "Test" and len(self.net.dataset.test_targets) > 0:
            return self.net.dataset.test_inputs[self.control_slider.value]

    def get_current_targets(self):
        """
        Return the targets at the slider position of the selected dataset,
        or None when that dataset has no targets.
        """
        if self.control_select.value == "Train" and len(self.net.dataset.train_targets) > 0:
            return self.net.dataset.train_targets[self.control_slider.value]
        elif self.control_select.value == "Test" and len(self.net.dataset.test_targets) > 0:
            return self.net.dataset.test_targets[self.control_slider.value]

    def update_slider_control(self, change):
        """
        Observer for the dataset slider: show the pattern at the new
        position.

        Args:
            change: a mapping with at least a "name" key; anything other
                than a "value" change is ignored.
        """
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            self.total_text.value = "of 0"
            return
        if change["name"] != "value":
            return
        self.position_text.value = self.control_slider.value
        selected = self.control_select.value
        if selected == "Train" and len(dataset.train_targets) > 0:
            all_inputs, all_targets = dataset.train_inputs, dataset.train_targets
        elif selected == "Test" and len(dataset.test_targets) > 0:
            all_inputs, all_targets = dataset.test_inputs, dataset.test_targets
        else:
            return
        self.total_text.value = "of %s" % len(all_inputs)
        if self.net.model is None:
            return
        index = self.control_slider.value
        inputs = all_inputs[index]
        if not dynamic_pictures_check():
            # No dynamic pictures: rebuild the whole picture instead.
            self.regenerate(inputs=inputs, targets=all_targets[index])
            return
        output = self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        if self.feature_bank.value in self.net.layer_dict.keys():
            self.net.propagate_to_features(self.feature_bank.value, inputs,
                                           cols=self.feature_columns.value,
                                           scale=self.feature_scale.value,
                                           html=False)
        show_targets = self.net.config["show_targets"]
        show_errors = self.net.config["show_errors"]
        if show_targets or show_errors:
            self._display_targets_and_errors(output, all_targets[index],
                                             show_targets, show_errors)

    def _display_targets_and_errors(self, output, targets, show_targets, show_errors):
        """
        Push the targets and/or the output-minus-target errors of one
        pattern to the network picture.

        Train and Test patterns are handled identically: a single output
        bank is wrapped in a list, several banks are passed bank by bank.
        """
        bank_count = len(self.net.output_bank_order)
        single_bank = bank_count == 1
        if show_targets:
            ## FIXME: use minmax of output bank
            self.net.display_component([targets] if single_bank else targets,
                                       "targets", class_id=self.class_id, minmax=(-1, 1))
        if show_errors:  ## minmax is error
            if single_bank:
                errors = [(np.array(output) - np.array(targets)).tolist()]
            else:
                errors = [np.array(output[bank]) - np.array(targets[bank])
                          for bank in range(bank_count)]
            self.net.display_component(errors, "errors", class_id=self.class_id, minmax=(-1, 1))

    def toggle_play(self, button):
        """Click handler for Play/Stop: flip the button and the player."""
        ## toggle
        if self.button_play.description == "Play":
            self.button_play.description = "Stop"
            self.button_play.icon = "pause"
            self.player.resume()
        else:
            self.button_play.description = "Play"
            self.button_play.icon = "play"
            self.player.pause()

    def prop_one(self, button=None):
        """Re-show the pattern at the current slider position."""
        self.update_slider_control({"name": "value"})

    def regenerate(self, button=None, inputs=None, targets=None):
        """
        Rebuild the network picture and store it in `net_svg`.

        Also used directly as a widget observer, in which case `button`
        is the change mapping.

        Args:
            button: the widget or change mapping that triggered the call.
            inputs: pattern to show; defaults to the current dataset input.
            targets: its targets; defaults to the current dataset targets.
        """
        ## Protection when deleting object on shutdown:
        if isinstance(button, dict) and 'new' in button and button['new'] is None:
            return
        ## Update the config:
        self.net.config["dashboard.features.bank"] = self.feature_bank.value
        self.net.config["dashboard.features.columns"] = self.feature_columns.value
        self.net.config["dashboard.features.scale"] = self.feature_scale.value
        inputs = inputs if inputs is not None else self.get_current_input()
        targets = targets if targets is not None else self.get_current_targets()
        features = None
        if self.feature_bank.value in self.net.layer_dict.keys() and inputs is not None:
            if self.net.model is not None:
                features = self.net.propagate_to_features(self.feature_bank.value, inputs,
                                                          cols=self.feature_columns.value,
                                                          scale=self.feature_scale.value,
                                                          display=False)
        svg = """<p style="text-align:center">%s</p>""" % (self.net.to_svg(
            inputs=inputs,
            targets=targets,
            class_id=self.class_id,
            highlights={self.feature_bank.value: {
                "border_color": "orange",
                "border_width": 30,
            }}))
        if inputs is not None and features is not None:
            # Picture and feature details side by side.
            html_horizontal = (
                ' <table align="center" style="width: 100%%;">'
                ' <tr>'
                ' <td valign="top" style="width: 50%%;">%s</td>'
                ' <td valign="top" align="center" style="width: 50%%;">'
                '<p style="text-align:center"><b>%s</b></p>%s</td>'
                ' </tr>'
                ' </table>'
            )
            # Picture above the feature details.
            html_vertical = (
                ' <table align="center" style="width: 100%%;">'
                ' <tr>'
                ' <td valign="top">%s</td>'
                ' </tr>'
                ' <tr>'
                ' <td valign="top" align="center">'
                '<p style="text-align:center"><b>%s</b></p>%s</td>'
                ' </tr>'
                ' </table>'
            )
            template = html_vertical if self.net.config["svg_rotate"] else html_horizontal
            self.net_svg.value = template % (
                svg, "%s details" % self.feature_bank.value, features)
        else:
            self.net_svg.value = svg

    def make_colormap_image(self, colormap_name):
        """
        Return a 300x25 image of the named colormap, or of the one given
        by get_colormap() when `colormap_name` is empty or None.
        """
        from .layers import Layer
        if not colormap_name:
            colormap_name = get_colormap()
        layer = Layer("Colormap", 100)
        minmax = layer.get_act_minmax()
        image = layer.make_image(np.arange(minmax[0], minmax[1], .01),
                                 colormap_name,
                                 {"pixels_per_unit": 1,
                                  "svg_rotate": self.net.config["svg_rotate"]}).resize((300, 25))
        return image

    def set_attr(self, obj, attr, value):
        """
        Store `value` under `attr` on `obj` (as an item when `obj` is a
        dict, as an attribute otherwise) and redraw.

        A dict `value` is unwrapped through its "value" key. Nothing
        happens when `value` is {} or None.
        """
        if value not in [{}, None]:  ## value is None when shutting down
            if isinstance(value, dict):
                value = value["value"]
            if isinstance(obj, dict):
                obj[attr] = value
            else:
                setattr(obj, attr, value)
            ## was crashing on Widgets.__del__, if get_ipython() no longer existed
            self.regenerate()

    def make_controls(self):
        """
        Create the dataset navigation widgets, hook up their handlers and
        return the VBox that contains them.

        Sets `button_play`, `position_text`, `control_buttons`,
        `control_slider`, `total_text` and `zoom_slider` on the instance.
        """
        layout = Layout(width='100%', height="100%")
        button_begin = Button(icon="fast-backward", layout=layout)
        button_prev = Button(icon="backward", layout=layout)
        button_next = Button(icon="forward", layout=layout)
        button_end = Button(icon="fast-forward", layout=layout)
        self.button_play = Button(icon="play", description="Play", layout=layout)
        step_down = Button(icon="sort-down", layout=Layout(width="95%", height="100%"))
        step_up = Button(icon="sort-up", layout=Layout(width="95%", height="100%"))
        up_down = HBox([step_down, step_up], layout=Layout(width="100%", height="100%"))
        refresh_button = Button(icon="refresh", layout=Layout(width="25%", height="100%"))

        self.position_text = IntText(value=0, layout=layout)

        self.control_buttons = HBox([
            button_begin,
            button_prev,
            self.position_text,
            button_next,
            button_end,
            self.button_play,
            up_down,
            refresh_button
        ], layout=Layout(width='100%', height="100%"))
        length = (len(self.net.dataset.train_inputs) - 1) if len(self.net.dataset.train_inputs) > 0 else 0
        self.control_slider = IntSlider(description="Dataset index",
                                        continuous_update=False,
                                        min=0,
                                        max=max(length, 0),
                                        value=0,
                                        layout=Layout(width='100%'))
        if self.net.config["dashboard.dataset"] == "Train":
            length = len(self.net.dataset.train_inputs)
        else:
            length = len(self.net.dataset.test_inputs)
        self.total_text = Label(value="of %s" % length, layout=Layout(width="100px"))
        self.zoom_slider = FloatSlider(description="Zoom",
                                       continuous_update=False,
                                       min=0, max=1.0,
                                       style={"description_width": 'initial'},
                                       layout=Layout(width="65%"),
                                       value=self.net.config["svg_scale"] if self.net.config["svg_scale"] is not None else 0.5)

        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names='value')
        refresh_button.on_click(lambda widget: (self.update_control_slider(),
                                                self.output.clear_output(),
                                                self.regenerate()))
        step_down.on_click(lambda widget: self.move_step("down"))
        step_up.on_click(lambda widget: self.move_step("up"))
        self.zoom_slider.observe(self.update_zoom_slider, names='value')
        self.position_text.observe(self.update_position_text, names='value')
        # Put them together:
        controls = VBox([HBox([self.control_slider, self.total_text], layout=Layout(height="40px")),
                         self.control_buttons], layout=Layout(width='100%'))
        controls.on_displayed(lambda widget: self.regenerate())
        return controls

    def move_step(self, direction):
        """
        Move the layer stepper up/down through network

        Args:
            direction: "up" selects the next entry of the details list;
                any other value selects the previous one. Both wrap.
        """
        options = [""] + [layer.name for layer in self.net.layers]
        index = options.index(self.feature_bank.value)
        if direction == "up":
            new_index = (index + 1) % len(options)
        else:  ## down
            new_index = (index - 1) % len(options)
        self.feature_bank.value = options[new_index]
        self.regenerate()

    def make_config(self):
        """
        Create the configuration widgets and return them inside a
        collapsed Accordion titled with the network name.

        Must run after make_controls() and after `feature_columns` and
        `feature_scale` exist, because those widgets are placed in the
        first column here.
        """
        layout = Layout()
        style = {"description_width": "initial"}
        checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                             layout=layout, style=style)
        checkbox1.observe(lambda change: self.set_attr(self.net.config, "show_targets", change["new"]), names='value')
        checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                             layout=layout, style=style)
        checkbox2.observe(lambda change: self.set_attr(self.net.config, "show_errors", change["new"]), names='value')

        hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                         style=style, layout=layout)
        hspace.observe(lambda change: self.set_attr(self.net.config, "hspace", change["new"]), names='value')
        vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                         style=style, layout=layout)
        vspace.observe(lambda change: self.set_attr(self.net.config, "vspace", change["new"]), names='value')
        self.feature_bank = Select(description="Details:", value=self.net.config["dashboard.features.bank"],
                                   options=[""] + [layer.name for layer in self.net.layers],
                                   rows=1)
        self.feature_bank.observe(self.regenerate, names='value')
        self.control_select = Select(
            options=['Test', 'Train'],
            value=self.net.config["dashboard.dataset"],
            description='Dataset:',
            rows=1
        )
        self.control_select.observe(self.change_select, names='value')
        column1 = [self.control_select,
                   self.zoom_slider,
                   hspace,
                   vspace,
                   HBox([checkbox1, checkbox2]),
                   self.feature_bank,
                   self.feature_columns,
                   self.feature_scale
                   ]
        ## Make layer selectable, and update-able:
        column2 = []
        layer = self.net.layers[-1]
        self.layer_select = Select(description="Layer:", value=layer.name,
                                   options=[layer.name for layer in self.net.layers],
                                   rows=1)
        self.layer_select.observe(self.update_layer_selection, names='value')
        column2.append(self.layer_select)
        self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
        self.layer_visible_checkbox.observe(self.update_layer, names='value')
        column2.append(self.layer_visible_checkbox)
        self.layer_colormap = Select(description="Colormap:",
                                     options=[""] + AVAILABLE_COLORMAPS,
                                     value=layer.colormap if layer.colormap is not None else "",
                                     layout=layout, rows=1)
        self.layer_colormap_image = HTML(value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
        self.layer_colormap.observe(self.update_layer, names='value')
        column2.append(self.layer_colormap)
        column2.append(self.layer_colormap_image)
        ## get dynamic minmax; if you change it it will set it in layer as override:
        minmax = layer.get_act_minmax()
        self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
        self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
        self.layer_mindim.observe(self.update_layer, names='value')
        self.layer_maxdim.observe(self.update_layer, names='value')
        column2.append(self.layer_mindim)
        column2.append(self.layer_maxdim)
        self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
        self.layer_feature.observe(self.update_layer, names='value')
        column2.append(self.layer_feature)
        self.svg_rotate = Checkbox(description="Rotate network",
                                   value=self.net.config["svg_rotate"],
                                   style={"description_width": 'initial'},
                                   layout=Layout(width="52%"))
        self.svg_rotate.observe(lambda change: self.set_attr(self.net.config, "svg_rotate", change["new"]), names='value')
        self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
        self.save_config_button.on_click(self.save_config)
        column2.append(HBox([self.svg_rotate, self.save_config_button]))
        config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                                VBox(column2, layout=Layout(width="100%"))])
        accordion = Accordion(children=[config_children])
        accordion.set_title(0, self.net.name)
        accordion.selected_index = None  # start collapsed
        return accordion

    def save_config(self, widget=None):
        """Click handler for the save button: delegate to the network."""
        self.net.save_config()

    def update_layer(self, change):
        """
        Update the layer object, and redisplay.

        Does nothing while update_layer_selection() is refreshing the
        widgets.
        """
        if self._ignore_layer_updates:
            return
        ## The rest indicates a change to a display variable.
        ## We need to save the value in the layer, and regenerate
        ## the display.
        # Get the layer:
        layer = self.net[self.layer_select.value]
        # Save the changed value in the layer:
        layer.feature = self.layer_feature.value
        layer.visible = self.layer_visible_checkbox.value
        ## These three, dealing with colors of activations,
        ## can be done with a prop_one():
        if "color" in change["owner"].description.lower():
            ## Matches: Colormap, leftmost color, rightmost color
            ## overriding dynamic minmax!
            layer.minmax = (self.layer_mindim.value, self.layer_maxdim.value)
            layer.colormap = self.layer_colormap.value if self.layer_colormap.value else None
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            self.prop_one()
        else:
            self.regenerate()

    def update_layer_selection(self, change):
        """
        Just update the widgets; don't redraw anything.
        """
        ## No need to redisplay anything
        self._ignore_layer_updates = True
        try:
            ## First, get the new layer selected:
            layer = self.net[self.layer_select.value]
            ## Now, let's update all of the values without updating:
            self.layer_visible_checkbox.value = layer.visible
            # update_layer() stores "no colormap" as None, but the selector
            # only offers "" for that case.
            self.layer_colormap.value = layer.colormap or ""
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            minmax = layer.get_act_minmax()
            self.layer_mindim.value = minmax[0]
            self.layer_maxdim.value = minmax[1]
            self.layer_feature.value = layer.feature
        finally:
            # Always re-arm update_layer(), even if a widget rejected a value.
            self._ignore_layer_updates = False
class InteractiveLogging(): def __init__(self, settings, test_name=None, default_window='hanning'): if default_window is None: default_window = 'None' self.settings = settings self.test_name = test_name self.dataset = datastructure.DataSet() # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out = Output() self.out_top = Output(layout=Layout(height='120px')) # Initialise variables self.current_view = 'Time' self.N_frames = 1 self.overlap = 0.5 self.iw_fft_power = 0 self.iw_tf_power = 0 self.legend_loc = 'lower right' # puts plot inside widget so can have buttons next to plot self.outplot = Output(layout=Layout(width='85%')) # BUTTONS items_measure = [ 'Log Data', 'Delete Last Measurement', 'Reset (clears all data)', 'Load Data' ] items_view = ['View Time', 'View FFT', 'View TF'] items_calc = ['Calc FFT', 'Calc TF', 'Calc TF average'] items_axes = ['xmin', 'xmax', 'ymin', 'ymax'] items_save = ['Save Dataset', 'Save Figure'] items_iw = ['multiply iw', 'divide iw'] items_legend = ['Left', 'on/off', 'Right'] self.buttons_measure = [ Button(description=i, layout=Layout(width='33%')) for i in items_measure ] self.buttons_measure[0].button_style = 'success' self.buttons_measure[0].style.font_weight = 'bold' self.buttons_measure[1].button_style = 'warning' self.buttons_measure[1].style.font_weight = 'bold' self.buttons_measure[2].button_style = 'danger' self.buttons_measure[2].style.font_weight = 'bold' self.buttons_measure[3].button_style = 'primary' self.buttons_measure[3].style.font_weight = 'bold' self.buttons_view = [ Button(description=i, layout=Layout(width='95%')) for i in items_view ] self.buttons_view[0].button_style = 'info' self.buttons_view[1].button_style = 'info' self.buttons_view[2].button_style = 'info' self.buttons_calc = [ Button(description=i, layout=Layout(width='99%')) for i in items_calc ] self.buttons_calc[0].button_style = 'primary' self.buttons_calc[1].button_style = 
'primary' self.buttons_calc[2].button_style = 'primary' self.buttons_iw_fft = [ Button(description=i, layout=Layout(width='50%')) for i in items_iw ] self.buttons_iw_fft[0].button_style = 'info' self.buttons_iw_fft[1].button_style = 'info' self.buttons_iw_tf = [ Button(description=i, layout=Layout(width='50%')) for i in items_iw ] self.buttons_iw_tf[0].button_style = 'info' self.buttons_iw_tf[1].button_style = 'info' self.buttons_match = Button(description='Match Amplitudes', layout=Layout(width='99%')) self.buttons_match.button_style = 'info' self.buttons_save = [ Button(description=i, layout=Layout(width='50%')) for i in items_save ] self.buttons_save[0].button_style = 'success' self.buttons_save[0].style.font_weight = 'bold' self.buttons_save[1].button_style = 'success' self.buttons_save[1].style.font_weight = 'bold' self.button_warning = Button( description= 'WARNING: Data may be clipped. Press here to delete last measurement.', layout=Layout(width='100%')) self.button_warning.button_style = 'danger' self.button_warning.style.font_weight = 'bold' self.button_X = Button(description='Auto X', layout=Layout(width='95%')) self.button_Y = Button(description='Auto Y', layout=Layout(width='95%')) self.button_X.button_style = 'success' self.button_Y.button_style = 'success' self.buttons_legend = [ Button(description=i, layout=Layout(width='31%')) for i in items_legend ] self.buttons_legend_toggle = ToggleButton(description='Click-and-drag', layout=Layout( width='95%', alignitems='start')) # TEXT/LABELS/DROPDOWNS self.item_iw_fft_label = Label(value='iw power={}'.format(0), layout=Layout(width='100%')) self.item_iw_tf_label = Label(value='iw power={}'.format(0), layout=Layout(width='100%')) self.item_label = Label(value="Frame length = {:.2f} seconds.".format( settings.stored_time / (self.N_frames - self.overlap * self.N_frames + self.overlap))) self.item_axis_label = Label(value="Axes control:", layout=Layout(width='95%')) self.item_view_label = Label(value="View 
data:", layout=Layout(width='95%')) self.item_legend_label = Label(value="Legend position:", layout=Layout(width='95%')) self.item_blank_label = Label(value="", layout=Layout(width='95%')) self.text_axes = [ FloatText(value=0, description=i, layout=Layout(width='95%')) for i in items_axes ] self.text_axes = [self.button_X] + [self.button_Y] + self.text_axes self.drop_window = Dropdown(options=['None', 'hanning'], value=default_window, description='Window:', layout=Layout(width='99%')) self.slide_Nframes = IntSlider(value=1, min=1, max=30, step=1, description='N_frames:', continuous_update=True, readout=False, layout=Layout(width='99%')) self.text_Nframes = IntText(value=1, description='N_frames:', layout=Layout(width='99%')) # VERTICAL GROUPS group0 = VBox([ self.buttons_calc[0], self.drop_window, HBox(self.buttons_iw_fft) ], layout=Layout(width='33%')) group1 = VBox([ self.buttons_calc[1], self.drop_window, self.slide_Nframes, self.text_Nframes, self.item_label, HBox(self.buttons_iw_tf), self.buttons_match ], layout=Layout(width='33%')) group2 = VBox([ self.buttons_calc[2], self.drop_window, HBox(self.buttons_iw_tf), self.buttons_match ], layout=Layout(width='33%')) group_view = VBox( [self.item_axis_label] + self.text_axes + [self.item_legend_label, self.buttons_legend_toggle] + [HBox(self.buttons_legend), self.item_view_label] + self.buttons_view, layout=Layout(width='20%')) # ASSEMBLE display(self.out_top) display(HBox([self.button_warning])) display(HBox(self.buttons_measure)) display(HBox([self.outplot, group_view])) display(HBox([group0, group1, group2])) display(HBox(self.buttons_save)) self.button_warning.layout.visibility = 'hidden' # second part to putting plot inside widget with self.outplot: self.p = plotting.PlotData(figsize=(7.5, 4)) ## Make buttons/boxes interactive self.text_axes[2].observe(self.xmin, names='value') self.text_axes[3].observe(self.xmax, names='value') self.text_axes[4].observe(self.ymin, names='value') 
self.text_axes[5].observe(self.ymax, names='value') self.button_X.on_click(self.auto_x) self.button_Y.on_click(self.auto_y) self.buttons_legend[0].on_click(self.legend_left) self.buttons_legend[1].on_click(self.legend_onoff) self.buttons_legend[2].on_click(self.legend_right) self.buttons_legend_toggle.observe(self.legend_toggle) self.slide_Nframes.observe(self.nframes_slide) self.text_Nframes.observe(self.nframes_text) self.buttons_measure[0].on_click(self.measure) self.buttons_measure[1].on_click(self.undo) self.buttons_measure[2].on_click(self.reset) self.buttons_measure[3].on_click(self.load_data) self.buttons_view[0].on_click(self.view_time) self.buttons_view[1].on_click(self.view_fft) self.buttons_view[2].on_click(self.view_tf) self.buttons_calc[0].on_click(self.fft) self.buttons_calc[1].on_click(self.tf) self.buttons_calc[2].on_click(self.tf_av) self.buttons_iw_fft[0].on_click(self.xiw_fft) self.buttons_iw_fft[1].on_click(self.diw_fft) self.buttons_iw_tf[0].on_click(self.xiw_tf) self.buttons_iw_tf[1].on_click(self.diw_tf) self.buttons_match.on_click(self.match) self.buttons_save[0].on_click(self.save_data) self.buttons_save[1].on_click(self.save_fig) self.button_warning.on_click(self.undo) self.refresh_buttons() # Put output text at bottom of display display(self.out) with self.out_top: try: streams.start_stream(settings) self.rec = streams.REC except: print('Data stream not initialised.') print( 'Possible reasons: pyaudio or PyDAQmx not installed, or acquisition hardware not connected.' 
) print('Please note that it won' 't be possible to log data.') def xmin(self, v): xmin = self.text_axes[2].value xlim = self.p.ax.get_xlim() self.p.ax.set_xlim([xmin, xlim[1]]) def xmax(self, v): xmax = self.text_axes[3].value xlim = self.p.ax.get_xlim() self.p.ax.set_xlim([xlim[0], xmax]) def ymin(self, v): ymin = self.text_axes[4].value ylim = self.p.ax.get_ylim() self.p.ax.set_ylim([ymin, ylim[1]]) def ymax(self, v): ymax = self.text_axes[5].value ylim = self.p.ax.get_ylim() self.p.ax.set_ylim([ylim[0], ymax]) def auto_x(self, b): self.p.auto_x() def auto_y(self, b): self.p.auto_y() def legend_left(self, b): self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.legend_loc = 'lower left' self.p.update_legend(self.legend_loc) def legend_right(self, b): self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.legend_loc = 'lower right' self.p.update_legend(self.legend_loc) def legend_onoff(self, b): self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: visibility = self.p.ax.get_legend().get_visible() self.p.ax.get_legend().set_visible(not visibility) def legend_toggle(self, b): self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.p.ax.get_legend().set_visible(True) self.p.legend.set_draggable(self.buttons_legend_toggle.value) def nframes_slide(self, v): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.N_frames = self.slide_Nframes.value self.text_Nframes.value = self.N_frames if len(self.dataset.time_data_list) is not 0: stored_time = self.dataset.time_data_list[ 0].settings.stored_time elif len(self.dataset.tf_data_list) is not 0: stored_time = self.dataset.tf_data_list[0].settings.stored_time else: stored_time = 0 print('Time or TF data 
settings not found') self.item_label.value = "Frame length = {:.2f} seconds.".format( stored_time / (self.N_frames - self.overlap * self.N_frames + self.overlap)) if self.current_view is 'TF': self.tf(None) def nframes_text(self, v): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.N_frames = self.text_Nframes.value self.slide_Nframes.value = self.N_frames if len(self.dataset.time_data_list) is not 0: stored_time = self.dataset.time_data_list[ 0].settings.stored_time elif len(self.dataset.tf_data_list) is not 0: stored_time = self.dataset.tf_data_list[0].settings.stored_time else: stored_time = 0 print('Time or TF data settings not found') self.item_label.value = "Frame length = {:.2f} seconds.".format( stored_time / (self.N_frames - self.overlap * self.N_frames + self.overlap)) if self.current_view is 'TF': self.tf(None) def measure(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.rec.trigger_detected = False self.buttons_measure[0].button_style = '' if self.settings.pretrig_samples is None: self.buttons_measure[0].description = 'Logging ({}s)'.format( self.settings.stored_time) else: self.buttons_measure[ 0].description = 'Logging ({}s, with trigger)'.format( self.settings.stored_time) d = acquisition.log_data(self.settings, test_name=self.test_name, rec=self.rec) self.dataset.add_to_dataset(d.time_data_list) N = len(self.dataset.time_data_list) self.p.update(self.dataset.time_data_list, sets=[N - 1], channels='all') self.p.auto_x() # self.p.auto_y() self.p.ax.set_ylim([-1, 1]) self.current_view = 'Time' self.buttons_measure[0].button_style = 'success' self.buttons_measure[0].description = 'Log Data' if 
np.any(np.abs(d.time_data_list[-1].time_data) > 0.95): self.button_warning.layout.visibility = 'visible' else: self.button_warning.layout.visibility = 'hidden' xlim = self.p.ax.get_xlim() ylim = self.p.ax.get_ylim() self.text_axes[2].value = xlim[0] self.text_axes[3].value = xlim[1] self.text_axes[4].value = ylim[0] self.text_axes[5].value = ylim[1] self.refresh_buttons() def undo(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.dataset.remove_last_data_item('TimeData') self.dataset.freq_data_list = datastructure.FreqDataList() self.dataset.tf_data_list = datastructure.TfDataList() N = len(self.dataset.time_data_list) self.p.update(self.dataset.time_data_list, sets=[N - 1], channels='all') self.button_warning.layout.visibility = 'hidden' self.refresh_buttons() def reset(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.dataset = datastructure.DataSet() self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: N = len(self.dataset.time_data_list) self.p.update(self.dataset.time_data_list, sets=[N - 1], channels='all') self.refresh_buttons() def load_data(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: d = file.load_data() if d is not None: self.dataset.add_to_dataset(d.time_data_list) self.dataset.add_to_dataset(d.freq_data_list) self.dataset.add_to_dataset(d.tf_data_list) self.dataset.add_to_dataset(d.cross_spec_data_list) self.dataset.add_to_dataset(d.sono_data_list) else: print('No data loaded') if len(self.dataset.time_data_list) is not 0: self.p.update(self.dataset.time_data_list, sets='all', 
channels='all') elif len(self.dataset.freq_data_list) is not 0: self.p.update(self.dataset.freq_data_list, sets='all', channels='all') elif len(self.dataset.tf_data_list) is not 0: self.p.update(self.dataset.tf_data_list, sets='all', channels='all') else: print('No data to view') self.refresh_buttons() self.p.auto_x() self.p.auto_y() xlim = self.p.ax.get_xlim() ylim = self.p.ax.get_ylim() self.text_axes[2].value = xlim[0] self.text_axes[3].value = xlim[1] self.text_axes[4].value = ylim[0] self.text_axes[5].value = ylim[1] def view_time(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.refresh_buttons() N = len(self.dataset.time_data_list) if N is not 0: self.p.update(self.dataset.time_data_list) if self.current_view is not 'Time': self.current_view = 'Time' self.p.auto_x() self.p.auto_y() xlim = self.p.ax.get_xlim() ylim = self.p.ax.get_ylim() self.text_axes[2].value = xlim[0] self.text_axes[3].value = xlim[1] self.text_axes[4].value = ylim[0] self.text_axes[5].value = ylim[1] else: print('no time data to display') def view_fft(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: N = len(self.dataset.freq_data_list) self.refresh_buttons() if N is not 0: self.p.update(self.dataset.freq_data_list) if self.current_view is not 'FFT': self.current_view = 'FFT' self.p.auto_x() self.p.auto_y() xlim = self.p.ax.get_xlim() ylim = self.p.ax.get_ylim() self.text_axes[2].value = xlim[0] self.text_axes[3].value = xlim[1] self.text_axes[4].value = ylim[0] self.text_axes[5].value = ylim[1] else: print('no FFT data to display') def view_tf(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up 
in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: N = len(self.dataset.tf_data_list) self.refresh_buttons() if N is not 0: self.p.update(self.dataset.tf_data_list) if self.current_view is not 'TF': self.current_view = 'TF' self.p.auto_x() self.p.auto_y() xlim = self.p.ax.get_xlim() ylim = self.p.ax.get_ylim() self.text_axes[2].value = xlim[0] self.text_axes[3].value = xlim[1] self.text_axes[4].value = ylim[0] self.text_axes[5].value = ylim[1] else: print('no TF data to display') def fft(self, b): window = self.drop_window.value if window is 'None': window = None # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.dataset.calculate_fft_set(window=window) self.p.update(self.dataset.freq_data_list) if self.current_view is not 'FFT': self.current_view = 'FFT' self.p.auto_x() self.p.auto_y() xlim = self.p.ax.get_xlim() ylim = self.p.ax.get_ylim() self.text_axes[2].value = xlim[0] self.text_axes[3].value = xlim[1] self.text_axes[4].value = ylim[0] self.text_axes[5].value = ylim[1] self.refresh_buttons() def tf(self, b): N_frames = self.N_frames window = self.drop_window.value if window is 'None': window = None # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.dataset.calculate_tf_set(window=window, N_frames=N_frames, overlap=self.overlap) self.p.update(self.dataset.tf_data_list) if self.current_view is not 'TF': self.current_view = 'TF' self.p.auto_x() self.p.auto_y() xlim = self.p.ax.get_xlim() ylim = self.p.ax.get_ylim() self.text_axes[2].value = xlim[0] self.text_axes[3].value = xlim[1] self.text_axes[4].value = ylim[0] self.text_axes[5].value = ylim[1] self.refresh_buttons() def 
tf_av(self, b): window = self.drop_window.value if window is 'None': window = None # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: self.dataset.calculate_tf_averaged(window=window) self.p.update(self.dataset.tf_data_list) if self.current_view is not 'TFAV': self.current_view = 'TFAV' self.p.auto_x() self.p.auto_y() xlim = self.p.ax.get_xlim() ylim = self.p.ax.get_ylim() self.text_axes[2].value = xlim[0] self.text_axes[3].value = xlim[1] self.text_axes[4].value = ylim[0] self.text_axes[5].value = ylim[1] self.refresh_buttons() def xiw_fft(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: if self.current_view is 'FFT': s = self.p.get_selected_channels() n_sets, n_chans = np.shape(s) for ns in range(n_sets): newdata = analysis.multiply_by_power_of_iw( self.dataset.freq_data_list[ns], power=1, channel_list=s[ns, :]) self.dataset.freq_data_list[ns] = newdata self.p.update(self.dataset.freq_data_list) self.p.auto_y() self.iw_fft_power += 1 print('Multiplied by (iw)**{}'.format(self.iw_fft_power)) else: print('First press <Calc FFT>') def diw_fft(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: if self.current_view is 'FFT': s = self.p.get_selected_channels() n_sets, n_chans = np.shape(s) for ns in range(n_sets): newdata = analysis.multiply_by_power_of_iw( self.dataset.freq_data_list[ns], power=-1, channel_list=s[ns, :]) self.dataset.freq_data_list[ns] = newdata self.p.update(self.dataset.freq_data_list) self.p.auto_y() self.iw_fft_power -= 1 print('Multiplied by 
(iw)**{}'.format(self.iw_fft_power)) else: print('First press <Calc FFT>') def xiw_tf(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: if (self.current_view is 'TF') or (self.current_view is 'TFAV'): s = self.p.get_selected_channels() n_sets, n_chans = np.shape(s) for ns in range(n_sets): newdata = analysis.multiply_by_power_of_iw( self.dataset.tf_data_list[ns], power=1, channel_list=s[ns, :]) self.dataset.tf_data_list[ns] = newdata self.p.update(self.dataset.tf_data_list) self.p.auto_y() self.iw_tf_power += 1 print('Multiplied selected channel by (iw)**{}'.format( self.iw_tf_power)) else: print('First press <Calc TF> or <Calc TF average>') def diw_tf(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: if (self.current_view is 'TF') or (self.current_view is 'TFAV'): s = self.p.get_selected_channels() n_sets, n_chans = np.shape(s) for ns in range(n_sets): newdata = analysis.multiply_by_power_of_iw( self.dataset.tf_data_list[ns], power=-1, channel_list=s[ns, :]) self.dataset.tf_data_list[ns] = newdata self.p.update(self.dataset.tf_data_list) self.p.auto_y() self.iw_tf_power -= 1 print('Multiplied selected channel by (iw)**{}'.format( self.iw_tf_power)) else: print('First press <Calc TF> or <Calc TF average>') def match(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: if (self.current_view is 'TF') or (self.current_view is 'TFAV'): freq_range = self.p.ax.get_xlim() current_calibration_factors = self.dataset.tf_data_list.get_calibration_factors( ) reference = 
current_calibration_factors[0][0] factors = analysis.best_match(self.dataset.tf_data_list, freq_range=freq_range, set_ref=0, ch_ref=0) factors = [reference * x for x in factors] self.dataset.tf_data_list.set_calibration_factors_all(factors) self.p.update(self.dataset.tf_data_list) print('scale factors:') print(factors) #self.p.auto_y() else: print( 'First press <View TF> or <Calc TF> or <Calc TF average>') def save_data(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: print('Saving dataset:') print(self.dataset) self.dataset.save_data() def save_fig(self, b): # the 'out' construction is to refresh the text output at each update # to stop text building up in the widget display self.out.clear_output(wait=False) self.out_top.clear_output(wait=False) with self.out_top: file.save_fig(self.p, figsize=(9, 5)) def refresh_buttons(self): if len(self.dataset.time_data_list) is 0: self.buttons_view[0].button_style = '' else: self.buttons_view[0].button_style = 'info' if len(self.dataset.freq_data_list) is 0: self.buttons_view[1].button_style = '' else: self.buttons_view[1].button_style = 'info' if len(self.dataset.tf_data_list) is 0: self.buttons_view[2].button_style = '' else: self.buttons_view[2].button_style = 'info'