Example 1
def __init__(self, md_string='', display=True, append=False):
    # DisplayHandle comes from IPython.display
    self.md_string = md_string
    self.append = append
    self.handle = DisplayHandle()  # handle used to update the rendered Markdown in place
    self.last_string = []
    if display:
        self.display()
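For reference, `DisplayHandle` is IPython's handle API: `display()` creates an output area and `update()` overwrites its contents. A minimal standalone sketch (not from any of the projects excerpted here):

from IPython.display import DisplayHandle, Markdown

handle = DisplayHandle()
handle.display(Markdown("*rendering...*"))  # create the output area
handle.update(Markdown("**done**"))         # replace its contents in place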
Example 2
def __init__(self, epochs=None, train_steps=None, val_steps=None):
    # One tqdm bar per level; train.orig_stdout is presumably the real stdout
    # saved by the surrounding module before notebook output redirection.
    self.epoch_progress = tqdm(total=epochs,
                               desc="Epoch",
                               unit=" epochs",
                               file=train.orig_stdout,
                               dynamic_ncols=True)
    self.train_progress = tqdm(total=train_steps,
                               desc="Train",
                               unit=" batch",
                               file=train.orig_stdout,
                               dynamic_ncols=True)
    self.val_progress = tqdm(total=val_steps,
                             desc="Validation",
                             unit=" batch",
                             file=train.orig_stdout,
                             dynamic_ncols=True)
    self.logs = None
    self.log_display = DisplayHandle("logs")  # display_id="logs": later updates can target this output
Example 3
    def __init__(self, **kwargs):

        # Force terminal to xterm-256color because it should have broad support
        kwargs['term'] = 'xterm-256color'

        super(NotebookManager, self).__init__(**kwargs)

        # Force 24-bit color
        self.term.number_of_colors = 1 << 24

        self._converter = HTMLConverter(self.term)
        self._output = []
        self._display = DisplayHandle()
        self._html = HTML('')
        self._primed = False

        # Default width to 100 unless specified
        self.width = self._width or 100
Example 4
def __init__(self, html_string='', display=True):
    self.html_string = html_string
    self.handle = DisplayHandle()
Example 5
def display(self):
    # Create a fresh display handle, wrap the chart, and render it through
    # the handle so later updates can replace the chart in place.
    self.DisplayHandle = DisplayHandle()
    self.VL = UpdatableVegaLite(self.H.vegalite())
    self.DisplayHandle.display(self.VL)
Example 6
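# Excerpt; assumes module-level imports roughly along these lines:
#   import tensorflow as tf
#   from tensorflow.data import Dataset
#   from tensorflow.errors import OutOfRangeError
#   from keras import backend as K
#   from keras.layers import Input
#   from tqdm import tqdm as _tqdm
# plus a gradient-projection helper _proj (sketched further below).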
def _adversarial_train(
    X, y, z,
    C, A,
    cls_loss,
    cls_schedule,
    adv_loss,
    adv_schedule,
    parity,
    bootstrap_epochs=0,
    epochs=32,
    batch_size=32,
    display_progress=False
):

    N = len(X)

    alpha = K.placeholder(ndim=0, dtype='float32')
    lr = K.placeholder(ndim=0, dtype='float32')

    x_batch, y_batch, z_batch = (
        Dataset.from_tensor_slices((X, y, z))
            .shuffle(N)
            .repeat(2*epochs+bootstrap_epochs)
            .batch(batch_size)
            .make_one_shot_iterator()
            .get_next()
    )
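    # repeat(2*epochs + bootstrap_epochs) supplies one pass over the data per
    # adversary epoch and one per classifier epoch, plus the bootstrap passes;
    # exhausting the iterator raises OutOfRangeError, ending training below.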

    adversary_gradients = K.gradients(
        adv_loss(x_batch, y_batch, z_batch),
        A.trainable_weights
    )

    adversary_updates = [
        K.update_sub(w, lr * dw)
        for w, dw in zip(A.trainable_weights, adversary_gradients)
    ]

    update_adversary = K.function(
          [lr],
          [tf.constant(0)],
          updates=adversary_updates
    )

    classifier_gradients = K.gradients(
        cls_loss(x_batch, y_batch),
        C.trainable_weights
    )

    a_c_gradients = K.gradients(
        adv_loss(x_batch, y_batch, z_batch),
        C.trainable_weights
    )

    classifier_updates = [
        K.update_sub(w, lr * (dCdw - alpha * dAdw - _proj(dCdw, dAdw)))
        for w, dCdw, dAdw in zip(
            C.trainable_weights,
            classifier_gradients,
            a_c_gradients
        )
    ]

    update_classifier = K.function(
          [alpha, lr],
          [tf.constant(0)],
          updates=classifier_updates
    )
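    # `_proj` is a projection helper defined elsewhere in the module. A minimal
    # sketch of what it is assumed to compute (hypothetical, not the project's
    # actual code): the projection of the classifier gradient onto the
    # adversary gradient,
    #
    #     def _proj(u, v):
    #         vv = K.sum(v * v) + K.epsilon()  # guard against a zero gradient
    #         return K.sum(u * v) / vv * v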

    if not display_progress:
        tqdm = _tqdm

    else:
        from matplotlib import pyplot as plt
        import matplotlib.ticker as ticker
        from IPython.display import DisplayHandle
        from IPython.display import clear_output

        from tqdm.notebook import tqdm as _tqdm_notebook
        tqdm = _tqdm_notebook

        dh = DisplayHandle()
        dh.display("Graphs loading ...")

        _X = Input((X.shape[1],))
        _Y = Input((y.shape[1],))
        _Z = Input((z.shape[1],))

        _classifier_loss = cls_loss(_X, _Y)
        _dcdcs = K.gradients(_classifier_loss, C.trainable_weights)

        _classifier_gradients = K.sum(
            [K.sum(K.abs(dcdc))
             for dcdc in _dcdcs]
        )

        _adversary_loss = adv_loss(_X, _Y, _Z)
        _dadas = K.gradients(_adversary_loss, A.trainable_weights)
        _adversary_gradients = K.sum(
            [K.sum(K.abs(dada))
             for dada in _dadas]
        )

        _dadcs = K.gradients(_adversary_loss, C.trainable_weights)

        _a_c_gradients = K.sum(
            [K.sum(K.abs(
                alpha * dadc -
                _proj(dcdc, dadc))) for dcdc, dadc in zip(_dcdcs, _dadcs)])

        _total_gradients = _classifier_gradients + _a_c_gradients

        cls_loss_f = K.function([_X, _Y], [_classifier_loss])
        cls_grad_f = K.function([_X, _Y], [_classifier_gradients])

        adv_loss_f = K.function([_X, _Y, _Z], [_adversary_loss])
        adv_grad_f = K.function([_X, _Y, _Z], [_adversary_gradients])

        a_c_grad_f = K.function([alpha, _X, _Y, _Z], [_a_c_gradients])
        total_grad_f = K.function([alpha, _X, _Y, _Z], [_total_gradients])

        # From here on, adv_loss and cls_loss are rebound from the loss
        # callables to history lists; the K.function graphs above have already
        # captured the original callables, so the shadowing is safe.
        adv_loss = []
        cls_loss = []

        adv_grad = []
        cls_grad = []
        a_c_grad = []
        total_grad = []

        adv_acc = []
        cls_acc = []
        dm_abs = []
        dm_rel = []
        dm_abs_ideal = []
        dm_rel_ideal = []
        dm_g0 = []
        dm_g1 = []
        base_class = []
        base_adv = []

        pred_range = []
        adv_range = []

        cls_lrs = []
        cls_alphas = []
        adv_lrs = []

        cls_xs = []
        adv_xs = []

        baseline_accuracy = max(y.mean(), 1 - y.mean())
        baseline_adversary_accuracy = max(z.mean(), 1.0 - z.mean())

    progress = tqdm(
        desc="training",
        unit="epoch",
        total=epochs,
        leave=False
    )

    def update_display(
        t,
        cls_lr=None, cls_alpha=None,
        adv_lr=None
    ):
        if not display_progress:
            return

        y_pred = C.predict(X)

        if cls_lr is not None:
            cls_xs.append(t)
            cls_lrs.append(cls_lr)
            cls_alphas.append(cls_alpha)

            cls_loss.append(cls_loss_f([X, y])[0])
            cls_grad.append(cls_grad_f([X, y])[0])
            a_c_grad.append(a_c_grad_f([cls_alpha, X, y, z])[0])
            total_grad.append(total_grad_f([cls_alpha, X, y, z])[0])

            y_acc = ((y_pred > 0.5) == y).mean()

            cls_acc.append(y_acc)
            base_class.append(baseline_accuracy)

            _dm = parity(y_pred > 0.5, y, z)
            dm_abs_ideal.append(0.0)
            dm_rel_ideal.append(1.0)
            dm_abs.append(abs(_dm[0]-_dm[1]))
            dm_rel.append(min(_dm[0], _dm[1]) / max(0.0001, _dm[0], _dm[1]))
            dm_g0.append(_dm[0])
            dm_g1.append(_dm[1])

            pred_range.append(y_pred.max() - y_pred.min())

        if adv_lr is not None:
            adv_xs.append(t)

            adv_lrs.append(adv_lr)

            adv_loss.append(adv_loss_f([X, y, z])[0])
            adv_grad.append(adv_grad_f([X, y, z])[0])

            z_pred = A.predict(x=[y_pred, z])
            z_acc = ((z_pred > 0.5) * 1 == z).mean()
            adv_acc.append(z_acc)

            base_adv.append(baseline_adversary_accuracy)
            adv_range.append(z_pred.max() - z_pred.min())

        fig, axs = plt.subplots(5, 1, figsize=(15, 15))

        axs1t = axs[1].twinx()
        axs2t = axs[2].twinx()
        axs3t = axs[3].twinx()

        axs[0].plot(cls_xs, cls_acc, label="classifier", color='green')
        axs[0].plot(cls_xs, base_class, label="baseline classifier", ls=':', color='green')
        axs[0].plot(adv_xs, adv_acc, label="adversary", color='red')
        axs[0].plot(adv_xs, base_adv, label="baseline adversary", ls=':', color='red')

        axs[1].plot(cls_xs, dm_abs, label="absolute disparity", color='red')
        axs[1].plot(cls_xs, dm_abs_ideal, label="ideal absolute disparity", color='red', ls=':')
        axs1t.plot(cls_xs, dm_rel, label="relative disparity", color='green')
        axs1t.plot(cls_xs, dm_rel_ideal, label="ideal relative disparity", color='green', ls=':')
        axs[1].plot(cls_xs, dm_g0, label="male positive", color="black")
        axs[1].plot(cls_xs, dm_g1, label="female positive", color="black", ls='--')

        axs[2].plot(cls_xs, cls_loss, label="classifier cls loss", color='green')
        axs[2].plot(adv_xs, adv_loss, label="adversary loss", color='red')
        axs2t.plot(cls_xs, pred_range, label="classifier range", color='green', ls=':')
        axs2t.plot(adv_xs, adv_range, label="adversary range", color='red', ls=':')

        axs[4].plot(cls_xs, cls_lrs, label="classifier lr", color='green', ls=":")
        axs[4].plot(cls_xs, cls_alphas, label="classifier alpha", color='green', ls="-")
        axs[4].plot(adv_xs, adv_lrs, label="adversary lr", color='red', ls=":")

        axs[3].plot(cls_xs, total_grad, label="classifier (total) gradients", color='green')
        axs[3].plot(cls_xs, cls_grad, label="classifier (cls) gradients", color='green', ls=':')
        axs[3].plot(cls_xs, a_c_grad, label="classifier (adv) gradients", color='green', ls='-.')

        axs3t.plot(adv_xs, adv_grad, label="adversary gradients", color='red')

        axs[0].set_title("prediction performance")
        axs[1].set_title("fairness characteristics")
        axs[2].set_title("loss characteristics")
        axs[3].set_title("learning characteristics")
        axs[4].set_title("learning parameters")

        axs[2].set_yscale("log", basey=2.0)
        axs[3].set_yscale("symlog", basey=2.0)
        axs[4].set_yscale("symlog", basey=2.0)

        axs[0].set_ylabel("accuracy")

        axs[1].set_ylabel("outcome Pr")
        axs1t.set_ylabel("outcome (min group Pr) / (max group Pr)")
        axs[2].set_ylabel("loss")
        axs2t.set_ylabel("Pr or Pr diff")
        axs[3].set_ylabel("grad")
        axs3t.set_ylabel("grad")
        axs[4].set_ylabel("parameter val")

        axs1t.legend(loc=4)
        axs2t.legend(loc=4)
        axs3t.legend(loc=4)

        for ax in axs:
            ax.set_xlabel("t")
            ax.legend(loc=3)
            ax.yaxis.grid(True, which='major')

        dh.update(fig)

        plt.close(fig)

    adv_lr = adv_schedule(1)
    cls_alpha, cls_lr = cls_schedule(1)

    t = -bootstrap_epochs + 1  # bootstrap epochs occupy t <= 0; the main loop starts at t = 1

    for _ in tqdm(
        range(bootstrap_epochs),
        desc="bootstrapping classifier",
        unit="epoch",
        leave=False
    ):
        for _ in tqdm(
            range(N // batch_size),
            desc="classifier",
            unit="batch",
            leave=False
        ):
            update_classifier([
                0.0,     # alpha = 0: no adversarial term while bootstrapping
                cls_lr   # classifier learning rate
            ])

        update_display(t, cls_lr=cls_lr, cls_alpha=0.0)
        t += 1

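    # Alternate adversary and classifier epochs until the one-shot iterator is
    # exhausted, which surfaces as OutOfRangeError.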
    while True:
        try:
            adv_lr = adv_schedule(t)
            cls_alpha, cls_lr = cls_schedule(t)

            for _ in tqdm(
                range(N // batch_size),
                desc="adversary",
                unit="batch",
                leave=False
            ):
                update_adversary([adv_lr])

            update_display(t, adv_lr=adv_lr)

            for _ in tqdm(
                range(N // batch_size),
                desc="classifier",
                unit="batch",
                leave=False
            ):
                update_classifier([
                    cls_alpha,
                    cls_lr
                ])

            update_display(t, cls_lr=cls_lr, cls_alpha=cls_alpha)

            t += 1
            progress.update(1)

        except OutOfRangeError:
            break
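The schedule arguments are caller-supplied callables: `adv_schedule(t)` returns the adversary learning rate and `cls_schedule(t)` returns an `(alpha, lr)` pair. A hypothetical pair, purely to illustrate the expected signatures:

def adv_schedule(t):
    return 1e-3  # constant adversary learning rate

def cls_schedule(t):
    alpha = min(1.0, 0.1 * max(t, 0))  # ramp the adversarial weight up over epochs
    return alpha, 1e-3                 # (alpha, classifier learning rate)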