Пример #1
0
class IPythonDisplay(Lockable):
    """Show the Vega-Lite rendering of ``h`` in a notebook and refresh it in place.

    ``h`` must expose a ``vegalite()`` method; its result is handed to
    ``UpdatableVegaLite``. The public methods are wrapped in ``@synchronized``
    (presumably serialised through the ``Lockable`` base -- confirm against its
    definition, which is not visible here).
    """

    def __init__(self, h):
        Lockable.__init__(self)
        self.H = h
        self.init()

    @synchronized
    def init(self):
        """Drop the current display handle and chart so the next update starts fresh."""
        self.DisplayHandle = self.VL = None

    @synchronized
    def display(self):
        """Create a new display handle and chart, then show the chart."""
        handle = DisplayHandle()
        self.DisplayHandle = handle
        chart = UpdatableVegaLite(self.H.vegalite())
        self.VL = chart
        handle.display(chart)

    @synchronized
    def update(self):
        """Refresh the existing output, or create it if nothing is displayed yet."""
        if self.DisplayHandle is not None and self.VL is not None:
            fresh_specs = self.H.vegalite()
            self.VL.update_specs(fresh_specs)
            self.DisplayHandle.update(self.VL)
        else:
            self.display()
Пример #2
0
 def __init__(self, md_string='', display=True, append=False):
     """Store the initial Markdown text and optionally show it right away.

     md_string: initial Markdown source.
     display: when truthy, ``self.display()`` is called at the end.
     append: stored on the instance; not otherwise used in this method.
     """
     self.last_string = []
     self.handle = DisplayHandle()
     self.append = append
     self.md_string = md_string
     if not display:
         return
     self.display()
Пример #3
0
def update_plot(
    y: Metric,
    fig: plt.Figure,
    ax: plt.Axes,
    line: list[plt.Line2D],
    display_id: DisplayHandle,
    window: int = 15,
    logging_steps: int = 1,
):
    """Redraw a live plot with the moving average of ``y``.

    Does nothing when ``in_notebook()`` is false. Otherwise ``y`` is converted
    to an array (a 2-D array is first averaged over its last axis), smoothed
    with ``moving_average(..., window=window)`` and written into ``line[0]``.
    The x values are ``logging_steps * arange(len(smoothed))``. Finally the
    axes are rescaled, the canvas redrawn and ``display_id`` updated with
    ``fig``.
    """
    if not in_notebook():
        return

    values = np.array(y)
    if values.ndim == 2:
        values = values.mean(-1)

    smoothed = moving_average(values.squeeze(), window=window)

    curve = line[0]
    curve.set_ydata(smoothed)
    curve.set_xdata(logging_steps * np.arange(smoothed.shape[0]))

    # Recompute data limits from the new line data before rescaling the view.
    ax.relim()
    ax.autoscale_view()
    fig.canvas.draw()
    display_id.update(fig)
Пример #4
0
class DynHTML():
    """An output element for displaying HTML that can be updated.

    The HTML is wrapped in ``StyledHTML`` and shown through a single
    ``DisplayHandle`` so later updates replace the output in place.
    """
    def __init__(self, html_string='', display=True):
        """Store the initial HTML and, unless ``display`` is falsy, show it.

        html_string: initial HTML source.
        display: when truthy, the element is displayed immediately. Pass
            ``display=False`` to defer and call ``display()`` manually.
        """
        self.html_string = html_string
        self.handle = DisplayHandle()
        # Honour the ``display`` flag; it was previously accepted but ignored.
        if display:
            self.display()

    def display(self):
        """Show the current HTML through this element's display handle."""
        self.handle.display(StyledHTML(self.html_string))

    def update(self, html_string):
        """Replace the stored HTML with ``html_string`` and refresh the output."""
        self.html_string = html_string
        self.handle.update(StyledHTML(self.html_string))
Пример #5
0
class DynMarkdown():
    """An output element for displaying Markdown that can be updated.

    Every call to ``update`` pushes the previous text onto ``last_string``
    (bounded to the most recent ``_MAX_HISTORY`` entries) so ``revert_one``
    can step back.
    """
    # Number of previous texts kept for revert_one().
    _MAX_HISTORY = 10

    def __init__(self, md_string='', display=True, append=False):
        """Store the initial text and optionally display it.

        md_string: initial Markdown source.
        display: when truthy, the element is displayed immediately.
        append: when truthy, ``update`` appends to the current text instead
            of replacing it.
        """
        self.md_string = md_string
        self.append = append
        self.handle = DisplayHandle()
        self.last_string = []
        if display:
            self.display()

    def display(self):
        """Show the current Markdown through this element's display handle."""
        self.handle.display(Markdown(self.md_string))

    def update(self, md_string):
        """Append or replace the text with ``str(md_string)`` and refresh.

        The text as it was before the call is recorded in the history first.
        """
        self.last_string.append(self.md_string)
        if len(self.last_string) > self._MAX_HISTORY:
            self.last_string = self.last_string[-self._MAX_HISTORY:]
        if self.append:
            # Two trailing spaces + newline is a Markdown hard line break.
            self.md_string += str(md_string) + '  \n'
        else:
            self.md_string = str(md_string)
        self.handle.update(Markdown(self.md_string))

    def clear(self):
        """Blank the output. The cleared text is not pushed onto the history."""
        self.md_string = ''
        self.handle.update(Markdown(self.md_string))

    def revert_one(self):
        """Restore the most recent history entry, or blank the output if none."""
        try:
            self.md_string = self.last_string.pop()
        except IndexError:
            # History is empty: nothing to go back to.
            self.md_string = ''
        self.handle.update(Markdown(self.md_string))
Пример #6
0
def update_joint_plots(
    plot_data1: LivePlotData,
    plot_data2: LivePlotData,
    display_id: DisplayHandle,
    window: int = 15,
    logging_steps: int = 1,
    fig: plt.Figure = None,
):
    """Refresh two live plots that share one figure.

    Does nothing when ``in_notebook()`` is false. For each ``plot_data``, the
    ``data`` attribute is converted to a squeezed array (a 2-D array is
    averaged over its last axis), smoothed with ``moving_average`` and written
    into ``plot_obj.line[0]``. The x values are
    ``logging_steps * arange(len(smoothed))``. Both axes are then rescaled, the
    figure redrawn and ``display_id`` updated. When ``fig`` is ``None`` the
    current pyplot figure is used.
    """
    if not in_notebook():
        return

    figure = plt.gcf() if fig is None else fig

    first, second = plot_data1.plot_obj, plot_data2.plot_obj

    raw1 = np.array(plot_data1.data).squeeze()  # type: np.ndarray
    raw2 = np.array(plot_data2.data).squeeze()  # type: np.ndarray

    # Collapse 2-D inputs to one value per row; leave anything else as is.
    reduced1 = np.mean(raw1, -1) if raw1.ndim == 2 else raw1
    reduced2 = np.mean(raw2, -1) if raw2.ndim == 2 else raw2

    smooth1 = moving_average(reduced1, window=window)
    smooth2 = moving_average(reduced2, window=window)

    curve1 = first.line[0]
    curve1.set_ydata(np.array(smooth1))
    curve1.set_xdata(logging_steps * np.arange(smooth1.shape[0]))

    curve2 = second.line[0]
    curve2.set_ydata(smooth2)
    curve2.set_xdata(logging_steps * np.arange(smooth2.shape[0]))

    # Recompute both data limits first, then rescale both views.
    for plot in (first, second):
        plot.ax.relim()
    for plot in (first, second):
        plot.ax.autoscale_view()

    figure.canvas.draw()
    display_id.update(figure)  # need to force colab to update plot
Пример #7
0
 def __init__(self, epochs=None, train_steps=None, val_steps=None):
     """Create the epoch, train and validation progress bars and the log display.

     epochs, train_steps, val_steps: passed as ``total`` to the respective
     ``tqdm`` bar; ``None`` leaves the total unset.
     """
     # Settings common to all three bars.
     shared = dict(file=train.orig_stdout, dynamic_ncols=True)
     self.epoch_progress = tqdm(total=epochs, desc="Epoch",
                                unit=" epochs", **shared)
     self.train_progress = tqdm(total=train_steps, desc="Train",
                                unit=" batch", **shared)
     self.val_progress = tqdm(total=val_steps, desc="Validation",
                              unit=" batch", **shared)
     self.log_display = DisplayHandle("logs")
     self.logs = None
Пример #8
0
    def __init__(self, **kwargs):
        """Initialise the manager and the notebook output state.

        Any caller-supplied ``term`` keyword is overridden before the parent
        constructor runs.
        """
        # xterm-256color is forced because it should have broad support
        kwargs.update(term='xterm-256color')
        super(NotebookManager, self).__init__(**kwargs)

        # Report 24-bit colour support on the terminal object
        self.term.number_of_colors = 2 ** 24

        # Output state: nothing buffered and nothing displayed yet
        self._primed = False
        self._output = []
        self._html = HTML('')
        self._display = DisplayHandle()
        self._converter = HTMLConverter(self.term)

        # Fall back to a width of 100 when none was specified
        self.width = self._width if self._width else 100
Пример #9
0
class ProgressCallback(Callback):
    """Training callback that shows tqdm progress bars and a per-epoch log table.

    Three bars are kept (epoch, train batches, validation/test batches), all
    writing to ``train.orig_stdout``. Per-epoch logs are accumulated in a
    pandas DataFrame and the latest row is pushed to a ``DisplayHandle``.

    The hooks read ``self.params`` and ``self.learner``, which are not set in
    this class (presumably provided by ``Callback`` or the training loop --
    confirm against the base class).
    """

    def __init__(self, epochs=None, train_steps=None, val_steps=None):
        """Create the progress bars; totals may be ``None`` and set later."""
        self.epoch_progress = tqdm(total=epochs,
                                   desc="Epoch",
                                   unit=" epochs",
                                   file=train.orig_stdout,
                                   dynamic_ncols=True)
        self.train_progress = tqdm(total=train_steps,
                                   desc="Train",
                                   unit=" batch",
                                   file=train.orig_stdout,
                                   dynamic_ncols=True)
        self.val_progress = tqdm(total=val_steps,
                                 desc="Validation",
                                 unit=" batch",
                                 file=train.orig_stdout,
                                 dynamic_ncols=True)
        # DataFrame of per-epoch logs; created on the first on_epoch_end call.
        self.logs = None
        self.log_display = DisplayHandle("logs")

    def refresh(self):
        """Redraw all three progress bars."""
        self.epoch_progress.refresh()
        self.train_progress.refresh()
        self.val_progress.refresh()

    def ncols(self):
        """Return whatever the train bar's ``dynamic_ncols`` callable reports
        for its output stream."""
        return self.train_progress.dynamic_ncols(self.train_progress.fp)

    def on_train_begin(self, logs=None):
        """Reset the epoch counter and set bar totals from params/learner."""
        self.epoch_progress.n = 0
        self.epoch_progress.total = self.params["epochs"]
        self.train_progress.total = self.params["steps"]
        self.val_progress.total = len(self.learner.val_loader)
        self.refresh()

    def on_train_end(self, logs=None):
        """Close all three progress bars."""
        self.epoch_progress.close()
        self.train_progress.close()
        self.val_progress.close()

    def on_epoch_begin(self, epoch, logs=None):
        """Advance the epoch bar and label it with the 1-based epoch number."""
        self.epoch_progress.update()
        self.epoch_progress.set_description("Epoch {}".format(epoch + 1))

    def on_epoch_end(self, epoch, logs=None):
        """Record ``logs`` as a new row, display it and reset the batch bars."""
        if self.logs is None:
            self.logs = pd.DataFrame(logs, index=[0])
        else:
            # DataFrame.append was removed in pandas 2.0; concat is the
            # supported equivalent for adding one row.
            self.logs = pd.concat([self.logs, pd.DataFrame([logs])],
                                  ignore_index=True)

        self.log_display.update(self.logs.tail(1))
        self.train_progress.n = 0
        self.val_progress.n = 0
        self.refresh()

    def on_train_batch_end(self, step, logs=None):
        """Advance the train bar and show the learner's metric summary.

        The ``logs`` argument is ignored in favour of
        ``self.learner.metrics.summary()``.
        """
        logs = self.learner.metrics.summary()
        self.train_progress.update()
        self.train_progress.set_postfix(logs)

    def on_test_begin(self, logs=None):
        """Replace the validation bar with a fresh "Test" bar sized to the
        test loader."""
        val_steps = len(self.learner.test_loader)
        self.val_progress = tqdm(total=val_steps,
                                 desc="Test",
                                 unit=" batch",
                                 file=train.orig_stdout,
                                 dynamic_ncols=True)

    def on_test_end(self, logs=None):
        """Display the final test logs (if any) and close the test bar."""
        if logs is not None:
            logs = pd.DataFrame([logs])
            self.log_display.update(logs)
        self.val_progress.close()

    def on_test_batch_end(self, step, logs=None):
        """Advance the test bar and show the learner's metric summary.

        The ``logs`` argument is ignored in favour of
        ``self.learner.metrics.summary()``.
        """
        logs = self.learner.metrics.summary()
        self.val_progress.update()
        self.val_progress.set_postfix(logs)
Пример #10
0
 def __init__(self, html_string='', display=True):
     """Store the initial HTML string and create the display handle.

     html_string: initial HTML source.
     display: accepted but not read in this method.
     """
     # NOTE(review): `display` is never used here, so nothing is shown on
     # construction -- confirm whether `self.display()` was meant to be
     # called when it is truthy.
     self.handle = DisplayHandle()
     self.html_string = html_string
Пример #11
0
class NotebookManager(BaseManager):
    """
    Args:
        counter_class(:py:term:`class`): Progress bar class (Default: :py:class:`Counter`)
        status_bar_class(:py:term:`class`): Status bar class (Default: :py:class:`StatusBar`)
        enabled(bool): Status (Default: True)
        width(int): Static output width (Default: 100)
        kwargs(Dict[str, Any]): Any additional :py:term:`keyword arguments<keyword argument>`
            will be used as default values when :py:meth:`counter` is called.

    Manager class for outputting progress bars to Jupyter notebooks

    The following keyword arguments are set if provided, but ignored:

      * *stream*
      * *set_scroll*
      * *companion_stream*
      * *no_resize*
      * *threaded*

    """
    def __init__(self, **kwargs):

        # Force terminal to xterm-256color because it should have broad support
        kwargs['term'] = 'xterm-256color'

        super(NotebookManager, self).__init__(**kwargs)

        # Force 24-bit color
        self.term.number_of_colors = 1 << 24

        self._converter = HTMLConverter(self.term)
        # One HTML fragment per bar position; index 0 holds position 1
        self._output = []
        self._display = DisplayHandle()
        self._html = HTML('')
        # Becomes True once the HTML object has been displayed the first time
        self._primed = False

        # Default width to 100 unless specified
        self.width = self._width or 100

    def __repr__(self):
        return '%s()' % self.__class__.__name__

    def _flush_streams(self):
        """
        Display buffered output

        The first call displays the HTML object; later calls update it in place.
        """

        if not self.enabled:
            return

        # Fragments are joined in reverse, so position 1 is rendered last
        self._html.data = '%s<div class="enlighten">\n%s\n</div>\n' % (
            self._converter.style, '\n'.join(reversed(self._output)))

        if self._primed:
            self._display.update(self._html)
        else:
            self._primed = True
            self._display.display(self._html)

    def stop(self):
        # See parent class for docstring

        if not self.enabled:
            return

        positions = set(self.counters.values())

        # Blank out every position that no longer has a counter.
        # Guarded because max() raises ValueError when no counters exist.
        if positions:
            for num in range(max(positions), 0, -1):
                if num not in positions:
                    self._output[num - 1] = '  <br>'

        for counter in self.counters:
            counter.enabled = False

        self._flush_streams()

    def write(self, output='', flush=True, counter=None, **kwargs):
        # See parent class for docstring

        if not self.enabled:
            return

        # Without a counter, output goes to position 1
        position = self.counters[counter] if counter else 1

        # If output is callable, call it with supplied arguments
        if callable(output):
            output = output(**kwargs)

        # If there is space between this bar and the last, fill with blank lines
        for _ in range(position - len(self._output)):
            self._output.append('  <br>')

        # Set output
        self._output[position - 1] = (
            '  <div class="enlighten-bar">\n    %s\n  </div>' %
            self._converter.to_html(output))

        if flush:
            self._flush_streams()
Пример #12
0
 def display(self):
     """Create a new display handle and chart from ``self.H.vegalite()`` and show it."""
     handle = DisplayHandle()
     self.DisplayHandle = handle
     chart = UpdatableVegaLite(self.H.vegalite())
     self.VL = chart
     handle.display(chart)
Пример #13
0
def _adversarial_train(
    X, y, z,
    C, A,
    cls_loss,
    cls_schedule,
    adv_loss,
    adv_schedule,
    parity,
    bootstrap_epochs=0,
    epochs=32,
    batch_size=32,
    display_progress=False
):
    """Alternate adversary and classifier gradient steps until the data runs out.

    The data is wrapped in a shuffled, batched one-shot dataset repeated
    ``2*epochs + bootstrap_epochs`` times. After an optional classifier-only
    bootstrap phase, each epoch runs ``N // batch_size`` adversary updates and
    then the same number of classifier updates. The loop stops when the
    iterator raises ``OutOfRangeError``.

    Args:
        X, y, z: arrays sliced together along their first axis. In display
            mode their ``shape[1]`` is used to build input placeholders.
        C, A: classifier and adversary models; both must expose
            ``trainable_weights`` and ``predict``.
        cls_loss: callable ``(x, y) -> loss tensor``.
        cls_schedule: callable ``t -> (alpha, lr)`` for the classifier.
        adv_loss: callable ``(x, y, z) -> loss tensor``.
        adv_schedule: callable ``t -> lr`` for the adversary.
        parity: callable ``(y_pred > 0.5, y, z)`` returning something
            indexable at ``[0]`` and ``[1]``; used for the plots only.
        bootstrap_epochs: number of classifier-only epochs run first.
        epochs: used for the dataset repeat count and the progress-bar total.
        batch_size: batch size of the dataset.
        display_progress: when true, use notebook progress bars and redraw a
            five-panel diagnostic figure after every phase.

    Returns:
        None. ``C`` and ``A`` are updated in place through their weights.
    """

    N = len(X)

    alpha = K.placeholder(ndim=0, dtype='float32')
    lr = K.placeholder(ndim=0, dtype='float32')

    x_batch, y_batch, z_batch = (
        Dataset.from_tensor_slices((X, y, z))
            .shuffle(N)
            .repeat(2*epochs+bootstrap_epochs)
            .batch(batch_size)
            .make_one_shot_iterator()
            .get_next()
    )

    adversary_gradients = K.gradients(
        adv_loss(x_batch, y_batch, z_batch),
        A.trainable_weights
    )

    # Plain gradient descent on the adversary: w -= lr * dw
    adversary_updates = [
        K.update_sub(w, lr * dw)
        for w, dw in zip(A.trainable_weights, adversary_gradients)
    ]

    update_adversary = K.function(
        [lr],
        [tf.constant(0)],
        updates=adversary_updates
    )

    classifier_gradients = K.gradients(
        cls_loss(x_batch, y_batch),
        C.trainable_weights
    )

    # Gradient of the adversary loss with respect to the classifier weights
    a_c_gradients = K.gradients(
        adv_loss(x_batch, y_batch, z_batch),
        C.trainable_weights
    )

    # Classifier step: own gradient, minus alpha times the adversary gradient,
    # minus _proj(dCdw, dAdw) (helper defined elsewhere in the file).
    classifier_updates = [
        K.update_sub(w, lr * (dCdw - alpha * dAdw - _proj(dCdw, dAdw)))
        for w, dCdw, dAdw in zip(
            C.trainable_weights,
            classifier_gradients,
            a_c_gradients
        )
    ]

    update_classifier = K.function(
        [alpha, lr],
        [tf.constant(0)],
        updates=classifier_updates
    )

    if not display_progress:
        tqdm = _tqdm

    else:
        from matplotlib import pyplot as plt
        from IPython.display import DisplayHandle

        from tqdm.notebook import tqdm as _tqdm_notebook
        tqdm = _tqdm_notebook

        dh = DisplayHandle()
        dh.display("Graphs loading ...")

        # Separate placeholders so diagnostics can be evaluated on the full
        # data set without consuming batches from the training iterator.
        _X = Input((X.shape[1],))
        _Y = Input((y.shape[1],))
        _Z = Input((z.shape[1],))

        _classifier_loss = cls_loss(_X, _Y)
        _dcdcs = K.gradients(_classifier_loss, C.trainable_weights)

        _classifier_gradients = K.sum(
            [K.sum(K.abs(dcdc))
             for dcdc in _dcdcs]
        )

        _adversary_loss = adv_loss(_X, _Y, _Z)
        _dadas = K.gradients(_adversary_loss, A.trainable_weights)
        _adversary_gradients = K.sum(
            [K.sum(K.abs(dada))
             for dada in _dadas]
        )

        _dadcs = K.gradients(_adversary_loss, C.trainable_weights)

        _a_c_gradients = K.sum(
            [K.sum(K.abs(
                alpha * dadc -
                _proj(dcdc, dadc))) for dcdc, dadc in zip(_dcdcs, _dadcs)])

        _total_gradients = _classifier_gradients + _a_c_gradients

        cls_loss_f = K.function([_X, _Y], [_classifier_loss])
        cls_grad_f = K.function([_X, _Y], [_classifier_gradients])

        adv_loss_f = K.function([_X, _Y, _Z], [_adversary_loss])
        adv_grad_f = K.function([_X, _Y, _Z], [_adversary_gradients])

        a_c_grad_f = K.function([alpha, _X, _Y, _Z], [_a_c_gradients])
        total_grad_f = K.function([alpha, _X, _Y, _Z], [_total_gradients])

        # Histories for plotting. Named *_hist so they do not shadow the
        # adv_loss / cls_loss callables passed in as parameters.
        adv_loss_hist = []
        cls_loss_hist = []

        adv_grad = []
        cls_grad = []
        a_c_grad = []
        total_grad = []

        adv_acc = []
        cls_acc = []
        dm_abs = []
        dm_rel = []
        dm_abs_ideal = []
        dm_rel_ideal = []
        dm_g0 = []
        dm_g1 = []
        base_class = []
        base_adv = []

        pred_range = []
        adv_range = []

        cls_lrs = []
        cls_alphas = []
        adv_lrs = []

        cls_xs = []
        adv_xs = []

        # Accuracy of always predicting the majority value
        baseline_accuracy = max(y.mean(), 1-y.mean())
        baseline_adversary_accuracy = max(z.mean(), 1.0-z.mean())

    progress = tqdm(
        desc="training",
        unit="epoch",
        total=epochs,
        leave=False
    )

    def update_display(
        t,
        cls_lr=None, cls_alpha=None,
        adv_lr=None
    ):
        """Record diagnostics for step ``t`` and redraw the figure.

        Classifier series are extended when ``cls_lr`` is given, adversary
        series when ``adv_lr`` is given. No-op unless ``display_progress``.
        """
        if not display_progress:
            return

        y_pred = C.predict(X)

        if cls_lr is not None:
            cls_xs.append(t)
            cls_lrs.append(cls_lr)
            cls_alphas.append(cls_alpha)

            cls_loss_hist.append(cls_loss_f([X, y])[0])
            cls_grad.append(cls_grad_f([X, y])[0])
            a_c_grad.append(a_c_grad_f([cls_alpha, X, y, z])[0])
            total_grad.append(total_grad_f([cls_alpha, X, y, z])[0])

            y_acc = ((y_pred > 0.5) == y).mean()

            cls_acc.append(y_acc)
            base_class.append(baseline_accuracy)

            _dm = parity(y_pred > 0.5, y, z)
            dm_abs_ideal.append(0.0)
            dm_rel_ideal.append(1.0)
            dm_abs.append(abs(_dm[0]-_dm[1]))
            # The 0.0001 floor avoids dividing by zero when both rates are 0
            dm_rel.append(min(_dm[0],_dm[1])/(max(0.0001, _dm[0], _dm[1])))
            dm_g0.append(_dm[0])
            dm_g1.append(_dm[1])

            pred_range.append(y_pred.max() - y_pred.min())

        if adv_lr is not None:
            adv_xs.append(t)

            adv_lrs.append(adv_lr)

            adv_loss_hist.append(adv_loss_f([X, y, z])[0])
            adv_grad.append(adv_grad_f([X, y, z])[0])

            z_pred = A.predict(x=[y_pred, z])
            z_acc = ((z_pred > 0.5) * 1 == z).mean()
            adv_acc.append(z_acc)

            base_adv.append(baseline_adversary_accuracy)
            adv_range.append(z_pred.max() - z_pred.min())

        fig, axs = plt.subplots(5, 1, figsize=(15, 15))

        axs1t = axs[1].twinx()
        axs2t = axs[2].twinx()
        axs3t = axs[3].twinx()

        axs[0].plot(cls_xs, cls_acc, label="classifier", color='green')
        axs[0].plot(cls_xs, base_class, label="baseline classifier", ls=':', color='green')
        axs[0].plot(adv_xs, adv_acc, label="adversary", color='red')
        axs[0].plot(adv_xs, base_adv, label="baseline adversary", ls=':', color='red')

        axs[1].plot(cls_xs, dm_abs, label="absolute disparity", color='red')
        axs[1].plot(cls_xs, dm_abs_ideal, label="ideal absolute disparity", color='red', ls=':')
        axs1t.plot(cls_xs, dm_rel, label="relative disparity", color='green')
        axs1t.plot(cls_xs, dm_rel_ideal, label="ideal relative disparity", color='green', ls=':')
        axs[1].plot(cls_xs, dm_g0, label="male positive", color="black")
        axs[1].plot(cls_xs, dm_g1, label="female positive", color="black")

        axs[2].plot(cls_xs, cls_loss_hist, label="classifier cls loss", color='green')
        axs[2].plot(adv_xs, adv_loss_hist, label="adversary loss", color='red')
        axs2t.plot(cls_xs, pred_range, label="classifier range", color='green', ls=':')
        axs2t.plot(adv_xs, adv_range, label="adversary range", color='red', ls=':')

        axs[4].plot(cls_xs, cls_lrs, label="classifier lr", color='green', ls=":")
        axs[4].plot(cls_xs, cls_alphas, label="classifier alpha", color='green', ls="-")
        axs[4].plot(adv_xs, adv_lrs, label="adversary lr", color='red', ls=":")

        axs[3].plot(cls_xs, total_grad, label="classifier (total) gradients", color='green')
        axs[3].plot(cls_xs, cls_grad, label="classifier (cls) gradients", color='green', ls=':')
        axs[3].plot(cls_xs, a_c_grad, label="classifier (adv) gradients", color='green', ls='-.')

        axs3t.plot(adv_xs, adv_grad, label="adversary gradients", color='red')

        axs[0].set_title("prediction performance")
        axs[1].set_title("fairness characteristics")
        axs[2].set_title("loss characteristics")
        axs[3].set_title("learning characteristics")
        axs[4].set_title("learning parameters")

        # NOTE(review): newer Matplotlib releases replaced the `basey`
        # keyword with `base` -- confirm the Matplotlib version this targets.
        axs[2].set_yscale("log", basey=2.0)
        axs[3].set_yscale("symlog", basey=2.0)
        axs[4].set_yscale("symlog", basey=2.0)

        axs[0].set_ylabel("accuracy")

        axs[1].set_ylabel("outcome Pr")
        axs1t.set_ylabel("outcome (min group Pr) / (max group Pr)")
        axs[2].set_ylabel("loss")
        axs2t.set_ylabel("Pr or Pr diff")
        axs[3].set_ylabel("grad")
        axs3t.set_ylabel("grad")
        axs[4].set_ylabel("parameter val")

        axs1t.legend(loc=4)
        axs2t.legend(loc=4)
        axs3t.legend(loc=4)

        for ax in axs:
            ax.set_xlabel("t")
            ax.legend(loc=3)
            ax.yaxis.grid(True, which='major')

        dh.update(fig)

        # The figure has been handed to the display handle; close it so
        # figures do not accumulate across updates.
        plt.close(fig)

    adv_lr = adv_schedule(1)
    cls_alpha, cls_lr = cls_schedule(1)

    # Bootstrap epochs are numbered so that the first regular epoch is t == 1
    t = -bootstrap_epochs+1

    for _ in tqdm(
        range(bootstrap_epochs),
        desc="bootstrapping classifier",
        unit="epoch",
        leave=False
    ):
        for _ in tqdm(
            range(N // batch_size),
            desc="classifier",
            unit="batch",
            leave=False
        ):
            # NOTE(review): the classifier is stepped with adv_lr here while
            # cls_lr is what gets logged below -- confirm which is intended.
            update_classifier([
                0.0,
                adv_lr
            ])

        update_display(t, cls_lr=cls_lr, cls_alpha=0.0)
        t += 1

    try:
        while True:
            adv_lr = adv_schedule(t)
            cls_alpha, cls_lr = cls_schedule(t)

            for _ in tqdm(
                range(N // batch_size),
                desc="adversary",
                unit="batch",
                leave=False
            ):
                update_adversary([adv_lr])

            update_display(t, adv_lr=adv_lr)

            for _ in tqdm(
                range(N // batch_size),
                desc="classifier",
                unit="batch",
                leave=False
            ):
                update_classifier([
                    cls_alpha,
                    cls_lr
                ])

            update_display(t, cls_lr=cls_lr, cls_alpha=cls_alpha)

            t += 1
            progress.update(1)

    except OutOfRangeError:
        # The one-shot iterator is exhausted: training is finished.
        pass

    finally:
        # Close the outer bar (created with leave=False) however the loop exits.
        progress.close()