示例#1
0
class MpltFigureViewerWidget(HBox):
    """Collapsible panel that shows a single figure at a time.

    The panel starts hidden (``display: none``) and becomes visible the first
    time :meth:`show_figure` is called. It remembers which "cell" object asked
    for the figure so that cell can be told to deselect itself when another
    cell takes over the panel.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Cell object whose figure is currently on screen, or None.
        self._current_cell = None

        layout_options = dict(
            display="none",
            flex_direction="row",
            border="solid 1px gray",
            width="100%",
            height="auto",
            overflow="auto",
        )
        self.layout = Layout(**layout_options)
        self._output = Output()
        self.children = [self._output]

    def reset(self):
        """Clear the panel, hide it and forget the current cell."""
        self._output.clear_output()
        self.layout.display = "none"
        self._current_cell = None

    def show_figure(self, figure, cell):
        """Make the panel visible and display *figure* on behalf of *cell*.

        The previously active cell (if any) gets ``unselect_figure()`` called
        before *cell* becomes the active one.
        """
        if self.layout.display != "flex":
            self.layout.display = "flex"

        previous_cell = self._current_cell
        if previous_cell is not None:
            previous_cell.unselect_figure()
        self._current_cell = cell

        self._output.clear_output()
        with self._output:
            # Objects exposing a ``canvas`` (presumably matplotlib figures)
            # are shown through that canvas; anything else is displayed as-is.
            display(figure.canvas if hasattr(figure, "canvas") else figure)
示例#2
0
class SearchWidget:
    """Small search bar: enter a name, press Enter or *Search*, see the result.

    Use :meth:`as_widget` to obtain the composed widget tree.
    """

    def __init__(self):
        self.search = Text(placeholder='Name')
        self.go_btn = Button(description='Search')
        self.clear_btn = Button(description='Clear')
        self.out = Output()

        # Only report the text value when the user commits it, not on
        # every keystroke, so the observer does not search on each key.
        self.search.continuous_update = False
        self.search.observe(self.do_search, 'value')
        self.go_btn.on_click(self.do_search)
        self.clear_btn.on_click(self.do_clear)

    def do_clear(self, cb):
        """Empty the result area (``cb`` is the clicked button, unused)."""
        self.out.clear_output()

    def do_search(self, cb):
        """Look up the typed name in ``env.project`` and show the outcome.

        ``cb`` is ignored: it is a button when triggered by *Search* and a
        trait-change dict when triggered by the text field.
        """
        query = self.search.value
        self.out.clear_output()

        with self.out:
            found = env.project[query]
            if not found.exists():
                print(f"{found.name} valid, but doesn't exist")
                return
            display(found.gui.header())

    def as_widget(self):
        """Return the search bar stacked above its result area."""
        toolbar = HBox([self.search, self.go_btn, self.clear_btn])
        return VBox([toolbar, self.out])
class Plotter(object):
    """A dynamic plotting widget for tracking training progress in notebooks.

    Values are accumulated per metric name with :meth:`update`. Every update
    re-renders the captured output area with an optional stack of line plots
    (one subplot per metric) followed by a text summary. The summary gives the
    current, maximum and minimum value of each metric, plus the iteration
    index at which each extreme occurred.
    """

    def __init__(self, id_string='', width=12, height=2.5, show_plot=True):
        """Create an empty plotter.

        Args:
            id_string: Title used for the figure and the text summary.
            width: Figure width (matplotlib ``figsize`` units).
            height: Height allotted to *each* metric's subplot.
            show_plot: If False, :meth:`update` only prints the text summary.
        """
        self.id_string = id_string
        self.width = width
        self.height = height
        self.output = Output()
        self.metrics = defaultdict(list)
        self.show_plot = show_plot

    @staticmethod
    def _format_metric(name, values, width):
        """Return the summary line for one metric, right-aligning *name*."""
        return ' '.join([
            ('%' + str(width) + 's') % name,
            '| current = %.2e' % values[-1],
            '| max = %.2e (iter %4d)' % (np.max(values), np.argmax(values)),
            '| min = %.2e (iter %4d)' % (np.min(values), np.argmin(values)),
        ])

    def update(self, **metrics):
        """Record new metric values and redraw the output area.

        Each keyword maps a metric name to either a single value or a
        list/tuple of values, which are appended in order.
        """
        for k, v in metrics.items():
            if isinstance(v, (list, tuple)):
                self.metrics[k].extend(v)
            else:
                self.metrics[k].append(v)

        # wait=True defers the clear until new output arrives (less flicker).
        self.output.clear_output(wait=True)
        with self.output:
            # A figure with zero subplots cannot be created, so only plot
            # once at least one metric exists.
            if self.show_plot and self.metrics:
                self.plot()
                plt.show()
            print(self.progress_string(), end='')

    def show(self):
        """Display the output area in the current cell."""
        display(self.output)

    def progress_string(self):
        """Return the text summary: ``id_string``, then one line per metric.

        Every line, including the last, ends with a newline. Metrics that have
        no values yet are skipped because they have no current/min/max.
        """
        width = max(map(len, self.metrics), default=0)
        lines = [self.id_string]
        lines.extend(
            self._format_metric(k, v, width)
            for k, v in self.metrics.items() if v)
        return '\n'.join(lines) + '\n'

    def plot(self):
        """Draw one line subplot per metric and return the new figure."""
        n_metrics = len(self.metrics)
        fig = plt.figure(figsize=(self.width, self.height * max(n_metrics, 1)))
        fig.suptitle(self.id_string)
        if n_metrics:
            # squeeze=False always returns a 2-D array, so a single metric
            # needs no special-casing.
            axs = fig.subplots(n_metrics, squeeze=False)[:, 0]
            for ax, (k, v) in zip(axs, self.metrics.items()):
                ax.plot(v)
                ax.grid()
                ax.set_title(k)
        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        return fig
示例#4
0
class SuperImpose(object):
    """Interactive tool to superimpose several spectra.

    A file chooser selects a data set. The *Copy* button stores it into one of
    ``N`` numbered slots, and *Display* draws every slot into ``self.spec``.
    """

    def __init__(self, base=None, filetype='*.msh5', N=None):
        """Build the widgets.

        Args:
            base: Starting directory passed to ``FileChooser``.
            filetype: File pattern passed to ``FileChooser``.
            N: Number of slots; asked for interactively with ``input()`` when
                None.
        """
        if N is None:
            N = int(input('how many spectra do you want to compare:  '))
        self.Chooser = FileChooser(base=base,
                                   filetype=filetype,
                                   mode='r',
                                   show=False)
        self.bsel = widgets.Button(
            description='Copy',
            layout=Layout(width='10%'),
            button_style='info',
            tooltip='copy selected data-set to entry below')
        # 1-based slot number that the next Copy writes into.
        self.to = widgets.IntText(value=1, min=1, max=N,
                                  layout=Layout(width='10%'))
        self.bsel.on_click(self.copy)
        self.bdisplay = widgets.Button(
            description='Display',
            layout=Layout(width='10%'),
            button_style='info',
            tooltip='display superimposition')
        self.bdisplay.on_click(self.display)
        self.spec = Output(layout={'border': '1px solid black'})
        self.DataList = [SpforSuper(slot, 'None') for slot in range(1, N + 1)]
        first_slot = self.DataList[0]
        first_slot.color.value = 'black'
        first_slot.fig = True  # only the first slot starts switched on

    def Show(self):
        """Display the complete interface in the current cell."""
        instructions = 'Select a file, and click on the Copy button to copy it to the chosen slot'
        display(widgets.Label(instructions))
        self.Chooser.show()
        display(HBox([self.bsel, widgets.Label('to'), self.to]))
        display(VBox([slot.me for slot in self.DataList]))
        display(self.bdisplay)
        display(self.spec)

    def copy(self, event):
        """Copy the chosen file into the target slot, then advance the target."""
        slot_number = self.to.value
        if 1 <= slot_number <= len(self.DataList):
            target = self.DataList[slot_number - 1]
            target.name.value = self.Chooser.file
            target.direct.value = 'up'
        else:
            print('Destination is out of range !')
        self.to.value = min(self.to.value, len(self.DataList)) + 1

    def display(self, event):
        """Redraw every slot into the shared output area."""
        self.spec.clear_output(wait=True)
        for entry in self.DataList:
            with self.spec:
                entry.display()
示例#5
0
 def toggle_lineout(self, change):
     """Start or stop the lineout plotting thread when the toggle changes.

     ``change`` is a traitlets change dict; ``change['new']`` is the new
     toggle state. Switching on runs ``__new_lineout_plot`` on a daemon
     thread. Switching off waits for that thread (if one was started) and
     clears ``self.out``.
     """
     if change['new']:
         # start a new thread so the interaction with original figure won't be blocked
         self.observer_thrd = threading.Thread(
             target=self.__new_lineout_plot)
         self.observer_thrd.daemon = True
         self.observer_thrd.start()
     else:
         thread = getattr(self, 'observer_thrd', None)
         # Nothing to wait for if the toggle was never switched on; joining
         # from inside the worker itself would raise RuntimeError.
         if thread is not None and thread is not threading.current_thread():
             # NOTE: join() does not stop the thread; it blocks until
             # __new_lineout_plot returns (presumably it notices the toggle
             # is off -- TODO confirm, otherwise this blocks the UI).
             thread.join()
         self.out.clear_output()
示例#6
0
class JupyterPlottingContext(PlottingContextBase):
    """ plotting in a jupyter widget using the `inline` backend """

    supports_update = False
    """ flag indicating whether the context supports that plots can be updated
    without redrawing the entire plot. The jupyter backend (`inline`) requires
    replotting of the entire figure, so an update is not supported."""

    def __enter__(self):
        """Start capturing plot output in an ipywidgets ``Output`` widget.

        On the initial plot, all open matplotlib figures are closed and a new
        ``Output`` widget is created (and displayed if ``self.show``). Later
        entries reuse that widget.
        """
        from IPython.display import display
        from ipywidgets import Output

        if self.initial_plot:
            # close all previous plots
            import matplotlib.pyplot as plt

            plt.close("all")

            # create output widget for capturing all plotting
            self._ipython_out = Output()

            if self.show:
                # only show the widget if necessary
                display(self._ipython_out)

        # capture plots in the output widget
        self._ipython_out.__enter__()

    def __exit__(self, *exc):
        """Finalize and show the figure, then stop capturing output.

        The output widget's context is exited even if finalizing the figure
        raises. Otherwise the capture started in ``__enter__`` would be left
        active.
        """
        import matplotlib.pyplot as plt

        try:
            # finalize plot
            super().__exit__(*exc)

            if self.show:
                # show the plot, but ...
                plt.show()  # show the figure to make sure it can be captured
            # ... also clear it the next time something is done
            self._ipython_out.clear_output(wait=True)
        finally:
            # stop capturing plots in the output widget
            self._ipython_out.__exit__(*exc)

        # close the figure, so figure windows do not accumulate
        plt.close(self.fig)

    def close(self):
        """ close the plot """
        super().close()
        # Best-effort close of the ipython output widget. It does not exist
        # if the context was never entered, and closing must not fail here.
        try:
            self._ipython_out.close()
        except Exception:
            pass
示例#7
0
class RunWidget:
    """Miniature code cell: an editor plus *Run* and *Clear* buttons.

    *Run* executes the editor contents in the current IPython shell
    (recording it in history) and shows any output underneath. *Clear* wipes
    both the output and the editor. The composed widget is ``self.view``.
    """

    def __init__(self):
        from ipywidgets import Button
        self.code = Textarea()
        self.execute = Button(description="Run")
        self.clear_btn = Button(description="Clear")
        self.output = Output()
        self.exec_count = Label("[ ]")
        for btn in (self.execute, self.clear_btn):
            btn.layout.width = "60px"
            btn.layout.height = "50px"
        self.code.layout.width = "550px"
        self.code.rows = 6
        button_column = VBox([self.execute, self.clear_btn, self.exec_count])
        editor_row = HBox([button_column, self.code])
        self.view = VBox([editor_row, self.output])
        self.execute.on_click(self.click)
        self.clear_btn.on_click(self.clear_click)
        self.ipython = get_ipython()

    def clear(self):
        """Empty the editor and reset its height and execution counter."""
        self.code.value = ""
        self.code.rows = 6
        self.exec_count.value = "[ ]"

    def set_code(self, code):
        """Load *code* into the editor, sizing it to fit (minimum 6 rows)."""
        self.code.value = code
        line_count = code.count("\n") + 1
        self.code.rows = max(line_count + 1, 6)
        self.exec_count.value = "[ ]"

    def click(self, b):
        """Run the editor contents and show the resulting execution count."""
        with self.output:
            outcome = self.ipython.run_cell(self.code.value, store_history=True)
            self.exec_count.value = f"[{outcome.execution_count or ' '}]"

    def clear_click(self, b):
        """Handler for *Clear*: wipe the output area and the editor."""
        self.output.clear_output()
        self.clear()
示例#8
0
 def plot_series(self,
                 out_widget: Output,
                 variable: str,
                 loc_id: str,
                 dim_id: str = None):
     """Plot one series into *out_widget*, replacing what it showed before.

     Args:
         out_widget: Output widget that receives the plot.
         variable: Variable name forwarded to ``self.get_data``.
         loc_id: Location identifier forwarded to ``self.get_data``; also
             used as the axes title.
         dim_id: Optional extra identifier forwarded to ``self.get_data``.
     """
     series = self.get_data(variable=variable, loc_id=loc_id, dim_id=dim_id)
     # if using bqplot down the track, see https://github.com/jtpio/voila-gpx-viewer
     out_widget.clear_output()
     with out_widget:
         # The plot's return value is deliberately not passed to display(),
         # which would print the repr of the matplotlib artists.
         series.plot(figsize=(16, 8))
         plt.gca().set_title(loc_id)
         show_inline_matplotlib_plots()
示例#9
0
 def __save_function(self, image_path):
     """Trigger a browser download of this widget's output area as a PNG.

     Injects JavaScript that uses html2canvas to rasterise the first output
     area element whose class list includes ``self.name``. The image is
     downloaded as ``<file name of image_path without extension>.png``.

     Args:
         image_path: Path whose base name (minus its final extension) is used
             as the download file name; the directory part is ignored.
     """
     import json

     # splitext removes only the last extension: "fig.v2.png" -> "fig.v2"
     # (split('.')[0] truncated it to "fig").
     img_name = os.path.splitext(os.path.basename(image_path))[0]
     j_code = """
         require(["html2canvas"], function(html2canvas) {
             var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
             console.log(element);
              html2canvas(element).then(function (canvas) { 
                 var myImage = canvas.toDataURL(); 
                 var a = document.createElement("a"); 
                 a.href = myImage; 
                 a.download = "$img_name$.png"; 
                 a.click(); 
                 a.remove(); 
             });
         });
         """
     j_code = j_code.replace('$it_name$', self.name)
     # The file name lands inside a double-quoted JS string literal; escape
     # it (quotes, backslashes, non-ASCII) so it cannot break the script.
     j_code = j_code.replace('$img_name$', json.dumps(img_name)[1:-1])
     # NOTE(review): tmp_out is never displayed; presumably it only serves to
     # run the JS without leaving visible output -- confirm in the target frontend.
     tmp_out = Output()
     with tmp_out:
         display(Javascript(j_code))
         tmp_out.clear_output()
示例#10
0
class OutputWidgetHandler(logging.Handler):
    """ Custom logging handler sending logs to an output widget

    New records are inserted at the top of the widget, so the most recent
    message is shown first.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        layout = {'border': '1px solid black'}
        self.out = Output(layout=layout)

    def emit(self, record):
        """ Overload of logging.Handler method

        Formats *record* and prepends it to the widget's outputs as a stdout
        stream entry. As the ``logging.Handler`` contract requires, failures
        are routed to ``handleError`` instead of propagating into the code
        that issued the log call.
        """
        try:
            formatted_record = self.format(record)
            new_output = {
                'name': 'stdout',
                'output_type': 'stream',
                'text': formatted_record + '\n'
            }
            self.out.outputs = (new_output, ) + self.out.outputs
        except RecursionError:  # same policy as logging.StreamHandler
            raise
        except Exception:
            self.handleError(record)

    def show_logs(self):
        """ Show the logs """
        display(self.out)

    def clear_logs(self):
        """ Clear the current logs """
        self.out.clear_output()
示例#11
0
class Replay(object):
    """Replay the recorded construction steps of a CadQuery workplane.

    Stores tessellation settings and creates a ``CadqueryDisplay`` viewer on
    construction. It converts a workplane's recorded ``_caller`` history into
    pseudo-code entries (``to_array`` / ``format_steps``), and ``select``
    renders the chosen intermediate objects in the viewer.
    """
    def __init__(self, quality, deviation, angular_tolerance, edge_accuracy,
                 debug, cad_width, height):
        """Store rendering options, then create and display the viewer widget."""
        # Output widget used as a context for the (debug) output of select().
        self.debug_output = Output()
        self.quality = quality
        self.deviation = deviation
        self.angular_tolerance = angular_tolerance
        self.edge_accuracy = edge_accuracy
        self.debug = debug
        self.cad_width = cad_width
        self.height = height
        self.display = CadqueryDisplay()
        widget = self.display.create(height=height, cad_width=cad_width)
        self.display.display(widget)

    def format_steps(self, raw_steps):
        """Turn a flat list of ``Step`` objects into ``(code, result_obj)`` entries.

        The steps are mutated in place:

        * a step at level > 0 without a ``result_name`` gets a synthetic name
          (``_v1``, ``_v2``, ...) when the next step's level differs;
        * a step whose ``result_obj`` already appeared earlier becomes a
          reference to that earlier name (``step.var`` is set and
          ``step.clear_func()`` is called).

        Walking backwards, once a step that references a variable is emitted,
        earlier steps at the same or a deeper level are dropped until a
        shallower step is reached.
        """
        def to_code(step, results):
            """Render one step as a pseudo-code line, indented by ``"| "`` per level."""
            def to_name(obj):
                # Workplane arguments are shown by their recorded result name
                # (or the object itself if none); others via str().
                if isinstance(obj, cq.Workplane):
                    name = results.get(obj, None)
                else:
                    name = str(obj)
                return obj if name is None else name

            if step.func != "":
                if step.func == "newObject":
                    args = ("...", )
                else:
                    args = tuple([to_name(arg) for arg in step.args])
                code = "%s%s%s" % ("| " * step.level, step.func, args)
                # str(args) ends in ",)" for a 1-tuple and ")" otherwise; strip
                # it so kwargs can be appended before the closing ")" below.
                # NOTE(review): for "newObject" args is always a 1-tuple, but
                # the check uses len(step.args) -- confirm this is intended.
                code = code[:-2] if len(step.args) == 1 else code[:-1]
                if len(step.args) > 0 and len(step.kwargs) > 0:
                    code += ","
                if step.kwargs != {}:
                    code += ", ".join(
                        ["%s=%s" % (k, v) for k, v in step.kwargs.items()])
                code += ")"
                if step.result_name != "":
                    code += " => %s" % step.result_name
            elif step.var != "":
                # step was replaced by a reference to an earlier result
                code = "%s%s" % ("| " * step.level, step.var)
            else:
                code = "ERROR"
            return code

        steps = []
        entries = []
        obj_index = 1

        # result_obj -> first result_name recorded for it (None = not seen yet)
        results = {step.result_obj: None for step in raw_steps}

        for i in range(len(raw_steps)):
            step = raw_steps[i]
            # the last step compares against its own level (no change)
            next_level = step.level if i == (len(raw_steps) -
                                             1) else raw_steps[i + 1].level

            # level change, so add/use the variable name
            if step.level > 0 and step.level != next_level and step.result_name == "":
                obj_name = "_v%d" % obj_index
                obj_index += 1
                step.result_name = obj_name
            steps.append(step)

        for step in steps:
            if results[step.result_obj] is None:
                # first occurrence, take note and keep
                results[step.result_obj] = step.result_name
            else:
                # later occurrences: remove function and add variable name
                step.var = results[step.result_obj]
                step.clear_func()

        # 1000000 acts as "no level limit"
        last_level = 1000000
        for step in reversed(steps):
            if step.level < last_level:
                last_level = 1000000
                entries.insert(0, (to_code(step, results), step.result_obj))
                if step.var != "":
                    # hide the earlier steps (same or deeper level) that
                    # built the referenced object
                    last_level = step.level

        return entries

    def to_array(self, workplane, level=0, result_name=""):
        """Collect ``Step`` objects for *workplane* and its parent chain.

        Walks ``obj.parent`` upwards. For each object with a ``_caller``
        record, ``walk`` flattens the recorded call tree, and Workplane
        arguments are expanded recursively one level deeper. Parent steps end
        up before their descendants in the returned list.

        NOTE(review): the ``result_name`` argument is never read -- it is
        overwritten from ``obj.name`` at the top of every loop iteration.
        """
        def walk(caller, level=0, result_name=""):
            """Flatten a ``_caller`` tree; child steps precede the caller's step, one level deeper."""
            stack = [
                Step(
                    level,
                    func=caller["func"],
                    args=caller["args"],
                    kwargs=caller["kwargs"],
                    result_name=result_name,
                    result_obj=caller["obj"],
                )
            ]
            for child in reversed(caller["children"]):
                stack = walk(child, level + 1) + stack
                for arg in child["args"]:
                    if isinstance(arg, cq.Workplane):
                        # rebinding result_name here does not affect the Step
                        # already created above
                        result_name = getattr(arg, "name", None)
                        stack = self.to_array(arg,
                                              level=level + 2,
                                              result_name=result_name) + stack
            return stack

        stack = []

        obj = workplane
        while obj is not None:
            caller = getattr(obj, "_caller", None)
            result_name = getattr(obj, "name", "")
            if caller is not None:
                stack = walk(caller, level, result_name) + stack
                for arg in caller["args"]:
                    if isinstance(arg, cq.Workplane):
                        result_name = getattr(arg, "name", "")
                        stack = self.to_array(arg,
                                              level=level + 1,
                                              result_name=result_name) + stack
            obj = obj.parent

        return stack

    def select_handler(self, change):
        """Trait observer: forward ``index`` changes to :meth:`select`.

        Runs inside ``debug_output``, presumably so errors are captured there.
        """
        with self.debug_output:
            if change["name"] == "index":
                self.select(change["new"])

    def select(self, indexes):
        """Render the entries at *indexes* (positions in ``self.stack``).

        NOTE(review): ``self.stack`` is not assigned in this class as shown;
        presumably the caller stores ``format_steps`` output there
        (``(code, result_obj)`` tuples) -- TODO confirm.
        """
        self.debug_output.clear_output()
        with self.debug_output:
            self.indexes = indexes
            cad_objs = [self.stack[i][1] for i in self.indexes]

        # Add hidden result to start with final size and allow for comparison
        if not isinstance(self.stack[-1][1].val(), cq.Vector):
            result = Part(self.stack[-1][1],
                          "Result",
                          show_faces=False,
                          show_edges=False)
            objs = [result] + cad_objs
        else:
            objs = cad_objs

        with self.debug_output:
            # Tessellate the selection and push it to the viewer, keeping the
            # current camera position (reset_camera=False).
            assembly = to_assembly(*objs)
            mapping = assembly.to_state()
            shapes = assembly.collect_mapped_shapes(
                mapping,
                quality=self.quality,
                deviation=self.deviation,
                angular_tolerance=self.angular_tolerance,
                edge_accuracy=self.edge_accuracy,
                render_edges=get_default("render_edges"),
                render_normals=get_default("render_normals"),
            )
            tree = assembly.to_nav_dict()

            self.display.add_shapes(shapes=shapes,
                                    mapping=mapping,
                                    tree=tree,
                                    bb=_combined_bb(shapes),
                                    reset_camera=False)
示例#12
0
class Graph(VBox):
    """Graph widget class for creating interactive graphs

    Keyword arguments:

    * `name` -- graph name (the SVG is written to ``output/<name>.svg``)
    * `delayed` -- use a draw button instead of updating on every change
    * `**kwargs` -- initial widget values: ``filter_in``, ``filter_out``,
      ``places``, ``references`` and the slider attributes ``r``,
      ``margin``, ``margin_left``, ``dist_x``, ``dist_y``, ``letters`` and
      ``max_by_year`` (GraphConfig attribute names). Category toggles are
      initialised from ``config.CLASSES``, not from kwargs.
    """
    def __init__(self, name="graph", delayed=False, **kwargs):
        # Redraw-coalescing counter (see display()). Starting at 1 keeps the
        # callbacks fired while the widgets are being built from drawing;
        # the delayed_draw() at the end of __init__ resets it and draws once.
        self._display_stack = 1
        self._display_categories = set()
        self._filter_in = None
        self._filter_out = None
        self._svg_name = ""
        self._initial = kwargs
        self.delayed = delayed

        self.graph_name = name
        self.toggle_widgets = OrderedDict()
        self.color_widgets = OrderedDict()
        self.font_color_widgets = OrderedDict()

        self.filter_in_widget = Text(description="Filter In",
                                     value=kwargs.get("filter_in", ""))
        self.filter_out_widget = Text(description="Filter Out",
                                      value=kwargs.get("filter_out", ""))

        self.r_widget = self.slider("R",
                                    "r",
                                    5,
                                    70,
                                    1,
                                    21,
                                    fn=self.update_r_widget)
        self.margin_widget = self.slider("Margin", "margin", 5, 170, 1, 59)
        self.margin_left_widget = self.slider("M. Left", "margin_left", 5, 170,
                                              1, 21)
        self.dist_x_widget = self.slider("Dist. X", "dist_x", 5, 170, 1, 76)
        self.dist_y_widget = self.slider("Dist. Y", "dist_y", 5, 170, 1, 76)
        self.letters_widget = self.slider("Letters", "letters", 1, 40, 1, 7)
        self.by_year_widget = self.slider("By Year", "max_by_year", 0, 50, 1,
                                          5)

        self.places_widget = ToggleButton(description="Places",
                                          value=kwargs.get("places", False))
        self.references_widget = ToggleButton(description="References",
                                              value=kwargs.get(
                                                  "references", True))
        self.delayed_widget = Button(description="Draw")

        self.output_widget = Output()

        self.filter_in_widget.observe(self.update_widget, "value")
        self.filter_out_widget.observe(self.update_widget, "value")
        self.places_widget.observe(self.update_widget, "value")
        self.references_widget.observe(self.update_widget, "value")
        self.delayed_widget.on_click(self.delayed_draw)

        self.create_widgets()

        # Apply the r-dependent bounds to the other sliders once up front.
        self.update_r_widget()

        super(Graph, self).__init__([
            HBox([
                VBox([self.filter_in_widget, self.filter_out_widget] +
                     list(self.toggle_widgets.values()) + [
                         HBox([w1, w2])
                         for w1, w2 in zip(self.color_widgets.values(),
                                           self.font_color_widgets.values())
                     ] + [self.places_widget, self.references_widget] +
                     ([self.delayed_widget] if delayed else [])),
                VBox([
                    self.r_widget,
                    self.margin_widget,
                    self.margin_left_widget,
                    self.dist_x_widget,
                    self.dist_y_widget,
                    self.letters_widget,
                    self.by_year_widget,
                ]),
            ]), self.output_widget
        ])
        self.layout.display = "flex"
        self.layout.align_items = "stretch"
        self.delayed_draw()

    def delayed_draw(self, *args):
        """Draw graph unconditionally (Draw button handler); resets the counter."""
        self._display_stack = 0
        self.display()

    def slider(self, description, attribute, min, max, step, default, fn=None):
        """Create an IntSlider bound to GraphConfig attribute *attribute*.

        The initial value comes from the constructor kwargs when present,
        else *default*. Changes call *fn* (default: :meth:`update_widget`).
        """
        widget = IntSlider(
            description=description,
            min=min,
            max=max,
            step=step,
            value=self._initial.get(attribute, default),
        )
        widget._configattr = attribute
        widget.observe(fn or self.update_widget, "value")
        return widget

    def update_widget(self, *args):
        """Callback for generic widgets: request a (coalesced) redraw."""
        self._display_stack += 1
        self.display()

    def update_r_widget(self, *args):
        """Callback for r_widget: re-derive the slider bounds that depend on r.

        The counter is incremented first, so the value changes below (which
        fire update_widget) do not redraw. The display() at the end performs
        the single redraw.
        """
        self._display_stack += 1
        r_value = self.r_widget.value
        dist_min = 2 * r_value + 2
        letters_max = int(r_value / 3.6)
        # For each slider: relax the bound first so the new value can be
        # assigned against the old bound, then tighten it to the new limit.
        self.margin_left_widget.min = -1
        self.margin_left_widget.value = max(r_value,
                                            self.margin_left_widget.value)
        self.margin_left_widget.min = r_value
        self.dist_x_widget.min = -1
        self.dist_x_widget.value = max(dist_min, self.dist_x_widget.value)
        self.dist_x_widget.min = dist_min
        self.dist_y_widget.min = -1
        self.dist_y_widget.value = max(dist_min, self.dist_y_widget.value)
        self.dist_y_widget.min = dist_min
        self.letters_widget.max = 5000
        self.letters_widget.value = min(letters_max, self.letters_widget.value)
        self.letters_widget.max = letters_max
        self.display()

    def visible_classes(self):
        """Yield the ``config.CLASSES`` entries marked "display" or "hide"."""
        for class_ in config.CLASSES:
            if class_[2] in ("display", "hide"):
                yield class_

    def create_category(self, name, attr, value, color, font_color):
        """Create the toggle and two colour pickers for one category."""
        # index by int(toggle value): False -> hidden, True -> default display
        VIS = ["none", ""]
        widget = self.toggle_widgets[attr] = ToggleButton(value=value,
                                                          description=name)
        wcolor = self.color_widgets[attr] = ColorPicker(value=color,
                                                        description=name,
                                                        width="180px")
        wfont_color = self.font_color_widgets[attr] = ColorPicker(
            value=font_color, width="110px")

        def visibility(*args):
            """Toggle visibility of the category's colour pickers and redraw."""
            self._display_stack += 1
            wcolor.layout.display = VIS[int(widget.value)]
            wfont_color.layout.display = VIS[int(widget.value)]
            self.display()

        widget.observe(visibility, "value")
        wcolor.observe(self.update_widget, "value")
        wfont_color.observe(self.update_widget, "value")
        visibility()

    def create_widgets(self):
        """Create custom categories from ``config.CLASSES``.

        Entry layout: ``(name, attr, "display"|"hide", color, font_color)``.
        """
        for class_ in self.visible_classes():
            self.create_category(
                class_[0],
                class_[1],
                (class_[2] == "display"),
                class_[3],
                class_[4],
            )

    def graph(self):
        """Reload the data, apply the filters, and write the SVG graph.

        Returns:
            ``(work_list, ref_list)``: the filtered works and the filtered
            references (empty when the References toggle is off).
        """
        reload()
        work_list = load_work()
        references = load_citations()

        self._svg_name = str(Path("output") / (self.graph_name + ".svg"))
        self._display_categories = {
            key
            for key, widget in self.toggle_widgets.items() if widget.value
        }
        self._filter_in = self.filter_in_widget.value.lower()
        self._filter_out = self.filter_out_widget.value.lower()

        work_list = list(filter(self.filter_work, work_list))
        ref_list = []
        if self.references_widget.value:
            # keep a reference only if both of its endpoints pass the filters
            ref_list = list(
                filter(
                    lambda x: self.filter_work(x.citation) and self.
                    filter_work(x.work), references))

        graph_config = GraphConfig()
        graph_config.r = self.r_widget.value
        graph_config.margin = self.margin_widget.value
        graph_config.margin_left = self.margin_left_widget.value
        graph_config.dist_x = self.dist_x_widget.value
        graph_config.dist_y = self.dist_y_widget.value
        graph_config.letters = self.letters_widget.value
        graph_config.max_by_year = self.by_year_widget.value
        graph_config.draw_place = self.places_widget.value
        graph_config.fill_color = self.work_colors

        create_graph(self._svg_name, work_list, ref_list, graph_config)
        return work_list, ref_list

    def work_key(self, work):
        """Return work category"""
        return oget(work, "category")

    def work_colors(self, work):
        """Return ``(fill, font)`` colours for *work*; white/black if uncategorised."""
        key = self.work_key(work)
        if key not in self.color_widgets:
            return ("white", "black")
        return (self.color_widgets[key].value,
                self.font_color_widgets[key].value)

    def filter_work(self, work):
        """Return True if *work* is in a displayed category and passes the filters.

        A work is rejected if any attribute's lowercased string contains the
        filter-out text. It is accepted if any attribute contains the
        filter-in text; an empty filter-in matches any attribute.
        """
        key = self.work_key(work)
        if key not in self._display_categories:
            return False
        for attr in dir(work):
            if self._filter_out and self._filter_out in str(getattr(
                    work, attr)).lower():
                return False
        for attr in dir(work):
            if self._filter_in in str(getattr(work, attr)).lower():
                return True
        return False

    def display(self, *args):
        """Display interactive graph.

        Returns False if the redraw was skipped because more callbacks are
        pending (or, in delayed mode, because Draw has not been pressed), and
        True after a redraw. In delayed mode the counter is never decremented
        here, so only delayed_draw() can trigger a redraw.
        """
        if self._display_stack:
            if not self.delayed:
                self._display_stack -= 1
            if self._display_stack:
                # Skip display if other widgets will invoke display soon
                return False
        self.output_widget.clear_output()
        with self.output_widget:
            self.graph()
            display(self._svg_name)
            svg = SVG(self._svg_name)
            # Tag the root element so svgPanZoom can select it below; the
            # [:4] slice assumes the data starts with "<svg".
            svg._data = svg._data[:4] + ' class="refgraph"' + svg._data[4:]
            display(svg)

            interaction = """
                $(".hoverable polyline, .hoverable line").mouseenter(
                    function(e) {
                        //e.stopPropagation();
                        $(this).css("stroke", "blue");
                        $(this).css("stroke-width", "3px");
                    }).mouseleave(
                    function() {
                        $(this).css("stroke", "black");
                        $(this).css("stroke-width", "inherit");
                    });
            """
            display(Javascript(interaction))
            # Path(__file__) / ".." is not traversable on POSIX (a regular
            # file is not a directory), so walk up with .parent instead.
            # read_text() also closes the file, unlike open(...).read().
            resources_dir = Path(__file__).resolve().parent.parent / "resources"
            pan_zoom_js = (resources_dir / "svg-pan-zoom.min.js").read_text(
                encoding="utf-8")
            display(
                HTML("""
                <script type="text/javascript">
                    %s

                    require(["./svg-pan-zoom"], function(svgPanZoom) {
                        svgPanZoom('.refgraph', {'minZoom': 0.1});
                    });
                </script>
            """ % (pan_zoom_js, )))

        return True
示例#13
0
class IFTMS(object):
    """A widget gathering all the 1D FTMS tools into one screen.

    The screen has a header above a set of tabs.
    - Header: a file chooser, the Load / Process / Peak Pick / Save buttons,
      and a small message area.
    - Tabs: raw fid, Spectrum, Peak List, Processing Parameters and Info.
    """

    def __init__(self, show=True, style=True):
        """Build the GUI.

        show: when True, display the widget immediately.
        style: when True, call ``injectcss()``.
        """
        # header
        #   filechooser
        self.base = BASE
        self.filechooser = FileChooser(self.base, accept=('MS', 'fid', 'FID'))
        self.datap = None  # Dataproc instance, set once a data-set is loaded
        self.wwait = None  # waiting wheel widget, set by wait()
        self.MAX_DISP_PEAKS = NbMaxDisplayPeaks

        #   buttons
        #       load
        self.bload = Button(description='Load',
                            layout=Layout(width='15%'),
                            tooltip='load and display experiment')
        self.bload.on_click(self.load)
        #       FT
        self.bproc = Button(description='Process',
                            layout=Layout(width='15%'),
                            tooltip='Fourier transform of the fid')
        self.bproc.on_click(self.process)
        #       pp
        self.bpeak = Button(description='Peak Pick',
                            layout=Layout(width='15%'),
                            tooltip='Detect Peaks')
        self.bpeak.on_click(self.peakpick)
        self.bsave = Button(description='Save',
                            layout=Layout(width='15%'),
                            tooltip='Save processed data set in msh5 format')
        self.bsave.on_click(self.save)

        # GUI set-up and scene
        # tools
        self.header = Output()
        with self.header:
            self.waitarea = Output()  # messages and waiting wheel
            self.buttonbar = HBox([
                self.bload, self.bproc, self.bpeak, self.bsave, self.waitarea
            ])
            display(
                Markdown('---\n# Select an experiment, and load to process'))
            display(self.filechooser)
            display(self.buttonbar)

        NODATA = HTML("<br><br><h3><i><center>No Data</center></i></h3>")
        # fid
        self.fid = Output()  # the area where the raw fid is shown
        with self.fid:
            display(NODATA)
            display(Markdown("use the `Load` button above"))
        # spectrum
        self.out1D = Output()  # the area where 1D is shown
        with self.out1D:
            display(NODATA)
            display(
                Markdown(
                    "After loading, use the `Process` or `Load` buttons above")
            )

        # peaklist
        self.peaklist = Output()  # the area where peak list is shown
        with self.peaklist:
            display(NODATA)
            display(
                Markdown("After Processing, use the `Peak Pick` button above"))

        # form
        self.outform = Output()  # the area where processing parameters are displayed
        with self.outform:
            self.paramform()
            display(self.form)

        # Info
        self.outinfo = Output()  # the area where info is shown
        self.showinfo()

        #  tabs
        self.tabs = widgets.Tab()
        self.tabs.children = [
            self.fid, self.out1D, self.peaklist, self.outform, self.outinfo
        ]
        self.tabs.set_title(0, 'raw fid')
        self.tabs.set_title(1, 'Spectrum')
        self.tabs.set_title(2, 'Peak List')
        self.tabs.set_title(3, 'Processing Parameters')
        self.tabs.set_title(4, 'Info')

        self.box = VBox([self.header, self.tabs])
        if style:
            injectcss()
        if show:
            display(self.box)

    def _flash(self, *msg, preclear=False):
        """Print a transient message (print-style arguments) in the wait area.

        The message stays visible until the next output to that area.
        With *preclear*, whatever was shown there before is dropped first,
        for example the waiting wheel.
        """
        if preclear:
            self.waitarea.clear_output(wait=True)
        with self.waitarea:
            print(*msg)
            self.waitarea.clear_output(wait=True)

    def _show_audit_trail(self):
        """Display the local audit_trail.txt file, if it exists."""
        # save() moves the file next to the saved data-set, so it may
        # legitimately be missing afterwards.
        try:
            with open('audit_trail.txt', 'r') as F:
                content = F.read()
        except FileNotFoundError:
            print('no audit trail available')
            return
        display(Markdown(content))

    def showinfo(self):
        """Show info on the data-set in memory in the Info tab.

        There are three cases:
        - nothing is loaded;
        - a raw fid is loaded, possibly already processed;
        - a processed data-set was loaded directly.
        """
        self.outinfo.clear_output()
        with self.outinfo:
            if self.datap is None:
                display(
                    HTML("<br><br><h3><i><center>No Data</center></i></h3>"))
            elif self.datap.data is not None:  # a fid is loaded
                display(Markdown("# Raw Dataset\n%s\n" % (self.selected, )))
                print(self.datap.data)
                if self.datap.DATA is not None:  # and has been processed
                    display(Markdown("# Audit-Trail"))
                    self._show_audit_trail()
            elif self.datap.DATA is not None:  # a processed has been loaded
                display(
                    Markdown("# Processed Dataset\n%s\n" % (self.selected, )))
                print(self.datap.DATA)
                self._show_audit_trail()

    def wait(self):
        "show a little waiting wheel"
        here = Path(__file__).parent
        with open(here / "icon-loader.gif", "rb") as F:
            gif = F.read()
        with self.waitarea:
            self.wwait = widgets.Image(value=gif, format='gif', width=40)
            display(self.wwait)

    def done(self):
        "remove the waiting wheel, if one is shown"
        wheel = getattr(self, 'wwait', None)
        if wheel is not None:
            wheel.close()
            self.wwait = None

    @property
    def selected(self):
        """Full path of the file currently selected in the file chooser."""
        return str(self.filechooser.selected)

    @property
    def title(self):
        """Name of the current selection, used as plot title."""
        return str(self.filechooser.name)

    def load(self, e):
        "load 1D data-set and display"
        self.fid.clear_output(wait=True)
        if self.selected.endswith(".msh5"):
            self.loadspike()
        else:
            self.loadbruker()

    def loadspike(self):
        """Load a processed .msh5 data-set and display it in the Spectrum tab."""
        fullpath = self.selected
        try:
            DATA = FTICRData(name=fullpath)
        except Exception:
            self._flash('Error while loading', self.selected, preclear=True)
            with self.outinfo:
                traceback.print_exc()
            return
        DATA.filename = fullpath  # filename and fullpath are equivalent !
        DATA.fullpath = fullpath
        U.auditinitial(title="Load file", append=False)
        DATA.set_unit('m/z')
        # no raw fid behind this data-set: only the processed part is set
        self.datap = Dataproc(None)
        self.datap.data = None
        self.datap.DATA = DATA
        self.showinfo()
        self.out1D.clear_output()
        with self.out1D:
            DATA.display(title=self.title, new_fig={'figsize': (10, 5)})
        self.tabs.selected_index = 1

    def loadbruker(self):
        """Load a raw Bruker fid, display it, and refresh the parameter form."""
        fullpath = self.selected
        try:
            data = BrukerMS.Import_1D(fullpath)
        except Exception:
            self._flash('Error while loading -', self.selected, preclear=True)
            with self.outinfo:
                traceback.print_exc()
            return
        data.filename = fullpath
        data.fullpath = fullpath  # filename and fullpath are equivalent !
        U.auditinitial(title="Load file", append=False)
        data.set_unit('sec')
        with self.fid:
            data.display(title=self.title, new_fig={'figsize': (10, 5)})
        self.datap = Dataproc(data)
        self.showinfo()
        self.param2form(self.datap.procparam)
        self.outform.clear_output()
        with self.outform:  # refresh param form
            display(self.form)
        self.tabs.selected_index = 0

    def save(self, e):
        """Save the processed 1D spectrum to a new msh5 file.

        If requested in the form, the grass noise is removed and the data
        compressed. The audit trail is then moved next to the saved file.
        """
        import shutil  # local import: only needed here

        self.wait()
        audit = U.auditinitial(title="Save file", append=True)
        # find name; fails when nothing is loaded or nothing is processed
        try:
            fullpath = self.datap.DATA.fullpath
        except AttributeError:
            self.done()
            self._flash('No processed dataset to save', preclear=True)
            return
        expname = U.find_free_filename(fullpath, 'Processed', '.msh5')

        # clean if required
        self.form2param()
        parameters = self.datap.procparam
        data = self.datap.DATA
        compress = False
        if parameters['grass_noise_todo'] == 'storage':
            data.zeroing(parameters['grass_noise_level'] * data.noise)
            data.eroding()
            compress = True
            U.audittrail(audit, "text", "grass noise removal",
                         "noise threshold", parameters['grass_noise_level'])
        try:
            data.save_msh5(expname, compressed=compress)
        except Exception:
            self.done()
            self._flash('Error while saving to file', self.selected,
                        preclear=True)
            with self.outinfo:
                traceback.print_exc()
            return
        data.filename = expname
        # move audit_trail.txt next to the saved file (portable, unlike `mv`)
        destination = str(Path(expname).with_suffix('')) + '_audit.txt'
        try:
            shutil.move('audit_trail.txt', destination)
        except FileNotFoundError:
            with self.outinfo:
                print('no audit_trail.txt found, audit trail not copied')

        with self.outinfo:
            display(
                Markdown('# Save locally\n Data set saved as "%s"\n ' %
                         (expname, )))
        self.done()
        self._flash('Data-set saved')
        self.filechooser.refresh('event')

    def process(self, e):
        "do the FT"
        if self.datap is None or self.datap.data is None:
            self._flash('Please load a raw dataset first')
            return
        self.wait()
        try:
            self.out1D.clear_output(wait=True)
            self.form2param()
            self.datap.process()
            with self.out1D:
                self.datap.DATA.display(title=self.title,
                                        new_fig={'figsize': (10, 5)})
            self.showinfo()
            self.tabs.selected_index = 1
        finally:
            self.done()  # never leave the wheel spinning, even on error

    def peakpick(self, e):
        "do the peak-picking"
        if self.datap is None:
            self._flash('Please load a dataset first')
            return
        if self.datap.DATA is None:
            self._flash('Please process the dataset first')
            return
        self.wait()
        try:
            self.peaklist.clear_output(wait=True)
            self.form2param()
            self.datap.peakpick()
            with self.out1D:
                self.datap.DATA.display_peaks(peak_label=True,
                                              NbMaxPeaks=self.MAX_DISP_PEAKS)
            with self.peaklist:
                display(
                    Markdown('%d Peaks detected' % len(self.datap.DATA.peaks)))
                display(HTML(self.datap.DATA.pk2pandas().to_html()))
            self.showinfo()
            self.tabs.selected_index = 2
        finally:
            self.done()  # never leave the wheel spinning, even on error

    def paramform(self):
        "draw the processing parameter form (stored in self.form)"
        ly = widgets.Layout(width='50%')
        style = {'description_width': '30%'}

        def dropdown(options, value, description):
            # U.procparam_MS[options] maps value -> label; Dropdown wants (label, value)
            opt = [(v, k) for k, v in U.procparam_MS[options].items()]
            return widgets.Dropdown(options=opt,
                                    value=value,
                                    description=description,
                                    layout=ly,
                                    style=style)

        chld = []
        chld.append(widgets.HTML('<h3>Processing</h3>'))

        chld.append(
            widgets.RadioButtons(options=['Yes', 'No'],
                                 value='Yes',
                                 description='center fid',
                                 layout=ly,
                                 style=style))

        opt = U.procparam_MS["apodisations"].keys()
        chld.append(
            widgets.Dropdown(options=opt,
                             value='hamming',
                             description='apod todo',
                             layout=ly,
                             style=style))

        chld.append(
            widgets.IntSlider(value=2,
                              min=1,
                              max=4,
                              description='zf level',
                              layout=ly,
                              style=style))

        chld.append(dropdown("baseline_correction", "offset", "baseline todo"))

        chld.append(dropdown("grass_noise", "storage", "grass noise todo"))

        chld.append(
            widgets.FloatSlider(value=3,
                                min=1,
                                max=10,
                                description='grass noise level',
                                layout=ly,
                                style=style))

        chld.append(widgets.HTML('<h3>Peak Picking</h3>'))

        chld.append(dropdown("peakpicking", "manual", "peakpicking todo"))

        chld.append(
            widgets.FloatLogSlider(value=10,
                                   min=0,
                                   max=3,
                                   description='peakpicking noise level',
                                   layout=ly,
                                   style=style))

        chld.append(
            widgets.RadioButtons(options=['Yes', 'No'],
                                 value='Yes',
                                 description='centroid',
                                 layout=ly,
                                 style=style))

        chld.append(
            widgets.IntText(value=self.MAX_DISP_PEAKS,
                            min=10,
                            max=10000,
                            description='max peak displayed',
                            layout=ly,
                            style=style))

        self.form = widgets.VBox(chld)

    # WARNING procparams have '_' in names while the form display with ' '
    def param2form(self, dico, verbose=DEBUG):
        """Copy the parameters stored in *dico* into the form widgets.

        Keys without a matching widget are ignored. They are reported
        when *verbose* is set.
        """
        # widget description -> widget
        myform = {vb.description: vb for vb in self.form.children}
        for k, v in dico.items():
            k = k.replace('_', ' ')
            if k in myform:
                myform[k].value = v
            elif verbose:
                print('key not in form:', k)

    def form2param(self, verbose=DEBUG):
        """Copy the form values into the processing parameters.

        'max_peak_displayed' is not a processing parameter. When it is not
        one of the procparam keys, it updates MAX_DISP_PEAKS instead.
        """
        keys = self.datap.procparam.keys()  # keys of the procparam
        for k, v in self.form2dico().items():
            if k in keys:
                self.datap.procparam[k] = v
            elif k == 'max_peak_displayed':
                self.MAX_DISP_PEAKS = v
            elif verbose:
                print('key missmatch:', k)

    def form2dico(self):
        """Return the form values as a dict keyed by description ('_' for ' ')."""
        return {
            vb.description.replace(' ', '_'): vb.value
            for vb in self.form.children
        }
示例#14
0
class MSPeaker(object):
    """An interactive peak-picker for MS experiments.

    The spectrum is shown with a zoom slider, a threshold slider (in
    multiples of the noise level) and Print / Export / Done buttons.
    The peak-picking is redone whenever the threshold changes.
    """

    def __init__(self, npkd, pkname):
        """Build the GUI, do a first peak-picking and display the result.

        npkd: the data-set to work on; must be an FTMSData instance.
        pkname: name of the file written by the Export button.
        Raises Exception when npkd is not an FTMSData.
        """
        if not isinstance(npkd, FTMSData):
            raise Exception('This modules requires a FTMS Dataset')
        self.npkd = npkd
        self.pkname = pkname
        self.zoom = widgets.FloatRangeSlider(
            value=[npkd.axis1.lowmass, npkd.axis1.highmass],
            min=npkd.axis1.lowmass,
            max=npkd.axis1.highmass,
            step=0.1,
            layout=Layout(width='100%'),
            description='zoom',
            continuous_update=False,
            readout=True,
            readout_format='.1f',
        )
        self.zoom.observe(self.display)
        self.tlabel = Label('threshold (x noise level):')
        self.thresh = widgets.FloatLogSlider(value=20.0,
                                             min=np.log10(3),
                                             max=2.0,
                                             base=10,
                                             step=0.01,
                                             layout=Layout(width='30%'),
                                             continuous_update=False,
                                             readout=True,
                                             readout_format='.1f')
        self.thresh.observe(self.pickpeak)
        self.peak_mode = widgets.Dropdown(options=['marker', 'bar'],
                                          value='marker',
                                          description='show as')
        self.peak_mode.observe(self.display)
        self.bexport = widgets.Button(
            description="Export",
            layout=Layout(width='7%'),
            button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Export to csv file')
        self.bexport.on_click(self.pkexport)
        self.bprint = widgets.Button(description="Print",
                                     layout=Layout(width='7%'),
                                     button_style='success',
                                     tooltip='Print to screen')
        self.bprint.on_click(self.pkprint)
        self.bdone = widgets.Button(description="Done",
                                    layout=Layout(width='7%'),
                                    button_style='warning',
                                    tooltip='Fix results')
        self.bdone.on_click(self.done)
        self.out = Output(layout={'border': '1px solid red'})
        display(
            VBox([
                self.zoom,
                HBox([
                    self.tlabel, self.thresh, self.peak_mode, self.bprint,
                    self.bexport, self.bdone
                ])
            ]))
        self.fig, self.ax = plt.subplots()
        self.pp()  # first peak-picking + display
        display(self.out)

    def pkprint(self, event):
        """Show the peak table (as HTML) in the output area."""
        self.out.clear_output(wait=True)
        with self.out:
            display(HTML(self.npkd.pk2pandas().to_html()))

    def pkexport(self, event):
        "exports the peaklist to file"
        with open(self.pkname, 'w') as FPK:
            print(self.pklist(), file=FPK)
        print('Peak list stored in ', self.pkname)

    def pklist(self):
        """Return the peak list as tab-separated text.

        Columns:
        - m/z;
        - intensity relative to the largest peak, in %;
        - resolution R = m/z / Dm, rounded to the thousand;
        - area = intensity * Dm.
        Dm is half the m/z span of the peak width. Only the header line is
        returned when there are no peaks.
        """
        text = ["m/z\t\tInt.(%)\tR\tarea(a.u.)"]
        data = self.npkd
        peaks = data.peaks
        if not peaks:  # nothing picked yet, or no peak above threshold
            return text[0]
        intmax = max(pk.intens for pk in peaks) / 100
        for pk in peaks:
            mz = data.axis1.itomz(pk.pos)
            Dm = 0.5 * (data.axis1.itomz(pk.pos - pk.width) -
                        data.axis1.itomz(pk.pos + pk.width))
            area = pk.intens * Dm
            # guard against zero width / all-zero intensities
            resolution = round(mz / Dm, -3) if Dm else float('nan')
            relint = pk.intens / intmax if intmax else 0.0
            text.append("%.6f\t%.1f\t%.0f\t%.0f" %
                        (mz, relint, resolution, area))
        return "\n".join(text)

    def display(self, event=None):
        """Redraw the spectrum and its peaks in the current zoom window.

        event: change dict sent by ipywidgets ``observe``. Only changes of
        the 'value' trait trigger a redraw; calling without argument
        always redraws.
        """
        if event is not None and event['name'] != 'value':
            return
        self.ax.clear()
        self.npkd.display(new_fig=False, figure=self.ax, zoom=self.zoom.value)
        try:
            self.npkd.display_peaks(peak_label=True,
                                    peak_mode=self.peak_mode.value,
                                    figure=self.ax,
                                    zoom=self.zoom.value,
                                    NbMaxPeaks=NbMaxDisplayPeaks)
            # threshold as a dotted red line across the zoom window
            x = self.zoom.value
            y = [self.npkd.peaks.threshold] * 2
            self.ax.plot(x, y, ':r')
            self.ax.annotate('%d peaks detected' % len(self.npkd.peaks),
                             (0.05, 0.95),
                             xycoords='figure fraction')
        except Exception as err:
            # peaks are drawn on a best-effort basis: keep the spectrum,
            # but say why the peaks could not be shown instead of hiding it
            with self.out:
                print('could not display peaks:', err)

    def pickpeak(self, event):
        "interactive wrapper to peakpick"
        if event['name'] == 'value':
            self.pp()

    def pp(self):
        "do the peak-picking calling pp().centroid(), then redraw"
        self.npkd.set_unit('m/z').peakpick(autothresh=self.thresh.value,
                                           verbose=False,
                                           zoom=self.zoom.value).centroid()
        self.display()

    def done(self, event):
        "exit GUI: close the controls and freeze the threshold in the label"
        for w in [
                self.zoom, self.thresh, self.peak_mode, self.bprint,
                self.bexport, self.bdone
        ]:
            w.close()
        self.tlabel.value = "threshold %.2f noise level" % self.thresh.value
示例#15
0
class measure_stepped_sine():
    """Interactive stepped-sine transfer-function measurement.

    Each press of 'Measure' takes a snapshot of the data stream. It finds
    the largest FFT line of channel 0 and records the ratio channel 1 /
    channel 0 at that frequency. Points are plotted in dB.
    The other buttons delete the last point, save the points as a
    transfer-function data-set, or save the figure.
    """

    def __init__(self, settings):
        """Create the buttons, start the data stream and set up the plot.

        settings: acquisition settings passed to ``dvma.start_stream``.
        """
        self.settings = settings

        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out = Output()

        words = ['Measure', 'Delete Last Measurement', 'Save Data', 'Save Fig']
        button_styles = ['success', 'warning', 'primary', 'primary']
        self.buttons = [
            widgets.Button(description=w, layout=Layout(width='25%'))
            for w in words
        ]
        for button, button_style in zip(self.buttons, button_styles):
            button.button_style = button_style
            button.style.font_weight = 'bold'
        display(widgets.HBox(self.buttons))

        # stays None when the stream cannot be started; checked before use
        self.rec = None
        try:
            dvma.start_stream(settings)
            self.rec = dvma.streams.REC
        except Exception:
            print('Data stream not initialised.')
            print(
                'Possible reasons: pyaudio or PyDAQmx not installed, or acquisition hardware not connected.'
            )
            print("Please note that it won't be possible to log data.")

        self.f = np.array([])  # measured frequencies (Hz)
        self.G = np.array([])  # measured gains (channel 1 / channel 0)
        self.fig, self.ax = plt.subplots(1, 1, figsize=(9, 5), dpi=100)
        self.ax.set_xlabel('Frequency (Hz)')
        self.ax.set_ylabel('Amplitude (dB)')
        self.ax.grid()
        self.ax.set_xlim([0, 500])
        self.ax.set_ylim([-50, 50])
        self.line, = self.ax.plot([], [],
                                  'x',
                                  markeredgewidth=2,
                                  label='stepped sine')
        self.line.axes.set_autoscaley_on(True)

        self.buttons[0].on_click(self.measure)
        self.buttons[1].on_click(self.undo)
        self.buttons[2].on_click(self.save)
        self.buttons[3].on_click(self.savefig)

        display(self.out)

    def _stream_missing(self):
        """Report it and return True when no data stream is available."""
        if self.rec is not None:
            return False
        self.out.clear_output(wait=False)
        with self.out:
            print('Data stream not initialised: cannot measure or save.')
        return True

    def measure(self, b):
        """Add one point: the dominant frequency and the gain at it."""
        if self._stream_missing():
            return
        time_data = dvma.stream_snapshot(self.rec)
        freq_data = dvma.calculate_fft(time_data, window='hanning')

        # frequency line with the largest magnitude on channel 0
        index = np.argmax(np.abs(freq_data.freq_data[:, 0]))

        self.f = np.append(self.f, freq_data.freq_axis[index])
        self.G = np.append(
            self.G,
            freq_data.freq_data[index, 1] / freq_data.freq_data[index, 0])

        self.update_line()

    def undo(self, b):
        """Remove the last measured point, if any."""
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        if len(self.f) > 0:
            self.f = np.delete(self.f, -1)
            self.G = np.delete(self.G, -1)
            self.update_line()
        else:
            with self.out:
                print('nothing to undo!')

    def save(self, b):
        """Save the points, sorted by frequency, as a transfer-function data-set."""
        if self._stream_missing():
            return
        self.out.clear_output(wait=False)
        with self.out:
            i_sort = np.argsort(self.f)
            ff = self.f[i_sort]
            GG = self.G[i_sort]
            tf_data = dvma.TfData(ff,
                                  GG,
                                  None,
                                  self.rec.settings,
                                  test_name='stepped_sine')
            d = dvma.DataSet(tf_data)
            d.save_data()

    def savefig(self, b):
        """Save the figure through ``dvma.save_fig``."""
        # the 'out' construction is to refresh the text output at each update
        # to stop text building up in the widget display
        self.out.clear_output(wait=False)
        with self.out:
            dvma.save_fig(self.fig, figsize=(9, 5))

    def update_line(self):
        """Redraw the points in dB, sorted by frequency, with a 3 dB margin."""
        i_sort = np.argsort(self.f)
        self.GdB = 20 * np.log10(np.abs(self.G))
        self.line.set_xdata(self.f[i_sort])
        self.line.set_ydata(self.GdB[i_sort])
        if len(self.GdB) > 0:
            self.line.axes.set_ylim(bottom=min(self.GdB) - 3,
                                    top=max(self.GdB) + 3)
示例#16
0
class Replay(object):
    """Step-by-step replay viewer for a CadQuery workplane history.

    ``stack`` holds (code line, workplane) pairs, as built by ``to_array``.
    ``select`` shows a subset of these steps while keeping the camera and
    view options of the current view.
    """

    def __init__(self, debug=False, cad_width=600, height=600):
        """debug: stored as ``_debug``; cad_width/height: viewer size."""
        self._debug = debug

        self.debug_output = Output()
        self.cad_width = cad_width
        self.height = height
        self.indexes = [0]
        self.stack = []  # (code line, workplane) pairs; normally set from to_array()
        self.view = None
        self.state = None

    def to_array(self, workplane, indent="", result_name=None):
        """Return the (code line, object) history of *workplane*.

        The parent chain is walked from *workplane* up to its root, and each
        ancestor's lines are prepended, so the list runs from the root down.
        Workplanes passed as arguments are expanded recursively with a
        deeper "| " indentation.
        """
        def to_code(name, args, kwargs, indent, result_name):
            """Render one call as ``indent + name(arg, ..., key=value, ...)``."""
            def to_name(obj):
                # objects with a name are shown by name, anything else as is
                return getattr(obj, "name", None) or obj

            params = [repr(to_name(arg)) for arg in args]
            params += ["%s=%s" % (k, v) for k, v in kwargs.items()]
            code = "%s%s(%s)" % (indent, name, ", ".join(params))
            if result_name is not None:
                code += (" => %s" % result_name)
            return code

        def walk(caller, indent="", result_name=None):
            # lines of the children (and of their workplane arguments)
            # come before the caller's own line
            delim = "| "
            stack = [(to_code(caller["func"], caller["args"], caller["kwargs"],
                              indent, result_name), caller["obj"])]
            for child in reversed(caller["children"]):
                stack = walk(child, indent + delim) + stack
                for arg in child["args"]:
                    if isinstance(arg, cq.Workplane):
                        result_name = getattr(arg, "name", None)
                        stack = self.to_array(arg, indent +
                                              (delim * 2), result_name) + stack
            return stack

        stack = []
        delim = "| "

        obj = workplane
        while obj is not None:
            caller = getattr(obj, "_caller", None)
            result_name = getattr(obj, "name", None)
            if caller is not None:
                stack = walk(caller, indent, result_name) + stack
                for arg in caller["args"]:
                    if isinstance(arg, cq.Workplane):
                        result_name = getattr(arg, "name", None)
                        stack = self.to_array(arg, indent + delim,
                                              result_name) + stack
            obj = obj.parent

        return stack

    def dump(self):
        """Print every step with the class name of its value."""
        for o in self.stack:
            print(o, o[1].val().__class__.__name__)

    def select(self, indexes):
        """Show the steps at *indexes*, keeping the current view settings.

        The axes, grid, ortho and transparency options and the camera
        rotation, zoom and position are carried over from the current view.
        Defaults are used for the first view.
        """
        with self.debug_output:
            self.indexes = indexes
            cad_objs = [self.stack[i][1] for i in self.indexes]

            # Save state
            if self.view is None:
                axes, grid, axes0, ortho, transparent = True, True, True, True, False
                rotation = zoom = position = None
            else:
                cq_view = self.view.cq_view
                axes = cq_view.axes.get_visibility()
                grid = cq_view.grid.get_visibility()
                axes0 = cq_view.axes.is_center()
                ortho = cq_view.is_ortho()
                transparent = cq_view.is_transparent()
                rotation = cq_view.camera.rotation
                zoom = cq_view.camera.zoom
                position = cq_view.camera.position
                # subtract center out of position to be prepared for _scale function
                if position is not None:
                    position = cq_view._sub(position, cq_view.bb.center)

            # Show new view
            self.view = self.show(cad_objs, position, rotation, zoom, axes,
                                  grid, axes0, ortho, transparent)

    def select_handler(self, change):
        """Widget observer: call ``select`` when the 'index' trait changes."""
        with self.debug_output:
            if change["name"] == "index":
                self.select(change["new"])

    def show(self,
             cad_objs,
             position,
             rotation,
             zoom,
             axes=True,
             grid=True,
             axes0=True,
             ortho=True,
             transparent=True):
        """Render *cad_objs* together with the hidden final result.

        Returns whatever the module-level ``show`` returns.
        """
        self.debug_output.clear_output()

        # Add hidden result to start with final size and allow for comparison
        result = Part(self.stack[-1][1],
                      "Result",
                      show_faces=False,
                      show_edges=False)
        with self.debug_output:
            return show(result,
                        *cad_objs,
                        transparent=transparent,
                        axes=axes,
                        grid=grid,
                        axes0=axes0,
                        ortho=ortho,
                        cad_width=self.cad_width,
                        height=self.height,
                        show_parents=(len(cad_objs) == 1),
                        position=position,
                        rotation=rotation,
                        zoom=zoom)
示例#17
0
文件: widgets.py 项目: Calysto/conx
class SequenceViewer(VBox):
    """
    SequenceViewer

    Arguments:
        title (str) - Title of sequence
        function (callable) - takes an index 0 to length - 1. Function should
            return a displayable or list of displayables
        length (int) - total number of frames in sequence
        play_rate (float) - seconds to wait between frames when auto-playing.
            Optional. Default is 0.5 seconds.

    >>> def function(index):
    ...     return [None]
    >>> sv = SequenceViewer("Title", function, 10)
    >>> ## Do this manually for testing:
    >>> sv.initialize()
    None
    >>> ## Testing:
    >>> class Dummy:
    ...     def update(self, result):
    ...         return result
    >>> sv.displayers = [Dummy()]
    >>> print("Testing"); sv.goto("begin") # doctest: +ELLIPSIS
    Testing...
    >>> print("Testing"); sv.goto("end") # doctest: +ELLIPSIS
    Testing...
    >>> print("Testing"); sv.goto("prev") # doctest: +ELLIPSIS
    Testing...
    >>> print("Testing"); sv.goto("next") # doctest: +ELLIPSIS
    Testing...

    """
    def __init__(self, title, function, length, play_rate=0.5):
        self.player = _Player(self, play_rate)
        self.player.start()
        self.title = title
        self.function = function
        self.length = length
        # Display handles, one per displayable; filled by initialize().
        # Starting empty lets slider events before initialize() be no-ops
        # instead of raising AttributeError.
        self.displayers = []
        self.output = Output()
        self.position_text = IntText(value=0, layout=Layout(width="100%"))
        self.total_text = Label(value="of %s" % self.length, layout=Layout(width="100px"))
        controls = self.make_controls()
        super().__init__([controls, self.output])

    def goto(self, position):
        """Move the slider.

        Args:
            position: ``"begin"``, ``"end"``, ``"prev"``, ``"next"`` (the
                last two wrap around at the ends) or an explicit int index.
                Any other value leaves the slider unchanged.
        """
        #### Position it:
        if position == "begin":
            self.control_slider.value = 0
        elif position == "end":
            self.control_slider.value = self.length - 1
        elif position == "prev":
            if self.control_slider.value - 1 < 0:
                self.control_slider.value = self.length - 1 # wrap around
            else:
                self.control_slider.value = max(self.control_slider.value - 1, 0)
        elif position == "next":
            if self.control_slider.value + 1 > self.length - 1:
                self.control_slider.value = 0 # wrap around
            else:
                self.control_slider.value = min(self.control_slider.value + 1, self.length - 1)
        elif isinstance(position, int):
            self.control_slider.value = position
        self.position_text.value = self.control_slider.value

    def toggle_play(self, button):
        """Switch between playing and paused, updating the button label/icon."""
        if self.button_play.description == "Play":
            self.button_play.description = "Stop"
            self.button_play.icon = "pause"
            self.player.resume()
        else:
            self.button_play.description = "Play"
            self.button_play.icon = "play"
            self.player.pause()

    def make_controls(self):
        """Build the slider/button controls and wire up their callbacks.

        Returns:
            A ``VBox`` holding the slider row and the navigation buttons.
            ``initialize`` runs when it is first displayed.
        """
        button_begin = Button(icon="fast-backward", layout=Layout(width='100%'))
        button_prev = Button(icon="backward", layout=Layout(width='100%'))
        button_next = Button(icon="forward", layout=Layout(width='100%'))
        button_end = Button(icon="fast-forward", layout=Layout(width='100%'))
        self.button_play = Button(icon="play", description="Play", layout=Layout(width="100%"))
        self.control_buttons = HBox([
            button_begin,
            button_prev,
            self.position_text,
            button_next,
            button_end,
            self.button_play,
        ], layout=Layout(width='100%', height="50px"))
        self.control_slider = IntSlider(description=self.title,
                                        continuous_update=False,
                                        min=0,
                                        max=max(self.length - 1, 0),
                                        value=0,
                                        style={"description_width": 'initial'},
                                        layout=Layout(width='100%'))
        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names='value')
        controls = VBox([HBox([self.control_slider, self.total_text], layout=Layout(height="40px")),
                         self.control_buttons], layout=Layout(width='100%'))
        controls.on_displayed(lambda widget: self.initialize())
        return controls

    def _results_as_list(self, index):
        """Call ``self.function(index)`` and return its result as a list."""
        results = self.function(index)
        try:
            return list(results)
        except TypeError:
            # Not iterable: treat the result as a single displayable.
            return [results]

    def initialize(self):
        """Display the results for the current index and keep their handles."""
        results = self._results_as_list(self.control_slider.value)
        self.displayers = [display(x, display_id=True) for x in results]

    def update_slider_control(self, change):
        """Refresh every display handle when the slider value changes."""
        if change["name"] == "value":
            self.position_text.value = self.control_slider.value
            self.output.clear_output(wait=True)
            results = self._results_as_list(self.control_slider.value)
            # zip stops at the shorter sequence, so a short result list leaves
            # the remaining displayers unchanged instead of raising IndexError.
            for displayer, result in zip(self.displayers, results):
                displayer.update(result)
示例#18
0
class ArticleNavigator:
    """Navigate on article list for insertion.

    Shows one article at a time, selected with a slider or the
    previous/next buttons, together with a form whose widgets and events
    come from ``form_definition()``. The code produced by
    ``create_info_code`` is handed to a ``RunWidget`` or a
    ``ReplaceCellWidget``, depending on ``config.RUN_WIDGET``.
    """

    def __init__(self, citation_var=None, citation_file=None, articles=None, backward=True, force_citation_file=True):
        """Build the navigator widgets.

        Args:
            citation_var: variable name of the citing work. When given, the
                work is looked up with ``work_by_varname``.
            citation_file: defaults to ``citation_var`` when falsy.
            articles: initial article list (see :meth:`set_articles`).
            backward: forwarded to ``should_add_info``.
            force_citation_file: when false, ``should_add_info`` receives
                ``citation_file=None``.
        """
        reload()
        self.force_citation_file = force_citation_file
        self.citation_var = citation_var
        self.citation_file = citation_file or citation_var
        self.disable_show = False
        self.work = work_by_varname(citation_var) if citation_var else None
        self.backward = backward
        self.to_display = []
        self.custom_widgets = []

        self.next_article_widget = Button(
            description="Next Article", icon="fa-caret-right")
        self.previous_article_widget = Button(
            description="Previous Article", icon="fa-caret-left")
        self.selector_widget = IntSlider(value=0, min=0, max=20, step=1)
        self.reload_article_widget = Button(
            description="Reload Article", icon="fa-refresh")

        self.article_number_widget = Label(value="")
        self.output_widget = Output()

        self.next_article_widget.on_click(self.next_article)
        self.previous_article_widget.on_click(self.previous_article)
        # React only to value changes. An unfiltered observer also fires for
        # every other trait update (e.g. ``max`` in set_articles) and would
        # redraw the article several times per slider move.
        self.selector_widget.observe(self.show, names="value")
        self.reload_article_widget.on_click(self.show)

        self.widgets = {}
        self.widgets_empty = {}

        form = form_definition()
        # widget spec: [0] key into WIDGET_CLS, [2] widget name,
        # optional [3] value used by erase_article_form to reset it.
        for widget in form["widgets"]:
            if widget[0] not in WIDGET_CLS:
                print("Error: Widgets type {} not found".format(widget[0]))
            else:
                self.widgets[widget[2]] = WIDGET_CLS[widget[0]](widget)
            if len(widget) >= 4:
                self.widgets_empty[widget[2]] = widget[3]

        # event spec: [0] widget name, [1] "observe" or "click",
        # [2] payload passed to EventRunner.execute.
        for event in form["events"]:
            if event[1] == "observe":
                # NOTE(review): unfiltered observe fires on every trait change;
                # consider names="value" if the form events only need values.
                self.widgets[event[0]].observe(self.process(event))
            elif event[1] == "click":
                self.widgets[event[0]].on_click(self.process(event))

        self.show_event = form["show"]

        hboxes = [
            HBox([
                self.previous_article_widget,
                self.reload_article_widget,
                self.next_article_widget
            ]),
        ]

        for widgets in form["order"]:
            hboxes.append(HBox([
                self.widgets[widget] for widget in widgets
            ]))

        hboxes.append(HBox([
            self.reload_article_widget,
            self.selector_widget,
            self.article_number_widget
        ]))

        hboxes.append(self.output_widget)

        self.runner_widget = RunWidget() if config.RUN_WIDGET else ReplaceCellWidget()
        hboxes.append(self.runner_widget.view)
        self.view = VBox(hboxes)

        self.set_articles(articles)
        self.erase_article_form()

    def process(self, event):
        """Return a widget callback that executes ``event[2]`` with an EventRunner."""
        runner = EventRunner(self)
        def action(b):
            runner.execute(event[2])
        return action

    def set_articles(self, articles):
        """Set list of articles and restart slider.

        Only entries accepted by :meth:`valid_articles` are kept. Redraws
        are suppressed while the slider is reset.
        """
        self.articles = list(self.valid_articles(articles))
        self.disable_show = True
        self.selector_widget.value = 0
        self.selector_widget.max = max(len(self.articles) - 1, 0)
        self.next_article_widget.disabled = self.selector_widget.value == self.selector_widget.max
        self.previous_article_widget.disabled = self.selector_widget.value == 0
        self.article_number_widget.value = "{}/{}".format(
            min(self.selector_widget.value + 1, len(self.articles)),
            len(self.articles)
        )
        self.disable_show = False

    def erase_article_form(self):
        """Reset form fields that declare an empty value, and refresh the counter label."""
        self.article_number_widget.value = "{}/{}".format(
            min(self.selector_widget.value + 1, len(self.articles)),
            len(self.articles)
        )
        for key, widget in self.widgets.items():
            if key in self.widgets_empty:
                if isinstance(widget, ToggleButton):
                    widget.value = self.widgets_empty[key] or False
                else:
                    widget.value = self.widgets_empty[key] or ""

    def next_article(self, b=None):
        """Next article click event"""
        self.selector_widget.value = min(self.selector_widget.value + 1, self.selector_widget.max)
        self.erase_article_form()
        self.show(clear=True)

    def previous_article(self, b=None):
        """Previous article click event"""
        self.selector_widget.value = max(self.selector_widget.value - 1, self.selector_widget.min)
        self.erase_article_form()
        self.show(clear=True)

    def valid_articles(self, articles, show=False):
        """Yield ``(article, nwork, info, should)`` for articles that should be added.

        ``show`` is accepted for call compatibility but is not used.
        Warnings from ``should_add_info`` are queued in ``self.to_display``.
        """
        if not articles:
            return
        for article in articles:
            should, nwork, info = should_add_info(
                article, self.work, article=article,
                backward=self.backward,
                citation_file=self.citation_file if self.force_citation_file else None,
                warning=lambda x: self.to_display.append(x)
            )
            if should["add"]:
                yield article, nwork, info, should

    def clear(self):
        """Clear cell and output (no-op while ``disable_show`` is set)."""
        if self.disable_show:
            return
        self.to_display = []
        self.runner_widget.clear()
        self.output_widget.clear_output()

    def update_info(self, info, field, widget, value=None, default=""):
        """Copy a widget value into ``info[field]``.

        ``value`` overrides the stored value. Nothing is written when the
        widget still holds ``default``.

        Returns:
            bool: whether the widget value is truthy.
        """
        if widget.value != default:
            info[field] = widget.value if value is None else value
        return bool(widget.value)

    def show_article(self, article, nwork, info, should):
        """Display article and pass its generated code to the runner widget."""
        result = create_info_code(
            nwork, info,
            self.citation_var, self.citation_file, should,
            ref=article.get("_ref", "")
        )
        self.runner_widget.set_code(result["code"])
        self.output_widget.clear_output()
        with self.output_widget:
            if self.to_display:
                display("\n".join(self.to_display))
            display_list(config.display_article(article))
            for key, value in result["extra"].items():
                display(HTML("<label>{}</label><input value='{}' style='width: 100%'></input>".format(key, value)))
        self.to_display = []

    def show(self, b=None, clear=True):
        """Generic display of the article selected by the slider.

        Used both as a click/observe callback (``b`` is the button or change
        dict) and directly.
        """
        # NOTE(review): ``_up`` is not referenced below; kept in case event
        # code depends on it — confirm before removing.
        _up = self.update_info
        reload()
        self.next_article_widget.disabled = self.selector_widget.value == self.selector_widget.max
        self.previous_article_widget.disabled = self.selector_widget.value == 0
        if clear:
            self.clear()
        if self.disable_show or not self.articles:
            return
        article, _, _, _ = self.articles[self.selector_widget.value]
        with self.output_widget:
            display_list(config.display_article(article))
        for article, nwork, info, should in self.valid_articles([article], show=True):
            runner = EventRunner(self, info=info)
            runner.execute(self.show_event)
            self.show_article(article, nwork, info, should)

    def browser(self):
        """Widget visualization: print a hint and return the root view."""
        with self.output_widget:
            print("Press 'Reload Article'")
        return self.view

    def _ipython_display_(self):
        """ Displays widget """
        display(self.browser())
示例#19
0
class ScholarUpdate:
    """Widget for curating database.

    Walks through ``worklist`` (work variable names). For each work it sends
    a ``SearchScholarQuery`` through ``querier`` and shows a table comparing
    the stored work with the first returned article. The generated update
    code goes to a ``RunWidget``/``ReplaceCellWidget`` or a textarea.
    """

    def __init__(self, querier, worklist, force=False, debug=False, index=0, rules=None):
        """
        Args:
            querier: object exposing ``tasks``, ``send_query`` and ``articles``.
            worklist: list of work variable names to visit.
            force: query even works already marked ``scholar_ok``.
            debug: initial state of the Debug toggle (higher log level).
            index: starting position in ``worklist``.
            rules: mapping passed to ``set_by_info``; defaults to
                ``config.BIBTEX_TO_INFO``.
        """
        reload()
        self.rules = rules or config.BIBTEX_TO_INFO
        self.worklist = worklist
        self.force = force
        self.querier = querier
        self.next_page_widget = Button(description="Next Work", icon="fa-arrow-right")
        self.reload_widget = Button(description="Reload", icon="fa-refresh")
        self.previous_page_widget = Button(description="Previous Work", icon="fa-arrow-left")
        self.debug_widget = ToggleButton(value=debug, description="Debug")
        self.textarea_widget = ToggleButton(value=False, description="TextArea")
        self.page_number_widget = Label(value="")
        self.output_widget = Output()
        self.next_page_widget.on_click(self.next_page)
        self.reload_widget.on_click(self.reload)
        self.previous_page_widget.on_click(self.previous_page)
        # Only redraw when the toggle's value flips, not on every trait change.
        self.textarea_widget.observe(self.show, names="value")
        self.runner_widget = RunWidget() if config.RUN_WIDGET else ReplaceCellWidget()
        self.view = VBox([
            HBox([
                self.previous_page_widget,
                self.reload_widget,
                self.next_page_widget,
                self.debug_widget,
                self.textarea_widget,
                self.page_number_widget
            ]),
            self.output_widget,
            self.runner_widget.view,
        ])
        self.index = index
        self.varname = ""
        self.work = None
        self.articles = []
        self.reload(show=False)

    def next_page(self, b):
        """Go to next page (clamped to the last work, never below 0)."""
        self.index = max(0, min(len(self.worklist) - 1, self.index + 1))
        self.reload(b)

    def previous_page(self, b):
        """Go to previous page (clamped at 0)."""
        self.index = max(0, self.index - 1)
        self.reload(b)

    def set_index(self):
        """Refresh the page label and enable/disable the navigation buttons."""
        self.page_number_widget.value = str(self.index)
        self.next_page_widget.disabled = self.index >= len(self.worklist) - 1
        self.previous_page_widget.disabled = self.index == 0

    def show(self, b=None):
        """Show comparison between the current work and the first Scholar result.

        Errors are printed to the output widget rather than raised, so one
        bad entry does not break the curation loop.
        """
        self.output_widget.clear_output()
        with self.output_widget:
            if not self.articles:
                print(self.varname, "<unknown>")
                return
            try:
                print(self.varname, getattr(self.work, self.rules.get(
                    "<scholar_ok>", "_some_invalid_attr_for_scholar_ok"
                ), False))
                var, work, articles = self.varname, self.work, self.articles
                meta = extract_info(articles[0])
                table = "<table>{}</table>"
                rows = ["<tr><th></th><th>{}</th><th>{}</th></tr>".format(var, "Scholar")]
                changes = set_by_info(work, meta, rules=self.rules)
                set_text = changes_dict_to_set_attribute(var, changes["set"])
                for key, value in changes["show"].items():
                    if value is not None:
                        meta_value, work_value = value
                        rows.append("<tr><td>{}</td><td>{}</td><td>{}</td></tr>".format(
                            key, work_value, meta_value
                        ))
                textarea = ""
                if self.textarea_widget.value:
                    textarea = "<textarea rows='{}' style='width: 100%'>{}</textarea>".format(len(rows), set_text)
                else:
                    self.runner_widget.set_code(set_text)
                display(HTML(table.format("".join(rows))+"<br>"+textarea))
            except Exception:
                traceback.print_exc(file=sys.stdout)
                print(self.varname, "<error>")

    def reload(self, b=None, show=True):
        """Load the work at ``self.index`` and query Scholar for it.

        Works already marked ``scholar_ok`` are skipped unless ``force`` is
        set. When ``show`` is true the comparison is displayed afterwards.
        """
        self.output_widget.clear_output()
        with self.output_widget:
            if self.debug_widget.value:
                ScholarConf.LOG_LEVEL = 3
            else:
                ScholarConf.LOG_LEVEL = 2
            # Bare ``reload`` is the module-level function, not this method.
            reload()
            self.querier.tasks.clear()
            # Drop the previous work's results so an early return below can
            # never display (or generate code from) another work's data.
            self.articles = []

            if self.index >= len(self.worklist):
                self.set_index()
                return
            self.varname = self.worklist[self.index]
            self.work = work_by_varname(self.varname)
            scholar_ok = oget(self.work, "scholar_ok", False, cvar=config.SCHOLAR_MAP)
            print(self.varname, scholar_ok)
            if scholar_ok and not self.force:
                self.set_index()
                return
            from .selenium_scholar import SearchScholarQuery
            query = SearchScholarQuery()

            query.set_scope(False)
            query.set_words(config.query_str(self.work))
            query.set_num_page_results(1)
            self.querier.send_query(query)

            self.articles = self.querier.articles
        if show:
            self.show()

        self.set_index()

    def browser(self):
        """Present widget"""
        self.show()
        return self.view

    def _ipython_display_(self):
        """ Displays widget """
        self.show()
        display(self.view)
示例#20
0
class Player(VBox):
    """Slider/button player mapping a float position to displayables.

    The slider runs from 0.0 to ``length * 0.1`` in steps of 0.1.
    ``function`` is called with the slider value and should return a
    displayable or a list/tuple of displayables.
    """
    def __init__(self, title, function, length, play_rate=0.1):
        """
        function - takes a slider value and returns displayables
        """
        self.player = _Player(self, play_rate)
        self.player.start()
        self.title = title
        self.function = function
        self.length = length
        # Display handles created by initialize(). Starting empty makes slider
        # events before initialize() no-ops instead of AttributeErrors.
        self.displayers = []
        self.output = Output()
        self.position_text = FloatText(value=0.0, layout=Layout(width="100%"))
        self.total_text = Label(value="of %s" % self._end_value(),
                                layout=Layout(width="100px"))
        controls = self.make_controls()
        super().__init__([controls, self.output])

    def _end_value(self):
        """Return the last slider position: ``length * 0.1`` rounded to 1 decimal."""
        return round(self.length * 0.1, 1)

    def update_length(self, length):
        """Change the number of steps and rescale the label and slider maximum."""
        self.length = length
        self.total_text.value = "of %s" % self._end_value()
        self.control_slider.max = round(max(self.length * 0.1, 0), 1)

    def goto(self, position):
        """Move the slider to ``"begin"``, ``"end"``, ``"prev"`` or ``"next"``.

        ``"prev"``/``"next"`` step by 0.1 and wrap around at either end.
        Any other value leaves the slider where it is.
        """
        #### Position it:
        if position == "begin":
            self.control_slider.value = 0.0
        elif position == "end":
            self.control_slider.value = self._end_value()
        elif position == "prev":
            if self.control_slider.value - 0.1 < 0:
                self.control_slider.value = self._end_value()  # wrap around
            else:
                self.control_slider.value = round(
                    max(self.control_slider.value - 0.1, 0), 1)
        elif position == "next":
            if round(self.control_slider.value + 0.1, 1) > self._end_value():
                self.control_slider.value = 0  # wrap around
            else:
                self.control_slider.value = round(
                    min(self.control_slider.value + 0.1, self.length * 0.1), 1)
        self.position_text.value = round(self.control_slider.value, 1)

    def toggle_play(self, button):
        """Switch between playing and paused, updating the button label/icon."""
        if self.button_play.description == "Play":
            self.button_play.description = "Stop"
            self.button_play.icon = "pause"
            self.player.resume()
        else:
            self.button_play.description = "Play"
            self.button_play.icon = "play"
            self.player.pause()

    def make_controls(self):
        """Build the slider/button controls and wire up their callbacks.

        Returns:
            A ``VBox`` with the slider row and the navigation buttons.
            ``initialize`` runs when it is first displayed.
        """
        button_begin = Button(icon="fast-backward",
                              layout=Layout(width="100%"))
        button_prev = Button(icon="backward", layout=Layout(width="100%"))
        button_next = Button(icon="forward", layout=Layout(width="100%"))
        button_end = Button(icon="fast-forward", layout=Layout(width="100%"))
        self.button_play = Button(icon="play",
                                  description="Play",
                                  layout=Layout(width="100%"))
        self.control_buttons = HBox(
            [
                button_begin,
                button_prev,
                self.position_text,
                button_next,
                button_end,
                self.button_play,
            ],
            layout=Layout(width="100%", height="50px"),
        )
        self.control_slider = FloatSlider(
            description=self.title,
            continuous_update=False,
            min=0.0,
            step=0.1,
            max=max(self._end_value(), 0.0),
            value=0.0,
            readout_format=".1f",
            style={"description_width": "initial"},
            layout=Layout(width="100%"),
        )
        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names="value")
        controls = VBox(
            [
                HBox([self.control_slider, self.total_text],
                     layout=Layout(height="40px")),
                self.control_buttons,
            ],
            layout=Layout(width="100%"),
        )
        controls.on_displayed(lambda widget: self.initialize())
        return controls

    def initialize(self):
        """
        Setup the displayer ids to map results to the areas.
        """
        results = self.function(self.control_slider.value)
        if not isinstance(results, (list, tuple)):
            results = [results]
        self.displayers = [display(x, display_id=True) for x in results]

    def update_slider_control(self, change):
        """
        If the slider changes the value, call the function
        and update display areas.
        """
        if change["name"] == "value":
            self.position_text.value = self.control_slider.value
            self.output.clear_output(wait=True)
            results = self.function(self.control_slider.value)
            if not isinstance(results, (list, tuple)):
                results = [results]
            # zip stops at the shorter sequence: extra displayers keep their
            # previous content instead of raising IndexError.
            for displayer, result in zip(self.displayers, results):
                displayer.update(result)
示例#21
0
文件: widgets.py 项目: Calysto/conx
class Dashboard(VBox):
    """
    Build the dashboard for Jupyter widgets. Requires running
    in a notebook/jupyterlab.
    """
    def __init__(self, net, width="95%", height="550px", play_rate=0.5):
        """Create the dashboard widgets for ``net``.

        Args:
            net: network to visualise. Its ``name`` and its ``config`` entries
                ``dashboard.features.columns`` / ``dashboard.features.scale``
                are read here.
            width: CSS width of the network SVG area.
            height: stored as ``self._height``; presumably consumed by the
                control/config builders — TODO confirm.
            play_rate: forwarded to ``_Player``; presumably the delay between
                auto-play steps — TODO confirm.
        """
        self._ignore_layer_updates = False
        self.player = _Player(self, play_rate)
        self.player.start()
        self.net = net
        # Random suffix in the class id; presumably to tell several dashboards
        # of the same network apart — TODO confirm.
        r = random.randint(1, 1000000)
        self.class_id = "picture-dashboard-%s-%s" % (self.net.name, r)
        self._width = width
        self._height = height
        ## Global widgets:
        style = {"description_width": "initial"}
        # NOTE(review): plain IntText/FloatText may not enforce min/max (the
        # Bounded* variants do) — confirm these bounds have any effect.
        self.feature_columns = IntText(description="Detail columns:",
                                       value=self.net.config["dashboard.features.columns"],
                                       min=0,
                                       max=1024,
                                       style=style)
        self.feature_scale = FloatText(description="Detail scale:",
                                       value=self.net.config["dashboard.features.scale"],
                                       min=0.1,
                                       max=10,
                                       style=style)
        # Any change of the feature layout redraws the dashboard.
        self.feature_columns.observe(self.regenerate, names='value')
        self.feature_scale.observe(self.regenerate, names='value')
        ## Hack to center SVG as justify-content is broken:
        self.net_svg = HTML(value="""<p style="text-align:center">%s</p>""" % ("",), layout=Layout(
            width=self._width, overflow_x='auto', overflow_y="auto",
            justify_content="center"))
        # Make controls first:
        self.output = Output()
        controls = self.make_controls()
        # Local name only; it shadows any module-level ``config`` inside __init__.
        config = self.make_config()
        super().__init__([config, controls, self.net_svg, self.output])

    def propagate(self, inputs):
        """
        Propagate inputs through the dashboard view of the network.

        With dynamic pictures available, the network updates the pictures in
        place and its output is returned. Otherwise the dashboard is
        regenerated for ``inputs`` and ``None`` is returned.
        """
        if dynamic_pictures_check():
            return self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        # Pass the caller's inputs (not the builtin ``input`` function).
        self.regenerate(inputs=inputs)
        return None

    def goto(self, position):
        """Move the dataset slider within the selected split.

        Args:
            position: ``"begin"``, ``"end"``, ``"prev"`` or ``"next"``
                (``prev``/``next`` wrap around at the ends).

        Does nothing when the dataset has no inputs/targets or when the
        selector holds a value other than ``"Train"``/``"Test"``.
        """
        if len(self.net.dataset.inputs) == 0 or len(self.net.dataset.targets) == 0:
            return
        if self.control_select.value == "Train":
            length = len(self.net.dataset.train_inputs)
        elif self.control_select.value == "Test":
            length = len(self.net.dataset.test_inputs)
        else:
            # Unknown selection: ``length`` would otherwise be unbound.
            return
        #### Position it:
        if position == "begin":
            self.control_slider.value = 0
        elif position == "end":
            self.control_slider.value = length - 1
        elif position == "prev":
            if self.control_slider.value - 1 < 0:
                self.control_slider.value = length - 1 # wrap around
            else:
                self.control_slider.value = max(self.control_slider.value - 1, 0)
        elif position == "next":
            if self.control_slider.value + 1 > length - 1:
                self.control_slider.value = 0 # wrap around
            else:
                self.control_slider.value = min(self.control_slider.value + 1, length - 1)
        self.position_text.value = self.control_slider.value


    def change_select(self, change=None):
        """Handle a change of the Train/Test dataset selector.

        Re-syncs the position slider with the newly selected split, then
        regenerates the dashboard.
        """
        sync_slider = self.update_control_slider
        sync_slider(change)
        self.regenerate()

    def update_control_slider(self, change=None):
        """Sync the slider, position text box and buttons with the selected split.

        The selection is stored in ``net.config["dashboard.dataset"]``. All
        controls except the refresh button are disabled when the dataset (or
        the selected ``"Train"``/``"Test"`` split) is empty.
        """
        self.net.config["dashboard.dataset"] = self.control_select.value
        if len(self.net.dataset.inputs) == 0 or len(self.net.dataset.targets) == 0:
            self.total_text.value = "of 0"
            self.control_slider.value = 0
            self.position_text.value = 0
            self.control_slider.disabled = True
            self.position_text.disabled = True
            for child in self.control_buttons.children:
                if not hasattr(child, "icon") or child.icon != "refresh":
                    child.disabled = True
            return
        if self.control_select.value == "Test":
            count = len(self.net.dataset.test_inputs)
        elif self.control_select.value == "Train":
            count = len(self.net.dataset.train_inputs)
        else:
            # Unknown selection: leave the controls as they are.
            return
        self.total_text.value = "of %s" % count
        minmax = (0, max(count - 1, 0))
        # Reset an out-of-range position before narrowing the slider bounds.
        if not minmax[0] <= self.control_slider.value <= minmax[1]:
            self.control_slider.value = 0
        self.control_slider.min = minmax[0]
        self.control_slider.max = minmax[1]
        disabled = count == 0
        self.control_slider.disabled = disabled
        # Re-enables the text box after the empty-dataset branch disabled it.
        self.position_text.disabled = disabled
        self.position_text.value = self.control_slider.value
        for child in self.control_buttons.children:
            if not hasattr(child, "icon") or child.icon != "refresh":
                child.disabled = disabled

    def update_zoom_slider(self, change):
        """Store the zoom slider value as ``svg_scale`` and regenerate the dashboard.

        Change events for traits other than ``value`` are ignored.
        """
        if change["name"] != "value":
            return
        self.net.config["svg_scale"] = self.zoom_slider.value
        self.regenerate()

    def update_position_text(self, change):
        """Move the slider to the index typed into the position text box.

        ``change`` is a traitlets change dict, for example
        ``{'name': 'value', 'old': 2, 'new': 3, 'owner': ..., 'type': 'change'}``.
        """
        typed_position = change["new"]
        self.control_slider.value = typed_position

    def get_current_input(self):
        """Return the input at the slider position in the selected split.

        Returns ``None`` when the selected split has no targets or the
        selection is neither ``"Train"`` nor ``"Test"``.
        """
        selection = self.control_select.value
        if selection == "Train" and len(self.net.dataset.train_targets) > 0:
            return self.net.dataset.train_inputs[self.control_slider.value]
        if selection == "Test" and len(self.net.dataset.test_targets) > 0:
            return self.net.dataset.test_inputs[self.control_slider.value]
        return None

    def get_current_targets(self):
        """Return the targets at the slider position in the selected split.

        Returns ``None`` when the selected split has no targets or the
        selection is neither ``"Train"`` nor ``"Test"``.
        """
        selection = self.control_select.value
        if selection == "Train" and len(self.net.dataset.train_targets) > 0:
            return self.net.dataset.train_targets[self.control_slider.value]
        if selection == "Test" and len(self.net.dataset.test_targets) > 0:
            return self.net.dataset.test_targets[self.control_slider.value]
        return None

    def update_slider_control(self, change):
        """
        React to a change of the dataset-index slider.

        Updates the position box and the "of N" total for the dataset half
        chosen in ``self.control_select`` ("Train" or "Test").  When a model
        exists, the selected pattern is then shown.  If
        ``dynamic_pictures_check()`` is false, this regenerates the SVG.
        Otherwise it propagates with live pictures and refreshes the feature
        view and, depending on ``self.net.config``, the targets and errors
        components.

        Args:
            change (dict): ipywidgets change notification; only changes whose
                ``"name"`` is ``"value"`` are acted on.
        """
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            self.total_text.value = "of 0"
            return
        if change["name"] != "value":
            return
        index = self.control_slider.value
        self.position_text.value = index
        ## Pick the dataset half once; both halves are handled identically below.
        if self.control_select.value == "Train" and len(dataset.train_targets) > 0:
            all_inputs, all_targets = dataset.train_inputs, dataset.train_targets
        elif self.control_select.value == "Test" and len(dataset.test_targets) > 0:
            all_inputs, all_targets = dataset.test_inputs, dataset.test_targets
        else:
            return
        self.total_text.value = "of %s" % len(all_inputs)
        if self.net.model is None:
            return
        inputs = all_inputs[index]
        targets = all_targets[index]
        if not dynamic_pictures_check():
            self.regenerate(inputs=inputs, targets=targets)
            return
        output = self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        if self.feature_bank.value in self.net.layer_dict.keys():
            self.net.propagate_to_features(self.feature_bank.value, inputs,
                                           cols=self.feature_columns.value,
                                           scale=self.feature_scale.value, html=False)
        if self.net.config["show_targets"]:
            ## display_component appears to take one vector per output bank,
            ## so a single bank's target is wrapped in a list; multi-bank
            ## targets are already per-bank.
            ## FIXME: use minmax of output bank
            single_bank = len(self.net.output_bank_order) == 1
            self.net.display_component([targets] if single_bank else targets,
                                       "targets",
                                       class_id=self.class_id,
                                       minmax=(-1, 1))
        if self.net.config["show_errors"]: ## minmax is error
            if len(self.net.output_bank_order) == 1:
                errors = [(np.array(output) - np.array(targets)).tolist()]
            else:
                errors = [np.array(output[bank]) - np.array(targets[bank])
                          for bank in range(len(self.net.output_bank_order))]
            self.net.display_component(errors, "errors", class_id=self.class_id, minmax=(-1, 1))

    def toggle_play(self, button):
        """
        Flip the Play/Stop button and resume or pause the player accordingly.

        Args:
            button: the clicked widget (unused).
        """
        starting = self.button_play.description == "Play"
        self.button_play.description = "Stop" if starting else "Play"
        self.button_play.icon = "pause" if starting else "play"
        action = self.player.resume if starting else self.player.pause
        action()

    def prop_one(self, button=None):
        """Re-run the slider handler as if the slider value had just changed."""
        fake_change = {"name": "value"}
        self.update_slider_control(fake_change)

    def regenerate(self, button=None, inputs=None, targets=None):
        ## Protection when deleting object on shutdown:
        if isinstance(button, dict) and 'new' in button and button['new'] is None:
            return
        ## Update the config:
        self.net.config["dashboard.features.bank"] = self.feature_bank.value
        self.net.config["dashboard.features.columns"] = self.feature_columns.value
        self.net.config["dashboard.features.scale"] = self.feature_scale.value
        inputs = inputs if inputs is not None else self.get_current_input()
        targets = targets if targets is not None else self.get_current_targets()
        features = None
        if self.feature_bank.value in self.net.layer_dict.keys() and inputs is not None:
            if self.net.model is not None:
                features = self.net.propagate_to_features(self.feature_bank.value, inputs,
                                                          cols=self.feature_columns.value,
                                                          scale=self.feature_scale.value, display=False)
        svg = """<p style="text-align:center">%s</p>""" % (self.net.to_svg(
            inputs=inputs,
            targets=targets,
            class_id=self.class_id,
            highlights={self.feature_bank.value: {
                "border_color": "orange",
                "border_width": 30,
            }}))
        if inputs is not None and features is not None:
            html_horizontal = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top" style="width: 50%%;">%s</td>
  <td valign="top" align="center" style="width: 50%%;"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            html_vertical = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top">%s</td>
</tr>
<tr>
  <td valign="top" align="center"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            self.net_svg.value = (html_vertical if self.net.config["svg_rotate"] else html_horizontal) % (
                svg, "%s details" % self.feature_bank.value, features)
        else:
            self.net_svg.value = svg

    def make_colormap_image(self, colormap_name):
        """
        Render a preview strip of a colormap.

        Args:
            colormap_name (str): colormap to preview; when falsy, the value
                of ``get_colormap()`` is used instead.

        Returns:
            The image from ``Layer.make_image`` over values spanning the
            temporary layer's activation minmax in 0.01 steps, resized to
            (300, 25).
        """
        from .layers import Layer
        name = colormap_name if colormap_name else get_colormap()
        probe = Layer("Colormap", 100)
        bounds = probe.get_act_minmax()
        ramp = np.arange(bounds[0], bounds[1], .01)
        settings = {"pixels_per_unit": 1,
                    "svg_rotate": self.net.config["svg_rotate"]}
        strip = probe.make_image(ramp, name, settings)
        return strip.resize((300, 25))

    def set_attr(self, obj, attr, value):
        """
        Store a widget-supplied value on a dict or object, then redraw.

        Args:
            obj: a dict (item assignment) or any object (``setattr``).
            attr (str): key / attribute name to set.
            value: new value; a dict is unwrapped to ``value["value"]``.
                ``None`` and ``{}`` are ignored and trigger no redraw.
        """
        ## value is None when shutting down; skipping the redraw then avoids
        ## the crash seen on Widgets.__del__ if get_ipython() no longer existed.
        if value in ({}, None):
            return
        new_value = value["value"] if isinstance(value, dict) else value
        if isinstance(obj, dict):
            obj[attr] = new_value
        else:
            setattr(obj, attr, new_value)
        self.regenerate()

    def make_controls(self):
        """
        Build the dataset-navigation controls of the dashboard.

        Creates the following and wires their callbacks:

        * the dataset-index slider and its "of N" total label;
        * the begin/prev/next/end, play, layer-step and refresh buttons;
        * the position box;
        * the zoom slider.

        The slider range and the total are both sized from the dataset named
        by ``self.net.config["dashboard.dataset"]``.

        Returns:
            VBox: the slider row above the button row; the network drawing
            is regenerated when it is first displayed.
        """
        layout = Layout(width='100%', height="100%")
        button_begin = Button(icon="fast-backward", layout=layout)
        button_prev = Button(icon="backward", layout=layout)
        button_next = Button(icon="forward", layout=layout)
        button_end = Button(icon="fast-forward", layout=layout)
        self.button_play = Button(icon="play", description="Play", layout=layout)
        step_down = Button(icon="sort-down", layout=Layout(width="95%", height="100%"))
        step_up = Button(icon="sort-up", layout=Layout(width="95%", height="100%"))
        up_down = HBox([step_down, step_up], layout=Layout(width="100%", height="100%"))
        refresh_button = Button(icon="refresh", layout=Layout(width="25%", height="100%"))

        self.position_text = IntText(value=0, layout=layout)

        self.control_buttons = HBox([
            button_begin,
            button_prev,
            self.position_text,
            button_next,
            button_end,
            self.button_play,
            up_down,
            refresh_button
        ], layout=Layout(width='100%', height="100%"))
        ## Size the slider and the total from the same dataset, so the slider
        ## cannot index past the end of the selected (e.g. Test) dataset:
        if self.net.config["dashboard.dataset"] == "Train":
            length = len(self.net.dataset.train_inputs)
        else:
            length = len(self.net.dataset.test_inputs)
        self.control_slider = IntSlider(description="Dataset index",
                                        continuous_update=False,
                                        min=0,
                                        max=max(length - 1, 0),
                                        value=0,
                                        layout=Layout(width='100%'))
        self.total_text = Label(value="of %s" % length, layout=Layout(width="100px"))
        self.zoom_slider = FloatSlider(description="Zoom",
                                       continuous_update=False,
                                       min=0, max=1.0,
                                       style={"description_width": 'initial'},
                                       layout=Layout(width="65%"),
                                       value=self.net.config["svg_scale"] if self.net.config["svg_scale"] is not None else 0.5)

        def refresh(widget):
            """Resync the slider range, clear the output area and redraw."""
            self.update_control_slider()
            self.output.clear_output()
            self.regenerate()

        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names='value')
        refresh_button.on_click(refresh)
        step_down.on_click(lambda widget: self.move_step("down"))
        step_up.on_click(lambda widget: self.move_step("up"))
        self.zoom_slider.observe(self.update_zoom_slider, names='value')
        self.position_text.observe(self.update_position_text, names='value')
        # Put them together:
        controls = VBox([HBox([self.control_slider, self.total_text], layout=Layout(height="40px")),
                         self.control_buttons], layout=Layout(width='100%'))
        controls.on_displayed(lambda widget: self.regenerate())
        return controls

    def move_step(self, direction):
        """
        Step the feature-bank selection one layer up or down, wrapping around.

        The cycle is "" (no bank) followed by every layer name in network
        order.  Any direction other than "up" steps down.  The network is
        redrawn afterwards.
        """
        names = [""]
        names.extend(layer.name for layer in self.net.layers)
        step = 1 if direction == "up" else -1
        position = names.index(self.feature_bank.value)
        self.feature_bank.value = names[(position + step) % len(names)]
        self.regenerate()

    def make_config(self):
        """
        Build the collapsible configuration panel of the dashboard.

        The left column holds:

        * the dataset selector and zoom slider;
        * bank/layer spacing;
        * the targets/errors checkboxes;
        * the feature-details controls.

        The right column edits the layer chosen in ``self.layer_select``
        (initially the last layer):

        * visibility;
        * colormap, with a preview image;
        * the values mapped to the colormap ends;
        * the feature to show;
        * "Rotate network" and a save-config button.

        Edits are written into ``self.net.config`` or the layer, and the
        drawing is redrawn.

        Returns:
            Accordion: one collapsed section titled with the network name.
        """
        layout = Layout()
        style = {"description_width": "initial"}
        layer_names = [each.name for each in self.net.layers]
        checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                             layout=layout, style=style)
        checkbox1.observe(lambda change: self.set_attr(self.net.config, "show_targets", change["new"]), names='value')
        checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                             layout=layout, style=style)
        checkbox2.observe(lambda change: self.set_attr(self.net.config, "show_errors", change["new"]), names='value')

        hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                         style=style, layout=layout)
        hspace.observe(lambda change: self.set_attr(self.net.config, "hspace", change["new"]), names='value')
        vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                         style=style, layout=layout)
        vspace.observe(lambda change: self.set_attr(self.net.config, "vspace", change["new"]), names='value')
        bank_options = [""] + layer_names
        bank = self.net.config["dashboard.features.bank"]
        ## A saved config may name a layer this network does not have; a
        ## Select rejects values outside its options, so fall back to "".
        self.feature_bank = Select(description="Details:",
                                   value=bank if bank in bank_options else "",
                                   options=bank_options,
                                   rows=1)
        self.feature_bank.observe(self.regenerate, names='value')
        self.control_select = Select(
            options=['Test', 'Train'],
            value=self.net.config["dashboard.dataset"],
            description='Dataset:',
            rows=1
        )
        self.control_select.observe(self.change_select, names='value')
        column1 = [self.control_select,
                   self.zoom_slider,
                   hspace,
                   vspace,
                   HBox([checkbox1, checkbox2]),
                   self.feature_bank,
                   self.feature_columns,
                   self.feature_scale
        ]
        ## Make layer selectable, and update-able:
        column2 = []
        layer = self.net.layers[-1]
        self.layer_select = Select(description="Layer:", value=layer.name,
                                   options=layer_names,
                                   rows=1)
        self.layer_select.observe(self.update_layer_selection, names='value')
        column2.append(self.layer_select)
        self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
        self.layer_visible_checkbox.observe(self.update_layer, names='value')
        column2.append(self.layer_visible_checkbox)
        self.layer_colormap = Select(description="Colormap:",
                                     options=[""] + AVAILABLE_COLORMAPS,
                                     value=layer.colormap if layer.colormap is not None else "", layout=layout, rows=1)
        self.layer_colormap_image = HTML(value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
        self.layer_colormap.observe(self.update_layer, names='value')
        column2.append(self.layer_colormap)
        column2.append(self.layer_colormap_image)
        ## get dynamic minmax; if you change it it will set it in layer as override:
        minmax = layer.get_act_minmax()
        self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
        self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
        self.layer_mindim.observe(self.update_layer, names='value')
        self.layer_maxdim.observe(self.update_layer, names='value')
        column2.append(self.layer_mindim)
        column2.append(self.layer_maxdim)
        self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
        self.layer_feature.observe(self.update_layer, names='value')
        column2.append(self.layer_feature)
        self.svg_rotate = Checkbox(description="Rotate network",
                                   value=self.net.config["svg_rotate"],
                                   style={"description_width": 'initial'},
                                   layout=Layout(width="52%"))
        self.svg_rotate.observe(lambda change: self.set_attr(self.net.config, "svg_rotate", change["new"]), names='value')
        self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
        self.save_config_button.on_click(self.save_config)
        column2.append(HBox([self.svg_rotate, self.save_config_button]))
        config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                                VBox(column2, layout=Layout(width="100%"))])
        accordion = Accordion(children=[config_children])
        accordion.set_title(0, self.net.name)
        accordion.selected_index = None
        return accordion

    def save_config(self, widget=None):
        """Persist the network's configuration (button callback; ``widget`` is unused)."""
        network = self.net
        network.save_config()

    def update_layer(self, change):
        """
        Copy the layer-editing widgets' values into the selected layer and redraw.

        Widgets whose description mentions "color" (Colormap, Leftmost and
        Rightmost color) are handled specially.  They also set
        ``layer.minmax`` (overriding the dynamic range) and the colormap,
        refresh the preview, and re-propagate one pattern.  Any other change
        just regenerates the drawing.

        Args:
            change (dict): ipywidgets change notification; ``change["owner"]``
                is the widget that changed.
        """
        ## Set while update_layer_selection() loads widget values; those are
        ## not user edits.  getattr() guards against the flag never having
        ## been initialized.
        if getattr(self, "_ignore_layer_updates", False):
            return
        layer = self.net[self.layer_select.value]
        layer.feature = self.layer_feature.value
        layer.visible = self.layer_visible_checkbox.value
        ## These three, dealing with colors of activations,
        ## can be done with a prop_one():
        if "color" in change["owner"].description.lower():
            ## overriding dynamic minmax!
            layer.minmax = (self.layer_mindim.value, self.layer_maxdim.value)
            layer.colormap = self.layer_colormap.value if self.layer_colormap.value else None
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            self.prop_one()
        else:
            self.regenerate()

    def update_layer_selection(self, change):
        """
        Load the newly selected layer's settings into the layer-editing widgets.

        Only widget values are changed; nothing is redrawn, because
        ``_ignore_layer_updates`` makes update_layer() ignore the resulting
        change events.

        Args:
            change (dict): ipywidgets change notification (unused; the
                selection is read from ``self.layer_select``).
        """
        self._ignore_layer_updates = True
        try:
            layer = self.net[self.layer_select.value]
            self.layer_visible_checkbox.value = layer.visible
            ## A layer without its own colormap (None) is shown as the blank
            ## option "", which is what the Select's options contain.
            self.layer_colormap.value = layer.colormap if layer.colormap is not None else ""
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            minmax = layer.get_act_minmax()
            self.layer_mindim.value = minmax[0]
            self.layer_maxdim.value = minmax[1]
            self.layer_feature.value = layer.feature
        finally:
            ## Re-enable update_layer() even if a widget rejected a value.
            self._ignore_layer_updates = False
示例#22
0
    class __Dbfs(object):
        """DBFS file browser shown in a sidecar panel.

        Args:
            dbutils (DBUtils): DBUtils object (only ``dbutils.fs`` is used)
        """
        def __init__(self, dbutils):
            self.dbutils = dbutils

        def create(self):
            """Create the sidecar view (up/refresh buttons, file list, path output) and list "/"."""
            self.sc = Sidecar(title="DBFS-%s" %
                              os.environ["DBJL_CLUSTER"].split("-")[-1])
            self.path = "/"
            self.flist = Select(options=[], rows=40, disabled=False)
            self.flist.observe(self.on_click, names="value")

            self.refresh = Button(description="refresh")
            self.refresh.on_click(self.on_refresh)
            self.output = Output()

            self.up = Button(description="up")
            self.up.on_click(self.on_up)

            with self.sc:
                display(
                    VBox([
                        HBox([self.up, self.refresh]), self.flist, self.output
                    ]))

            self.update()

        def convertBytes(self, fsize):
            """Convert bytes to largest unit

            Only switches unit once the size exceeds 10 of the next unit
            (e.g. 10 KB); the size is truncated to an int.

            Args:
                fsize (int): Size in bytes

            Returns:
                tuple: size of largest unit, largest unit
            """
            size = fsize
            unit = "B"
            if size > 1024 * 1024 * 1024 * 10:
                size = int(size / 1024.0 / 1024.0 / 1024.0)
                unit = "GB"
            elif size > 1024 * 1024 * 10:
                size = int(size / 1024.0 / 1024.0)
                unit = "MB"
            elif size > 1024 * 10:
                size = int(size / 1024.0)
                unit = "KB"
            return (size, unit)

        def update(self):
            """List ``self.path`` and refill the file list.

            The list holds a blank entry, then directories, then files
            shown as "name (size unit)"; directories and files are each
            sorted case-insensitively by name.
            """
            with self.output:
                print("updating ...")
            fobjs = self.dbutils.fs.ls(self.path)
            self.show_path(self.path)

            dirs = sorted([fobj.name for fobj in fobjs if fobj.isDir()],
                          key=lambda x: x.lower())
            # Sort by the whole file name, not just its first character.
            file_objs = sorted((fobj for fobj in fobjs if not fobj.isDir()),
                               key=lambda fobj: fobj.name.lower())
            files = [
                "%s (%d %s)" % ((fobj.name, ) + self.convertBytes(fobj.size))
                for fobj in file_objs
            ]
            self.flist.options = [""] + dirs + files

        def show_path(self, path):
            """Show path in output widget

            Args:
                path (str): Currently selected path; a trailing
                    " (size unit)" annotation is stripped before display.
            """
            self.output.clear_output()
            with self.output:
                print("dbfs:" + re.sub(r"\s\(.*?\)$", "", path))

        def on_refresh(self, b):
            """Refresh handler

            Args:
                b (ipywidgets.Button): clicked button
            """
            self.update()

        def on_up(self, b):
            """Up handler: move to the parent directory (staying at "/" at the root).

            Args:
                b (ipywidgets.Button): clicked button
            """
            # dirname("") (the root with its slash stripped) is "", which is
            # not a valid path to list; stay at the root instead.
            self.path = os.path.dirname(self.path.rstrip("/")) or "/"
            self.update()

        def on_click(self, change):
            """Selection handler for the file list.

            Entries ending in "/" are directories: the browser descends into
            them.  Any other entry is a file, whose path is shown in the
            output widget.

            Args:
                change (dict): ipywidgets change notification with the
                    ``"old"`` and ``"new"`` selected entries
            """
            selected = change["new"]
            # The blank first entry and None (when options are replaced)
            # are not paths; nothing happens either when "old" is None.
            if not selected or change["old"] is None:
                return
            new_path = os.path.join(self.path, selected)
            if selected.endswith("/"):
                self.path = new_path
                self.update()
            else:
                self.show_path(new_path)

        def close(self):
            """Close view"""
            self.sc.close()
示例#23
0
def cytoscapegraph(graph, onto=None, infobox=None, style=None):
    """Return an ipycytoscape figure for an OntoGraph ``Graph`` instance.

    Args:
        graph: OntoGraph graph; its Graphviz source ``graph.dot.source`` is
            converted to the cytoscape figure.
        onto: the accompanying ontology.  It is required for the mouse actions
            (click / hover information box); when None, no box is created.
        infobox: placement of the information box when ``onto`` is given:
            ``'left'`` or ``'right'`` of the figure in a 1x3 grid; any other
            value places it below the figure.
        style: not used by this function.

    Returns:
        The ``CytoscapeWidget`` if ``onto`` is None.  Otherwise a
        ``GridspecLayout`` (``infobox`` 'left'/'right') or a ``VBox``.
    """
    from pathlib import Path

    import ipycytoscape
    import networkx as nx
    import pydotplus
    from IPython.display import Image, display
    from ipywidgets import GridspecLayout, Output, VBox
    from networkx.readwrite.json_graph import cytoscape_data

    # Define the styles, this has to be aligned with the graphviz values
    dotplus = pydotplus.graph_from_dot_data(graph.dot.source)
    # if graph doesn't have multiedges, use dotplus.set_strict(true)
    G = nx.nx_pydot.from_pydot(dotplus)

    colours, styles, fill = cytoscape_style()

    def lookup(mapping, key, default):
        """Return mapping[key], or default if the key is missing."""
        try:
            return mapping[key]
        except KeyError:
            return default

    data = cytoscape_data(G)['elements']
    for edge in data['edges']:
        edge_data = edge['data']
        # Keep the label text before its last space, without a leading quote.
        edge_data['label'] = edge_data['label'].rsplit(' ', 1)[0].lstrip('"')
        # Style tables are keyed by the relation without an Inverse(...) wrapper.
        lab = edge_data['label'].replace('Inverse(', '').rstrip(')')
        edge_data['colour'] = lookup(colours, lab, 'black')
        edge_data['style'] = lookup(styles, lab, 'solid')
        # Inverse relations get a diamond arrow head instead of a triangle.
        if edge_data['label'].startswith('Inverse('):
            edge_data['targetarrow'] = 'diamond'
        else:
            edge_data['targetarrow'] = 'triangle'
        edge_data['sourcearrow'] = 'none'
        edge_data['fill'] = lookup(fill, lab, 'filled')

    cytofig = ipycytoscape.CytoscapeWidget()
    cytofig.graph.add_graph_from_json(data, directed=True)

    cytofig.set_style([
        {
            'selector': 'node',
            'css': {
                'content': 'data(label)',
                'background-color': 'blue'
            },
        },
        {
            'selector': 'node:parent',
            'css': {
                'background-opacity': 0.333
            }
        },
        {
            'selector': 'edge',
            'style': {
                'width': 2,
                'line-color': 'data(colour)',
                'line-style': 'data(style)'
            }
        },
        {
            'selector': 'edge.directed',
            'style': {
                'curve-style': 'bezier',
                'target-arrow-shape': 'data(targetarrow)',
                'target-arrow-color': 'data(colour)',
                'target-arrow-fill': 'data(fill)',
                'mid-source-arrow-shape': 'data(sourcearrow)',
                'mid-source-arrow-color': 'data(colour)'
            },
        },
        {
            'selector': 'edge.multiple_edges',
            'style': {
                'curve-style': 'bezier'
            }
        },
        {
            'selector': ':selected',
            'css': {
                'background-color': 'black',
                'line-color': 'black',
                'target-arrow-color': 'black',
                'source-arrow-color': 'black',
                'text-outline-color': 'black'
            },
        },
    ])

    if onto is None:
        return cytofig

    out = Output(layout={'border': '1px solid black'})

    def log_clicks(node):
        """Print the clicked entity's details (and any matching picture) in ``out``."""
        # Everything runs inside ``with out`` so errors are reported in the box.
        with out:
            label = node["data"]["label"]
            entity = onto.get_by_label(label)
            print(entity)
            print(f'parents: {entity.get_parents()}')
            try:
                print(f'elucidation: {entity.elucidation[0]}')
            except (AttributeError, IndexError):
                pass

            try:
                for annotation in entity.annotations:
                    print(f'annotation: {annotation}')
            except AttributeError:
                pass

            # Best effort: not every entity has an IRI.
            try:
                print(f'iri: {entity.iri}')
            except Exception:
                pass
            # Show "<label>.png" or else "<label>.jpg" if such a file exists.
            try:
                for suffix in ('.png', '.jpg'):
                    if Path(label + suffix).exists():
                        display(Image(label + suffix, width=100))
                        break
            except Exception:  # FIXME: make this more specific
                pass
            out.clear_output(wait=True)

    def log_mouseovers(node):
        """Print the hovered entity in ``out``."""
        with out:
            print(onto.get_by_label(node["data"]["label"]))
        out.clear_output(wait=True)

    cytofig.on('node', 'click', log_clicks)
    cytofig.on('node', 'mouseover', log_mouseovers)
    # The handler must be a callable; clear_output(...) itself returns None.
    cytofig.on('node', 'mouseout', lambda node: out.clear_output(wait=True))
    if infobox == 'left':
        grid = GridspecLayout(1, 3, height='400px')
        grid[0, 0] = out
        grid[0, 1:] = cytofig
        return grid
    if infobox == 'right':
        grid = GridspecLayout(1, 3, height='400px')
        grid[0, 0:-1] = cytofig
        grid[0, 2] = out
        return grid
    return VBox([cytofig, out])
示例#24
0
import os

# Live output area for the periodic plots below.
out = Output()
display(out)

# Store (and plot) the solution every `save_rate` time steps.
save_rate = 10000
for i in log_progress(range(1, len(T)), every=save_rate):
    # Advance the state `uk` by one step of length dt with RK45; the state
    # appears to be in Fourier space (np.fft.ifft is applied before saving).
    sln = sp.integrate.solve_ivp(der, (0, dt), uk, method='RK45')
    # extract solution
    u = sln.y[:, -1]
    # dealiasing
    u *= da
    # save solution at each n-th time-point
    if not i % save_rate:
        U.append(np.fft.ifft(u).real)
        T_saved.append(T[i])
        # plot solution
        out.clear_output(wait=True)
        with out:
            plt.plot(X, -U[-1])
            plt.show()
            print("time:     %.2f \u03BCs" % T_saved[-1])
            dist = v_water * T_saved[-1]
            print("distance: %.2f mm" % dist)
    # define next initial condition for integrator
    uk = u

# In[191]:

direct = "KdV_forced_delta_func_PS/"
# Create the target directory if needed; otherwise np.save would raise
# FileNotFoundError only after the whole (long) integration has run.
os.makedirs(direct, exist_ok=True)
np.save(direct + "U.npy", U)
np.save(direct + "T.npy", T_saved)
np.save(direct + "X.npy", X)
示例#25
0
class Progress(JupyterMixin, RenderHook):
    """Renders an auto-updating progress bar(s).

    Args:
        console (Console, optional): Optional Console instance. Default will use an internal Console instance writing to stdout.
        auto_refresh (bool, optional): Enable auto refresh. If disabled, you will need to call `refresh()`.
        refresh_per_second (Optional[int], optional): Number of times per second to refresh the progress information or None to use default (10). Defaults to None.
        speed_estimate_period: (float, optional): Period (in seconds) used to calculate the speed estimate. Defaults to 30.
        transient: (bool, optional): Clear the progress on exit. Defaults to False.
        redirect_stdout: (bool, optional): Enable redirection of stdout, so ``print`` may be used. Defaults to True.
        redirect_stderr: (bool, optional): Enable redirection of stderr. Defaults to True.
        get_time: (Callable, optional): A callable that gets the current time, or None to use time.monotonic. Defaults to None.
    """

    def __init__(
        self,
        *columns: Union[str, ProgressColumn],
        console: Console = None,
        auto_refresh: bool = True,
        refresh_per_second: int = None,
        speed_estimate_period: float = 30.0,
        transient: bool = False,
        redirect_stdout: bool = True,
        redirect_stderr: bool = True,
        get_time: GetTimeCallable = None,
    ) -> None:
        assert (refresh_per_second is None
                or refresh_per_second > 0), "refresh_per_second must be > 0"
        self._lock = RLock()
        self.columns = columns or (
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeRemainingColumn(),
        )
        self.console = console or get_console()
        # No background refresh thread in Jupyter; refresh() redraws the widget.
        self.auto_refresh = auto_refresh and not self.console.is_jupyter
        self.refresh_per_second = refresh_per_second or 10
        self.speed_estimate_period = speed_estimate_period
        self.transient = transient
        self._redirect_stdout = redirect_stdout
        self._redirect_stderr = redirect_stderr
        self.get_time = get_time or monotonic
        self._tasks: Dict[TaskID, Task] = {}
        self._live_render = LiveRender(self.get_renderable())
        self._task_index: TaskID = TaskID(0)
        self._refresh_thread: Optional[_RefreshThread] = None
        self._started = False
        self.print = self.console.print
        self.log = self.console.log
        # Original streams saved by _enable_redirect_io, restored on stop.
        self._restore_stdout: Optional[IO[str]] = None
        self._restore_stderr: Optional[IO[str]] = None
        self.ipy_widget: Optional[Any] = None

    @property
    def tasks(self) -> List[Task]:
        """Get a list of Task instances."""
        with self._lock:
            return list(self._tasks.values())

    @property
    def task_ids(self) -> List[TaskID]:
        """A list of task IDs."""
        with self._lock:
            return list(self._tasks.keys())

    @property
    def finished(self) -> bool:
        """Check if all tasks have been completed (True when there are no tasks)."""
        with self._lock:
            if not self._tasks:
                return True
            return all(task.finished for task in self._tasks.values())

    def _enable_redirect_io(self):
        """Enable redirecting of stdout / stderr (terminal consoles only).

        Each enabled stream is replaced by a ``_FileProxy`` wrapping the
        original stream; the original is remembered so that
        ``_disable_redirect_io`` can put it back.
        """
        if self.console.is_terminal:
            if self._redirect_stdout:
                self._restore_stdout = sys.stdout
                sys.stdout = _FileProxy(self.console, sys.stdout)
            if self._redirect_stderr:
                self._restore_stderr = sys.stderr
                sys.stderr = _FileProxy(self.console, sys.stderr)

    def _disable_redirect_io(self):
        """Disable redirecting of stdout / stderr, restoring the saved streams."""
        if self._restore_stdout is not None:
            sys.stdout = self._restore_stdout
            self._restore_stdout = None
        if self._restore_stderr is not None:
            sys.stderr = self._restore_stderr
            self._restore_stderr = None

    def start(self) -> None:
        """Start the progress display (no-op if already started)."""
        with self._lock:
            if self._started:
                return
            self._started = True
            self.console.show_cursor(False)
            self._enable_redirect_io()
            self.console.push_render_hook(self)
            self.refresh()
            if self.auto_refresh:
                self._refresh_thread = _RefreshThread(self,
                                                      self.refresh_per_second)
                self._refresh_thread.start()

    def stop(self) -> None:
        """Stop the progress display (no-op if not started)."""
        with self._lock:
            if not self._started:
                return
            self._started = False
            try:
                if self.auto_refresh and self._refresh_thread is not None:
                    self._refresh_thread.stop()
                self.refresh()
                if self.console.is_terminal:
                    self.console.line()
            finally:
                # Always restore cursor, streams and render hooks, even if
                # the final refresh failed.
                self.console.show_cursor(True)
                self._disable_redirect_io()
                self.console.pop_render_hook()
        # Join outside the lock: the refresh thread may need the lock to finish.
        if self._refresh_thread is not None:
            self._refresh_thread.join()
            self._refresh_thread = None
        if self.transient:
            self.console.control(self._live_render.restore_cursor())
        if self.ipy_widget is not None and self.transient:  # pragma: no cover
            self.ipy_widget.clear_output()
            self.ipy_widget.close()

    def __enter__(self) -> "Progress":
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.stop()

    def track(
        self,
        sequence: Union[Iterable[ProgressType], Sequence[ProgressType]],
        total: int = None,
        task_id: Optional[TaskID] = None,
        description="Working...",
        update_period: float = 0.1,
    ) -> Iterable[ProgressType]:
        """Track progress by iterating over a sequence.

        Args:
            sequence (Sequence[ProgressType]): A sequence of values you want to iterate over and track progress.
            total: (int, optional): Total number of steps. Default is len(sequence).
            task_id: (TaskID): Task to track. Default is new task.
            description: (str, optional): Description of task, if new task is created.
            update_period (float, optional): Minimum time (in seconds) between calls to update(). Defaults to 0.1.

        Returns:
            Iterable[ProgressType]: An iterable of values taken from the provided sequence.

        Raises:
            ValueError: If ``total`` is None and ``sequence`` has no length.
        """
        if total is None:
            if isinstance(sequence, Sized):
                task_total = len(sequence)
            else:
                raise ValueError(
                    f"unable to get size of {sequence!r}, please specify 'total'"
                )
        else:
            task_total = total

        if task_id is None:
            task_id = self.add_task(description, total=task_total)
        else:
            self.update(task_id, total=task_total)
        with self:
            if self.auto_refresh:
                with _TrackThread(self, task_id,
                                  update_period) as track_thread:
                    for value in sequence:
                        yield value
                        track_thread.completed += 1
            else:
                # Without a refresh thread, advance and redraw per item.
                advance = self.advance
                refresh = self.refresh
                for value in sequence:
                    yield value
                    advance(task_id, 1)
                    refresh()

    def start_task(self, task_id: TaskID) -> None:
        """Start a task.

        Starts a task (used when calculating elapsed time). You may need to call this manually,
        if you called ``add_task`` with ``start=False``.

        Args:
            task_id (TaskID): ID of task.
        """
        with self._lock:
            task = self._tasks[task_id]
            if task.start_time is None:
                task.start_time = self.get_time()

    def stop_task(self, task_id: TaskID) -> None:
        """Stop a task.

        This will freeze the elapsed time on the task.

        Args:
            task_id (TaskID): ID of task.
        """
        with self._lock:
            task = self._tasks[task_id]
            current_time = self.get_time()
            if task.start_time is None:
                task.start_time = current_time
            task.stop_time = current_time

    def _record_sample(self, task: Task, current_time: float,
                       update_completed: float) -> None:
        """Append a speed sample to ``task``, discarding stale and excess samples.

        Samples older than ``speed_estimate_period`` are dropped and the
        buffer is trimmed to at most 1000 entries before appending.
        Callers must hold ``self._lock``.
        """
        old_sample_time = current_time - self.speed_estimate_period
        _progress = task._progress
        popleft = _progress.popleft
        while _progress and _progress[0].timestamp < old_sample_time:
            popleft()
        while len(_progress) > 1000:
            popleft()
        _progress.append(ProgressSample(current_time, update_completed))

    def update(
        self,
        task_id: TaskID,
        *,
        total: float = None,
        completed: float = None,
        advance: float = None,
        description: str = None,
        visible: bool = None,
        refresh: bool = False,
        **fields: Any,
    ) -> None:
        """Update information associated with a task.

        Args:
            task_id (TaskID): Task id (returned by add_task).
            total (float, optional): Updates task.total if not None.
            completed (float, optional): Updates task.completed if not None.
            advance (float, optional): Add a value to task.completed if not None.
            description (str, optional): Change task description if not None.
            visible (bool, optional): Set visible flag if not None.
            refresh (bool): Force a refresh of progress information. Default is False.
            **fields (Any): Additional data fields required for rendering.
        """
        with self._lock:
            task = self._tasks[task_id]
            completed_start = task.completed

            if total is not None:
                task.total = total
            if advance is not None:
                task.completed += advance
            # An explicit `completed` wins over `advance`.
            if completed is not None:
                task.completed = completed
            if description is not None:
                task.description = description
            if visible is not None:
                task.visible = visible
            task.fields.update(fields)
            update_completed = task.completed - completed_start

            if refresh:
                self.refresh()

            self._record_sample(task, self.get_time(), update_completed)

    def advance(self, task_id: TaskID, advance: float = 1) -> None:
        """Advance task by a number of steps.

        Args:
            task_id (TaskID): ID of task.
            advance (float): Number of steps to advance. Default is 1.
        """
        current_time = self.get_time()
        with self._lock:
            task = self._tasks[task_id]
            completed_start = task.completed
            task.completed += advance
            update_completed = task.completed - completed_start
            self._record_sample(task, current_time, update_completed)

    def refresh(self) -> None:
        """Refresh (render) the progress information."""
        if self.console.is_jupyter:  # pragma: no cover
            try:
                from ipywidgets import Output
                from IPython.display import display
            except ImportError:
                import warnings

                warnings.warn('install "ipywidgets" for Jupyter support')
            else:
                with self._lock:
                    # Lazily create a single Output widget and redraw into it.
                    if self.ipy_widget is None:
                        self.ipy_widget = Output()
                        display(self.ipy_widget)

                    with self.ipy_widget:
                        self.ipy_widget.clear_output(wait=True)
                        self.console.print(self.get_renderable())

        elif self.console.is_terminal and not self.console.is_dumb_terminal:
            with self._lock:
                self._live_render.set_renderable(self.get_renderable())
                with self.console:
                    self.console.print(Control(""))

    def get_renderable(self) -> RenderableType:
        """Get a renderable for the progress display."""
        renderable = RenderGroup(*self.get_renderables())
        return renderable

    def get_renderables(self) -> Iterable[RenderableType]:
        """Get a number of renderables for the progress display."""
        table = self.make_tasks_table(self.tasks)
        yield table

    def make_tasks_table(self, tasks: Iterable[Task]) -> Table:
        """Get a table to render the Progress display.

        Args:
            tasks (Iterable[Task]): An iterable of Task instances, one per row of the table.

        Returns:
            Table: A table instance.
        """

        table = Table.grid(padding=(0, 1))
        for _ in self.columns:
            table.add_column()
        for task in tasks:
            if task.visible:
                row: List[RenderableType] = []
                append = row.append
                for index, column in enumerate(self.columns):
                    if isinstance(column, str):
                        append(column.format(task=task))
                        table.columns[index].no_wrap = True
                    else:
                        widget = column(task)
                        append(widget)
                        if isinstance(widget, (str, Text)):
                            table.columns[index].no_wrap = True
                table.add_row(*row)
        return table

    def add_task(
        self,
        description: str,
        start: bool = True,
        total: int = 100,
        completed: int = 0,
        visible: bool = True,
        **fields: Any,
    ) -> TaskID:
        """Add a new 'task' to the Progress display.

        Args:
            description (str): A description of the task.
            start (bool, optional): Start the task immediately (to calculate elapsed time). If set to False,
                you will need to call `start` manually. Defaults to True.
            total (int, optional): Number of total steps in the progress if known. Defaults to 100.
            completed (int, optional): Number of steps completed so far. Defaults to 0.
            visible (bool, optional): Enable display of the task. Defaults to True.
            **fields (str): Additional data fields required for rendering.

        Returns:
            TaskID: An ID you can use when calling `update`.
        """
        with self._lock:
            task_id = self._task_index
            task = Task(
                task_id,
                description,
                total,
                completed,
                visible=visible,
                fields=fields,
                _get_time=self.get_time,
            )
            self._tasks[task_id] = task
            # Reserve the id before anything that may raise, so a failed
            # refresh can never make the next task overwrite this one.
            self._task_index = TaskID(int(task_id) + 1)
            if start:
                self.start_task(task_id)
            self.refresh()
            return task_id

    def remove_task(self, task_id: TaskID) -> None:
        """Delete a task.

        Args:
            task_id (TaskID): A task ID.

        Raises:
            KeyError: If no task with ``task_id`` exists.
        """
        with self._lock:
            del self._tasks[task_id]

    def process_renderables(
            self,
            renderables: List[ConsoleRenderable]) -> List[ConsoleRenderable]:
        """Process renderables to restore cursor and display progress."""
        if self.console.is_terminal:
            renderables = [
                self._live_render.position_cursor(),
                *renderables,
                self._live_render,
            ]
        return renderables
示例#26
0
文件: widgets.py 项目: uday1889/conx
class Dashboard(VBox):
    """
    Build the dashboard for Jupyter widgets. Requires running
    in a notebook/jupyterlab.
    """
    def __init__(self, net, width="95%", height="550px", play_rate=0.5):
        """
        Build the dashboard widgets for ``net``.

        Args:
            net: the network to visualize (its ``name`` and ``config`` are read here).
            width: CSS width of the network drawing area.
            height: stored as ``self._height``.
            play_rate: passed to the background ``_Player``.
        """
        self._ignore_layer_updates = False
        # The player is created and started before anything else, as originally.
        self.player = _Player(self, play_rate)
        self.player.start()
        self.net = net
        suffix = random.randint(1, 1000000)
        self.class_id = "picture-dashboard-%s-%s" % (self.net.name, suffix)
        self._width, self._height = width, height
        ## Global widgets:
        text_style = {"description_width": "initial"}
        self.feature_columns = IntText(
            description="Feature columns:",
            value=self.net.config["dashboard.features.columns"],
            min=0,
            max=1024,
            style=text_style,
        )
        self.feature_scale = FloatText(
            description="Feature scale:",
            value=self.net.config["dashboard.features.scale"],
            min=0.1,
            max=10,
            style=text_style,
        )
        for feature_control in (self.feature_columns, self.feature_scale):
            feature_control.observe(self.regenerate, names='value')
        ## Hack to center SVG as justify-content is broken:
        empty_svg = """<p style="text-align:center">%s</p>""" % ("",)
        svg_layout = Layout(width=self._width, overflow_x='auto', overflow_y="auto",
                            justify_content="center")
        self.net_svg = HTML(value=empty_svg, layout=svg_layout)
        # Make controls first:
        self.output = Output()
        control_row = self.make_controls()
        config_panel = self.make_config()
        super().__init__([config_panel, control_row, self.net_svg, self.output])

    def propagate(self, inputs):
        """
        Propagate inputs through the dashboard view of the network.

        When ``dynamic_pictures_check()`` is true, the network propagates
        ``inputs`` and updates the dashboard pictures in place; the network
        output is returned. Otherwise the whole network view is regenerated
        for ``inputs`` and None is returned.
        """
        if dynamic_pictures_check():
            return self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        # Fallback: redraw the complete SVG for these inputs.
        self.regenerate(inputs=inputs)
        return None

    def goto(self, position):
        """
        Move the dataset slider to ``position`` and sync the position box.

        Args:
            position (str): one of "begin", "end", "prev" or "next". "prev"
                and "next" wrap around at the ends of the selected dataset;
                any other value leaves the slider where it is.

        Does nothing when the dataset has no inputs or targets, or when the
        dataset selector is neither "Train" nor "Test".
        """
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            return
        if self.control_select.value == "Train":
            length = len(dataset.train_inputs)
        elif self.control_select.value == "Test":
            length = len(dataset.test_inputs)
        else:
            # Unknown selection: `length` would be undefined below.
            return
        #### Position it:
        current = self.control_slider.value
        if position == "begin":
            self.control_slider.value = 0
        elif position == "end":
            self.control_slider.value = length - 1
        elif position == "prev":
            # wrap around from the first item to the last
            self.control_slider.value = length - 1 if current - 1 < 0 else current - 1
        elif position == "next":
            # wrap around from the last item to the first
            self.control_slider.value = 0 if current + 1 > length - 1 else current + 1
        self.position_text.value = self.control_slider.value


    def change_select(self, change=None):
        """
        Handle a change of the Train/Test dataset selector.

        Re-syncs the slider controls for the newly selected dataset, then
        redraws the network view.
        """
        sync_controls = self.update_control_slider
        sync_controls(change)
        self.regenerate()

    def update_control_slider(self, change=None):
        """
        Sync the slider, position box, total label and navigation buttons
        with the dataset currently chosen in ``control_select``.

        Also records the choice in ``net.config["dashboard.dataset"]``.
        Controls are disabled when there is nothing to navigate; a button
        whose icon is "refresh" is never disabled.

        Args:
            change: widget-observer change dict (unused).
        """
        self.net.config["dashboard.dataset"] = self.control_select.value
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            self.total_text.value = "of 0"
            self.control_slider.value = 0
            self.position_text.value = 0
            self.control_slider.disabled = True
            self.position_text.disabled = True
            self._set_nav_buttons_disabled(True)
            return
        if self.control_select.value == "Test":
            count = len(dataset.test_inputs)
        elif self.control_select.value == "Train":
            count = len(dataset.train_inputs)
        else:
            # Unknown selection: nothing to sync (previously an UnboundLocalError).
            return
        self.total_text.value = "of %s" % count
        upper = max(count - 1, 0)
        # Reset an out-of-range index before narrowing the slider bounds.
        if not (0 <= self.control_slider.value <= upper):
            self.control_slider.value = 0
        self.control_slider.min = 0
        self.control_slider.max = upper
        disabled = count == 0
        self.control_slider.disabled = disabled
        self.position_text.disabled = disabled
        self.position_text.value = self.control_slider.value
        self._set_nav_buttons_disabled(disabled)

    def _set_nav_buttons_disabled(self, disabled):
        """Enable/disable every control button except the refresh button."""
        for child in self.control_buttons.children:
            if not hasattr(child, "icon") or child.icon != "refresh":
                child.disabled = disabled

    def update_zoom_slider(self, change):
        """On a zoom-slider value change, store the scale in the config and redraw."""
        if change["name"] != "value":
            return
        self.net.config["svg_scale"] = self.zoom_slider.value
        self.regenerate()

    def update_position_text(self, change):
        """
        Move the dataset slider to the index typed in the position box.

        ``change`` is an ipywidgets change dict, e.g.
        {'name': 'value', 'old': 2, 'new': 3, 'owner': IntText(...), 'type': 'change'}.
        """
        typed_index = change["new"]
        self.control_slider.value = typed_index

    def get_current_input(self):
        """
        Return the input at the slider index of the selected dataset split.

        Returns None when the selected split ("Train" or "Test") has no
        targets, or when the selector holds any other value.
        """
        dataset = self.net.dataset
        selected = self.control_select.value
        if selected == "Train" and len(dataset.train_targets) > 0:
            return dataset.train_inputs[self.control_slider.value]
        if selected == "Test" and len(dataset.test_targets) > 0:
            return dataset.test_inputs[self.control_slider.value]
        return None

    def get_current_targets(self):
        """
        Return the targets at the slider index of the selected dataset split.

        Returns None when the selected split ("Train" or "Test") has no
        targets, or when the selector holds any other value.
        """
        dataset = self.net.dataset
        selected = self.control_select.value
        if selected == "Train" and len(dataset.train_targets) > 0:
            return dataset.train_targets[self.control_slider.value]
        if selected == "Test" and len(dataset.test_targets) > 0:
            return dataset.test_targets[self.control_slider.value]
        return None

    def update_slider_control(self, change):
        """
        Display the dataset item selected by the slider.

        Used as a slider observer (and by ``prop_one``); only ``"value"``
        changes are handled. Syncs the position box and total label, then
        shows the selected Train/Test item through ``_display_dataset_item``.
        Nothing is displayed when the network has no model yet.

        Args:
            change: ipywidgets change dict; only its ``"name"`` is read.
        """
        dataset = self.net.dataset
        if len(dataset.inputs) == 0 or len(dataset.targets) == 0:
            self.total_text.value = "of 0"
            return
        if change["name"] != "value":
            return
        self.position_text.value = self.control_slider.value
        if self.control_select.value == "Train" and len(dataset.train_targets) > 0:
            inputs, targets = dataset.train_inputs, dataset.train_targets
        elif self.control_select.value == "Test" and len(dataset.test_targets) > 0:
            inputs, targets = dataset.test_inputs, dataset.test_targets
        else:
            return
        self.total_text.value = "of %s" % len(inputs)
        if self.net.model is None:
            return
        index = self.control_slider.value
        self._display_dataset_item(inputs[index], targets[index])

    def _display_dataset_item(self, inputs, targets):
        """Propagate one dataset item and show its features, targets and errors."""
        if not dynamic_pictures_check():
            # No in-place picture updates available: redraw the whole SVG.
            self.regenerate(inputs=inputs, targets=targets)
            return
        output = self.net.propagate(inputs, class_id=self.class_id, update_pictures=True)
        if self.feature_bank.value in self.net.layer_dict.keys():
            self.net.propagate_to_features(self.feature_bank.value, inputs,
                                           cols=self.feature_columns.value,
                                           scale=self.feature_scale.value, html=False)
        # Single-output networks get their targets/errors wrapped as one bank.
        single_bank = len(self.net.output_bank_order) == 1
        if self.net.config["show_targets"]:  ## FIXME: use minmax of output bank
            self.net.display_component([targets] if single_bank else targets,
                                       "targets",
                                       class_id=self.class_id,
                                       minmax=(-1, 1))
        if self.net.config["show_errors"]:  ## minmax is error
            if single_bank:
                errors = [(np.array(output) - np.array(targets)).tolist()]
            else:
                errors = [np.array(output[bank]) - np.array(targets[bank])
                          for bank in range(len(self.net.output_bank_order))]
            self.net.display_component(errors, "errors", class_id=self.class_id, minmax=(-1, 1))

    def toggle_play(self, button):
        """Flip the play button between Play/Stop and resume/pause the player."""
        start_playing = self.button_play.description == "Play"
        if start_playing:
            self.button_play.description, self.button_play.icon = "Stop", "pause"
            self.player.resume()
            return
        self.button_play.description, self.button_play.icon = "Play", "play"
        self.player.pause()

    def prop_one(self, button=None):
        """Button callback: re-display the item at the current slider index."""
        value_change = {"name": "value"}
        self.update_slider_control(value_change)

    def regenerate(self, button=None, inputs=None, targets=None):
        ## Protection when deleting object on shutdown:
        if isinstance(button, dict) and 'new' in button and button['new'] is None:
            return
        ## Update the config:
        self.net.config["dashboard.features.bank"] = self.feature_bank.value
        self.net.config["dashboard.features.columns"] = self.feature_columns.value
        self.net.config["dashboard.features.scale"] = self.feature_scale.value
        inputs = inputs if inputs is not None else self.get_current_input()
        targets = targets if targets is not None else self.get_current_targets()
        features = None
        if self.feature_bank.value in self.net.layer_dict.keys() and inputs is not None:
            if self.net.model is not None:
                features = self.net.propagate_to_features(self.feature_bank.value, inputs,
                                                          cols=self.feature_columns.value,
                                                          scale=self.feature_scale.value, display=False)
        svg = """<p style="text-align:center">%s</p>""" % (self.net.to_svg(inputs=inputs, targets=targets,
                                                                           class_id=self.class_id),)
        if inputs is not None and features is not None:
            html_horizontal = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top" style="width: 50%%;">%s</td>
  <td valign="top" align="center" style="width: 50%%;"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            html_vertical = """
<table align="center" style="width: 100%%;">
 <tr>
  <td valign="top">%s</td>
</tr>
<tr>
  <td valign="top" align="center"><p style="text-align:center"><b>%s</b></p>%s</td>
</tr>
</table>"""
            self.net_svg.value = (html_vertical if self.net.config["svg_rotate"] else html_horizontal) % (
                svg, "%s features" % self.feature_bank.value, features)
        else:
            self.net_svg.value = svg

    def make_colormap_image(self, colormap_name):
        """
        Render a strip image of ``colormap_name`` (or ``get_colormap()`` when
        falsy) over a throwaway layer's activation range, resized to (300, 25).
        """
        from .layers import Layer
        name = colormap_name or get_colormap()
        strip_layer = Layer("Colormap", 100)
        act_range = strip_layer.get_act_minmax()
        values = np.arange(act_range[0], act_range[1], .01)
        options = {"pixels_per_unit": 1,
                   "svg_rotate": self.net.config["svg_rotate"]}
        strip = strip_layer.make_image(values, name, options)
        return strip.resize((300, 25))

    def set_attr(self, obj, attr, value):
        """
        Store a widget value on ``obj`` (dict key or attribute), then redraw.

        ``value`` may be a change dict, in which case its ``"value"`` entry is
        used. Nothing happens when ``value`` is ``{}`` or None (None is seen
        when shutting down).
        """
        if value in [{}, None]:
            return
        new_value = value["value"] if isinstance(value, dict) else value
        if isinstance(obj, dict):
            obj[attr] = new_value
        else:
            setattr(obj, attr, new_value)
        ## was crashing on Widgets.__del__, if get_ipython() no longer existed
        self.regenerate()

    def make_controls(self):
        """
        Build the navigation controls for stepping through the dataset.

        Creates and stores on ``self``: the play button, the position text box,
        the button row, the "Dataset index" slider, the "of N" total label and
        the zoom slider. It then wires their callbacks and returns the
        assembled ``VBox``.

        The slider range and the total label are both sized from the dataset
        named by ``self.net.config["dashboard.dataset"]``. "Train" selects
        train inputs; anything else selects test inputs. This keeps the
        slider's last index equal to N - 1 for the N shown in the label.

        Returns:
            VBox: slider and total label on top, the button row below.
        """
        button_begin = Button(icon="fast-backward", layout=Layout(width='100%'))
        button_prev = Button(icon="backward", layout=Layout(width='100%'))
        button_next = Button(icon="forward", layout=Layout(width='100%'))
        button_end = Button(icon="fast-forward", layout=Layout(width='100%'))
        self.button_play = Button(icon="play", description="Play", layout=Layout(width="100%"))
        refresh_button = Button(icon="refresh", layout=Layout(width="25%"))

        self.position_text = IntText(value=0, layout=Layout(width="100%"))

        self.control_buttons = HBox([
            button_begin,
            button_prev,
            self.position_text,
            button_next,
            button_end,
            self.button_play,
            refresh_button
        ], layout=Layout(width='100%', height="50px"))

        ## Size the slider and the "of N" label from the SAME dataset so they
        ## agree. Indices run 0 .. N-1; an empty dataset gives a 0..0 slider.
        if self.net.config["dashboard.dataset"] == "Train":
            length = len(self.net.dataset.train_inputs)
        else:
            length = len(self.net.dataset.test_inputs)
        self.control_slider = IntSlider(description="Dataset index",
                                        continuous_update=False,
                                        min=0,
                                        max=max(length - 1, 0),
                                        value=0,
                                        layout=Layout(width='100%'))
        self.total_text = Label(value="of %s" % length, layout=Layout(width="100px"))
        self.zoom_slider = FloatSlider(description="Zoom",
                                       continuous_update=False,
                                       min=0, max=1.0,
                                       style={"description_width": 'initial'},
                                       layout=Layout(width="65%"),
                                       value=self.net.config["svg_scale"] if self.net.config["svg_scale"] is not None else 0.5)

        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names='value')

        def on_refresh(widget):
            ## Same three calls, in the same order, as before:
            ## update_control_slider(), clear the output, regenerate().
            self.update_control_slider()
            self.output.clear_output()
            self.regenerate()

        refresh_button.on_click(on_refresh)
        self.zoom_slider.observe(self.update_zoom_slider, names='value')
        self.position_text.observe(self.update_position_text, names='value')
        # Put them together:
        controls = VBox([HBox([self.control_slider, self.total_text], layout=Layout(height="40px")),
                         self.control_buttons], layout=Layout(width='100%'))
        ## Draw the network once the controls are actually shown:
        controls.on_displayed(lambda widget: self.regenerate())
        return controls

    def make_config(self):
        """
        Build the collapsible configuration panel for the network display.

        Column 1 holds network-wide settings:
        - the dataset select and zoom slider;
        - horizontal and vertical spacing;
        - the show-targets and show-errors checkboxes;
        - the feature-bank select and feature controls.

        Column 2 holds per-layer settings for the layer chosen in
        ``self.layer_select`` (initially the last layer):
        - visibility;
        - colormap, with a preview image;
        - the color range;
        - the feature index.
        It also holds the "Rotate network" checkbox and a save button.

        Most widgets are stored on ``self`` so the update_* handlers can read
        them.

        Note: this uses ``self.zoom_slider`` (created in make_controls) and
        ``self.feature_columns`` / ``self.feature_scale``, which must already
        exist.

        Returns:
            Accordion: a single collapsed section titled with the net's name.
        """
        layout = Layout()
        style = {"description_width": "initial"}
        checkbox1 = Checkbox(description="Show Targets", value=self.net.config["show_targets"],
                             layout=layout, style=style)
        checkbox1.observe(lambda change: self.set_attr(self.net.config, "show_targets", change["new"]), names='value')
        checkbox2 = Checkbox(description="Errors", value=self.net.config["show_errors"],
                             layout=layout, style=style)
        checkbox2.observe(lambda change: self.set_attr(self.net.config, "show_errors", change["new"]), names='value')

        hspace = IntText(value=self.net.config["hspace"], description="Horizontal space between banks:",
                         style=style, layout=layout)
        hspace.observe(lambda change: self.set_attr(self.net.config, "hspace", change["new"]), names='value')
        vspace = IntText(value=self.net.config["vspace"], description="Vertical space between layers:",
                         style=style, layout=layout)
        vspace.observe(lambda change: self.set_attr(self.net.config, "vspace", change["new"]), names='value')
        self.feature_bank = Select(description="Features:", value=self.net.config["dashboard.features.bank"],
                                   options=[""] + [each.name for each in self.net.layers
                                                   if self.net._layer_has_features(each.name)],
                                   rows=1)
        self.feature_bank.observe(self.regenerate, names='value')
        self.control_select = Select(
            options=['Test', 'Train'],
            value=self.net.config["dashboard.dataset"],
            description='Dataset:',
            rows=1
        )
        self.control_select.observe(self.change_select, names='value')
        column1 = [self.control_select,
                   self.zoom_slider,
                   hspace,
                   vspace,
                   HBox([checkbox1, checkbox2]),
                   self.feature_bank,
                   self.feature_columns,
                   self.feature_scale
        ]
        ## Make layer selectable, and update-able:
        column2 = []
        layer = self.net.layers[-1]
        self.layer_select = Select(description="Layer:", value=layer.name,
                                   options=[each.name for each in self.net.layers],
                                   rows=1)
        self.layer_select.observe(self.update_layer_selection, names='value')
        column2.append(self.layer_select)
        self.layer_visible_checkbox = Checkbox(description="Visible", value=layer.visible, layout=layout)
        self.layer_visible_checkbox.observe(self.update_layer, names='value')
        column2.append(self.layer_visible_checkbox)
        ## "" is the "no colormap" option (a layer colormap of None maps to it):
        self.layer_colormap = Select(description="Colormap:",
                                     options=[""] + AVAILABLE_COLORMAPS,
                                     value=layer.colormap if layer.colormap is not None else "", layout=layout, rows=1)
        self.layer_colormap_image = HTML(value="""<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap)))
        self.layer_colormap.observe(self.update_layer, names='value')
        column2.append(self.layer_colormap)
        column2.append(self.layer_colormap_image)
        ## get dynamic minmax; if you change it it will set it in layer as override:
        minmax = layer.get_act_minmax()
        self.layer_mindim = FloatText(description="Leftmost color maps to:", value=minmax[0], style=style)
        self.layer_maxdim = FloatText(description="Rightmost color maps to:", value=minmax[1], style=style)
        self.layer_mindim.observe(self.update_layer, names='value')
        self.layer_maxdim.observe(self.update_layer, names='value')
        column2.append(self.layer_mindim)
        column2.append(self.layer_maxdim)
        self.layer_feature = IntText(value=layer.feature, description="Feature to show:", style=style)
        self.layer_feature.observe(self.update_layer, names='value')
        column2.append(self.layer_feature)
        self.svg_rotate = Checkbox(description="Rotate network",
                                   value=self.net.config["svg_rotate"],
                                   style={"description_width": 'initial'},
                                   layout=Layout(width="52%"))
        self.svg_rotate.observe(lambda change: self.set_attr(self.net.config, "svg_rotate", change["new"]), names='value')
        self.save_config_button = Button(icon="save", layout=Layout(width="10%"))
        self.save_config_button.on_click(self.save_config)
        column2.append(HBox([self.svg_rotate, self.save_config_button]))
        config_children = HBox([VBox(column1, layout=Layout(width="100%")),
                                VBox(column2, layout=Layout(width="100%"))])
        accordion = Accordion(children=[config_children])
        accordion.set_title(0, self.net.name)
        ## Start collapsed:
        accordion.selected_index = None
        return accordion

    def save_config(self, widget=None):
        """
        Button callback: persist the network's config via ``self.net.save_config()``.

        Arguments:
            widget - the clicked button (ignored). It is optional so this can
                also be called directly.
        """
        net = self.net
        net.save_config()

    def update_layer(self, change):
        """
        Copy the per-layer widget values into the selected layer and redisplay.

        Does nothing while ``self._ignore_layer_updates`` is set; that flag is
        set while update_layer_selection loads a layer's values into the
        widgets.

        The layer's ``feature`` and ``visible`` are always updated. A change
        from a widget whose description contains "color" also updates the
        colormap and the fixed min/max range; those widgets are Colormap,
        Leftmost color and Rightmost color. Such a change then refreshes the
        colormap preview and calls ``self.prop_one()``. Any other change calls
        ``self.regenerate()``.

        Arguments:
            change (dict) - traitlets change notification; only
                ``change["owner"].description`` is consulted.
        """
        if self._ignore_layer_updates:
            return
        ## The rest indicates a change to a display variable: save the value
        ## in the layer, then refresh the display.
        layer = self.net[self.layer_select.value]
        layer.feature = self.layer_feature.value
        layer.visible = self.layer_visible_checkbox.value
        if "color" in change["owner"].description.lower():
            ## Matches: Colormap, leftmost color, rightmost color.
            ## These only affect activation colors, so a prop_one() is enough.
            ## Setting minmax overrides the layer's dynamic min/max.
            layer.minmax = (self.layer_mindim.value, self.layer_maxdim.value)
            ## "" in the Select means "no colormap", stored on the layer as None.
            layer.colormap = self.layer_colormap.value if self.layer_colormap.value else None
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            self.prop_one()
        else:
            self.regenerate()

    def update_layer_selection(self, change):
        """
        Load the newly selected layer's settings into the per-layer widgets.

        Only widget values change; nothing is redrawn. While the widgets are
        being set, ``self._ignore_layer_updates`` is True. This stops
        update_layer (which observes those widgets) from writing values back
        into the layer or redrawing. The flag is cleared in a ``finally``, so a
        failure part-way through cannot leave layer editing permanently
        disabled.

        Arguments:
            change (dict) - traitlets change notification. Unused: the
                selection is read from ``self.layer_select.value``.
        """
        self._ignore_layer_updates = True
        try:
            layer = self.net[self.layer_select.value]
            self.layer_visible_checkbox.value = layer.visible
            ## The colormap Select offers "" as its "no colormap" option, so a
            ## layer colormap of None must be shown as "" (None is not an option).
            self.layer_colormap.value = layer.colormap if layer.colormap is not None else ""
            self.layer_colormap_image.value = """<img src="%s"/>""" % self.net._image_to_uri(self.make_colormap_image(layer.colormap))
            minmax = layer.get_act_minmax()
            self.layer_mindim.value = minmax[0]
            self.layer_maxdim.value = minmax[1]
            self.layer_feature.value = layer.feature
        finally:
            self._ignore_layer_updates = False
Example #27
0
class Dbfs(object):
    """Database browser implementation

    Interactive DBFS browser built from widgets. The top row holds the
    refresh and up buttons plus the current path. Below it, a ``Select``
    list of the current folder's entries sits next to a preview pane.

    Args:
        dbutils (DBUtils): DBUtils object (for fs only)
    """

    # File extensions (lower case) that show_preview knows how to render.
    _PREVIEW_EXTENSIONS = frozenset([
        "md", "html", "csv", "txt", "sh", "sql", "py", "scala", "json",
        "jpg", "jpeg", "png", "gif",
    ])

    def __init__(self, dbutils):
        self.dbutils = dbutils
        self.running = False
        self.path = None
        self.flist = None
        self.refresh = None
        self.path_view = None
        self.preview = None
        self.up = None

    def create(self, path="/", height="400px"):
        """Build and display the browser widgets, then list ``path``.

        Args:
            path (str): Folder to start in. Defaults to "/".
            height (str): CSS height of the file list and the preview pane.
        """
        if self.running:
            print("dbfs browser already running. Use close() first")
            return
        self.path = path
        self.flist = Select(options=[],
                            disabled=False,
                            layout={"height": height})
        self.flist.observe(self.on_click, names="value")

        self.refresh = Button(icon="refresh", layout={"width": "40px"})
        self.refresh.on_click(self.on_refresh)
        self.path_view = Output()
        self.preview = Output(
            layout={
                "width": "800px",
                "height": height,
                "overflow": "scroll",
                "border": "1px solid gray",
            })

        self.up = Button(icon="arrow-up", layout={"width": "40px"})
        self.up.on_click(self.on_up)

        display(
            VBox([
                HBox([self.refresh, self.up, self.path_view]),
                HBox([self.flist, self.preview]),
            ]))

        self.update()
        self.running = True

    def convertBytes(self, fsize):
        """Convert bytes to largest unit

        A unit is only used once the size exceeds 10 of that unit. For
        example, 5 KB is reported as (5120, "B"). Sizes are truncated to int.

        Args:
            fsize (int): Size in bytes

        Returns:
            tuple: size of largest unit, largest unit
        """
        size = fsize
        unit = "B"
        if size > 1024 * 1024 * 1024 * 10:
            size = int(size / 1024.0 / 1024.0 / 1024.0)
            unit = "GB"
        elif size > 1024 * 1024 * 10:
            size = int(size / 1024.0 / 1024.0)
            unit = "MB"
        elif size > 1024 * 10:
            size = int(size / 1024.0)
            unit = "KB"
        return (size, unit)

    @staticmethod
    def _strip_size(path):
        """Remove the " (size unit)" suffix that update() appends to file entries."""
        return re.sub(r"\s\(.*?\)$", "", path)

    def update(self):
        """List ``self.path`` into the file list.

        The list starts with an empty entry, then directories, then files
        shown as "name (size unit)". Each group is sorted case-insensitively
        by name. If the folder cannot be listed, an error is printed in the
        path view and the list is left unchanged.

        Returns:
            bool: True if the folder could be listed, False otherwise.
        """
        self.path_view.clear_output()
        self.preview.clear_output()
        with self.path_view:
            print("updating ...")
        try:
            fobjs = self.dbutils.fs.ls(self.path)
        except Exception:  # any listing failure is reported in the UI, not raised
            with self.path_view:
                print("Error: Cannot access folder")
            return False

        self.show_path(self.path)

        dirs = sorted((fobj.name for fobj in fobjs if fobj.isDir()),
                      key=lambda name: name.lower())
        # Sort by the full file name, not the displayed "name (size)" string.
        files = [
            "%s (%d %s)" % ((fobj.name,) + self.convertBytes(fobj.size))
            for fobj in sorted((f for f in fobjs if not f.isDir()),
                               key=lambda f: f.name.lower())
        ]
        self.flist.options = [""] + dirs + files
        return True

    def show_path(self, path):
        """Show path in output widget as "dbfs:<path>"

        Args:
            path (str): Currently selected path (a " (size unit)" suffix is stripped)
        """
        self.path_view.clear_output()
        with self.path_view:
            print("dbfs:" + self._strip_size(path))

    def show_preview(self, path):
        """Show a preview of the selected file in the preview pane.

        The pane is always cleared first. This way, selecting a file type that
        cannot be previewed does not leave the previous file's preview
        visible.

        Supported types:
        - html;
        - code (py, sh, sql, scala);
        - images (jpg, jpeg, png, gif);
        - markdown;
        - csv, shown as a pandas DataFrame;
        - plain text (txt, json).

        Files are read from "/dbfs" + path.

        Args:
            path (str): Currently selected path (may carry a " (size unit)" suffix)
        """
        self.preview.clear_output()
        real_path = self._strip_size(path)
        ext = real_path.split(".")[-1].lower()
        if ext not in self._PREVIEW_EXTENSIONS:
            return
        filename = "/dbfs" + real_path
        with self.preview:
            if ext == "html":
                display(HTML(filename=filename))
            elif ext in ("py", "sh", "sql", "scala"):
                display(Code(filename=filename))
            elif ext in ("jpg", "jpeg", "png", "gif"):
                display(Image(filename=filename))
            elif ext == "md":
                display(Markdown(filename=filename))
            elif ext == "csv":
                display(pd.read_csv(filename))
            else:
                # errors="replace": a stray undecodable byte should not kill the preview
                with open(filename, "r", encoding="utf-8", errors="replace") as fd:
                    print(fd.read())

    def on_refresh(self, _):
        """Refresh handler: re-list the current folder.

        Args:
            _ (ipywidgets.Button): clicked button (unused)
        """
        self.update()

    def on_up(self, _):
        """Up handler: move to the parent folder, never above "/".

        Args:
            _ (ipywidgets.Button): clicked button (unused)
        """
        # dirname("") is "" when already at the root, so fall back to "/".
        self.path = os.path.dirname(self.path.rstrip("/")) or "/"
        self.update()

    def on_click(self, change):
        """Selection handler for the file list.

        Selecting a folder (an entry ending in "/") navigates into it. If the
        folder cannot be listed, the previous path is restored. Selecting
        anything else shows its path and a preview.

        Two cases are ignored:
        - the initial selection (``old`` is None);
        - a cleared selection (``new`` is None).

        Args:
            change (dict): traitlets change notification with "old" and "new" values
        """
        selected = change["new"]
        if change["old"] is None or selected is None:
            return
        new_path = os.path.join(self.path, selected)
        if selected.endswith("/"):
            old_path = self.path
            self.path = new_path
            if not self.update():
                self.path = old_path
        else:
            self.show_path(new_path)
            self.show_preview(new_path)

    def close(self):
        """Mark the browser as stopped so create() can be called again.

        Note: the displayed widgets are not closed or removed.
        """
        self.running = False
Example #28
0
File: widgets.py  Project: uday1889/conx
class SequenceViewer(VBox):
    """
    SequenceViewer

    Arguments:
        title (str) - Title of sequence
        function (callable) - takes an index 0 to length - 1. Function should
            return a displayable or list of displayables
        length (int) - total number of frames in sequence
        play_rate (float) - seconds to wait between frames when auto-playing.
            Optional. Default is 0.5 seconds.

    >>> def function(index):
    ...     return [None]
    >>> sv = SequenceViewer("Title", function, 10)
    >>> ## Do this manually for testing:
    >>> sv.initialize()
    None
    >>> ## Testing:
    >>> class Dummy:
    ...     def update(self, result):
    ...         return result
    >>> sv.displayers = [Dummy()]
    >>> print("Testing"); sv.goto("begin") # doctest: +ELLIPSIS
    Testing...
    >>> print("Testing"); sv.goto("end") # doctest: +ELLIPSIS
    Testing...
    >>> print("Testing"); sv.goto("prev") # doctest: +ELLIPSIS
    Testing...
    >>> print("Testing"); sv.goto("next") # doctest: +ELLIPSIS
    Testing...

    """
    def __init__(self, title, function, length, play_rate=0.5):
        ## The auto-play helper is created and started before the VBox
        ## itself is initialized.
        self.player = _Player(self, play_rate)
        self.player.start()
        self.title = title
        self.function = function
        self.length = length
        self.output = Output()
        self.position_text = IntText(value=0, layout=Layout(width="100%"))
        self.total_text = Label(value="of %s" % self.length, layout=Layout(width="100px"))
        controls = self.make_controls()
        super().__init__([controls, self.output])

    def goto(self, position):
        """
        Move the slider to a frame.

        Arguments:
            position - "begin", "end", "prev" or "next" (prev/next wrap
                around the ends), or an int frame index.
        """
        if position == "begin":
            self.control_slider.value = 0
        elif position == "end":
            self.control_slider.value = self.length - 1
        elif position == "prev":
            if self.control_slider.value - 1 < 0:
                self.control_slider.value = self.length - 1 # wrap around
            else:
                self.control_slider.value = max(self.control_slider.value - 1, 0)
        elif position == "next":
            if self.control_slider.value + 1 > self.length - 1:
                self.control_slider.value = 0 # wrap around
            else:
                self.control_slider.value = min(self.control_slider.value + 1, self.length - 1)
        elif isinstance(position, int):
            self.control_slider.value = position
        self.position_text.value = self.control_slider.value

    def toggle_play(self, button):
        """Play button callback: switch between playing and paused."""
        if self.button_play.description == "Play":
            self.button_play.description = "Stop"
            self.button_play.icon = "pause"
            self.player.resume()
        else:
            self.button_play.description = "Play"
            self.button_play.icon = "play"
            self.player.pause()

    def make_controls(self):
        """Build the slider/total row and the navigation button row, wire callbacks, return a VBox."""
        button_begin = Button(icon="fast-backward", layout=Layout(width='100%'))
        button_prev = Button(icon="backward", layout=Layout(width='100%'))
        button_next = Button(icon="forward", layout=Layout(width='100%'))
        button_end = Button(icon="fast-forward", layout=Layout(width='100%'))
        self.button_play = Button(icon="play", description="Play", layout=Layout(width="100%"))
        self.control_buttons = HBox([
            button_begin,
            button_prev,
            self.position_text,
            button_next,
            button_end,
            self.button_play,
        ], layout=Layout(width='100%', height="50px"))
        self.control_slider = IntSlider(description=self.title,
                                        continuous_update=False,
                                        min=0,
                                        max=max(self.length - 1, 0),
                                        value=0,
                                        style={"description_width": 'initial'},
                                        layout=Layout(width='100%'))
        ## Hook them up:
        button_begin.on_click(lambda button: self.goto("begin"))
        button_end.on_click(lambda button: self.goto("end"))
        button_next.on_click(lambda button: self.goto("next"))
        button_prev.on_click(lambda button: self.goto("prev"))
        self.button_play.on_click(self.toggle_play)
        self.control_slider.observe(self.update_slider_control, names='value')
        controls = VBox([HBox([self.control_slider, self.total_text], layout=Layout(height="40px")),
                         self.control_buttons], layout=Layout(width='100%'))
        ## First frame is displayed once the controls are shown:
        controls.on_displayed(lambda widget: self.initialize())
        return controls

    def _get_results(self, index):
        """
        Call ``self.function(index)`` and normalize the result to a list.

        Iterables are expanded with ``list``. Strings and bytes are NOT
        split into characters. Strings, bytes and non-iterables become a
        one-element list.
        """
        results = self.function(index)
        if isinstance(results, (str, bytes)):
            return [results]
        try:
            return list(results)
        except TypeError:
            return [results]

    def initialize(self):
        """Display the current frame's results and keep their display handles for updates."""
        results = self._get_results(self.control_slider.value)
        self.displayers = [display(x, display_id=True) for x in results]

    def update_slider_control(self, change):
        """
        Slider observer: sync the position box and update the displayed frame.

        If nothing has been displayed yet (initialize() not run), only the
        position box is synced. Results are paired with display handles
        positionally. A frame that returns fewer results than were first
        displayed leaves the extra displays unchanged instead of raising.
        """
        if change["name"] != "value":
            return
        self.position_text.value = self.control_slider.value
        self.output.clear_output(wait=True)
        displayers = getattr(self, "displayers", None)
        if displayers is None:
            return
        results = self._get_results(self.control_slider.value)
        for displayer, result in zip(displayers, results):
            displayer.update(result)
Example #29
0
File: live.py  Project: weizoudm/rich
class Live(JupyterMixin, RenderHook):
    """Renders an auto-updating live display of any given renderable.

    Can be used as a context manager: entering calls ``start()`` (refreshing
    immediately if a renderable was given) and exiting calls ``stop()``.

    Args:
        renderable (RenderableType, optional): The renderable to live display. Defaults to displaying nothing.
        console (Console, optional): Optional Console instance. Defaults to the console returned by ``get_console()``.
        screen (bool, optional): Enable alternate screen mode. Defaults to False.
        auto_refresh (bool, optional): Enable auto refresh. If disabled, you will need to call `refresh()` or `update()` with refresh flag. Defaults to True.
        refresh_per_second (float, optional): Number of times per second to refresh the live display. Must be > 0. Defaults to 4.
        transient (bool, optional): Clear the renderable on exit. Defaults to False.
        redirect_stdout (bool, optional): Enable redirection of stdout, so ``print`` may be used. Only applied when the console is a terminal. Defaults to True.
        redirect_stderr (bool, optional): Enable redirection of stderr. Only applied when the console is a terminal. Defaults to True.
        vertical_overflow (VerticalOverflowMethod, optional): How to handle renderable when it is too tall for the console. Defaults to "ellipsis".
        get_renderable (Callable[[], RenderableType], optional): Optional callable to get renderable. When given it takes precedence over ``renderable``. Defaults to None.
    """

    def __init__(
        self,
        renderable: Optional[RenderableType] = None,
        *,
        console: Optional[Console] = None,
        screen: bool = False,
        auto_refresh: bool = True,
        refresh_per_second: float = 4,
        transient: bool = False,
        redirect_stdout: bool = True,
        redirect_stderr: bool = True,
        vertical_overflow: VerticalOverflowMethod = "ellipsis",
        get_renderable: Optional[Callable[[], RenderableType]] = None,
    ) -> None:
        assert refresh_per_second > 0, "refresh_per_second must be > 0"
        self._renderable = renderable
        self.console = console if console is not None else get_console()
        self._screen = screen
        # True only while the alternate screen is actually active (set in start()).
        self._alt_screen = False

        self._redirect_stdout = redirect_stdout
        self._redirect_stderr = redirect_stderr
        # Original streams saved by _enable_redirect_io, restored by _disable_redirect_io.
        self._restore_stdout: Optional[IO[str]] = None
        self._restore_stderr: Optional[IO[str]] = None

        # Re-entrant lock shared by start/stop/update and the rendering paths.
        self._lock = RLock()
        # Output widget created lazily by refresh() when running under Jupyter.
        self.ipy_widget: Optional[Any] = None
        self.auto_refresh = auto_refresh
        self._started: bool = False
        self.transient = transient

        self._refresh_thread: Optional[_RefreshThread] = None
        self.refresh_per_second = refresh_per_second

        self.vertical_overflow = vertical_overflow
        self._get_renderable = get_renderable
        self._live_render = LiveRender(
            self.get_renderable(), vertical_overflow=vertical_overflow
        )
        # can't store just clear_control as the live_render shape is lazily computed on render

    def get_renderable(self) -> RenderableType:
        """Return the current renderable.

        Uses the ``get_renderable`` callable if one was given, otherwise the
        stored renderable; a falsy result is replaced by ``""``.
        """
        renderable = (
            self._get_renderable()
            if self._get_renderable is not None
            else self._renderable
        )
        return renderable or ""

    def start(self, refresh: bool = False) -> None:
        """Start live rendering display.

        No-op if already started. Registers this Live with the console,
        optionally switches to the alternate screen, hides the cursor,
        redirects stdout/stderr (terminals only), installs the render hook
        and, when ``auto_refresh`` is set, starts the refresh thread.

        Args:
            refresh (bool, optional): Also refresh. Defaults to False.
        """
        with self._lock:
            if self._started:
                return
            self.console.set_live(self)
            self._started = True
            if self._screen:
                self._alt_screen = self.console.set_alt_screen(True)
            self.console.show_cursor(False)
            self._enable_redirect_io()
            self.console.push_render_hook(self)
            if refresh:
                self.refresh()
            if self.auto_refresh:
                self._refresh_thread = _RefreshThread(self, self.refresh_per_second)
                self._refresh_thread.start()

    def stop(self) -> None:
        """Stop live rendering display.

        No-op if not started. Stops the refresh thread, does a final
        full-height render (outside Jupyter / alt screen), and always restores
        IO redirection, the render hook, the cursor and the normal screen.
        """
        with self._lock:
            if not self._started:
                return
            self.console.clear_live()
            self._started = False
            try:
                if self.auto_refresh and self._refresh_thread is not None:
                    self._refresh_thread.stop()
                # allow it to fully render on the last even if overflow
                self.vertical_overflow = "visible"
                if not self._alt_screen:
                    if not self.console.is_jupyter:
                        self.refresh()
                    if self.console.is_terminal:
                        self.console.line()
            finally:
                self._disable_redirect_io()
                self.console.pop_render_hook()
                self.console.show_cursor(True)
                if self._alt_screen:
                    self.console.set_alt_screen(False)

        # Joined outside the lock -- presumably so a refresh in progress on the
        # thread (refresh() may take self._lock) can finish; confirm against _RefreshThread.
        if self._refresh_thread is not None:
            self._refresh_thread.join()
            self._refresh_thread = None
        if self.transient and not self._screen:
            self.console.control(self._live_render.restore_cursor())
        if self.ipy_widget is not None:  # pragma: no cover
            if self.transient:
                self.ipy_widget.close()
            else:
                # jupyter last refresh must occur after console pop render hook
                # i am not sure why this is needed
                self.refresh()

    def __enter__(self) -> "Live":
        self.start(refresh=self._renderable is not None)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.stop()

    def _enable_redirect_io(self) -> None:
        """Enable redirecting of stdout / stderr.

        Only when the console is a terminal, and only for streams that are
        not already a ``FileProxy``; the originals are saved for restoring.
        """
        if self.console.is_terminal:
            if self._redirect_stdout and not isinstance(sys.stdout, FileProxy):  # type: ignore
                self._restore_stdout = sys.stdout
                sys.stdout = FileProxy(self.console, sys.stdout)
            if self._redirect_stderr and not isinstance(sys.stderr, FileProxy):  # type: ignore
                self._restore_stderr = sys.stderr
                sys.stderr = FileProxy(self.console, sys.stderr)

    def _disable_redirect_io(self) -> None:
        """Disable redirecting of stdout / stderr (restore any saved originals)."""
        if self._restore_stdout:
            sys.stdout = self._restore_stdout
            self._restore_stdout = None
        if self._restore_stderr:
            sys.stderr = self._restore_stderr
            self._restore_stderr = None

    @property
    def renderable(self) -> RenderableType:
        """Get the renderable that is being displayed

        Wrapped in ``Screen`` while the alternate screen is active.

        Returns:
            RenderableType: Displayed renderable.
        """
        renderable = self.get_renderable()
        return Screen(renderable) if self._alt_screen else renderable

    def update(self, renderable: RenderableType, *, refresh: bool = False) -> None:
        """Update the renderable that is being displayed

        Args:
            renderable (RenderableType): New renderable to use.
            refresh (bool, optional): Refresh the display. Defaults to False.
        """
        with self._lock:
            self._renderable = renderable
            if refresh:
                self.refresh()

    def refresh(self) -> None:
        """Update the display of the Live Render.

        Jupyter: re-prints into a lazily created ``ipywidgets.Output`` (warns
        if ipywidgets is missing). Real terminal: prints an empty ``Control``
        so the render hook redraws. Files / dumb terminals: only once stopped
        and not transient, so the final result is still written.
        """
        self._live_render.set_renderable(self.renderable)
        if self.console.is_jupyter:  # pragma: no cover
            try:
                from IPython.display import display
                from ipywidgets import Output
            except ImportError:
                import warnings

                warnings.warn('install "ipywidgets" for Jupyter support')
            else:
                with self._lock:
                    if self.ipy_widget is None:
                        self.ipy_widget = Output()
                        display(self.ipy_widget)

                    with self.ipy_widget:
                        self.ipy_widget.clear_output(wait=True)
                        self.console.print(self._live_render.renderable)
        elif self.console.is_terminal and not self.console.is_dumb_terminal:
            with self._lock, self.console:
                self.console.print(Control(""))
        elif (
            not self._started and not self.transient
        ):  # if it is finished allow files or dumb-terminals to see final result
            with self.console:
                self.console.print(Control(""))

    def process_renderables(
        self, renderables: List[ConsoleRenderable]
    ) -> List[ConsoleRenderable]:
        """Process renderables to restore cursor and display progress.

        Render hook: on interactive consoles, prefix a cursor reset (home on
        the alt screen, otherwise reposition over the previous live render)
        and append the live render after the new output. On non-interactive
        consoles the live render is appended only after stopping, unless
        transient.
        """
        self._live_render.vertical_overflow = self.vertical_overflow
        if self.console.is_interactive:
            # lock needs acquiring as user can modify live_render renderable at any time unlike in Progress.
            with self._lock:
                # determine the control command needed to clear previous rendering
                reset = (
                    Control.home()
                    if self._alt_screen
                    else self._live_render.position_cursor()
                )
                renderables = [
                    reset,
                    *renderables,
                    self._live_render,
                ]
        elif (
            not self._started and not self.transient
        ):  # if it is finished render the final output for files or dumb_terminals
            renderables = [*renderables, self._live_render]

        return renderables
Example #30
0
class App:
    """Map widget for selecting grid cells and listing intersecting CMR collections.

    On construction, the cells of grid level ``settings["enabled_grid"]``
    from ``above_grid`` are added to a ``Map`` together with ``esri`` and a
    layer that holds the current selection. Shapes drawn with the attached
    ``DrawControl`` select the cells they touch (see
    ``update_selected_cells``). The UI (map above an output area) is
    displayed immediately.
    """

    # grid level (matched against ``properties["grid_level"]``) to show
    settings = {"enabled_grid": "B"}

    def __init__(self, session=None):
        """Build and display the map UI.

        Parameters
        ----------
        session : optional
            Stored as ``self.session``; not used by the methods of this class.
        """
        self.session = session
        self.use_grid = self.settings["enabled_grid"]

        # one polygon layer per grid cell of the chosen level, keyed by cell id
        self.grid_layers = LayerGroup()
        self.grid_dict = {}
        for feat in above_grid["features"]:
            if feat["properties"]["grid_level"] != self.use_grid:
                continue
            cell = Cell(feat)
            self.grid_dict[cell.id] = cell
            self.grid_layers.add_layer(cell.layer)

        # holds the outline of the currently selected cells
        self.selected_layer = LayerGroup()

        self.map = Map(
            layers=(esri, self.grid_layers, self.selected_layer),
            center=(65, -100),
            zoom=3,
            width="auto",
            height="auto",
            scroll_wheel_zoom=True,
        )

        # polyline/circle/circlemarker get empty options (presumably
        # disabling those tools); polygons and rectangles use draw_style
        self.draw_control = DrawControl()
        self.draw_control.polyline = {}
        self.draw_control.circle = {}
        self.draw_control.circlemarker = {}
        self.draw_control.remove = False
        self.draw_control.edit = False
        self.draw_control.polygon = {**draw_style}
        self.draw_control.rectangle = {**draw_style}
        self.draw_control.on_draw(self.update_selected_cells)
        self.map.add_control(self.draw_control)

        # receives the table of matching collections
        self.output = Output(layout=Layout(width="auto", height="auto"))

        self.ui = VBox([self.map, self.output], layout=Layout(width="auto"))
        display(self.ui)

    def update_selected_cells(self, *args, **kwargs):
        """Select the grid cells touched by a drawn shape and list matching collections.

        Registered with ``self.draw_control.on_draw``; the drawn feature is
        read from the ``geo_json`` keyword argument. The union of every
        intersecting grid cell is drawn into ``self.selected_layer`` and the
        map is re-centred on it. The rows of ``above_results_df`` whose first
        box intersects the union are stored on ``self.coll``, wrapped in a
        grid on ``self.tab``, and shown in ``self.output``.

        If the shape touches no cells, the previous selection is cleared and
        nothing else changes.
        """
        # remove the user's drawing; the selection outline replaces it
        self.draw_control.clear()

        drawn_geom = shape(kwargs["geo_json"]["geometry"])
        hit_shapes = [
            cell.shape
            for cell in self.grid_dict.values()
            if drawn_geom.intersects(cell.shape)
        ]

        self.selected_layer.clear_layers()
        if not hit_shapes:
            # an empty union has no exterior to draw or centroid to centre on
            return

        union = cascaded_union(hit_shapes)
        centroid = union.centroid

        # Non-adjacent cells merge into a MultiPolygon, which has no single
        # exterior ring; draw each of its parts instead.
        for part in getattr(union, "geoms", [union]):
            xs, ys = part.exterior.coords.xy
            self.selected_layer.add_layer(Polygon(locations=list(zip(ys, xs))))
        self.map.center = (centroid.y, centroid.x)

        # Collections whose first box intersects the merged cells (test
        # against drawn_geom instead to match the drawn shape strictly).
        selected = [
            label
            for label, collection in above_results_df.iterrows()
            if CMR_box_to_Shapely_box(collection.boxes[0]).intersects(union)
        ]
        # iterrows() yields index labels, so select by label, not position
        self.coll = above_results_df.loc[selected]

        self.tab = qgrid.show_grid(
            self.coll[["dataset_id", "time_start", "time_end", "boxes"]],
            grid_options={
                'forceFitColumns': False,
                'minColumnWidth': "0",
                'maxColumnWidth': "400"
            },
            show_toolbar=False)

        self.output.clear_output()
        with self.output:
            display(self.tab)
示例#31
0
class Annotator:
    """
    Interactive Ipython widget for annotating Set cards images.
    Annotations are iteratively saved as json files, one per image.

    Parameters
    ----------
    directory : str
        path to a directory containing images to annotate (jpg or png)
    output_directory : str
        directory to save annotations, if None (default) save to subdirectory DEFAULT_OUTPUT_SUBDIR

    """
    def __init__(self, directory, output_directory=None):
        self.input_directory = directory
        self.output_dir = self.set_output_dir(output_directory)
        # images still to label; those with an annotation file are skipped
        self.examples = self.list_examples_to_annotate()
        self.cursor = 0  # index in self.examples of the image being shown
        self.annotations = []  # (image path, json string) submitted this session
        self.progress_message = HTML()
        self.output_message = Output()
        self.output_image = Output()
        self.label_buttons, self.submit_button = self.make_all_buttons()

    def annotate(self):
        """Run the annotation widgets."""
        self.set_progression_message()
        display(self.progress_message)
        display(self.output_image)
        self.initialize_label_buttons()
        self.display_all_buttons()
        display(self.output_message)
        self.show_next_example()
        # Register the bound method (not a fresh lambda) and drop any earlier
        # registration first. Otherwise calling annotate() again attaches a
        # second handler, each click is processed twice, and the second pass
        # sees the just-reset buttons and reports missing values.
        self.submit_button.on_click(self.on_button_clicked, remove=True)
        self.submit_button.on_click(self.on_button_clicked)

    def set_progression_message(self):
        """Update the progress line with annotated/remaining counts."""
        nb_annotations = len(self.annotations)
        nb_remaining = len(self.examples) - self.cursor
        self.progress_message.value = f'{nb_annotations} example(s) annotated, {nb_remaining} example(s) remaining'

    def set_output_dir(self, output_dir: str) -> str:
        """Return the absolute output directory, creating it if needed.

        If ``output_dir`` is None, DEFAULT_OUTPUT_SUBDIR inside the input
        directory is used.
        """
        if output_dir is None:
            output_dir = os.path.join(self.input_directory,
                                      DEFAULT_OUTPUT_SUBDIR)
        output_dir = os.path.abspath(output_dir)
        if not os.path.exists(output_dir):
            # exist_ok: don't fail if the directory appears between the
            # existence check and the creation
            os.makedirs(output_dir, exist_ok=True)
            logger.info('created output directory: %s', output_dir)
        return output_dir

    def list_examples_to_annotate(self) -> List[str]:
        """
        Return list of paths to images to annotate.

        Excludes examples already annotated, based on the files in output_directory.

        Returns
        -------
        List[str] of path to images to annotate

        """
        all_examples = list_images_in_directory(self.input_directory)
        # a set gives O(1) membership tests instead of scanning a list per image
        already_annotated_image_names = set(
            Annotator.get_basenames_in_directory(directory=self.output_dir))
        return [
            example for example in all_examples
            if get_basename(example) not in already_annotated_image_names
        ]

    @staticmethod
    def get_basenames_in_directory(directory: str):
        """Return the list of the base names (filename without extension) of the files in given directory."""
        return [get_basename(path) for path in os.listdir(directory)]

    @staticmethod
    def make_all_buttons() -> (Dict[str, ToggleButtons], Button):
        """Build all necessary buttons.

        Return
        ------
        buttons : (dict of {str: ToggleButtons}, Button)
            label buttons (for number, color, shape and shading), keyed by
            their description, and submit button

        """
        number_button = ToggleButtons(
            options=[('1', Number.ONE), ('2', Number.TWO),
                     ('3', Number.THREE)],
            description=ButtonName.NUMBER.value,
        )
        color_button = ToggleButtons(
            options=[('red', Color.RED), ('green', Color.GREEN),
                     ('purple', Color.PURPLE)],
            description=ButtonName.COLOR.value,
        )
        shape_button = ToggleButtons(
            options=[('oval', Shape.OVAL), ('diamond', Shape.DIAMOND),
                     ('squiggle', Shape.SQUIGGLE)],
            description=ButtonName.SHAPE.value,
        )
        shading_button = ToggleButtons(
            options=[('open', Shading.OPEN), ('striped', Shading.STRIPED),
                     ('solid', Shading.SOLID)],
            description=ButtonName.SHADING.value,
        )
        submit_button = Button(description=ButtonName.SUBMIT.value)

        label_buttons = {
            number_button.description: number_button,
            color_button.description: color_button,
            shape_button.description: shape_button,
            shading_button.description: shading_button,
        }
        return label_buttons, submit_button

    def display_all_buttons(self):
        """Display every label button, then the submit button."""
        for button in self.label_buttons.values():
            display(button)
        display(self.submit_button)

    def initialize_label_buttons(self):
        """Clear the selection of every label button."""
        for button in self.label_buttons.values():
            button.value = None

    def disable_all_buttons(self):
        """Disable label and submit buttons (used once all images are done)."""
        for button in self.label_buttons.values():
            button.disabled = True
        self.submit_button.disabled = True

    def show_next_example(self):
        """Show the image at the cursor, or finish when none remain."""
        self.set_progression_message()
        if self.cursor >= len(self.examples):
            self.disable_all_buttons()
            with self.output_message:
                print('Annotation completed.')
            return
        with self.output_image:
            self.output_image.clear_output()
            display(Image(self.examples[self.cursor], width=200))

    def on_button_clicked(self, but):
        """Submit handler: validate the labels, save them, and advance."""
        responses = self.get_label_buttons_responses()
        missing_attributes = Annotator.get_missing_attributes(responses)
        self.output_message.clear_output()
        if missing_attributes:
            with self.output_message:
                print(f"Missing value for {missing_attributes}. Retry.")
        else:
            with self.output_message:
                annotation = Annotator.get_annotation_as_json_string(responses)
                current_example = self.examples[self.cursor]
                self.annotations.append((current_example, annotation))
                self.save_annotation(annotation, current_example)
                print(f"Annotation submitted: {annotation}")
            self.initialize_label_buttons()
            self.cursor += 1
            self.show_next_example()

    def get_label_buttons_responses(
            self) -> Dict[str, Union[Number, Color, Shape, Shading]]:
        """Return label buttons values."""
        return {
            att: button.value
            for att, button in self.label_buttons.items()
        }

    @staticmethod
    def get_annotation_as_json_string(
            response: Dict[str, Union[Number, Color, Shape, Shading]]) -> str:
        """Get label buttons responses as json string (using each response's ``.value``)."""
        annotation = {
            att: button_response.value
            for att, button_response in response.items()
        }
        return json.dumps(annotation)

    def save_annotation(self, annotation: str, example: str) -> None:
        """Save annotation as ``<image basename>.json`` in the output directory."""
        basename = os.path.splitext(os.path.basename(example))[0]
        destination_path = os.path.join(self.output_dir, basename + '.json')
        with open(destination_path, 'w', encoding='utf-8') as destination:
            destination.write(annotation + '\n')
        logger.info('annotation saved to %s', destination_path)

    @staticmethod
    def get_missing_attributes(
        responses: Dict[str, Union[Number, Color, Shape,
                                   Shading]]) -> List[str]:
        """Return the list of labels button's name without response."""
        return [att for att, response in responses.items() if response is None]
示例#32
0
class Replay(object):
    def __init__(self, debug, cad_width, height):
        """Create the debug output area and store the viewer settings.

        ``cad_width`` and ``height`` are forwarded to ``show`` when
        rendering; ``debug`` is stored as given. No view exists until the
        first call to ``select``.
        """
        self.debug_output = Output()
        self.debug, self.cad_width, self.height = debug, cad_width, height
        self.view = None

    def format_steps(self, raw_steps):
        """Turn a flat list of recorded steps into display entries.

        Parameters
        ----------
        raw_steps : list
            Step objects (presumably as built by ``to_array``). Attributes
            used here: ``level``, ``func``, ``args``, ``kwargs``,
            ``result_name``, ``result_obj``, ``var`` and ``clear_func()``.
            ``result_obj`` values are used as dict keys, so they must be
            hashable.

        Returns
        -------
        list of (str, object)
            ``(code_text, result_obj)`` pairs in the original order, for the
            steps that survive the level filter at the end.

        Note: mutates the steps it receives (sets ``result_name`` and
        ``var``, and calls ``clear_func()``).
        """
        def to_code(step, results):
            """Render one step as a pseudo-code line prefixed by "| " per level."""
            def to_name(obj):
                # workplanes print as their recorded result name (or the
                # object itself if none is recorded); anything else via str()
                if isinstance(obj, cq.Workplane):
                    name = results.get(obj, None)
                else:
                    name = str(obj)
                return obj if name is None else name

            if step.func != "":
                if step.func == "newObject":
                    args = ("...", )
                else:
                    args = tuple([to_name(arg) for arg in step.args])
                # e.g. "| | func('a', 'b')" using the tuple's repr for the args
                code = "%s%s%s" % ("| " * step.level, step.func, args)
                # strip the closing ")" (plus the trailing "," of a 1-tuple
                # repr) so kwargs can be appended before re-closing.
                # NOTE(review): this tests len(step.args), not len(args); for
                # "newObject" args is always a 1-tuple, so a newObject step
                # whose step.args length is not 1 keeps a stray "," — confirm.
                code = code[:-2] if len(step.args) == 1 else code[:-1]
                if len(step.args) > 0 and len(step.kwargs) > 0:
                    code += ","
                if step.kwargs != {}:
                    code += ", ".join(
                        ["%s=%s" % (k, v) for k, v in step.kwargs.items()])
                code += ")"
                if step.result_name != "":
                    code += (" => %s" % step.result_name)
            elif step.var != "":
                # a repeated object: show only its variable name
                code = "%s%s" % ("| " * step.level, step.var)
            else:
                code = ("ERROR")
            return code

        steps = []
        entries = []
        obj_index = 1

        # result_obj -> result_name of its first occurrence (filled below)
        results = {step.result_obj: None for step in raw_steps}

        for i in range(len(raw_steps)):
            step = raw_steps[i]
            # the last step compares against its own level (no change)
            next_level = step.level if i == (len(raw_steps) -
                                             1) else raw_steps[i + 1].level

            # level change, so add/use the variable name
            if step.level > 0 and step.level != next_level and step.result_name == "":
                obj_name = "_v%d" % obj_index
                obj_index += 1
                step.result_name = obj_name
            steps.append(step)

        for step in steps:
            if results[step.result_obj] is None:
                # first occurrence, take note and keep
                # NOTE(review): if this result_name is "", later occurrences
                # get var "" and render as "ERROR" in to_code — confirm this
                # cannot happen.
                results[step.result_obj] = step.result_name
            else:
                # next occurrences remove function and add variable name
                step.var = results[step.result_obj]
                step.clear_func()

        # Walk backwards. After keeping a step that references a variable,
        # skip earlier steps at the same or deeper level until a shallower
        # one appears (presumably the steps that built that object, which
        # are already shown at its first occurrence).
        last_level = 1000000
        for step in reversed(steps):
            if step.level < last_level:
                last_level = 1000000
                entries.insert(0, (to_code(step, results), step.result_obj))
                if step.var != "":
                    last_level = step.level

        return entries

    def to_array(self, workplane, level=0, result_name=""):
        """Flatten the recorded call history of ``workplane`` into a list of Steps.

        Follows ``workplane`` and its ``.parent`` chain. For each object that
        has a ``_caller`` record (a dict with keys ``func``, ``args``,
        ``kwargs``, ``obj`` and ``children``), the record is expanded with
        ``walk`` and prepended, so ancestors end up before descendants.
        Workplane arguments are expanded recursively at a deeper ``level``.

        Parameters
        ----------
        workplane : cq.Workplane
            Object whose history is flattened.
        level : int
            Nesting level assigned to the top-level steps.
        result_name : str
            NOTE(review): currently unused; it is overwritten inside the
            loop below before being read.

        Returns
        -------
        list
            Step objects in execution order.
        """
        def walk(caller, level=0, result_name=""):
            # one Step for this call, preceded by its children's steps
            stack = [
                Step(level,
                     func=caller["func"],
                     args=caller["args"],
                     kwargs=caller["kwargs"],
                     result_name=result_name,
                     result_obj=caller["obj"])
            ]
            for child in reversed(caller["children"]):
                stack = walk(child, level + 1) + stack
                for arg in child["args"]:
                    if isinstance(arg, cq.Workplane):
                        # rebinds walk's parameter after it was used above;
                        # to_array ignores the value passed here anyway
                        result_name = getattr(arg, "name", None)
                        stack = self.to_array(arg,
                                              level=level + 2,
                                              result_name=result_name) + stack
            return stack

        stack = []

        obj = workplane
        while obj is not None:
            caller = getattr(obj, "_caller", None)
            result_name = getattr(obj, "name", "")
            if caller is not None:
                stack = walk(caller, level, result_name) + stack
                # workplanes passed as arguments contribute their own history
                for arg in caller["args"]:
                    if isinstance(arg, cq.Workplane):
                        result_name = getattr(arg, "name", "")
                        stack = self.to_array(arg,
                                              level=level + 1,
                                              result_name=result_name) + stack
            obj = obj.parent

        return stack

    def select(self, indexes):
        with self.debug_output:
            self.indexes = indexes
            cad_objs = [self.stack[i][1] for i in self.indexes]

            # Save state
            axes = True if self.view is None else self.view.cq_view.axes.get_visibility(
            )
            grid = True if self.view is None else self.view.cq_view.grid.get_visibility(
            )
            axes0 = True if self.view is None else self.view.cq_view.axes.is_center(
            )
            ortho = True if self.view is None else self.view.cq_view.is_ortho()
            transparent = False if self.view is None else self.view.cq_view.is_transparent(
            )
            rotation = None if self.view is None else self.view.cq_view.camera.rotation
            zoom = None if self.view is None else self.view.cq_view.camera.zoom
            position = None if self.view is None else self.view.cq_view.camera.position
            # substract center out of position to be prepared for _scale function
            if position is not None:
                position = self.view.cq_view._sub(position,
                                                  self.view.cq_view.bb.center)

            # Show new view
            self.view = self.show(cad_objs, position, rotation, zoom, axes,
                                  grid, axes0, ortho, transparent)

    def select_handler(self, change):
        with self.debug_output:
            if change["name"] == "index":
                self.select(change["new"])

    def show(self,
             cad_objs,
             position,
             rotation,
             zoom,
             axes=True,
             grid=True,
             axes0=True,
             ortho=True,
             transparent=True):
        """Render ``cad_objs`` in a new viewer and return what ``show`` returns.

        Clears ``self.debug_output`` first. Unless the value of the last
        ``self.stack`` entry is a ``cq.Vector``, that entry is prepended as a
        "Result" part with faces and edges hidden. This sizes the view for
        the finished model and allows comparison. Parent display is enabled
        only when exactly one object is selected. The call to the
        module-level ``show`` runs inside ``self.debug_output``.
        """
        self.debug_output.clear_output()

        final_obj = self.stack[-1][1]
        if isinstance(final_obj.val(), cq.Vector):
            objs = cad_objs
        else:
            hidden_result = Part(final_obj,
                                 "Result",
                                 show_faces=False,
                                 show_edges=False)
            objs = [hidden_result] + cad_objs

        with self.debug_output:
            single_selection = len(cad_objs) == 1
            return show(*objs,
                        cad_width=self.cad_width,
                        height=self.height,
                        position=position,
                        rotation=rotation,
                        zoom=zoom,
                        axes=axes,
                        grid=grid,
                        axes0=axes0,
                        ortho=ortho,
                        transparent=transparent,
                        show_parents=single_selection)