# Example #1
# 0
def run(autoencoder, classifier, optimizer, loader, split, epoch):
    """Run one epoch of joint autoencoder/classifier training or evaluation.

    For every batch the data is encoded and decoded (reconstruction loss),
    the latent code is classified (cross-entropy loss), and, when the
    constraint has global variables, ``oracle.general_attack`` is run with
    both models in eval mode to obtain ``z_batches``. The DL2 loss returned
    by ``oracle.evaluate`` is combined with the other losses. When
    ``split == 'train'`` the combined loss is back-propagated and
    ``optimizer`` takes a step; for any other split no parameters are
    updated and the models are kept in eval mode.

    At the end of the epoch accuracy, balanced accuracy, confusion-matrix
    counts, F1, the averaged losses, statistical parity, equalized odds and
    the learning rate are logged to the module-level ``writer``. A histogram
    of per-sample L-inf distances between clean and adversarial latent
    codes is logged when an attack was run.

    Relies on module-level names defined elsewhere in the file: ``device``,
    ``oracle``, ``args``, ``writer``, ``Statistics``, ``tqdm`` and the loss /
    fairness / sklearn metric helpers.

    Args:
        autoencoder: model exposing ``encode`` and ``decode``.
        classifier: model over latent codes exposing ``predict``.
        optimizer: optimizer stepped on the training split; the learning
            rate of its first param group is logged.
        loader: iterable of ``(data, targets, protected)`` batches.
        split: split name; ``'train'`` enables parameter updates.
        epoch: epoch index used for logging.

    Returns:
        The ``Statistics`` accumulator of the per-batch combined loss.
    """
    is_train = split == 'train'
    predictions, targets, l_inf_diffs = list(), list(), list()
    tot_mix_loss, tot_ce_loss, tot_dl2_loss = Statistics.get_stats(3)
    tot_l2_loss, tot_stat_par, tot_eq_odds = Statistics.get_stats(3)

    n_tvars = oracle.constraint.n_tvars
    progress_bar = tqdm(loader)

    for data_batch, targets_batch, protected_batch in progress_bar:
        batch_size = data_batch.shape[0]
        data_batch = data_batch.to(device)
        targets_batch = targets_batch.to(device)
        protected_batch = protected_batch.to(device)

        # Partition the batch into n_tvars disjoint, contiguous chunks of
        # size k, one per constraint term variable.
        assert batch_size % n_tvars == 0, (
            'batch size must be divisible by the number of term variables'
        )
        k = batch_size // n_tvars
        x_batches = [data_batch[i * k:(i + 1) * k] for i in range(n_tvars)]
        y_batches = [
            targets_batch[i * k:(i + 1) * k] for i in range(n_tvars)
        ]

        # Set the mode explicitly on every batch: a previous training epoch
        # (or the attack below) may have left the models in the other mode.
        autoencoder.train(is_train)
        classifier.train(is_train)

        latent_data = autoencoder.encode(data_batch)

        data_batch_dec = autoencoder.decode(latent_data)
        # Reconstruction error: L2 norm over dim 1, one value per row.
        l2_loss = torch.norm(data_batch_dec - data_batch, dim=1)

        logits = classifier(latent_data)
        cross_entropy = binary_cross_entropy(logits, targets_batch)
        predictions_batch = classifier.predict(latent_data)

        stat_par = statistical_parity(predictions_batch, protected_batch)
        eq_odds = equalized_odds(
            targets_batch, predictions_batch, protected_batch
        )

        predictions.append(predictions_batch.detach().cpu())
        targets.append(targets_batch.detach().cpu())

        # The adversarial search runs with both models in eval mode.
        autoencoder.eval()
        classifier.eval()

        if oracle.constraint.n_gvars > 0:
            domains = oracle.constraint.get_domains(x_batches, y_batches)
            z_batches = oracle.general_attack(
                x_batches, y_batches, domains, num_restarts=1,
                num_iters=args.dl2_iters, args=args
            )
            # z_batches[0] corresponds to x_batches[0], i.e. the first k
            # rows of the batch, so compare against those latents only.
            latent_adv = autoencoder.encode(z_batches[0]).detach()
            l_inf_diffs.append(
                torch.abs(latent_data[:k] - latent_adv)
                .max(1)[0].detach().cpu()
            )
        else:
            # No global variables: nothing to attack, hence no adversarial
            # latents to measure.
            # NOTE(review): assumes oracle.evaluate accepts z_batches=None
            # in this case — confirm against the DL2 oracle.
            z_batches = None

        autoencoder.train(is_train)
        classifier.train(is_train)

        _, dl2_loss, _ = oracle.evaluate(
            x_batches, y_batches, z_batches, args
        )
        mix_loss = torch.mean(
            cross_entropy + args.dl2_weight * dl2_loss +
            args.dec_weight * l2_loss
        )

        if is_train:
            optimizer.zero_grad()
            mix_loss.backward()
            optimizer.step()

        tot_ce_loss.add(cross_entropy.mean().item())
        tot_dl2_loss.add(dl2_loss.mean().item())
        tot_mix_loss.add(mix_loss.mean().item())
        tot_l2_loss.add(l2_loss.mean().item())
        tot_stat_par.add(stat_par.mean().item())
        tot_eq_odds.add(eq_odds.mean().item())

        progress_bar.set_description(
            f'[{split}] epoch={epoch:d}, ce_loss={tot_ce_loss.mean():.4f}, '
            f'dl2_loss={tot_dl2_loss.mean():.4f}, '
            f'mix_loss={tot_mix_loss.mean():.4f}'
        )

    predictions = torch.cat(predictions)
    targets = torch.cat(targets)

    accuracy = accuracy_score(targets, predictions)
    balanced_accuracy = balanced_accuracy_score(targets, predictions)
    tn, fp, fn, tp = confusion_matrix(targets, predictions).ravel()
    f1 = f1_score(targets, predictions)

    writer.add_scalar('Accuracy/%s' % split, accuracy, epoch)
    writer.add_scalar('Balanced Accuracy/%s' % split, balanced_accuracy, epoch)
    writer.add_scalar('Cross Entropy/%s' % split, tot_ce_loss.mean(), epoch)
    writer.add_scalar('Decoder Loss/%s' % split, tot_l2_loss.mean(), epoch)
    writer.add_scalar('DL2 Loss/%s' % split, tot_dl2_loss.mean(), epoch)
    writer.add_scalar('Loss/%s' % split, tot_mix_loss.mean(), epoch)
    writer.add_scalar('True Positives/%s' % split, tp, epoch)
    writer.add_scalar('False Negatives/%s' % split, fn, epoch)
    writer.add_scalar('True Negatives/%s' % split, tn, epoch)
    writer.add_scalar('False Positives/%s' % split, fp, epoch)
    writer.add_scalar('F1 Score/%s' % split, f1, epoch)
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)
    writer.add_scalar('Stat. Parity/%s' % split, tot_stat_par.mean(), epoch)
    writer.add_scalar('Equalized Odds/%s' % split, tot_eq_odds.mean(), epoch)
    # Only present when at least one batch ran the adversarial search.
    if l_inf_diffs:
        writer.add_histogram(
            'L-inf Differences/%s' % split, torch.cat(l_inf_diffs), epoch
        )

    return tot_mix_loss
# Example #2
# 0
def run(autoencoder, classifier, optimizer, loader, split, epoch=None):
    """Run one epoch of classifier training or evaluation on latent codes.

    Each batch is encoded with ``autoencoder`` and classified in latent
    space. When ``args.adversarial`` is set, the latent codes are first
    replaced by PGD-perturbed codes (radius ``args.delta``). When
    ``split == 'train'`` the mean cross-entropy is back-propagated and
    ``optimizer`` takes a step. For any other split the classifier is put in
    eval mode and no parameters are updated.

    At the end of the epoch accuracy, balanced accuracy, confusion-matrix
    counts, F1, mean cross-entropy, statistical parity, equalized odds and
    the learning rate are logged to the module-level ``writer``.

    Relies on module-level names defined elsewhere in the file: ``device``,
    ``oracle``, ``args``, ``writer``, ``Statistics``, ``tqdm``, ``PGD``,
    ``F`` and the loss / fairness / sklearn metric helpers. The
    autoencoder's train/eval mode is not changed here.

    Args:
        autoencoder: model exposing ``encode``.
        classifier: model over latent codes exposing ``predict``.
        optimizer: optimizer stepped on the training split; the learning
            rate of its first param group is logged.
        loader: iterable of ``(data, targets, protected)`` batches.
        split: split name; ``'train'`` enables parameter updates.
        epoch: epoch index used for logging. If omitted, a module-level
            ``epoch`` variable is used (the original behavior).

    Returns:
        The ``Statistics`` accumulator of the per-batch cross-entropy.

    Raises:
        NameError: if ``epoch`` is neither passed nor defined at module level.
    """
    if epoch is None:
        # Backward compatibility: callers that do not pass ``epoch`` rely on
        # a module-level ``epoch`` variable. Fail before doing any work
        # rather than after the first optimizer step.
        epoch = globals().get('epoch')
        if epoch is None:
            raise NameError(
                "epoch was not passed and no module-level 'epoch' is defined"
            )

    is_train = split == 'train'
    predictions, targets = list(), list()
    tot_ce_loss, tot_stat_par, tot_eq_odds = Statistics.get_stats(3)

    progress_bar = tqdm(loader)

    if args.adversarial:
        # Inputs are latent codes, so no clipping range applies.
        attack = PGD(classifier,
                     args.delta,
                     F.binary_cross_entropy_with_logits,
                     clip_min=float('-inf'),
                     clip_max=float('inf'))

    for data_batch, targets_batch, protected_batch in progress_bar:
        batch_size = data_batch.shape[0]
        data_batch = data_batch.to(device)
        targets_batch = targets_batch.to(device)
        protected_batch = protected_batch.to(device)

        # Keep the batch-size check against the constraint's term
        # variables; the per-variable split itself is not used here.
        assert batch_size % oracle.constraint.n_tvars == 0

        # Set the mode explicitly so evaluation never runs in the train mode
        # left over from a previous training epoch.
        classifier.train(is_train)

        latent_data = autoencoder.encode(data_batch)

        if args.adversarial:
            latent_data = attack.attack(args.delta / 10,
                                        latent_data,
                                        20,
                                        targets_batch,
                                        targeted=False,
                                        num_restarts=1,
                                        random_start=True)

        logits = classifier(latent_data)
        ce_loss = binary_cross_entropy(logits, targets_batch)
        predictions_batch = classifier.predict(latent_data)

        stat_par = statistical_parity(predictions_batch, protected_batch)
        eq_odds = equalized_odds(targets_batch, predictions_batch,
                                 protected_batch)

        predictions.append(predictions_batch.detach().cpu())
        targets.append(targets_batch.detach().cpu())

        if is_train:
            optimizer.zero_grad()
            ce_loss.mean().backward()
            optimizer.step()

        tot_ce_loss.add(ce_loss.mean().item())
        tot_stat_par.add(stat_par.mean().item())
        tot_eq_odds.add(eq_odds.mean().item())

        progress_bar.set_description(
            f'[{split}] epoch={epoch:d}, ce_loss={tot_ce_loss.mean():.4f}')

    predictions = torch.cat(predictions)
    targets = torch.cat(targets)

    accuracy = accuracy_score(targets, predictions)
    balanced_accuracy = balanced_accuracy_score(targets, predictions)
    tn, fp, fn, tp = confusion_matrix(targets, predictions).ravel()
    f1 = f1_score(targets, predictions)

    writer.add_scalar('Accuracy/%s' % split, accuracy, epoch)
    writer.add_scalar('Balanced Accuracy/%s' % split, balanced_accuracy, epoch)
    writer.add_scalar('Cross Entropy/%s' % split, tot_ce_loss.mean(), epoch)
    writer.add_scalar('True Positives/%s' % split, tp, epoch)
    writer.add_scalar('False Negatives/%s' % split, fn, epoch)
    writer.add_scalar('True Negatives/%s' % split, tn, epoch)
    writer.add_scalar('False Positives/%s' % split, fp, epoch)
    writer.add_scalar('F1 Score/%s' % split, f1, epoch)
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)
    writer.add_scalar('Stat. Parity/%s' % split, tot_stat_par.mean(), epoch)
    writer.add_scalar('Equalized Odds/%s' % split, tot_eq_odds.mean(), epoch)

    return tot_ce_loss