# Example #1
# 0
def log_progress(sequence,
                 every=None,
                 size=None,
                 user_label=None,
                 refresh=False):
    """Yield the items of ``sequence`` while showing a Jupyter progress widget.

    An ``IntProgress`` bar and an ``HTML`` label, stacked in a ``VBox``, are
    displayed immediately and updated as iteration proceeds. Items are
    yielded unchanged.

    Args:
        sequence: Any iterable. If ``size`` is given or ``len(sequence)``
            works, a bounded bar is shown. Otherwise a full "info" bar is
            shown and the label only reports the running count.
        every: Refresh the widgets on the first item and on every
            ``every``-th item after that. When the size is known and this is
            ``None``, it defaults to 1 for sizes up to 200 and to
            ``int(size / 200)`` (about every 0.5%) above that.
        size: Total number of items; takes precedence over ``len(sequence)``.
        user_label: Optional text appended to the label as `` - <label>``.
            Nothing is appended when it is ``None``.
        refresh: If true, close the widget box once ``sequence`` is exhausted.

    Yields:
        Each item of ``sequence``.

    Raises:
        ValueError: If the size cannot be determined and ``every`` is not
            given. Because this is a generator, it is raised on the first
            ``next()``, not at call time.
    """
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    is_iterator = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            # No len(): treat the input as an iterator of unknown length.
            is_iterator = True
    if size is not None:
        if every is None:
            # Update on every item for short inputs, otherwise roughly every 0.5%.
            every = 1 if size <= 200 else int(size / 200)
    elif every is None:
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        raise ValueError('sequence is iterator, set every')

    if is_iterator:
        # Total unknown: show a full "info" bar and only count items in the label.
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    box = VBox(children=[label, progress])
    display(box)

    # Omit the suffix entirely instead of rendering " - None" when no label is given.
    suffix = u'' if user_label is None else u' - {}'.format(user_label)

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            if index == 1 or index % every == 0:
                if is_iterator:
                    label.value = u'{index} / ?{suffix}'.format(
                        index=index, suffix=suffix)
                else:
                    progress.value = index
                    label.value = u'{index} / {size}{suffix}'.format(
                        index=index, size=size, suffix=suffix)
            yield record
        if refresh:
            box.close()
    except BaseException:
        # Any exception marks the bar as failed before it propagates. This
        # includes GeneratorExit, thrown at `yield` when the consumer stops
        # early, and KeyboardInterrupt.
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = index
        label.value = str(index or '?') + suffix
class NBMasterBar(MasterBar):
    """Notebook master progress bar with a text log, a child bar and a graph.

    All pieces are stacked in one ``VBox``. ``inner_dict`` maps the names
    ``'pb1'`` (master bar box), ``'text'`` (HTML log), ``'pb2'`` (current
    child bar box) and ``'graph'`` (plot output) to their widgets. ``order``
    fixes the vertical order in which the present ones are shown.
    """
    # Default legend labels for the curves drawn by `update_graph`. This list
    # is never mutated in place. When more labels are needed, `update_graph`
    # binds an extended copy on the instance, so this class-level list stays
    # unchanged for every other NBMasterBar.
    names = ['train', 'valid']

    def __init__(self, gen, total=None, hide_graph=False, order=None):
        """Build the widget layout.

        Args:
            gen: Iterable driving the master bar; forwarded to ``MasterBar``.
            total: Optional length of ``gen``; forwarded to ``MasterBar``.
            hide_graph: If true, ``update_graph`` does nothing.
            order: List of the names ``'pb1'``, ``'text'``, ``'pb2'`` and
                ``'graph'`` giving the display order. Defaults to
                ``['pb1', 'text', 'pb2', 'graph']``.
        """
        super().__init__(gen, NBProgressBar, total)
        self.report = []  # [line, elapsed_time] pairs appended by write()
        self.text = HTML()
        # NOTE(review): `first_bar` is presumably created by MasterBar.__init__ — confirm.
        self.vbox = VBox([self.first_bar.box, self.text])
        if order is None: order = ['pb1', 'text', 'pb2', 'graph']
        self.inner_dict = {'pb1': self.first_bar.box, 'text': self.text}
        self.hide_graph, self.order = hide_graph, order

    def on_iter_begin(self):
        """Start the overall and per-write timers, then display the widget box."""
        self.start_t = self.last_t = time()
        display(self.vbox)

    def on_iter_end(self):
        """Close the widgets and print the total time followed by the write() log.

        Logged lines are left-padded to a common width so that their elapsed
        times, shown in parentheses, line up.
        """
        total_time = format_time(time() - self.start_t)
        width = max((len(line) for line, _ in self.report), default=0)
        parts = [f'Total time: {total_time}\n']
        for line, elapsed in self.report:
            ending = f'  ({elapsed})\n' if elapsed != '' else '\n'
            parts.append(line.ljust(width) + ending)
        self.vbox.close()
        print(''.join(parts))

    def add_child(self, child):
        """Register ``child`` as the current second-level bar and re-layout."""
        self.child = child
        self.inner_dict['pb2'] = self.child.box
        # Include the graph only once update_graph() has created its output widget.
        if hasattr(self, 'out'): self.show(['pb1', 'pb2', 'text', 'graph'])
        else:                    self.show(['pb1', 'pb2', 'text'])

    def show(self, child_names):
        """Display the widgets named in ``child_names``, in ``self.order`` order."""
        to_show = [name for name in self.order if name in child_names]
        self.vbox.children = [self.inner_dict[n] for n in to_show]

    def write(self, line):
        """Append ``line`` to the HTML log and record the time since the last write.

        The elapsed time is empty if ``on_iter_begin`` has not run yet, since
        it is what first sets ``last_t``.
        """
        if hasattr(self, 'last_t'):
            cur_time = time()
            elapsed_time = format_time(cur_time - self.last_t)
            self.last_t = cur_time
        else: elapsed_time = ''
        self.report.append([line, elapsed_time])
        self.text.value += line + '<p>'

    def update_graph(self, graphs, x_bounds=None, y_bounds=None):
        """Redraw the graph with one curve per element of ``graphs``.

        Args:
            graphs: Sized sequence. Each element is unpacked into
                ``ax.plot(*g)``, presumably an ``(x, y)`` pair (TODO confirm
                with callers). Curves are labelled from ``self.names``, and
                extra curves get empty labels.
            x_bounds: Optional ``(xmin, xmax)`` passed to ``ax.set_xlim``.
            y_bounds: Optional ``(ymin, ymax)`` passed to ``ax.set_ylim``.
        """
        if self.hide_graph: return
        if not hasattr(self, 'fig'):
            self.fig, self.ax = plt.subplots(1, figsize=(6,4))
        self.out = widgets.Output()
        self.inner_dict['graph'] = self.out
        self.ax.clear()
        if len(self.names) < len(graphs):
            # Rebind instead of `+=`. Augmented assignment would extend the
            # class-level `names` list in place, affecting every instance.
            self.names = self.names + [''] * (len(graphs) - len(self.names))
        for g, n in zip(graphs, self.names): self.ax.plot(*g, label=n)
        self.ax.legend(loc='upper right')
        if x_bounds is not None: self.ax.set_xlim(*x_bounds)
        if y_bounds is not None: self.ax.set_ylim(*y_bounds)
        with self.out:
            clear_output(wait=True)
            display(self.ax.figure)
        if hasattr(self, 'child') and self.child.is_active: self.show(['pb1', 'pb2', 'text', 'graph'])
        else: self.show(['pb1', 'text', 'graph'])