예제 #1
0
    def eval_model(self, sess, adv_level=0.05):
        """Print clean and adversarial accuracy of the classifier on the test set.

        Builds an ``attacks.LinfPGDAttack`` on ``self.loss_op`` with respect to
        the input placeholder ``self.x``.  For every ``(x_test, y_test)`` batch
        yielded by ``self.get_test_gen(sess)`` it runs ``self.pred_op``
        (with ``self.dropout_rate`` fed as 0.0) on the clean batch and on its
        adversarially perturbed copy.  On batches 3, 13, 23, ... the clean and
        adversarial images are written to ``images/cl_original.png`` and
        ``images/cl_adversarial.png`` via ``self.test_generate``.

        Args:
            sess: TensorFlow session used to run the graph and the attack.
            adv_level: Perturbation budget handed to the PGD attack.

        Returns:
            ``(normal_accuracy, adversarial_accuracy)``: the values of
            ``self.pred_op`` averaged over batches (every batch weighted
            equally), or ``(nan, nan)`` if the generator yields no batches.
        """
        test_gen = self.get_test_gen(sess)

        x, y = self.x, self.y
        dropout_rate = self.dropout_rate
        accuracy_op = self.pred_op

        # Remaining positional arguments: 30 iterations, step size 0.01 and a
        # boolean flag -- NOTE(review): presumably random_start; confirm
        # against attacks.LinfPGDAttack.
        attack = attacks.LinfPGDAttack(self.loss_op, x, adv_level, 30, 0.01,
                                       True)

        normal_sum, adv_sum = 0, 0
        n_batches = 0
        for n_batches, (x_test, y_test) in enumerate(test_gen, start=1):
            adversarial_x = attack.perturb(x_test, y_test, x, y, sess)

            normal_sum += sess.run(
                accuracy_op,
                feed_dict={x: x_test, y: y_test, dropout_rate: 0.0})
            adv_sum += sess.run(
                accuracy_op,
                feed_dict={x: adversarial_x, y: y_test, dropout_rate: 0.0})

            # Periodically dump the clean / adversarial batch for inspection.
            if n_batches % 10 == 3:
                batch_size = x_test.shape[0]
                self.test_generate(x_test, batch_size,
                                   'images/cl_original.png')
                self.test_generate(adversarial_x, batch_size,
                                   'images/cl_adversarial.png')

        if n_batches == 0:
            # An empty test generator previously raised ZeroDivisionError.
            normal_acc = adv_acc = float('nan')
        else:
            normal_acc = normal_sum / n_batches
            adv_acc = adv_sum / n_batches

        print("------------ Test ----------------")
        print("Normal Accuracy:", normal_acc)
        print("Normal Adversarial Accuracy:", adv_acc)
        return normal_acc, adv_acc
예제 #2
0
    def test_attack(self, attack_type, max_images=50):
        """Measure how much an adversarial input perturbation changes ``self.G``.

        The trained generator is restored and, for each batch of the dataset
        loader, a fresh attack of the requested type is built on ``self.G``.
        For every target label from ``self.create_labels`` it compares the
        generator output for the clean input with the output for the
        perturbed input.  Per batch, an image strip [input, perturbed input,
        one attacked translation per target] is saved to
        ``<self.result_dir>/<i+1>-images.jpg``.

        Finally it prints:
        - the number of evaluated (batch, target) pairs;
        - their mean L1 and L2 (MSE) errors;
        - the mean ``norm(0)`` and ``norm(-inf)`` of the output difference;
        - the fraction of pairs whose MSE exceeds 0.05.

        Args:
            attack_type: ``'PGD'``, ``'FGSM'`` or ``'iFGSM'``.
            max_images: Stop after this many loader iterations (default 50,
                the previously hard-coded limit).

        Raises:
            ValueError: if ``attack_type`` or ``self.dataset`` is unsupported.
        """
        # Attack classes are resolved by name so that only the selected one
        # has to exist in the ``attacks`` module.
        attack_class_names = {
            'PGD': 'LinfPGDAttack',
            'FGSM': 'FGSM',
            'iFGSM': 'iFGSM',
        }
        if attack_type not in attack_class_names:
            raise ValueError(
                'Unsupported attack_type {!r}; expected one of {}.'.format(
                    attack_type, sorted(attack_class_names)))

        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader
        else:
            raise ValueError(
                'Unsupported dataset {!r}; expected CelebA or RaFD.'.format(
                    self.dataset))

        # Metrics accumulated (as Python floats) over all (batch, target) pairs.
        l1_error, l2_error, min_dist, l0_error = 0.0, 0.0, 0.0, 0.0
        n_dist, n_samples = 0, 0

        for i, (x_real, c_org) in enumerate(data_loader):
            # Prepare input images and target domain labels.
            x_real = x_real.to(self.device)
            c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset,
                                            self.selected_attrs)

            # A fresh attack object is created for every batch.
            curr_attack = getattr(attacks, attack_class_names[attack_type])(
                model=self.G, device=self.device, feat=None)

            # Image strip: input, adversarial input, one translation per target.
            x_fake_list = [x_real]

            for idx, c_trg in enumerate(c_trg_list):
                print('image', i, 'class', idx)
                with torch.no_grad():
                    # Reference translation of the clean input.  To evaluate
                    # against a blurred input use self.blur_tensor(x_real).
                    gen_noattack, _ = self.G(x_real, c_trg)

                # The attack's own adversarial image is discarded; the
                # adversarial input is rebuilt as x_real + perturb.
                _, perturb = curr_attack.perturb(x_real, gen_noattack, c_trg)
                x_adv = x_real + perturb

                with torch.no_grad():
                    gen, _ = self.G(x_adv, c_trg)

                    # The adversarial input is shown once per strip.
                    if idx == 0:
                        x_fake_list.append(x_adv)
                    x_fake_list.append(gen)

                    diff = gen - gen_noattack
                    mse = F.mse_loss(gen, gen_noattack).item()
                    l1_error += F.l1_loss(gen, gen_noattack).item()
                    l2_error += mse
                    l0_error += diff.norm(0).item()
                    min_dist += diff.norm(float('-inf')).item()
                    if mse > 0.05:
                        n_dist += 1
                    n_samples += 1

            # Save the translated images.
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(self.result_dir,
                                       '{}-images.jpg'.format(i + 1))
            save_image(self.denorm(x_concat.data.cpu()),
                       result_path,
                       nrow=1,
                       padding=0)
            if i + 1 >= max_images:  # stop after this many images
                break

        if n_samples == 0:
            # Avoid dividing by zero when nothing was evaluated.
            print('No images were evaluated.')
            return

        # Print metrics
        print('{} images. L1 error: {}. L2 error: {}.'.format(
            n_samples, l1_error / n_samples, l2_error / n_samples))
        print('L0 error: {}. Min dist: {}. Fraction with L2 error > 0.05: {}.'
              .format(l0_error / n_samples, min_dist / n_samples,
                      n_dist / n_samples))
예제 #3
0
def _ramped_value(value, epoch, start_epoch, ramp_epochs):
    """Scale ``value`` by the warm-up factor for ``epoch``.

    If ``ramp_epochs`` is 0 the factor is a step function: 0 before
    ``start_epoch`` and 1 from ``start_epoch`` on.  Otherwise it rises
    linearly from 0 at ``start_epoch`` to 1 at ``start_epoch + ramp_epochs``
    and is capped at 1 afterwards.
    """
    if ramp_epochs == 0:
        return value * (epoch >= start_epoch)
    return value * min(max(epoch - start_epoch, 0) / ramp_epochs, 1)


def train_good(model, train_loader_in, train_loader_out, val_loader_in,
               val_loader_out_list, device, expfolder, args):
    """Train ``model`` on in-distribution data plus an OOD penalty and log it.

    Each training step pairs one batch from ``train_loader_in`` (loss
    ``ce_loss``) with one batch from ``train_loader_out`` (OOD inputs).  From
    epoch ``schedule['out_start_epoch']`` on, an OOD loss selected by
    ``args.method`` is added with weight ``kappa``:

    * ``'plain'`` -- no OOD term,
    * ``'OE'``    -- cross-entropy of the OOD predictions to the uniform
      distribution,
    * ``'CEDA'``  -- sum of the maximal log-softmax confidences,
    * ``'GOOD'``  -- quantile-split loss that uses the upper bounds from
      ``model.ibp_elision_forward`` on ``[x - eps, x + eps]`` for the lower
      quantile and the clean logit spread for the rest.

    If ``args.acet`` is ``'ce'`` or ``'lc'``, an adversarial
    (``attacks.LinfPGDAttack``) OOD term is added on top.  ``kappa`` and
    ``eps`` follow the step / linear warm-up given by the schedule.

    After every epoch, training statistics and a validation run
    (``evaluation.evaluate_ibp_lc``) are written to TensorBoard.  State dicts
    are saved under ``<model_folder>/state_dicts``.

    Returns:
        The model folder created for this run.

    Raises:
        ValueError: for an unsupported optimizer, ACET loss, in-distribution
            dataset (no validation eps configured), or method (once the OOD
            term becomes active), or if the training loaders yield no batches.
    """
    train_in_name = train_loader_in.dataset.__repr__().split()[1]
    # L-inf radius for the per-epoch validation, keyed by in-distribution
    # dataset name (smaller values for MNIST can be useful if no guarantees
    # are given for it).  Checked up front so an unsupported dataset fails
    # before the first epoch instead of after it.
    val_eps_by_dataset = {'MNIST': 0.3, 'SVHN': 0.01, 'CIFAR10': 0.01}
    if train_in_name not in val_eps_by_dataset:
        raise ValueError(
            f'No validation eps configured for in-distribution dataset '
            f'{train_in_name}. Must be one of {sorted(val_eps_by_dataset)}.')
    val_eps = val_eps_by_dataset[train_in_name]
    val_th = 0.3  # for evaluating how many samples get conf values > 30%.

    print(model.layers)
    starttime = datetime.datetime.utcnow()

    schedule = schedules.schedule(args)

    print(f'schedule: {schedule}')
    model_folder = names.model_folder_name(expfolder, starttime, args,
                                           schedule)
    for subfolder in ['state_dicts', 'sample_images', 'logs', 'batch_images']:
        os.makedirs(f'{model_folder}/{subfolder}/', exist_ok=True)
    tb_subfolder = f'tb_logs/{args.tb_folder}/{model_folder}'
    os.makedirs(tb_subfolder, exist_ok=True)
    writer = SummaryWriter(tb_subfolder)
    print(f'model folder: {model_folder}')
    print(f'tb_subfolder: {tb_subfolder}')

    trainstart_message = f'Training {model.__name__} for {schedule["epochs"]} epochs of {2*min(len(train_loader_in.dataset), len(train_loader_out.dataset))} samples.'
    print(trainstart_message)

    if schedule['optimizer'] == 'SGDM':
        optimizer = optim.SGD(model.parameters(),
                              lr=schedule['start_lr'],
                              weight_decay=0.05,
                              momentum=0.9,
                              nesterov=True)
    elif schedule['optimizer'] == 'ADAM':
        optimizer = optim.Adam(model.parameters(),
                               lr=schedule['start_lr'],
                               weight_decay=0.005)
    else:
        raise ValueError(
            f'Optimizer {schedule["optimizer"]} not supported. Must be SGDM or ADAM.'
        )
    print(f'Optimizer settings: {optimizer}')
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        schedule['lr_decay_epochs'],
        gamma=schedule['lr_decay_factor'],
        last_epoch=-1)

    num_classes = model.num_classes

    for epoch in range(schedule['epochs']):
        # Per-epoch tracking: losses are summed over batches here and
        # normalised per sample after the epoch.
        in_samples_this_epoch = 0
        corrects_in_this_epoch = 0
        ce_losses_in_epoch = []

        out_samples_this_epoch = 0
        above_quantile_losses_out_epoch = []
        below_quantile_losses_out_epoch = []
        full_good_losses_out_epoch = []
        zero_good_losses_out_epoch = []
        oe_losses_out_epoch, ceda_losses_out_epoch = [], []
        losses_out_epoch, kappa_losses_out_epoch = [], []
        log_confs_out_epoch, ub_log_confs_out_epoch = [], []
        acet_losses_out_epoch = []

        losses_epoch = []

        # Hyperparameters for this epoch (step or linear warm-up).
        kappa_epoch = _ramped_value(schedule['kappa'], epoch,
                                    schedule['out_start_epoch'],
                                    schedule['kappa_epoch_ramp'])
        eps_epoch = _ramped_value(schedule['eps'], epoch,
                                  schedule['eps_start_epoch'],
                                  schedule['eps_epoch_ramp'])

        # If ACET is turned on, it is used in addition to args.method.
        if args.acet:
            if args.acet == 'ce':
                acet_lossfn = lossfunctions.CrossEntropyLossDistr
            elif args.acet == 'lc':
                acet_lossfn = lossfunctions.LogConf
            else:
                raise ValueError(
                    f"ACET loss {args.acet} not supported. Must be 'ce' or 'lc'."
                )
            pgd = attacks.LinfPGDAttack(epsilon=eps_epoch,
                                        n=schedule['acet_n'],
                                        loss_fn=acet_lossfn,
                                        random_start=False,
                                        device=device)
        do_acet_epoch = args.acet and kappa_epoch > 0

        model.train()
        for batch_number, data in enumerate(
                zip(train_loader_in, train_loader_out), 0):
            img_batch_parts = [d[0].to(device) for d in data]
            lbl_batch_parts = [d[1].to(device) for d in data]
            img_batch_in, img_batch_out = img_batch_parts
            lbl_batch_in = lbl_batch_parts[0]
            # Uniform label for the OOD samples, as it represents the optimum.
            # NOTE(review): built with the label dtype; relies on PyTorch type
            # promotion turning 1/num_classes * int-tensor into a float tensor.
            lbl_batch_out = 1 / num_classes * torch.ones(
                lbl_batch_parts[1].size() + (num_classes, ),
                dtype=lbl_batch_parts[1].dtype).to(device)

            batch_size_in = len(img_batch_in)
            batch_size_out = len(img_batch_out)

            in_samples_this_epoch += batch_size_in
            out_samples_this_epoch += batch_size_out

            # Save one example batch of each kind.
            if epoch == 0 and batch_number == 0:
                vutils.save_image(img_batch_in,
                                  model_folder + '/batch_images/in_batch0.png')
                vutils.save_image(
                    img_batch_out,
                    model_folder + '/batch_images/out_batch0.png')

            optimizer.zero_grad()  # resets the calculated gradients

            logit_batch_in = model(img_batch_in)

            ce_loss_in = ce_loss(logit_batch_in, lbl_batch_in)
            ce_losses_in_epoch.append(ce_loss_in.detach().cpu().numpy())
            _, predicted_class_in = logit_batch_in.max(dim=-1)
            corrects_in_this_epoch += predicted_class_in.eq(
                lbl_batch_in).sum().item()

            if do_acet_epoch:
                if eps_epoch > 0:
                    adv_batch_out, _ = pgd.perturbt(img_batch_out,
                                                    lbl_batch_out, model)
                    # Make sure the model isn't left in eval mode by the attack.
                    model.train()
                else:
                    adv_batch_out = img_batch_out
                logit_adv_batch_out = model(adv_batch_out)
                acet_loss_out = acet_lossfn(logit_adv_batch_out,
                                            lbl_batch_out).sum()
                acet_losses_out_epoch.append(
                    acet_loss_out.detach().cpu().numpy())

            # Losses on the OOD inputs.
            if args.method in {'OE', 'CEDA'}:
                logit_batch_out = model(img_batch_out)
                log_pred_out = logit_batch_out.log_softmax(dim=-1)
                log_conf_out_batch = log_pred_out.max(dim=-1)[0]

                ceda_loss_out = log_conf_out_batch.sum()
                oe_loss_out = -(log_pred_out / num_classes).sum()
                ceda_losses_out_epoch.append(
                    ceda_loss_out.detach().cpu().numpy())
                oe_losses_out_epoch.append(oe_loss_out.detach().cpu().numpy())
                log_confs_out_epoch.append(
                    log_conf_out_batch.detach().cpu().numpy())

            if args.method == 'GOOD':
                # Only the third output of the bound propagation is used.
                _, _, ud_logit_out_batch = model.ibp_elision_forward(
                    img_batch_out - eps_epoch, img_batch_out + eps_epoch,
                    num_classes)
                ub_log_conf_out_batch = ud_logit_out_batch.max(dim=-1)[0].max(
                    dim=-1)[0]
                logit_batch_out = model(img_batch_out)
                # Equals ub_log_conf_out_batch for eps=0, but needs only one pass.
                logit_diff = (logit_batch_out.max(dim=-1)[0] -
                              logit_batch_out.min(dim=-1)[0])
                n_below = math.floor(batch_size_out * args.good_quantile)
                n_above = batch_size_out - n_below
                # Above or exactly at the quantile, i.e. 'not below'.
                above_quantile_indices = ub_log_conf_out_batch.topk(
                    n_above, largest=True)[1]
                below_quantile_indices = ub_log_conf_out_batch.topk(
                    n_below, largest=False)[1]

                above_quantile_loss_out = (
                    (logit_diff[above_quantile_indices])**2 / 2).log1p().sum()
                below_quantile_loss_out = (
                    (ub_log_conf_out_batch[below_quantile_indices])**2 /
                    2).log1p().sum()
                good_loss_out = above_quantile_loss_out + below_quantile_loss_out

                # For tracking only.
                zero_good_loss_out = (logit_diff**2 / 2).log1p().sum()
                full_good_loss_out = (ub_log_conf_out_batch**2 /
                                      2).log1p().sum()
                log_pred_out = logit_batch_out.log_softmax(dim=-1)
                log_conf_out_batch = log_pred_out.max(dim=-1)[0]
                ceda_loss_out = log_conf_out_batch.sum()
                oe_loss_out = -(log_pred_out / num_classes).sum()

                above_quantile_losses_out_epoch.append(
                    above_quantile_loss_out.detach().cpu().numpy())
                below_quantile_losses_out_epoch.append(
                    below_quantile_loss_out.detach().cpu().numpy())
                zero_good_losses_out_epoch.append(
                    zero_good_loss_out.detach().cpu().numpy())
                full_good_losses_out_epoch.append(
                    full_good_loss_out.detach().cpu().numpy())
                ceda_losses_out_epoch.append(
                    ceda_loss_out.detach().cpu().numpy())
                oe_losses_out_epoch.append(oe_loss_out.detach().cpu().numpy())
                log_confs_out_epoch.append(
                    log_conf_out_batch.detach().cpu().numpy())
                ub_log_confs_out_epoch.append(
                    ub_log_conf_out_batch.detach().cpu().numpy())

                # Save example out-batch splits.
                if epoch % 10 == 0 and batch_number == 0:
                    if len(above_quantile_indices) > 0:
                        vutils.save_image(
                            img_batch_out[above_quantile_indices],
                            model_folder +
                            f'/batch_images/{epoch:3d}batch0_above_quantile.png'
                        )
                    if len(below_quantile_indices) > 0:
                        vutils.save_image(
                            img_batch_out[below_quantile_indices],
                            model_folder +
                            f'/batch_images/{epoch:3d}batch0_below_quantile.png'
                        )

            if args.method == 'plain' or epoch < schedule['out_start_epoch']:
                # Clone so adding ACET to it cannot change ce_loss_in.
                loss_batch = ce_loss_in.clone()
                loss_name = 'in_ce'
                losses_out_epoch.append(0)
                kappa_losses_out_epoch.append(0)
            elif args.method == 'OE':
                loss_batch = ce_loss_in + kappa_epoch * oe_loss_out
                loss_name = f'in_ce+{kappa_epoch}*oe_loss_out'
                losses_out_epoch.append(oe_loss_out.detach().cpu().numpy())
                kappa_losses_out_epoch.append(
                    kappa_epoch * oe_loss_out.detach().cpu().numpy())
            elif args.method == 'CEDA':
                loss_batch = ce_loss_in + kappa_epoch * ceda_loss_out
                loss_name = f'in_ce+{kappa_epoch}*ceda_loss_out'
                losses_out_epoch.append(ceda_loss_out.detach().cpu().numpy())
                kappa_losses_out_epoch.append(
                    kappa_epoch * ceda_loss_out.detach().cpu().numpy())
            elif args.method == 'GOOD':
                loss_batch = ce_loss_in + kappa_epoch * good_loss_out
                loss_name = f'in_ce + {kappa_epoch}*(above_quantile_loss_out + eps{eps_epoch}below_quantile_loss_out)'
                losses_out_epoch.append(good_loss_out.detach().cpu().numpy())
                kappa_losses_out_epoch.append(
                    kappa_epoch * good_loss_out.detach().cpu().numpy())
            else:
                # Previously this left loss_batch/loss_name unbound.
                raise ValueError(
                    f'Method {args.method} not supported. Must be plain, OE, CEDA or GOOD.'
                )

            # ACET is added on top.
            if do_acet_epoch:
                loss_batch += kappa_epoch * acet_loss_out
                loss_name += f'+{kappa_epoch}*out_{eps_epoch}_adv_conf'

            losses_epoch.append(loss_batch.detach().cpu().numpy())

            # Between backward() and step() there should be no computations;
            # only saving the gradients makes sense there.
            loss_batch.backward()
            optimizer.step()

        if in_samples_this_epoch == 0:
            raise ValueError(
                f'Training loaders yielded no batches in epoch {epoch}.')

        ce_loss_in_epoch = np.sum(ce_losses_in_epoch) / in_samples_this_epoch
        accuracy_epoch = corrects_in_this_epoch / in_samples_this_epoch
        loss_epoch = np.sum(
            losses_epoch) / in_samples_this_epoch  # per in sample!

        loss_out_epoch = np.sum(losses_out_epoch) / out_samples_this_epoch
        kappa_loss_out_epoch = np.sum(
            kappa_losses_out_epoch) / out_samples_this_epoch

        if do_acet_epoch:
            acet_loss_out_epoch = np.sum(
                acet_losses_out_epoch) / out_samples_this_epoch

        if args.method in {'OE', 'CEDA', 'GOOD'}:
            oe_loss_out_epoch = np.sum(
                oe_losses_out_epoch) / out_samples_this_epoch
            ceda_loss_out_epoch = np.sum(
                ceda_losses_out_epoch) / out_samples_this_epoch
            log_confs_out_all = np.concatenate(log_confs_out_epoch)
            log_conf_out_epoch = np.sum(
                log_confs_out_all) / out_samples_this_epoch
            # Median of the per-sample values.  It must not be divided by the
            # sample count as the mean is; the old code did that.
            median_log_conf_out_epoch = np.median(log_confs_out_all)

        if args.method == 'GOOD':
            above_quantile_loss_out_epoch = np.sum(
                above_quantile_losses_out_epoch) / out_samples_this_epoch
            below_quantile_loss_out_epoch = np.sum(
                below_quantile_losses_out_epoch) / out_samples_this_epoch
            full_good_loss_out_epoch = np.sum(
                full_good_losses_out_epoch) / out_samples_this_epoch
            zero_good_loss_out_epoch = np.sum(
                zero_good_losses_out_epoch) / out_samples_this_epoch
            ub_log_confs_out_all = np.concatenate(ub_log_confs_out_epoch)
            ub_log_conf_out_epoch = np.sum(
                ub_log_confs_out_all) / out_samples_this_epoch
            median_ub_log_conf_out_epoch = np.median(ub_log_confs_out_all)

        s_0 = f'Epoch {epoch} (lr={get_lr(optimizer)}) complete with mean training loss {loss_epoch} (ce_loss_in: {ce_loss_in_epoch}, loss_out:{loss_out_epoch}, used loss:{loss_name}).'
        s_1 = 'Time since start of training: {0}.\n'.format(
            datetime.datetime.utcnow() - starttime)
        print(s_0)
        print(s_1)

        writer.add_scalar('TrainIn/loss_total_per_in', loss_epoch, epoch)
        writer.add_scalar('TrainIn/ce_loss_in', ce_loss_in_epoch, epoch)
        writer.add_scalar('TrainIn/accuracy', accuracy_epoch, epoch)

        writer.add_scalar('TrainOut/loss_out', loss_out_epoch, epoch)
        writer.add_scalar('TrainOut/kappa_loss_out', kappa_loss_out_epoch,
                          epoch)
        if do_acet_epoch:
            writer.add_scalar('TrainOut/acet_loss_out', acet_loss_out_epoch,
                              epoch)
        if args.method == 'GOOD':
            writer.add_scalar('TrainOut/above_quantile_loss_out',
                              above_quantile_loss_out_epoch, epoch)
            writer.add_scalar('TrainOut/below_quantile_loss_out',
                              below_quantile_loss_out_epoch, epoch)
            writer.add_scalar('TrainOut/full_good_loss_out',
                              full_good_loss_out_epoch, epoch)
            writer.add_scalar('TrainOut/zero_good_loss_out',
                              zero_good_loss_out_epoch, epoch)
        if args.method in {'OE', 'CEDA', 'GOOD'}:
            writer.add_scalar('TrainOut/oe_loss_out', oe_loss_out_epoch, epoch)
            writer.add_scalar('TrainOut/ceda_loss_out', ceda_loss_out_epoch,
                              epoch)
            writer.add_scalar('TrainOut/log_conf_out', log_conf_out_epoch,
                              epoch)
            writer.add_scalar('TrainOut/median_log_conf_out',
                              median_log_conf_out_epoch, epoch)
            writer.add_histogram('Train_log_conf_out', log_confs_out_all,
                                 epoch)
        if args.method == 'GOOD':
            writer.add_scalar('TrainOut/ub_log_conf_out',
                              ub_log_conf_out_epoch, epoch)
            writer.add_scalar('TrainOut/median_ub_log_conf_out',
                              median_ub_log_conf_out_epoch, epoch)
            writer.add_histogram('Train_ub_log_conf_out',
                                 ub_log_confs_out_all, epoch)
        writer.add_scalar('TrainHyPa/eps', eps_epoch, epoch)
        writer.add_scalar('TrainHyPa/kappa', kappa_epoch, epoch)
        writer.add_scalar('TrainHyPa/learning_rate', get_lr(optimizer), epoch)

        do_valuation = True  # the whole evaluation only takes a few seconds.
        if do_valuation:
            eval_result_dict = evaluation.evaluate_ibp_lc(model,
                                                          val_loader_in,
                                                          val_loader_out_list,
                                                          eps=val_eps,
                                                          conf_th=val_th,
                                                          device=device,
                                                          n_pgd=0,
                                                          n_samples=1000)

            # The first five entries of each per-OOD-set result hold the
            # in-distribution metrics; read them from the first OOD set.
            in_accuracy, pred_in_confidences, pred_in_mean_confidence, pred_in_above_th, number_of_in_datapoints = eval_result_dict[
                val_loader_out_list[0].dataset.__repr__().split()[1]][:5]
            writer.add_scalar('Val/in_accuracy', in_accuracy, epoch)
            writer.add_scalar('Val/mean_confidence', pred_in_mean_confidence,
                              epoch)
            writer.add_scalar('Val/confidences_above_{0:.2f}'.format(val_th),
                              pred_in_above_th / number_of_in_datapoints,
                              epoch)
            writer.add_scalar('Val/eps', val_eps, epoch)

            writer.add_histogram('Val/pred_in_confidences',
                                 pred_in_confidences, epoch)

            for val_loader_out in val_loader_out_list:
                out_name = val_loader_out.dataset.__repr__().split()[1]
                in_accuracy, pred_in_confidences, pred_in_mean_confidence, pred_in_above_th, number_of_in_datapoints, pred_out_confidences, pred_out_mean_confidence, pred_out_above_th, number_of_out_datapoints, ub_el_out_confidences, ub_elision_mean_out_confidence, ub_elision_median_out_confidence, ub_elision_out_below_th, auroc_from_predictions, auroc_out_guaranteed_softmax_elision, auroc_from_predictions_conservative, auroc_out_guaranteed_softmax_elision_conservative, pred_adv_out_confidences, adversarial_pred_out_mean_confidence, adversarial_pred_out_median_confidence, adversarial_pred_out_above_th = eval_result_dict[
                    out_name]

                writer.add_scalar('Val{0}/mean_confidence'.format(out_name),
                                  pred_out_mean_confidence, epoch)
                writer.add_scalar('Val{0}/mean_ub_confidence'.format(out_name),
                                  ub_elision_mean_out_confidence, epoch)
                writer.add_scalar(
                    'Val{0}/median_ub_confidence'.format(out_name),
                    ub_elision_median_out_confidence, epoch)
                writer.add_scalar(
                    'Val{0}/confidences_above_{1:.2f}'.format(
                        out_name, val_th),
                    pred_out_above_th / number_of_out_datapoints, epoch)
                writer.add_scalar(
                    'Val{0}/ub_confidences_below_{1:.2f}'.format(
                        out_name, val_th),
                    ub_elision_out_below_th / number_of_out_datapoints, epoch)
                writer.add_scalar('Val{0}/AUC'.format(out_name),
                                  auroc_from_predictions, epoch)
                writer.add_scalar('Val{0}/GAUC'.format(out_name),
                                  auroc_out_guaranteed_softmax_elision, epoch)
                writer.add_scalar('Val{0}/cAUC'.format(out_name),
                                  auroc_from_predictions_conservative, epoch)
                writer.add_scalar(
                    'Val{0}/cGAUC'.format(out_name),
                    auroc_out_guaranteed_softmax_elision_conservative, epoch)

                # Tag now uses the same 'Val<name>/...' layout as the others.
                writer.add_histogram('Val{0}/confidences'.format(out_name),
                                     pred_out_confidences, epoch)
                writer.add_histogram('Val{0}/ub_confidences'.format(out_name),
                                     ub_el_out_confidences, epoch)

        lr_scheduler.step()
        if epoch % 50 == 0 or epoch in (103, 105):
            save_filename = model_folder + '/state_dicts/{0:03d}.pt'.format(
                epoch)
            torch.save(model.state_dict(), save_filename)
        torch.cuda.empty_cache()
        del data, img_batch_parts, lbl_batch_parts
        # When loading 80M Tiny Images, training becomes much slower from
        # epoch 2 on unless the data file is reopened.
        if hasattr(train_loader_out, 'reopen_data_file'):
            train_loader_out.reopen_data_file()
    stoptime = datetime.datetime.utcnow()
    dt = stoptime - starttime
    save_filename = model_folder + '/state_dicts/' + str(epoch) + 'fin.pt'
    torch.save(model.state_dict(), save_filename)
    print('Training finished after {0} seconds'.format(dt))
    writer.close()
    return model_folder
예제 #4
0
    def animation(self, mode='animate_image'):
        """Animate images towards reference expressions, optionally under attack.

        Loads the first checkpoint matching ``*G.ckpt`` in
        ``self.animation_models_dir`` (ordered by ``self.numericalSort``) into
        ``self.G``. It then reads ``self.animation_attributes_path``, where each
        non-empty line is ``<image_name> <v_1> ... <v_c_dim>``. The values
        (divided by 5) form one target vector. ``<image_name>`` is loaded from
        ``self.animation_attribute_images_dir`` and also used as the output
        file name.

        Args:
            mode: ``'animate_random_batch'`` animates the first 7 images of
                one batch from ``self.data_loader`` towards every target and
                saves one tiled image per target.
                ``'animate_image'`` (default) animates every file in
                ``self.animation_images_dir`` after perturbing it with
                ``attacks.LinfPGDAttack``. It saves the outputs and prints
                L1 / L2 / L0 / L_-inf distortion between attacked and clean
                outputs.

        Raises:
            ValueError: if ``mode`` is not one of the two supported values.
            FileNotFoundError: if no ``*G.ckpt`` checkpoint is found.
        """
        from PIL import Image
        from torchvision import transforms as T

        valid_modes = ('animate_random_batch', 'animate_image')
        if mode not in valid_modes:
            # Previously an unknown mode was a silent no-op, which hid typos.
            raise ValueError(
                f'Unknown animation mode {mode!r}; expected one of {valid_modes}.')

        # Map images to [-1, 1]; outputs are mapped back with (x + 1) / 2
        # before saving.
        regular_image_transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

        def load_image(path):
            """Load ``path`` as a normalized 3-channel tensor, closing the file."""
            with Image.open(path) as img:
                # convert('RGB') guards against grayscale/RGBA/palette files,
                # which would not fit the 3-channel pipeline.
                return regular_image_transform(img.convert('RGB'))

        checkpoints = sorted(
            glob.glob(os.path.join(self.animation_models_dir, '*G.ckpt')),
            key=self.numericalSort)
        if not checkpoints:
            raise FileNotFoundError(
                f'No *G.ckpt checkpoint found in {self.animation_models_dir}')
        G_path = checkpoints[0]
        # NOTE(review): weights are mapped to cuda:{gpu_id} but the model is
        # then moved to cuda:0 -- confirm which device is intended.
        self.G.load_state_dict(
            torch.load(G_path, map_location=f'cuda:{self.gpu_id}'))
        self.G = self.G.cuda(0)

        # Blank lines are skipped so a trailing newline does not crash parsing.
        with open(self.animation_attributes_path, 'r') as txt_file:
            csv_lines = [line for line in txt_file if line.strip()]

        targets = torch.zeros(len(csv_lines), self.c_dim)
        input_images = torch.zeros(len(csv_lines), 3, 128, 128)
        reference_expression_images = []
        for idx, line in enumerate(csv_lines):
            image_name, *values = line.split()
            input_images[idx] = load_image(
                os.path.join(self.animation_attribute_images_dir, image_name))
            reference_expression_images.append(image_name)
            # Presumably AU intensities on a 0-5 scale -- TODO confirm.
            targets[idx] = torch.tensor([float(v) / 5. for v in values])

        if mode == 'animate_random_batch':
            animation_batch_size = 7

            self.data_iter = iter(self.data_loader)
            images_to_animate, _ = next(self.data_iter)
            images_to_animate = images_to_animate[:animation_batch_size].cuda()

            # Pure inference: no gradients are needed here.
            with torch.no_grad():
                for target_idx in range(targets.size(0)):
                    targets_au = targets[target_idx].unsqueeze(0).repeat(
                        animation_batch_size, 1).cuda()
                    resulting_images_att, resulting_images_reg = self.G(
                        images_to_animate, targets_au)
                    resulting_images = self.imFromAttReg(
                        resulting_images_att, resulting_images_reg,
                        images_to_animate).cuda()

                    # Tile order: one blank (-1) tile, the source images, the
                    # reference expression image, then the animated results.
                    save_images = -torch.ones(
                        (animation_batch_size + 1) * 2, 3, 128, 128).cuda()
                    save_images[1:animation_batch_size + 1] = images_to_animate
                    save_images[animation_batch_size + 1] = \
                        input_images[target_idx]
                    save_images[animation_batch_size + 2:] = resulting_images

                    save_image(
                        (save_images + 1) / 2,
                        os.path.join(self.animation_results_dir,
                                     reference_expression_images[target_idx]))
            return

        # mode == 'animate_image'
        # Accumulated distortion of attacked outputs w.r.t. clean outputs.
        l1_error, l2_error, min_dist, l0_error = 0.0, 0.0, 0.0, 0.0
        n_dist, n_samples = 0, 0

        pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device)

        images_to_animate_path = sorted(
            glob.glob(os.path.join(self.animation_images_dir, '*')))

        for idx, image_path in enumerate(images_to_animate_path):
            image_to_animate = load_image(image_path).unsqueeze(0).cuda()
            # Stem used to name the outputs; os.path keeps multi-dot names
            # intact and is portable across path separators.
            image_stem = os.path.splitext(os.path.basename(image_path))[0]

            # NOTE(review): the last target row is skipped here, unlike in
            # 'animate_random_batch' -- confirm this is intentional.
            for target_idx in range(targets.size(0) - 1):
                print('image', idx, 'AU', target_idx)

                targets_au = targets[target_idx].unsqueeze(0).cuda()

                # Clean (unattacked) output, used as the attack target and as
                # the reference for the metrics below.
                with torch.no_grad():
                    att_clean, reg_clean = self.G(image_to_animate, targets_au)
                    resulting_image_noattack = self.imFromAttReg(
                        att_clean, reg_clean, image_to_animate).cuda()

                # Alternative attacks (swap in for the call below):
                #   Wrong class:
                #     pgd_attack.perturb(image_to_animate, image_to_animate, targets[0, :].unsqueeze(0).cuda())
                #   Joint class conditional:
                #     pgd_attack.perturb_joint_class(image_to_animate, image_to_animate, targets[:, :].cuda())
                #   Iterative class conditional:
                #     pgd_attack.perturb_iter_class(image_to_animate, image_to_animate, targets[:, :].cuda())
                #   Iterative data:
                #     pgd_attack.perturb_iter_data(image_to_animate, all_images, image_to_animate, targets[68, :].unsqueeze(0).cuda())
                # Normal attack:
                x_adv, perturb = pgd_attack.perturb(
                    image_to_animate, resulting_image_noattack, targets_au)

                # Rebuild the adversarial input from the perturbation; keep
                # this line when transferring attacks.
                # (No attack: x_adv = image_to_animate)
                x_adv = image_to_animate + perturb

                with torch.no_grad():
                    att_adv, reg_adv = self.G(x_adv, targets_au)
                    resulting_image = self.imFromAttReg(
                        att_adv, reg_adv, x_adv).cuda()

                save_image(
                    (resulting_image + 1) / 2,
                    os.path.join(
                        self.animation_results_dir, image_stem + '_' +
                        reference_expression_images[target_idx]))
                if target_idx == 0:
                    save_image(
                        (x_adv + 1) / 2,
                        os.path.join(self.animation_results_dir,
                                     image_stem + '_ref.jpg'))

                # Compare to ground-truth (clean) output. Accumulate Python
                # floats so the final report prints plain numbers.
                diff = resulting_image - resulting_image_noattack
                mse = F.mse_loss(resulting_image,
                                 resulting_image_noattack).item()
                l1_error += F.l1_loss(resulting_image,
                                      resulting_image_noattack).item()
                l2_error += mse
                l0_error += diff.norm(0).item()
                min_dist += diff.norm(float('-inf')).item()

                # A sample counts as "distorted" when its MSE exceeds 0.05.
                if mse > 0.05:
                    n_dist += 1
                n_samples += 1

        if n_samples == 0:
            # Avoid ZeroDivisionError when no images were found.
            print('No images were animated in {}; no metrics to report.'.format(
                self.animation_images_dir))
            return

        # Print metrics
        print(
            '{} images. L1 error: {}. L2 error: {}. prop_dist: {}. L0 error: {}. L_-inf error: {}.'
            .format(n_samples, l1_error / n_samples, l2_error / n_samples,
                    float(n_dist) / float(n_samples), l0_error / n_samples,
                    min_dist / n_samples))
예제 #5
0
파일: evaluation.py 프로젝트: j-cb/GOOD
def evaluate_ibp_lc(
    model,
    test_loader_in,
    test_loader_out_list,
    eps,
    conf_th,
    device,
    short_name=None,
    n_pgd=0,
    model_path=None,
    n_samples=30000,
    do_accuracy_above=False,
    save_plots=False
):  #pass model_path to save evaluation graphs in the model directory
    starttime_eval = datetime.datetime.utcnow()
    torch.set_printoptions(profile="full")
    in_name = str(test_loader_in.dataset.__repr__()).split()[1]
    num_classes = model.num_classes
    print('Evaluating {0} for accuracy and confidence-based OOD detection.'.
          format(model_path))
    model.eval()
    k = min(32, n_samples)  #number of worst case images to be saved
    if n_pgd > 0:
        pgd_attack_out = attacks.LinfPGDAttack(epsilon=eps,
                                               n=n_pgd,
                                               loss_fn=lossfunctions.LogConf,
                                               random_start=False,
                                               device=device)
        pgd_attack_in = attacks.LinfPGDAttack(epsilon=eps,
                                              n=n_pgd,
                                              loss_fn=nn.LogConf,
                                              random_start=False,
                                              device=device)

    #prepare LaTeX-usable outputs
    table_str_0 = f'Model '
    table_str_1 = f'{model_path} '

    print('\nIn dataset: {0}'.format(str(test_loader_in.dataset.__repr__())))
    n_in_samples = min(n_samples, len(test_loader_in.dataset))
    number_of_in_datapoints = 0
    labels_in = np.zeros(n_in_samples, dtype=int)
    logits_in = np.zeros((n_in_samples, num_classes))
    pred_scores_in = np.zeros((n_in_samples, num_classes))
    adversarial_logits_in = np.zeros((n_in_samples, num_classes))
    adversarial_pred_scores_in = np.zeros((n_in_samples, num_classes))
    ud_logits_in = np.zeros((n_in_samples, num_classes, num_classes))
    for batch, (img_torch, lbl) in enumerate(test_loader_in):
        bs = len(lbl)
        number_of_in_datapoints += bs
        if number_of_in_datapoints > n_in_samples:  #adjust length of last batch
            assert number_of_in_datapoints - n_in_samples < bs
            bs = bs - (number_of_in_datapoints - n_in_samples)
            img_torch = img_torch[:bs]
            lbl = lbl[:bs]
        labels_in[test_loader_in.batch_size *
                  batch:test_loader_in.batch_size * batch + bs] = lbl.numpy()
        img = img_torch.to(device)
        lbl_in_batch = lbl.to(device)

        #get clean, adversarial and guaranteed predictions
        logit_in_batch = model(img)
        pred_score_in_batch = logit_in_batch.softmax(dim=-1)
        if n_pgd > 0:
            #Run an adversarial attack trying to make prediction wrong.
            adversarial_img, _ = pgd_attack_in.perturbt(img, lbl, model)
            model.eval()
            adversarial_logit_in_batch = model(adversarial_img)
            adversarial_pred_score_in_batch = adversarial_logit_in_batch.softmax(
                dim=-1)
        else:
            adversarial_img = img
            adversarial_logit_in_batch = model(adversarial_img)
            adversarial_pred_score_in_batch = adversarial_logit_in_batch.softmax(
                dim=-1)
        l_logit_in_batch, u_logit_in_batch, ud_logit_in_batch = model.ibp_elision_forward(
            torch.clamp(img - eps, 0, 1), torch.clamp(img + eps, 0, 1),
            num_classes)

        labels_in[test_loader_in.batch_size *
                  batch:test_loader_in.batch_size * batch +
                  bs] = lbl_in_batch.detach().cpu().numpy()
        logits_in[test_loader_in.batch_size *
                  batch:test_loader_in.batch_size * batch +
                  bs] = logit_in_batch.detach().cpu().numpy()
        pred_scores_in[test_loader_in.batch_size *
                       batch:test_loader_in.batch_size * batch +
                       bs] = pred_score_in_batch.detach().cpu().numpy()
        adversarial_logits_in[test_loader_in.batch_size *
                              batch:test_loader_in.batch_size * batch +
                              bs] = adversarial_logit_in_batch.detach().cpu(
                              ).numpy()
        adversarial_pred_scores_in[
            test_loader_in.batch_size *
            batch:test_loader_in.batch_size * batch +
            bs] = adversarial_pred_score_in_batch.detach().cpu().numpy()
        ud_logits_in[test_loader_in.batch_size *
                     batch:test_loader_in.batch_size * batch +
                     bs] = ud_logit_in_batch.detach().cpu().numpy()

        #save the first batch of in-distribution images and prepare folders
        if model_path and batch == 0:
            model_folder = 'evals/' + ''.join(model_path.split(
                '/')[:-2]) + 'eval_' + starttime_eval.strftime(
                    "%m-%d-%H-%M-%S") + 'e=' + str(eps)
            print(rf'Evaluation outputs saved in {model_folder}')
            os.makedirs(model_folder + '/sample_images/', exist_ok=True)
            os.makedirs(model_folder + '/values/', exist_ok=True)
            save_path = model_folder + '/sample_images/{1}_{0}first_batch_val_in'.format(
                str(test_loader_in.dataset).split()[1], short_name)
            vutils.save_image(img.detach(),
                              save_path + '.png',
                              normalize=False)

        #stop loader iteration if specified number of samples is reached
        if number_of_in_datapoints >= n_in_samples:
            break

    #analysis of the in-distribution results
    pred_in_classes = np.argmax(logits_in, axis=1)
    pred_adversarial_in_classes = np.argmax(adversarial_logits_in, axis=1)
    in_accuracy = accuracy(pred_in_classes, labels_in)
    adv_in_accuracy = accuracy(pred_adversarial_in_classes, labels_in)

    pred_in_log_confidences = log_confs_from_logits(logits_in)
    pred_in_confidences, pred_in_confidence_mean, pred_in_confidence_median, pred_in_confidences_below_th, pred_in_confidences_above_th, pred_in_lowest_conf_indices, pred_in_highest_conf_indices = conf_stats_from_log_confs(
        pred_in_log_confidences, conf_th, k)

    pred_in_right_confidences, pred_in_wrong_confidences = right_and_wrong_confidences_from_logits(
        logits_in, labels_in)
    pred_in_worst_conf_indices = (-pred_in_wrong_confidences).argsort()[:k]
    pred_in_worst_ce_indices = pred_in_right_confidences.argsort()[:k]

    #analyze if and by how much the accuracy improves if low confidence predictions are discarded
    if do_accuracy_above:
        #conf_thresholds = [0.0, 0.1, 0.11, 0.12, 0.13, 0.14, 0.16, 0.18, 0.2, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.999]
        #for conf in conf_thresholds:
        #    print(f'Accuracy wherever the confidence is above {conf:.2f}: {accuracy_above(logits_in, labels_in, conf):.2%}')
        t = np.linspace(0, 1, 1000)
        accuracy_above_vectorized = np.vectorize(accuracy_above,
                                                 excluded=[0, 1])
        plt.plot(t,
                 accuracy_above_vectorized(logits_in, labels_in, t),
                 c='#44DD44')
        plt.box(False)
        save_path = model_folder + f'/sample_images/{short_name}_accuracy_above_threshold.png'
        plt.axvline(x=1.0, linestyle='--', c='#BBBBBB')
        plt.axvline(x=0.1, linestyle='--', c='#BBBBBB')
        plt.gca().spines['top'].set_visible(False)
        plt.gca().spines['right'].set_visible(False)
        plt.grid(which='major', axis='y')
        plt.yticks([0.9, 0.92, 0.94, 0.96, 0.98, 1.0])
        plt.savefig(save_path, bbox_inches='tight')
        plt.close()

        t = np.linspace(0, 1, 1000)
        frac_above_vectorized = np.vectorize(frac_above, excluded=[0, 1])
        plt.plot(t,
                 frac_above_vectorized(logits_in, labels_in, t),
                 c='#4444DD')
        plt.box(False)
        save_path = model_folder + f'/sample_images/{short_name}_frac_above_threshold.png'
        plt.axvline(x=1.0, linestyle='--', c='#BBBBBB')
        plt.axvline(x=0.1, linestyle='--', c='#BBBBBB')
        plt.gca().spines['top'].set_visible(False)
        plt.gca().spines['right'].set_visible(False)
        plt.grid(which='major', axis='y')
        plt.savefig(save_path, bbox_inches='tight')
        plt.close()
    pred_adv_in_log_confidences = log_confs_from_logits(adversarial_logits_in)
    pred_adv_in_confidences, pred_adv_in_confidence_mean, pred_adv_in_confidence_median, pred_adv_in_confidences_below_th, pred_adv_in_confidences_above_th, pred_adv_in_lowest_conf_indices, pred_adv_in_highest_conf_indices = conf_stats_from_log_confs(
        pred_adv_in_log_confidences, conf_th, k)

    ub_elision_in_log_probs = -logsumexp(-ud_logits_in, axis=-1)
    ub_elision_in_log_probs[:, labels_in] = -1
    ub_elision_in_log_confs = np.amax(ub_elision_in_log_probs, axis=-1)
    ub_elision_in_confidences = np.exp(ub_elision_in_log_confs)
    ub_elision_mean_in_confidence = np.mean(ub_elision_in_confidences)
    ub_elision_median_in_confidence = np.median(ub_elision_in_confidences)
    ub_elision_in_below_th = np.sum(ub_elision_in_confidences < conf_th)

    auroc_vs_adversarials = auroc(pred_in_confidences, pred_adv_in_confidences)

    print(
        'The accuracy of the predictions of the {0} model on the in-distribution is {1:.2f}%. '
        .format(model.__name__, in_accuracy * 100),
        'Mean confidence: {0:.4f}. Median confidence: {1:.4f}. Above {2}: {3}/{4}.,'
        .format(pred_in_confidence_mean, pred_in_confidence_median, conf_th,
                pred_in_confidences_above_th, n_in_samples))
    table_str_0 += f'& Acc. '
    table_str_1 += f'& {100*in_accuracy:03.1f} '
    if n_pgd > 0:
        print(
            'Adversarial in samples under {0}: Accuracy: {1}'.format(
                pgd_attack_in.__name__, adv_in_accuracy),
            'Mean confidence: {0:.4f}. Median confidence: {1:.4f}. Above {2}: {3}/{4}.'
            .format(pred_adv_in_confidence_mean, pred_adv_in_confidence_median,
                    conf_th, pred_adv_in_confidences_above_th, n_in_samples),
            'AUROC in confidences vs adversarial in confidences: {0:.4f}'.
            format(auroc_vs_adversarials))

    if model_path:
        if save_plots:
            plotting.save_images(k,
                                 test_loader_in,
                                 pred_in_lowest_conf_indices,
                                 pred_in_confidences,
                                 pred_scores_in,
                                 in_name,
                                 model_folder,
                                 descr='pred_lowest',
                                 index_selection='lowest',
                                 short_name=short_name)
            plotting.save_images(k,
                                 test_loader_in,
                                 pred_in_worst_conf_indices,
                                 pred_in_wrong_confidences,
                                 pred_scores_in,
                                 in_name,
                                 model_folder,
                                 descr='pred_worst_wrong',
                                 index_selection='worst_wrong_prob',
                                 short_name=short_name)
            plotting.save_images(k,
                                 test_loader_in,
                                 pred_in_worst_ce_indices,
                                 pred_in_right_confidences,
                                 pred_scores_in,
                                 in_name,
                                 model_folder,
                                 descr='pred_worst_ce',
                                 index_selection='worst_right_prob',
                                 short_name=short_name)
            plt.hist(pred_in_confidences,
                     color='g',
                     bins=50,
                     range=(0.1, 1.0),
                     log=False,
                     label=f'{in_name}_in_pred_confidences')
            plt.ylim(ymin=0, ymax=len(pred_in_confidences) / 1.5)
            plt.xlim(xmin=0.09, xmax=1.01)
            plt.axvline(pred_in_confidences.mean(),
                        color='k',
                        linestyle='dashed',
                        linewidth=2)
            save_path = model_folder + f'/sample_images/{short_name}{in_name}_in_pred_confidences_hist'
            plt.savefig(save_path, bbox_inches='tight')
            plt.close()
        np.savetxt(
            model_folder +
            f'/values/{short_name}_{in_name}_pred_in_confidences.txt',
            pred_in_confidences)

    returns = {}
    for test_loader_out in test_loader_out_list:
        starttime_eval_out = datetime.datetime.utcnow()
        out_name = str(test_loader_out.dataset.__repr__()).split()[1]
        print('\nOut dataset: {0}'.format(test_loader_out.dataset.__repr__()))

        n_out_samples = min(n_samples, len(test_loader_out.dataset))
        print(f'{n_out_samples} out samples.')
        number_of_out_datapoints = 0
        #initialize numpy arrays for result values
        logits_out = np.zeros((n_out_samples, num_classes))
        l_logits_out = np.zeros((n_out_samples, num_classes))
        u_logits_out = np.zeros((n_out_samples, num_classes))
        ud_logits_out = np.zeros((n_out_samples, num_classes, num_classes))
        pred_scores_out = np.zeros((n_out_samples, num_classes))
        adversarial_logits_out = np.zeros((n_out_samples, num_classes))
        adversarial_pred_scores_out = np.zeros((n_out_samples, num_classes))
        th_eps_out = np.zeros(n_out_samples)

        for batch, (img_torch, lbl) in enumerate(test_loader_out):
            bs = len(lbl)
            number_of_out_datapoints += bs
            if number_of_out_datapoints > n_out_samples:  #as above, reduce the length of the last batch if it overflows the total number of samples.
                assert number_of_out_datapoints - n_out_samples < bs
                bs = bs - (number_of_out_datapoints - n_out_samples)
                img_torch = img_torch[:bs]
                lbl = 0 * lbl[:bs]
            logit_out_batch = np.zeros((bs, num_classes))
            img_out = img_torch.to(device)
            lbl_out_batch = lbl.to(device)

            #get the model outputs
            logit_out_batch = model(img_out)
            pred_score_out_batch = logit_out_batch.softmax(dim=-1)
            l_logit_out_batch, u_logit_out_batch, ud_logit_out_batch = model.ibp_elision_forward(
                torch.clamp(img_out - eps, 0, 1),
                torch.clamp(img_out + eps, 0, 1), num_classes)
            if n_pgd > 0:
                adversarial_out_img, _ = pgd_attack_out.perturbt(
                    img_out, lbl, model)
                adversarial_logit_out_batch = model(adversarial_out_img)
                adversarial_pred_score_out_batch = adversarial_logit_out_batch.softmax(
                    dim=-1)
            else:
                adversarial_out_img = img_out
                adversarial_logit_out_batch = model(adversarial_out_img)
                adversarial_pred_score_out_batch = adversarial_logit_out_batch.softmax(
                    dim=-1)

            logits_out[test_loader_out.batch_size *
                       batch:test_loader_out.batch_size * batch +
                       bs] = logit_out_batch.detach().cpu().numpy()
            pred_scores_out[test_loader_out.batch_size *
                            batch:test_loader_out.batch_size * batch +
                            bs] = pred_score_out_batch.detach().cpu().numpy()

            l_logits_out[test_loader_out.batch_size *
                         batch:test_loader_out.batch_size * batch +
                         bs] = l_logit_out_batch.detach().cpu().numpy()
            u_logits_out[test_loader_out.batch_size *
                         batch:test_loader_out.batch_size * batch +
                         bs] = u_logit_out_batch.detach().cpu().numpy()
            ud_logits_out[test_loader_out.batch_size *
                          batch:test_loader_out.batch_size * batch +
                          bs] = ud_logit_out_batch.detach().cpu().numpy()

            adversarial_logits_out[test_loader_out.batch_size *
                                   batch:test_loader_out.batch_size * batch +
                                   bs] = adversarial_logit_out_batch.detach(
                                   ).cpu().numpy()
            adversarial_pred_scores_out[
                test_loader_out.batch_size *
                batch:test_loader_out.batch_size * batch +
                bs] = adversarial_pred_score_out_batch.detach().cpu().numpy()

            if model_path and batch == 0:  #save some example images that the evaluation is run on
                save_path = model_folder + f'/sample_images/{short_name}_{out_name}first_batch_val'
                vutils.save_image(img_out.detach(),
                                  save_path + '.png',
                                  normalize=False)
                save_path = model_folder + f'/sample_images/{short_name}_{out_name}first_batch_val_pgd'
                vutils.save_image(adversarial_out_img.detach(),
                                  save_path + '.png',
                                  normalize=False)
            save_adversarials = False  #set this to True to save computed adversarial images.
            if model_path and n_pgd > 0 and save_adversarials:
                os.makedirs(model_folder +
                            f'/sample_images/{short_name}_{out_name}_adv/',
                            exist_ok=True)
                save_path = model_folder + f'/sample_images/{short_name}_{out_name}_adv/_batch{batch:3d}'
                for i in range(len(adversarial_out_img)):
                    vutils.save_image(img_out[i].detach(),
                                      f'{save_path}_out{i}.png',
                                      normalize=False)
                    vutils.save_image(adversarial_out_img[i].detach(),
                                      f'{save_path}_adv{i}.png',
                                      normalize=False)
            if number_of_out_datapoints >= n_out_samples:
                break

        pred_out_log_confidences = log_confs_from_logits(logits_out)
        pred_out_confidences, pred_out_confidence_mean, pred_out_confidence_median, pred_out_confidences_below_th, pred_out_confidences_above_th, pred_out_lowest_conf_indices, pred_out_highest_conf_indices = conf_stats_from_log_confs(
            pred_out_log_confidences, conf_th, k)

        #we calculate the bounds based on ibp with and without elision so we can judge the difference if we need to
        out_logit_spreads = u_logits_out.max(axis=-1) - l_logits_out.min(
            axis=-1)
        ub_spread_out_log_confidences = out_logit_spreads - np.log(num_classes)
        ub_spread_out_confidences, ub_spread_out_confidence_mean, ub_spread_out_confidence_median, ub_spread_out_confidences_below_th, ub_spread_out_confidences_above_th, ub_spread_out_lowest_conf_indices, ub_spread_out_highest_conf_indices = conf_stats_from_log_confs(
            ub_spread_out_log_confidences, conf_th, k)

        ub_out_logit_differences = u_logits_out[:, :, np.
                                                newaxis] - l_logits_out[:, np.
                                                                        newaxis, :]
        ub_out_log_confidences = ub_log_confs_from_ud_logits(
            ub_out_logit_differences, force_diag_0=True)
        ub_out_confidences, ub_out_confidence_mean, ub_out_confidence_median, ub_out_confidences_below_th, ub_out_confidences_above_th, ub_out_lowest_conf_indices, ub_out_highest_conf_indices = conf_stats_from_log_confs(
            ub_out_log_confidences, conf_th, k)

        ub_el_out_log_confidences = ub_log_confs_from_ud_logits(
            ud_logits_out, force_diag_0=False)
        ub_el_out_confidences, ub_el_out_confidence_mean, ub_el_out_confidence_median, ub_el_out_confidences_below_th, ub_el_out_confidences_above_th, ub_el_out_lowest_conf_indices, ub_el_out_highest_conf_indices = conf_stats_from_log_confs(
            ub_el_out_log_confidences, conf_th, k)

        pred_adv_out_log_confidences = log_confs_from_logits(
            adversarial_logits_out)
        pred_adv_out_confidences, pred_adv_out_confidence_mean, pred_adv_out_confidence_median, pred_adv_out_confidences_below_th, pred_adv_out_confidences_above_th, pred_adv_out_lowest_conf_indices, pred_adv_out_highest_conf_indices = conf_stats_from_log_confs(
            pred_adv_out_log_confidences, conf_th, k)

        auroc_from_predictions = auroc(pred_in_confidences,
                                       pred_out_confidences)
        auroc_out_guaranteed_spread = auroc(
            pred_in_confidences, np.nan_to_num(ub_spread_out_confidences))
        auroc_out_guaranteed_softmax = auroc(pred_in_confidences,
                                             np.nan_to_num(ub_out_confidences))
        auroc_out_guaranteed_softmax_elision = auroc(
            pred_in_confidences, np.nan_to_num(ub_el_out_confidences))
        auroc_out_adversarial = auroc(pred_in_confidences,
                                      pred_adv_out_confidences)

        auroc_from_predictions_conservative = auroc_conservative(
            pred_in_confidences, pred_out_confidences)
        auroc_out_guaranteed_spread_conservative = auroc_conservative(
            pred_in_confidences, ub_spread_out_confidences)
        auroc_out_guaranteed_softmax_conservative = auroc_conservative(
            pred_in_confidences, ub_out_confidences)
        auroc_out_guaranteed_softmax_elision_conservative = auroc_conservative(
            pred_in_confidences, ub_el_out_confidences)
        auroc_out_adversarial_conservative = auroc_conservative(
            pred_in_confidences, pred_adv_out_confidences)

        if model_path:
            if save_plots:
                plotting.save_images(k,
                                     test_loader_out,
                                     pred_out_highest_conf_indices,
                                     pred_out_confidences,
                                     pred_scores_out,
                                     out_name,
                                     model_folder,
                                     descr='pred',
                                     index_selection='highest',
                                     short_name=short_name)
                plt.hist(th_eps_out,
                         bins=100,
                         log=False,
                         label='eps with guarantee <{0:.2f}'.format(conf_th))
                save_path = model_folder + f'/sample_images/{short_name}{out_name}eps_hist'
                plt.savefig(save_path, bbox_inches='tight')
                plt.close()
                plotting.save_images(k,
                                     test_loader_out,
                                     pred_adv_out_highest_conf_indices,
                                     pred_adv_out_confidences,
                                     adversarial_pred_scores_out,
                                     out_name,
                                     model_folder,
                                     descr='adv',
                                     index_selection='highest',
                                     short_name=short_name)
                plotting.save_images(k,
                                     test_loader_out,
                                     ub_el_out_highest_conf_indices,
                                     ub_el_out_confidences,
                                     pred_scores_out,
                                     out_name,
                                     model_folder,
                                     descr='ub_el',
                                     index_selection='highest',
                                     short_name=short_name)
                bins = np.linspace(0.1, 0.4, 200)
                plt.hist(pred_out_confidences,
                         bins=bins,
                         log=False,
                         label='out_pred')
                plt.hist(ub_el_out_confidences,
                         bins=bins,
                         log=False,
                         label='out_ub')
                plt.hist(pred_in_confidences,
                         bins=bins,
                         log=False,
                         label='in_pred')
                plt.legend()
                save_path = model_folder + f'/sample_images/{short_name}{out_name}compare_hist'
                plt.savefig(save_path, bbox_inches='tight')
                plt.close()

            np.savetxt(
                model_folder +
                f'/values/{short_name}{out_name}_pred_out_confidences.txt',
                pred_out_confidences)
            np.savetxt(
                model_folder +
                f'/values/{short_name}{out_name}_pred_adv_out_confidences.txt',
                pred_adv_out_confidences)
            np.savetxt(
                model_folder +
                f'/values/{short_name}{out_name}_ub_el_out_confidences.txt',
                ub_el_out_confidences)

        #The non-elision values are mainly for comparing the methods and diagnosis but give less tight results, so you might want to comment them out.
        print(
            'The mean confidence of the predictions of the {0} model is {1:.4f} on the in-distribution and {2:.4f} on the out-distribution.'
            .format(model.__name__, pred_in_confidence_mean,
                    pred_out_confidence_mean))
        print(
            'For epsilon={0} on the out-distribution, the guaranteed confidence upper bound from logit spread has a mean of {1} and a median of {2}.'
            .format(eps, ub_spread_out_confidence_mean,
                    ub_spread_out_confidence_median))
        print(
            'For epsilon={0} on the out-distribution, the guaranteed confidence upper bound from softmax has a mean of {1} and a median of {2}.'
            .format(eps, ub_out_confidence_mean, ub_out_confidence_median))
        print(
            'For epsilon={0} on the out-distribution, the guaranteed confidence upper bound from softmax with elision has a mean of {1} and a median of {2}.'
            .format(eps, ub_el_out_confidence_mean,
                    ub_el_out_confidence_median))
        print(
            'In-samples confidence above {0}: {1}/{2} = {3}. Out-samples predicted confidence above {0}: {4}/{5} = {6}. Out-samples confidence {9}-guaranteed from spread below {0}: {7}/{5} = {8}. Out-samples confidence {9}-guaranteed from softmax below {0}: {10}/{5} = {11}. Out-samples confidence {9}-guaranteed from softmax with elision below {0}: {12}/{5} = {13}.'
            .format(conf_th, pred_in_confidences_above_th, n_in_samples,
                    pred_in_confidences_above_th / n_in_samples,
                    pred_out_confidences_above_th, n_out_samples,
                    pred_out_confidences_above_th / n_out_samples,
                    ub_spread_out_confidences_below_th,
                    ub_spread_out_confidences_below_th / n_out_samples, eps,
                    ub_out_confidences_below_th,
                    ub_out_confidences_below_th / n_out_samples,
                    ub_el_out_confidences_below_th,
                    ub_el_out_confidences_below_th / n_out_samples))
        print(
            'Prediction AUROC: {0}. {1}-guaranteed AUROC: {2} (logit spread), {3} (softmax), {4} (elision softmax).'
            .format(auroc_from_predictions, eps, auroc_out_guaranteed_spread,
                    auroc_out_guaranteed_softmax,
                    auroc_out_guaranteed_softmax_elision))
        print(
            'Conservative Prediction AUROC: {0}. {1}-guaranteed AUROC: {2} (logit spread), {3} (softmax), {4} (elision softmax).'
            .format(auroc_from_predictions_conservative, eps,
                    auroc_out_guaranteed_spread_conservative,
                    auroc_out_guaranteed_softmax_conservative,
                    auroc_out_guaranteed_softmax_elision_conservative))
        if n_pgd > 0:
            print(
                'Adversarial outs under {0}: Mean confidence: {1:.4f}. Median confidence: {2:.4f}. Confidence above {3}: {4}/{5} = {6:.4f}. AUROC: {7}.\n'
                .format(
                    pgd_attack_out.__name__, pred_adv_out_confidence_mean,
                    pred_adv_out_confidence_median, conf_th,
                    pred_adv_out_confidences_above_th,
                    len(pred_adv_out_confidences),
                    pred_adv_out_confidences_above_th /
                    len(pred_adv_out_confidences), auroc_out_adversarial))
        latex_str = 'Prediction/Adversarial/Guaranteed AUROC: {0:03.1f} & {1:03.1f} & {2:03.1f}'.format(
            100 * auroc_from_predictions, 100 * auroc_out_adversarial,
            100 * auroc_out_guaranteed_softmax_elision)
        latex_str_conservative = 'Conservative Prediction/Adversarial/Guaranteed AUROC: {0:03.1f} & {1:03.1f} & {2:03.1f}'.format(
            100 * auroc_from_predictions_conservative,
            100 * auroc_out_adversarial_conservative,
            100 * auroc_out_guaranteed_softmax_elision_conservative)
        print(latex_str)
        print(latex_str_conservative)
        table_str_0 += f'& {out_name} P & A & G '
        table_str_1 += f' & {auroc_from_predictions_conservative*100:03.1f} & {auroc_out_adversarial_conservative*100:3.1f} & {auroc_out_guaranteed_softmax_elision_conservative*100:3.1f} '
        returns[out_name] = (
            in_accuracy, pred_in_confidences, pred_in_confidence_mean,
            pred_in_confidences_above_th, n_in_samples, pred_out_confidences,
            pred_out_confidence_mean, pred_out_confidences_above_th,
            n_out_samples, ub_el_out_confidences, ub_el_out_confidence_mean,
            ub_el_out_confidence_median, ub_el_out_confidences_below_th,
            auroc_from_predictions, auroc_out_guaranteed_softmax_elision,
            auroc_from_predictions_conservative,
            auroc_out_guaranteed_softmax_elision_conservative,
            pred_adv_out_confidences, pred_adv_out_confidence_mean,
            pred_adv_out_confidence_median, pred_adv_out_confidences_above_th)
        print('Eval for this out dataset: {0}.\n'.format(
            datetime.datetime.utcnow() - starttime_eval_out))
    table_str_0 += f' \\\\'
    table_str_1 += f' \\\\'  #Note that the adversarial evaluations are sub-optimal since a dedicated attack evaluation is needed to get strong results.
    print(table_str_0)
    print(table_str_1)
    print('Total time for eval: {0}.\n'.format(datetime.datetime.utcnow() -
                                               starttime_eval))
    return returns


#def eps_reduction_for_th(img_out, model, eps_start, )