import torch
from tqdm import tqdm

# The attack helpers referenced below (pgd, fgsm, encoder_attack, l2_attack,
# rate_attack, distortion_attack, cw_infomax_encoder_attack, source2target,
# cw_vae_encoder_attack) are assumed to be defined elsewhere in the project;
# hedged sketches for several of them appear between the examples.


def get_encoder_transfer_stats(args, source_model, target_model, loader, log):

    source_z_l2_norms = AverageMeter()
    target_z_l2_norms = AverageMeter()

    source_model.eval()
    target_model.eval()

    # len(loader) is already the number of batches for a DataLoader
    batch = tqdm(loader, total=len(loader))
    for i, (X, y) in enumerate(batch):

        if args.gpu:
            X, y = X.cuda(), y.cuda()

        source_x_adv, source_e_adv, source_diff, source_max_diff = encoder_attack(
            X,
            source_model,
            args.num_steps,
            args.epsilon,
            args.alpha,
            random_restart=True)

        source_z_l2_norms.update(source_diff)
        _, _, E_adv = target_model(source_x_adv)
        _, _, E = target_model(X)
        l2 = torch.norm(E - E_adv, dim=-1, p=2).mean()
        target_z_l2_norms.update(l2)

        print("Src L2 {src_l2:3f}\t"
              "Tgt L2 {tgt_l2:3f}\t".format(src_l2=source_z_l2_norms.avg,
                                            tgt_l2=target_z_l2_norms.avg),
              file=log)
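
# All of these examples rely on an `AverageMeter` helper that is not shown in
# this listing. A minimal sketch, assuming the usual running-average interface
# (as in the PyTorch ImageNet example) with the .update(val, n=...) / .avg
# usage seen above; the histogram variant used later would also store the raw
# values:
class AverageMeter:
    """Track a running sum and count and expose their average."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        if torch.is_tensor(val):
            # detach and reduce so the meter never holds onto the graph
            val = val.detach().float().mean().item()
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
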
def get_attack_stats(args, model, loader, log):

    clean_errors = AverageMeter()
    adv_errors = AverageMeter()
    # L2 norm between representations for clean and adversarial input, a proxy for stability
    l2_norms = AverageMeter()
    change_fraction = AverageMeter()

    model.eval()
    batch = tqdm(loader, total=len(loader))
    for i, (X, y) in enumerate(batch):

        if args.gpu:
            X, y = X.cuda(), y.cuda()

        Z_clean = model.encoder(X, intermediate=True)

        # adv samples
        if args.attack == "pgd":
            X_adv, delta, out, out_adv = pgd(model=model,
                                             X=X,
                                             y=y,
                                             epsilon=args.epsilon,
                                             alpha=args.alpha,
                                             num_steps=args.num_steps,
                                             p='inf')

        elif args.attack == "fgsm":
            X_adv, delta, out, out_adv = fgsm(model=model,
                                              X=X,
                                              y=y,
                                              epsilon=args.epsilon)

        err_clean = (out.data != y).float().sum() / X.size(0)
        err_adv = (out_adv.data != y).float().sum() / X.size(0)

        clean_errors.update(err_clean)
        adv_errors.update(err_adv)

        # pass perturbed input through classifier's encoder, get perturbed representations
        Z_adv = model.encoder(X_adv, intermediate=True)

        Z_l2 = torch.norm(Z_clean - Z_adv, p=2, dim=-1, keepdim=True).mean()
        l2_norms.update(Z_l2)

        # relative elementwise change; the small epsilon guards against division by zero
        fraction = torch.abs(Z_clean - Z_adv) / (torch.abs(Z_clean) + 1e-12)
        change_fraction.update(fraction.max())

        batch.set_description("Clean Err {} Adv Err {} L2 {} Frac {}".format(
            clean_errors.avg, adv_errors.avg, l2_norms.avg,
            change_fraction.avg))

        # print to logfile
        print("clean_err: ",
              clean_errors.avg,
              " adv_err: ",
              adv_errors.avg,
              "l2 norm: ",
              l2_norms.avg,
              "l1 frac: ",
              change_fraction.avg,
              file=log)

    print(' * Clean Error {clean_error.avg:.3f}\t'
          ' Adv Error {adv_errors.avg:.3f}\t'
          ' L2 norm {l2_norms.avg:.3f}\t'
          ' L1 frac {change_frac.avg:.3f}\t'.format(
              clean_error=clean_errors,
              adv_errors=adv_errors,
              l2_norms=l2_norms,
              change_frac=change_fraction),
          file=log)
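
# `pgd` above is assumed to be a standard L-infinity projected gradient
# descent attack with the (X_adv, delta, out, out_adv) return convention used
# here. A minimal sketch under that assumption; clamping to the valid input
# range is omitted, and only p='inf' is implemented:
def pgd(model, X, y, epsilon, alpha, num_steps, p='inf'):
    assert p == 'inf', "this sketch only implements the L-inf ball"
    delta = torch.empty_like(X).uniform_(-epsilon, epsilon).requires_grad_(True)
    for _ in range(num_steps):
        loss = torch.nn.functional.cross_entropy(model(X + delta), y)
        loss.backward()
        with torch.no_grad():
            delta += alpha * delta.grad.sign()  # ascend the loss, then project
            delta.clamp_(-epsilon, epsilon)
        delta.grad.zero_()
    delta = delta.detach()
    X_adv = X + delta
    with torch.no_grad():
        out = model(X).max(1)[1]
        out_adv = model(X_adv).max(1)[1]
    return X_adv, delta, out, out_adv
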
def get_attack_stats(loader,
                     encoder,
                     decoder,
                     beta,
                     attack_type="l2_attack",
                     num_samples=10,
                     epsilon=0.05):

    clean_rate = AverageMeter()
    adv_rate = AverageMeter()

    clean_dist = AverageMeter()
    adv_dist = AverageMeter()

    batch = tqdm(loader, total=len(loader))
    for i, (X, y) in enumerate(batch):
        if attack_type == "l2_attack":
            delta, mu, cov, loss = l2_attack(X,
                                             encoder,
                                             num_steps=100,
                                             alpha=0.1,
                                             epsilon=epsilon,
                                             p="inf")

        elif attack_type == "rate_attack":
            delta, mu, cov, loss = rate_attack(X,
                                               encoder,
                                               num_steps=100,
                                               alpha=0.1,
                                               epsilon=epsilon,
                                               p="inf")

        elif attack_type == "distortion_attack":
            delta, mu, cov, loss = distortion_attack(X,
                                                     encoder,
                                                     decoder=decoder,
                                                     num_steps=100,
                                                     alpha=0.1,
                                                     epsilon=epsilon,
                                                     p="inf")

        elif attack_type == "random_attack":
            delta, mu, cov, loss = random_attack(X,
                                                 encoder,
                                                 epsilon=epsilon,
                                                 p="inf")

        Z = torch.bmm(
            cov.repeat(num_samples, 1, 1),
            torch.randn(X.size(0) * num_samples, mu.shape[1],
                        1).to(X.device))

        Z = Z.squeeze() + mu.repeat(num_samples, 1)
        X_hat = decoder(Z)

        # KL(q(z|x) || N(0, I)); note the *squared* mean norm in the closed form
        rate = 0.5 * (-log_diagonal_det(cov) - cov.shape[1] + trace(cov) +
                      torch.norm(mu, p=2, dim=1) ** 2)
        distortion = torch.pow(X_hat - X.repeat(num_samples, 1, 1, 1),
                               2).sum(dim=(-1, -2, -3))
        adv_rate.update(rate, n=X.size(0) * num_samples)
        adv_dist.update(distortion, n=X.size(0) * num_samples)

        batch.set_description("Adv Rate {} Adv Distortion {}: ".format(
            adv_rate.avg, adv_dist.avg))

        X = X.repeat(num_samples, 1, 1, 1)
        mu, cov = encoder(X)
        Z = torch.bmm(
            cov,
            torch.randn(X.size(0), mu.shape[1], 1).to(X.device))
        Z = Z.squeeze() + mu
        X_hat = decoder(Z)

        rate = 0.5 * (-log_diagonal_det(cov) - cov.shape[1] + trace(cov) +
                      torch.norm(mu, p=2, dim=1) ** 2)
        # note that distortion is calculated from one sample from the posterior
        distortion = torch.pow(X_hat - X, 2).sum(dim=(-1, -2, -3))

        clean_rate.update(rate, n=X.size(0))
        clean_dist.update(distortion, n=X.size(0))
        batch.set_description("Rate {} Distortion {}: ".format(
            clean_rate.avg, clean_dist.avg))

    # get histogram of rate and distortion for clean and adversarial
    # clean_rate.get_histogram(title="Rate", filename="rate_{}_{}.png".format(beta, attack_type))
    # adv_rate.get_histogram(title="Adv Rate", filename="adv_rate_{}_{}.png".format(beta, attack_type))
    # clean_dist.get_histogram(title="Distortion", filename="distortion_{}_{}.png".format(beta, attack_type))
    # adv_dist.get_histogram(title="Adv Distortion", filename="adv_distortion{}_{}.png".format(beta, attack_type))
    return clean_rate.avg, adv_rate.avg, clean_dist.avg, adv_dist.avg
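
# `log_diagonal_det` and `trace` above are defined elsewhere. Hypothetical
# sketches, assuming `cov` is a batch of diagonal matrices of shape [B, d, d]
# as the bmm-based sampling suggests:
def log_diagonal_det(cov):
    # log-determinant of a diagonal matrix: sum of the log of its diagonal
    return torch.log(torch.diagonal(cov, dim1=-2, dim2=-1)).sum(dim=-1)


def trace(cov):
    return torch.diagonal(cov, dim1=-2, dim2=-1).sum(dim=-1)


# The random baseline invoked above could look like this sketch (the loss
# entry is a placeholder, since nothing is optimized):
def random_attack(X, encoder, epsilon, p="inf"):
    delta = torch.randn_like(X).sign() * epsilon
    mu, cov = encoder(X + delta)
    return delta, mu, cov, torch.zeros(())
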
def get_attack_stats(args,
                     encoder,
                     classifier,
                     discriminator,
                     loader,
                     log,
                     type="class"):

    clean_errors = AverageMeter()
    adv_errors = AverageMeter()

    mi_meter = AverageMeter()
    mi_adv_adv_meter = AverageMeter()
    mi_adv_clean_meter = AverageMeter()
    mi_clean_adv_meter = AverageMeter()

    c_l2_norms = AverageMeter()
    c_l2_frac = AverageMeter()

    fc_l2_norms = AverageMeter()
    fc_l2_frac = AverageMeter()

    z_l2_norms = AverageMeter()
    z_l2_frac = AverageMeter()

    batch = tqdm(loader, total=len(loader))
    for i, (X, y) in enumerate(batch):

        if args.gpu:
            X, y = X.cuda(), y.cuda()

        if type == "class" and classifier is not None:
            # adv samples using classifier
            if args.attack == "pgd":
                X_adv, delta, out, out_adv = pgd(model=classifier,
                                                 X=X,
                                                 y=y,
                                                 epsilon=args.epsilon,
                                                 alpha=args.alpha,
                                                 num_steps=args.num_steps,
                                                 p='inf')

            elif args.attack == "fgsm":
                X_adv, delta, out, out_adv = fgsm(model=classifier,
                                                  X=X,
                                                  y=y,
                                                  epsilon=args.epsilon)

        elif type == "encoder":
            X_adv, E_adv, diff, max_diff = encoder_attack(X,
                                                          encoder,
                                                          args.num_steps,
                                                          args.epsilon,
                                                          args.alpha,
                                                          random_restart=True)

            batch.set_description("Avg Diff {} Max Diff {}".format(
                diff, max_diff))

        elif type == "impostor":
            batch_size = X.shape[0]
            # using the given batch form X_s X_t pairs
            X_s = X[0:batch_size // 2]
            X_t = X[batch_size // 2:]
            # set y to the labels for X_s and X to X_s for later computation and logging
            y = y[0:batch_size // 2]
            X = X_s

            X_adv, E_adv, diff, min_diff = cw_infomax_encoder_attack(
                X_s,
                X_t,
                encoder=encoder,
                num_steps=2000,
                alpha=0.001,
                c=0.1,
                p=2)

            batch.set_description("Avg Diff {} Min Diff {}".format(
                diff, min_diff))

        elif type == "random":

            # randn, not rand: rand_like is in [0, 1), so .sign() would always be +1
            delta = torch.randn_like(X).sign() * args.epsilon
            X_adv = X + delta
            _, _, E = encoder(X)
            _, _, E_d = encoder(X_adv)
            norm = torch.norm(E - E_d, p=2, dim=-1)
            batch.set_description("Avg Diff {} Max Diff {} ".format(
                norm.mean(), norm.max()))

        if classifier is not None:
            # UPDATE CLEAN and ADV ERRORS
            logits_clean = classifier(X)
            logits_adv = classifier(X_adv)
            out = logits_clean.max(1)[1]
            out_adv = logits_adv.max(1)[1]

            err_clean = (out.data != y).float().sum() / X.size(0)
            err_adv = (out_adv.data != y).float().sum() / X.size(0)
            clean_errors.update(err_clean)
            adv_errors.update(err_adv)

        # UPDATE L2 NORM METERS
        C_clean, FC_clean, Z_clean = encoder(X)
        C_adv, FC_adv, Z_adv = encoder(X_adv)

        l2 = torch.norm(Z_clean - Z_adv, p=2, dim=-1, keepdim=True)
        fraction = (l2 / torch.norm(Z_clean, p=2, dim=-1, keepdim=True))
        z_l2_norms.update(l2.mean())
        z_l2_frac.update(fraction.mean())

        l2 = torch.norm(C_clean - C_adv, p=2, dim=(-1, -2, -3), keepdim=True)
        fraction = (l2 /
                    torch.norm(C_clean, p=2, dim=(-1, -2, -3), keepdim=True))
        c_l2_norms.update(l2.mean())
        c_l2_frac.update(fraction.mean())

        l2 = torch.norm(FC_clean - FC_adv, p=2, dim=-1, keepdim=True)
        fraction = (l2 / torch.norm(FC_clean, p=2, dim=-1, keepdim=True))
        fc_l2_norms.update(l2.mean())
        fc_l2_frac.update(fraction.mean())

        with torch.no_grad():
            # evaluate the critic scores for X and E
            mi, E = discriminator(X=X)

            # evaluate the critic scores for X_adv and E_adv
            mi_adv_adv, E_adv = discriminator(X=X_adv)

            # evaluate the critic scores for X_adv and E_clean
            mi_adv_clean, _ = discriminator(X_adv, E=E)

            # evaluate the critic scores for X, E_adv
            mi_clean_adv, _ = discriminator(X, E=E_adv)

        # UPDATE MI METERS
        mi_meter.update(mi)
        mi_adv_adv_meter.update(mi_adv_adv)
        mi_adv_clean_meter.update(mi_adv_clean)
        mi_clean_adv_meter.update(mi_clean_adv)

        batch.set_description(
            "MI(X, E) {} MI(X_adv, E_adv) {} MI(X_adv, E) {} MI(X, E_adv) {}".
            format(mi, mi_adv_adv, mi_adv_clean, mi_clean_adv))

        # print to logfile
        print("Error Clean {clean_errors.avg:.3f}\t Error Adv{adv_errors.avg:.3f}\t "
                "C L2 {c_l2_norms.avg:.3f}\t C L2 Frac{c_l2_frac.avg:.3f}\t"
                "FC L2 {fc_l2_norms.avg:.3f}\t FC L2 Frac{fc_l2_frac.avg:.3f}\t"
                "Z L2 {z_l2_norms.avg:.3f}\t Z L2 Frac{z_l2_frac.avg:.3f}\t"
                "MI(X, E) {mi.avg:.3f}\t MI(X_adv, E_adv) {mi_adv_adv.avg:.3f}\t "
                "MI(X_adv, E) {mi_adv_clean.avg:.3f}\t MI(X, E_adv) {mi_clean_adv.avg:.3f}\t".format(
              clean_errors=clean_errors, adv_errors=adv_errors,
              c_l2_norms=c_l2_norms, c_l2_frac=c_l2_frac,
              fc_l2_norms=fc_l2_norms, fc_l2_frac=fc_l2_frac,
              z_l2_norms=z_l2_norms, z_l2_frac=z_l2_frac,
              mi=mi_meter, mi_adv_adv=mi_adv_adv_meter, mi_adv_clean=mi_adv_clean_meter, mi_clean_adv=mi_clean_adv_meter), \
                      file=log)

        log.flush()
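
# `encoder_attack` above is assumed to be a PGD-style attack that maximizes
# the L2 gap between clean and perturbed representations (the encoder returns
# a (C, FC, Z) triple and the final Z is attacked). A hypothetical sketch:
def encoder_attack(X, encoder, num_steps, epsilon, alpha, random_restart=True):
    with torch.no_grad():
        _, _, E = encoder(X)
    delta = (torch.empty_like(X).uniform_(-epsilon, epsilon)
             if random_restart else torch.zeros_like(X))
    delta.requires_grad_(True)
    for _ in range(num_steps):
        _, _, E_adv = encoder(X + delta)
        loss = torch.norm(E_adv - E, p=2, dim=-1).mean()
        loss.backward()
        with torch.no_grad():
            delta += alpha * delta.grad.sign()  # gradient *ascent* on the gap
            delta.clamp_(-epsilon, epsilon)
        delta.grad.zero_()
    with torch.no_grad():
        X_adv = X + delta.detach()
        _, _, E_adv = encoder(X_adv)
        dist = torch.norm(E_adv - E, p=2, dim=-1)
    return X_adv, E_adv, dist.mean(), dist.max()
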
def get_classifier_transfer_stats(args, source_model, target_model, loader,
                                  log):

    source_clean_errors = AverageMeter()
    source_adv_errors = AverageMeter()
    target_clean_errors = AverageMeter()
    target_adv_errors = AverageMeter()

    # compute the attack success rate: fraction of changed predictions
    source_success_changed_meter = AverageMeter()
    target_success_changed_meter = AverageMeter()

    # compute the attack success rate: fraction of changed predictions from correct predictions
    source_success_correct_meter = AverageMeter()
    target_success_correct_meter = AverageMeter()

    # compute the fraction of successfully transferred attacks
    fraction_transfer_meter = AverageMeter()

    # compute the fraction of successfully transferred attacks with the same prediction
    fraction_transfer_same_meter = AverageMeter()

    source_model.eval()
    target_model.eval()

    batch = tqdm(loader, total=len(loader))
    for i, (X, y) in enumerate(batch):

        if args.gpu:
            X, y = X.cuda(), y.cuda()

        # adv samples
        if args.attack == "pgd":
            X_adv, delta, source_out, source_out_adv = pgd(
                model=source_model,
                X=X,
                y=y,
                epsilon=args.epsilon,
                alpha=args.alpha,
                num_steps=args.num_steps,
                p='inf')

        elif args.attack == "fgsm":
            X_adv, delta, source_out, source_out_adv = fgsm(
                model=source_model, X=X, y=y, epsilon=args.epsilon)

        err_clean = (source_out.data != y).float().sum() / X.size(0)
        err_adv = (source_out_adv.data != y).float().sum() / X.size(0)

        # compute success rate overall
        source_success_all = (source_out_adv.data != source_out.data).float()

        # compute success rate for correctly classified inputs
        source_mask = (source_out.data == y).float()
        source_success_correct = (
            (source_out_adv.data != source_out.data).float() * source_mask)

        # update source model stats
        source_clean_errors.update(err_clean)
        source_adv_errors.update(err_adv)
        source_success_changed_meter.update(source_success_all.sum() /
                                            X.size(0))
        source_success_correct_meter.update(source_success_correct.sum() /
                                            source_mask.sum())

        # pass X_adv to target model and compute success rate
        target_out = target_model(X)
        target_out_adv = target_model(X_adv)
        err_clean = (target_out.data != y).float().sum() / X.size(0)
        err_adv = (target_out_adv.data != y).float().sum() / X.size(0)

        # compute success rate overall
        target_success_all = (target_out_adv.data != target_out.data).float()

        # compute success rate for correctly classified inputs
        target_mask = (target_out.data == y).float()
        target_success_correct = (
            (target_out_adv.data != target_out.data).float() * target_mask)

        # update target model stats
        target_clean_errors.update(err_clean)
        target_adv_errors.update(err_adv)
        target_success_changed_meter.update(target_success_all.sum() /
                                            X.size(0))
        target_success_correct_meter.update(target_success_correct.sum() /
                                            target_mask.sum())

        # compute fraction of successfully transferred attacks
        fraction_transfer_meter.update(
            (source_success_all * target_success_all).float().sum() /
            source_success_all.sum())

        # fraction of matched misclassifications among successfully transferred attacks
        transferred = (source_success_all * target_success_all).float()
        same_pred = (target_out_adv.data == source_out_adv.data).float()
        fraction_transfer_same_meter.update(
            (transferred * same_pred).sum() / transferred.sum())

        batch.set_description(
            "Adv Error {} Transfer Rate {} Same Pred Rate {}".format(
                source_adv_errors.avg, fraction_transfer_meter.avg,
                fraction_transfer_same_meter.avg))

    print('Source Clean Error {source_clean_error.avg:.3f}\t'
          'Source Adv Error {source_adv_error.avg:.3f}\t'
          'Target Clean Error {target_clean_error.avg:.3f}\t'
          'Target Adv Error {target_adv_error.avg:.3f}\t'
          'Source Changed {source_success_changed.avg:.3f}\t'
          'Target Changed {target_success_changed.avg:.3f}\t'
          'Source Correct Changed {source_success_correct.avg:.3f}\t'
          'Target Correct Changed {target_success_correct.avg:.3f}\t'
          'Fraction Transfer {fraction_transfer.avg:.3f}\t'
          'Fraction Transfer Same {fraction_transfer_same.avg:.3f}\t'.format(
              source_clean_error=source_clean_errors,
              source_adv_error=source_adv_errors,
              target_clean_error=target_clean_errors,
              target_adv_error=target_adv_errors,
              source_success_changed=source_success_changed_meter,
              target_success_changed=target_success_changed_meter,
              source_success_correct=source_success_correct_meter,
              target_success_correct=target_success_correct_meter,
              fraction_transfer=fraction_transfer_meter,
              fraction_transfer_same=fraction_transfer_same_meter),
          file=log)
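
# `fgsm` above is assumed to be the single-step fast gradient sign method
# with the same return convention as `pgd`. A minimal sketch under that
# assumption (input-range clamping again omitted):
def fgsm(model, X, y, epsilon):
    X = X.clone().detach().requires_grad_(True)
    loss = torch.nn.functional.cross_entropy(model(X), y)
    loss.backward()
    delta = (epsilon * X.grad.sign()).detach()
    X_adv = (X + delta).detach()
    with torch.no_grad():
        out = model(X).max(1)[1]
        out_adv = model(X_adv).max(1)[1]
    return X_adv, delta, out, out_adv
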
def get_attack_stats(args, encoder, classifier, loader, log, type="class"):
    clean_errors = AverageMeter()
    adv_errors = AverageMeter()

    c_l2_norms = AverageMeter()
    c_l2_frac = AverageMeter()

    fc_l2_norms = AverageMeter()
    fc_l2_frac = AverageMeter()

    z_l2_norms = AverageMeter()
    z_l2_frac = AverageMeter()

    classifier.eval()

    batch = tqdm(loader, total=len(loader))
    for i, (X, y) in enumerate(batch):

        if args.gpu:
            X, y = X.cuda(), y.cuda()

        if type == "class":
            # adv samples using classifier
            if args.attack == "pgd":
                X_adv, delta, out, out_adv = pgd(model=classifier, X=X, y=y, epsilon=args.epsilon,
                                                 alpha=args.alpha, num_steps=args.num_steps, p='inf')

            elif args.attack == "fgsm":
                X_adv, delta, out, out_adv = fgsm(model=classifier, X=X, y=y, epsilon=args.epsilon)

        elif type == "encoder":
            X_adv, E_adv, diff, max_diff = encoder_attack(X, encoder, args.num_steps, args.epsilon, args.alpha,
                                                          random_restart=True)

            batch.set_description("Avg Diff {} Max Diff {}".format(diff, max_diff))

            # run classifier on adversarial representations
            logits_clean = classifier(X)
            logits_adv = classifier(X_adv)
            out = logits_clean.max(1)[1]
            out_adv = logits_adv.max(1)[1]

        elif type == "impostor":
            batch_size = X.shape[0]
            # using the given batch form X_s X_t pairs
            X_s = X[0:batch_size // 2]
            X_t = X[batch_size // 2:]
            # set y to the labels for X_s and X to X_s for later computation and logging
            y = y[0:batch_size // 2]
            X = X_s

            X_adv, E_adv, diff, min_diff = source2target(X_s, X_t, encoder=encoder, epsilon=2.0,
                                                         max_steps=70000, step_size=0.001)

            # run classifier on adversarial representations
            logits_clean = classifier(X_s)
            logits_adv = classifier(X_adv)
            out = logits_clean.max(1)[1]
            out_adv = logits_adv.max(1)[1]

            batch.set_description("Avg Diff {} Min Diff {}".format(diff, min_diff))

        elif type == "random":

            # randn, not rand: rand_like is in [0, 1), so .sign() would always be +1
            delta = torch.randn_like(X).sign() * args.epsilon
            X_adv = X + delta
            _, _, E = encoder(X)
            _, _, E_d = encoder(X_adv)
            norm = torch.norm(E - E_d, p=2, dim=-1)

            # run classifier on adversarial representations
            logits_clean = classifier(X)
            logits_adv = classifier(X_adv)
            out = logits_clean.max(1)[1]
            out_adv = logits_adv.max(1)[1]

            batch.set_description("Avg Diff {} Max Diff {} ".format(norm.mean(), norm.max()))

        # UPDATE CLEAN and ADV ERRORS
        err_clean = (out.data != y).float().sum() / X.size(0)
        err_adv = (out_adv.data != y).float().sum() / X.size(0)
        clean_errors.update(err_clean)
        adv_errors.update(err_adv)

        # UPDATE L2 NORM METERS
        C_clean, FC_clean, Z_clean = encoder(X)
        C_adv, FC_adv, Z_adv = encoder(X_adv)

        l2 = torch.norm(Z_clean - Z_adv, p=2, dim=-1, keepdim=True)
        fraction = (l2 / torch.norm(Z_clean, p=2, dim=-1, keepdim=True))
        z_l2_norms.update(l2.mean())
        z_l2_frac.update(fraction.mean())

        l2 = torch.norm(C_clean - C_adv, p=2, dim=(-1, -2, -3), keepdim=True)
        fraction = (l2 / torch.norm(C_clean, p=2, dim=(-1, -2, -3), keepdim=True))
        c_l2_norms.update(l2.mean())
        c_l2_frac.update(fraction.mean())

        l2 = torch.norm(FC_clean - FC_adv, p=2, dim=-1, keepdim=True)
        fraction = (l2 / torch.norm(FC_clean, p=2, dim=-1, keepdim=True))
        fc_l2_norms.update(l2.mean())
        fc_l2_frac.update(fraction.mean())

        # print to logfile
        print("Error Clean {clean_errors.avg:.3f}\t Error Adv{adv_errors.avg:.3f}\t "
              "C L2 {c_l2_norms.avg:.3f}\t C L2 Frac{c_l2_frac.avg:.3f}\t"
              "FC L2 {fc_l2_norms.avg:.3f}\t FC L2 Frac{fc_l2_frac.avg:.3f}\t"
              "Z L2 {z_l2_norms.avg:.3f}\t Z L2 Frac{z_l2_frac.avg:.3f}\t".format(
            clean_errors=clean_errors, adv_errors=adv_errors,
            c_l2_norms=c_l2_norms, c_l2_frac=c_l2_frac,
            fc_l2_norms=fc_l2_norms, fc_l2_frac=fc_l2_frac,
            z_l2_norms=z_l2_norms, z_l2_frac=z_l2_frac),
            file=log)

        log.flush()
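
# `cw_vae_encoder_attack`, used in the next example, is assumed to be a
# Carlini-Wagner-style latent matching attack: delta is optimized so the VAE
# posterior mean of X_s + delta approaches that of the target X_t, with an
# L_p penalty on delta weighted by c. A hypothetical sketch (the real
# objective and the meaning of the returned Z_b may differ):
def cw_vae_encoder_attack(X_s, X_t, encoder, p=2, num_steps=2500, c=1.5,
                          alpha=0.001):
    with torch.no_grad():
        mu_t, _ = encoder(X_t)
    delta = torch.zeros_like(X_s, requires_grad=True)
    opt = torch.optim.Adam([delta], lr=alpha)
    loss = torch.zeros(())
    for _ in range(num_steps):
        mu_b, _ = encoder(X_s + delta)
        match = torch.norm(mu_b - mu_t, p=2, dim=-1).mean()
        penalty = torch.norm(delta.view(delta.size(0), -1), p=p, dim=-1).mean()
        loss = match + c * penalty
        opt.zero_grad()
        loss.backward()
        opt.step()
    with torch.no_grad():
        mu_b, _ = encoder(X_s + delta)
    return delta.detach(), mu_b, loss.detach()
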
def vae_transfer(src_vae, tgt_vae, loader, log, gpu):

    main_loss = AverageMeter()
    transfer_loss = AverageMeter()
    decoder_loss = AverageMeter()
    transfer_decoder_loss = AverageMeter()

    batch = tqdm(loader, total=len(loader))
    for X, y in batch:
        if gpu:
            X = X.cuda()

        batch_size = X.shape[0]
        X_s = X[0:batch_size // 2]
        X_t = X[batch_size // 2:]
        delta, Z_b, loss = cw_vae_encoder_attack(X_s,
                                                 X_t,
                                                 encoder=src_vae.encoder,
                                                 p=2,
                                                 num_steps=2500,
                                                 c=1.5,
                                                 alpha=0.001)

        # evaluate decoding wrt target image
        _, _, _, X_hat_tgt = src_vae(X_t)
        _, _, Z_b, X_hat_adv = src_vae(X_s + delta)
        decoder_loss.update(
            torch.norm(X_hat_adv - X_hat_tgt, p=2, dim=(-3, -2, -1)).mean())
        main_loss.update(loss)

        mu_tgt, cov_tgt = tgt_vae.encoder(X_t)
        Z_tgt = torch.bmm(
            cov_tgt,
            torch.randn(X_t.size(0), mu_tgt.shape[1], 1).to(X.device))
        Z_tgt = Z_tgt.squeeze() + mu_tgt.squeeze()

        mu_adv, cov_adv = tgt_vae.encoder(X_s + delta)
        Z_adv = torch.bmm(
            cov_adv,
            torch.randn(X_s.size(0), mu_adv.shape[1], 1).to(X.device))
        Z_adv = Z_adv.squeeze() + mu_adv.squeeze()

        X_hat_tgt = tgt_vae.decoder(Z_tgt)
        X_hat_adv = tgt_vae.decoder(Z_adv)
        loss = l2_wasserstein(mu_tgt, mu_adv, cov_tgt, cov_adv)
        transfer_decoder_loss.update(
            torch.norm(X_hat_adv - X_hat_tgt, p=2, dim=(-3, -2, -1)).mean())
        transfer_loss.update(loss)

    print("Encoder Loss: {}\t "
          "Transfer Encoder Loss: {}\t"
          "Decoder Matching Loss {}\t"
          "Transfer Decoder Matching Loss {}".format(
              main_loss.avg, transfer_loss.avg, decoder_loss.avg,
              transfer_decoder_loss.avg),
          file=log)

    log.flush()
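
# `l2_wasserstein` above is assumed to compute the 2-Wasserstein distance
# between diagonal Gaussians. With per-dimension standard deviations on the
# diagonal of cov (as the bmm-based sampling suggests), the closed form
# reduces to this sketch:
def l2_wasserstein(mu1, mu2, cov1, cov2):
    std1 = torch.diagonal(cov1, dim1=-2, dim2=-1)
    std2 = torch.diagonal(cov2, dim1=-2, dim2=-1)
    mu1, mu2 = mu1.view(mu1.size(0), -1), mu2.view(mu2.size(0), -1)
    w2_sq = ((mu1 - mu2) ** 2).sum(dim=-1) + ((std1 - std2) ** 2).sum(dim=-1)
    return torch.sqrt(w2_sq).mean()
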
def infomax_transfer(src_encoder, tgt_encoder, src_decoder, tgt_decoder,
                     loader, log, gpu):

    main_loss = AverageMeter()
    transfer_loss = AverageMeter()
    decoder_loss = AverageMeter()
    transfer_decoder_loss = AverageMeter()

    batch = tqdm(loader, total=len(loader))
    for X, y in batch:
        if gpu:
            X = X.cuda()

        batch_size = X.shape[0]
        X_s = X[0:batch_size // 2]
        X_t = X[batch_size // 2:]
        delta, Z_b, loss = cw_infomax_encoder_attack(X_s,
                                                     X_t,
                                                     encoder=src_encoder,
                                                     num_steps=2000,
                                                     alpha=0.001,
                                                     c=0.1,
                                                     p=2)

        # evaluate decoding wrt target image
        with torch.no_grad():
            _, _, Z_tgt = src_encoder(X_t)
            X_hat_tgt = src_decoder(Z_tgt)
            X_hat_adv = src_decoder(Z_b)

            recon_loss = torch.norm(X_hat_tgt - X_hat_adv,
                                    p=2,
                                    dim=(-3, -2, -1)).mean()
            decoder_loss.update(recon_loss)
            main_loss.update(loss)

            # compute transfer losses
            _, _, Z = tgt_encoder(X_s)
            _, _, Z_adv = tgt_encoder(X_s + delta)
            _, _, Z_tgt = tgt_encoder(X_t)

            X_hat_tgt = tgt_decoder(Z_tgt)
            loss = torch.norm(Z_tgt - Z_adv, dim=-1, p=2).mean()
            X_hat_adv = tgt_decoder(Z_adv)
            recon_loss = torch.norm(X_hat_tgt - X_hat_adv,
                                    p=2,
                                    dim=(-3, -2, -1)).mean()
            transfer_loss.update(loss)
            transfer_decoder_loss.update(recon_loss)

    print("Encoder Loss: {}\t "
          "Transfer Encoder Loss: {}\t"
          "Decoder Matching Loss {}\t"
          "Transfer Decoder Matching Loss {}".format(
              main_loss.avg, transfer_loss.avg, decoder_loss.avg,
              transfer_decoder_loss.avg),
          file=log)

    log.flush()
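
# A hypothetical driver for the evaluation loops above; the dataset, the
# checkpoint files, and the logfile path are all assumptions, and the loaded
# objects must expose the encoder/decoder interfaces used in the examples:
if __name__ == "__main__":
    import torchvision

    gpu = torch.cuda.is_available()
    loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            "data", train=False, download=True,
            transform=torchvision.transforms.ToTensor()),
        batch_size=128, shuffle=False,
        drop_last=True)  # keeps batches evenly splittable into (X_s, X_t) halves

    src_vae = torch.load("src_vae.pt")  # assumed: pickled trained VAEs
    tgt_vae = torch.load("tgt_vae.pt")
    if gpu:
        src_vae, tgt_vae = src_vae.cuda(), tgt_vae.cuda()

    with open("transfer.log", "w") as log:
        vae_transfer(src_vae, tgt_vae, loader, log, gpu)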