Пример #1
0
def train_loop(model, loader, optimizer, criterion, epoch="-", normalization=None,
               store=None, adv=False):
    """Train `model` for one epoch; optionally on L2-PGD adversarial examples.

    Logs the epoch's running accuracy to `store.tensorboard` when a store
    is provided. Assumes CUDA is available (batches are moved with .cuda()).
    """
    acc_meter = AverageMeter()
    for data, target in loader:
        data, target = data.cuda(), target.cuda()
        model.train()
        if adv:
            # Replace the clean batch with adversarial examples before the step.
            data = utils.L2PGD(model, data, target, normalization,
                               step_size=0.5, Nsteps=20,
                               eps=1.25, targeted=False, use_tqdm=False)

        optimizer.zero_grad()
        logits = utils.forward_pass(model, data, normalization)
        loss = criterion(logits, target)
        loss.backward()
        optimizer.step()

        # Track accuracy on the (possibly adversarial) batch in eval mode.
        with torch.no_grad():
            model.eval()
            batch_acc = utils.accuracy(model, data, target, normalization)
            acc_meter.update(batch_acc, data.shape[0])
    if store:
        store.tensorboard.add_scalar("train_accuracy", acc_meter.avg, epoch)
Пример #2
0
    def train_step(self, loader):
        """Run one epoch of alternating generator/discriminator updates.

        Returns a dict with the epoch-average "D_loss", "G_loss" and
        "MSE_loss" values.
        """
        self.discriminator.train()
        self.generator.train()
        device = self.args.device
        meters = {
            "D_loss": AverageMeter(),
            "G_loss": AverageMeter(),
            "MSE_loss": AverageMeter(),
        }
        progress = tqdm(loader)
        for _, x_batch, _, m_batch in progress:
            x_batch, m_batch = x_batch.to(device), m_batch.to(device)

            # --- Generator update ---
            self.g_optimizer.zero_grad()
            sample, random_combined, x_hat = self.generator(x_batch, m_batch)
            G_loss, mse_loss = self.generator_loss(
                m_batch, self.discriminator(x_hat, m_batch), random_combined,
                sample)
            (G_loss + self.alpha * mse_loss).backward()
            self.g_optimizer.step()

            # --- Discriminator update (x_hat detached: no gradient into G) ---
            self.d_optimizer.zero_grad()
            D_prob = self.discriminator(x_hat.detach(), m_batch)
            D_loss = self.discriminator_loss(m_batch, D_prob)
            D_loss.backward()
            self.d_optimizer.step()

            n = x_batch.shape[0]
            meters["D_loss"].update(D_loss.detach().item(), n)
            meters["G_loss"].update(G_loss.detach().item(), n)
            meters["MSE_loss"].update(mse_loss.detach().item(), n)
            progress.set_description(" ".join(
                f"{name}: {meter.avg:.4f}" for name, meter in meters.items()))
        return {name: meter.avg for name, meter in meters.items()}
Пример #3
0
    def eval_metrics(self,
                     loader,
                     mode="Train",
                     feature_metrics=True,
                     pred_metrics=True):
        """Evaluate prediction and feature-importance metrics over `loader`.

        Args:
            loader: iterable of (x, y, g) batches, where g is the ground-truth
                feature-importance indicator.
            mode: string prefix ("Train"/"Test") used in result keys.
            feature_metrics: if True, compute TPR/FDR of the importance scores.
            pred_metrics: if True, compute AUC/APR/ACC of the predictions.

        Returns:
            (result_dict, g_hat, y_hat, g_true, y_true) where result_dict maps
            "<mode>-<metric>" to its batch-size-weighted average and the other
            four are the concatenated per-sample arrays.
        """
        result_dict = {}
        if feature_metrics:
            names = [
                mode + "-TPR-Mean", mode + "-TPR-STD", mode + "-FDR-Mean",
                mode + "-FDR-STD"
            ]
            for name in names:
                result_dict[name] = AverageMeter()

        if pred_metrics:
            names = [mode + "-AUC", mode + "-APR", mode + "-ACC"]
            for name in names:
                result_dict[name] = AverageMeter()
        g_hats, y_hats = [], []
        g_trues, y_trues = [], []
        with torch.no_grad():
            for x, y, g in loader:
                x = x.to(self.args.device)
                # NOTE(review): .numpy() requires CPU tensors — this assumes
                # the model outputs (and g, y) are on CPU; confirm args.device.
                y_hat = self.model.predict(x).detach().numpy()
                g_hat = self.model.importance_score(x).detach().numpy()
                if pred_metrics:

                    auc, apr, acc = prediction_performance_metric(y, y_hat)
                    result_dict[mode + "-AUC"].update(auc, y.shape[0])
                    result_dict[mode + "-APR"].update(apr, y.shape[0])
                    result_dict[mode + "-ACC"].update(acc, y.shape[0])

                if feature_metrics:
                    # Binarize importance scores at the 0.5 threshold.
                    importance_score = 1. * (g_hat > 0.5)
                    # Evaluate the performance of feature importance
                    mean_tpr, std_tpr, mean_fdr, std_fdr = feature_performance_metric(
                        g.detach().numpy(), importance_score)
                    result_dict[mode + "-TPR-Mean"].update(
                        mean_tpr, y.shape[0])
                    result_dict[mode + "-TPR-STD"].update(std_tpr, y.shape[0])
                    result_dict[mode + "-FDR-Mean"].update(
                        mean_fdr, y.shape[0])
                    result_dict[mode + "-FDR-STD"].update(std_fdr, y.shape[0])
                # Accumulate raw per-batch outputs for the concatenated returns.
                g_hats.append(g_hat)
                y_hats.append(y_hat)
                g_trues.append(g.detach().numpy())
                y_trues.append(y.detach().numpy())

        # Collapse each meter into its scalar average.
        for metric, val in result_dict.items():
            result_dict[metric] = val.avg

        g_hat = np.concatenate(g_hats, axis=0)
        y_hat = np.concatenate(y_hats, axis=0)
        g_true = np.concatenate(g_trues, axis=0)
        y_true = np.concatenate(y_trues, axis=0)
        return result_dict, g_hat, y_hat, g_true, y_true
Пример #4
0
def eval_loop(model, loader, epoch="-", normalization=None, store=None, adv=False):
    """Evaluate `model` on `loader`, optionally under an L2-PGD attack.

    Logs the accuracy to tensorboard and the store's 'result' table when a
    store is provided. Assumes CUDA is available.
    """
    acc_meter = AverageMeter()
    progress = tqdm(iter(loader), total=len(loader), position=0, leave=True)
    model.eval()

    for data, target in progress:
        data, target = data.cuda(), target.cuda()
        if adv:
            # Attack the batch before measuring accuracy.
            data = utils.L2PGD(model, data, target, normalization,
                               step_size=0.5, Nsteps=20,
                               eps=1.25, targeted=False, use_tqdm=False)

        batch_acc = utils.accuracy(model, data, target, normalization)
        acc_meter.update(batch_acc, data.shape[0])

        progress.set_description(
            f"Epoch: {epoch}, Adv: {adv}, TEST accuracy={acc_meter.avg:.2f}")
        progress.refresh()
    if store:
        tag = f"test_accuracy_{str(adv)}"
        store.tensorboard.add_scalar(tag, acc_meter.avg, epoch)
        print(tag)
        store['result'].update_row({tag: acc_meter.avg, 'epoch': epoch})
Пример #5
0
def _model_loop(args, loop_type, loader, atm, opts, epoch, advs, writer):
    """Run one train or validation epoch of the MI-estimation pipeline.

    Args:
        args: experiment namespace; reads attack settings (constraint, eps,
            attack_lr, attack_steps, random_restarts, use_best,
            eps_fadein_epochs) and task selection (task, estimator_loss,
            classifier_loss).
        loop_type: 'train' or 'val'; anything else raises ValueError.
        loader: iterable of (input, target) batches.
        atm: model wrapper exposing `forward_custom` and
            `attacker.model.custom_loss_func`.
        opts: 4-tuple (opt_enc, opt_dim_local, opt_dim_global, opt_cla);
            only used when loop_type == 'train'.
        epoch: current epoch index (used for eps fade-in and display).
        advs: 1-tuple whose single boolean element enables adversarial
            evaluation.
        writer: unused in this function — presumably a summary writer;
            confirm with callers before removing.

    Returns:
        (precs_cla.avg, losses_enc_dim.avg): epoch-average classifier
        precision and DIM (encoder) loss.

    Raises:
        ValueError: for an invalid `loop_type`.
        NotImplementedError: for an unknown `args.task`.
    """

    if not loop_type in ['train', 'val']:
        err_msg = "loop_type ({0}) must be 'train' or 'val'".format(loop_type)
        raise ValueError(err_msg)

    is_train = (loop_type == 'train')

    # `advs` is a 1-tuple; unpack its single adversarial-eval flag.
    adv_eval, = advs

    prec = 'NatPrec' if not adv_eval else 'AdvPrec'
    loop_msg = 'Train' if loop_type == 'train' else 'Val'

    # switch to train/eval mode depending (train()/eval() return the module)
    atm = atm.train() if is_train else atm.eval()

    # If adv training (or evaling), set eps and random_restarts appropriately
    eps = calc_fadein_eps(epoch, args.eps_fadein_epochs, args.eps) \
        if is_train else args.eps
    random_restarts = 0 if is_train else args.random_restarts

    # Attack configuration shared by every forward_custom call below.
    attack_kwargs = {
        'constraint': args.constraint,
        'eps': eps,
        'step_size': args.attack_lr,
        'iterations': args.attack_steps,
        'random_start': False,
        'random_restarts': random_restarts,
        'use_best': bool(args.use_best)
    }

    if is_train:
        opt_enc, opt_dim_local, opt_dim_global, opt_cla = opts
    else:
        opt_enc, opt_dim_local, opt_dim_global, opt_cla = None, None, None, None

    # Running averages reported in the progress bar and returned at the end.
    losses_cla = AverageMeter()
    precs_cla = AverageMeter()
    losses_enc_dim = AverageMeter()

    iterator = tqdm(enumerate(loader), total=len(loader))
    for i, (input, target) in iterator:
        target = target.cuda(non_blocking=True)

        # Compute Loss: eval
        if not is_train:
            # Evaluation performs two separate forward passes: one for the
            # DIM (mutual-information) loss, one for the classifier loss.
            # if adv_mi_type == 'lo':
            attack_kwargs['custom_loss'] = partial(
                atm.attacker.model.custom_loss_func, loss_type='dim')
            # elif adv_mi_type == 'up':
            #     attack_kwargs['custom_loss'] = atm.attacker.model.cal_adv_mi_up_loss_dim
            loss_enc_dim, _, _ = atm.forward_custom(
                input=input,
                target=None,  # no need for target in computing mi
                loss_type='dim',
                make_adv=adv_eval,
                detach=True,  # whatever in eval mode
                enc_in_eval=True,
                **attack_kwargs)

            attack_kwargs['custom_loss'] = partial(
                atm.attacker.model.custom_loss_func, loss_type='cla')
            _, loss_cla, prec_cla = atm.forward_custom(input=input,
                                                       target=target,
                                                       loss_type='cla',
                                                       make_adv=adv_eval,
                                                       detach=True,
                                                       enc_in_eval=True,
                                                       **attack_kwargs)

        # Compute Loss: train
        else:
            # Each task configures one combined forward pass: which loss to
            # optimize, whether to attack, whether to detach the encoder, and
            # whether the encoder stays in eval mode.
            if args.task == 'estimate-mi':
                target = None
                loss_type = 'dim'
                make_adv = True if args.estimator_loss == 'worst' else False
                detach = True
                enc_in_eval = True

            elif args.task == 'train-encoder':
                target = None
                loss_type = 'dim'
                make_adv = True if args.estimator_loss == 'worst' else False
                detach = True
                enc_in_eval = False

            elif args.task == 'train-classifier':
                target = target
                loss_type = 'cla'
                make_adv = True if args.classifier_loss == 'robust' else False
                detach = True
                enc_in_eval = True

            elif args.task == 'train-model':
                target = target
                loss_type = 'cla'
                make_adv = True if args.classifier_loss == 'robust' else False
                detach = False
                enc_in_eval = False

            else:
                raise NotImplementedError

            attack_kwargs['custom_loss'] = partial(
                atm.attacker.model.custom_loss_func, loss_type=loss_type)
            loss_enc_dim, loss_cla, prec_cla = atm.forward_custom(
                input=input,
                loss_type=loss_type,
                target=target,
                make_adv=make_adv,
                detach=detach,
                enc_in_eval=enc_in_eval,
                **attack_kwargs)

        # Compute gradient and do SGD step
        if is_train:

            # Which optimizers step depends on the task: MI estimation updates
            # only the DIM heads; encoder training adds opt_enc; classifier
            # training updates only opt_cla; 'train-model' backprops the
            # classifier loss into the encoder.
            if args.task == 'estimate-mi':
                opt_dim_local.zero_grad()
                opt_dim_global.zero_grad()
                loss_enc_dim.backward()
                opt_dim_local.step()
                opt_dim_global.step()

            elif args.task == 'train-encoder':
                opt_enc.zero_grad()
                opt_dim_local.zero_grad()
                opt_dim_global.zero_grad()
                loss_enc_dim.backward()
                opt_enc.step()
                opt_dim_local.step()
                opt_dim_global.step()

            elif args.task == 'train-classifier':
                opt_cla.zero_grad()
                loss_cla.backward()  #retain_graph=True
                opt_cla.step()

            elif args.task == 'train-model':
                opt_enc.zero_grad()
                # opt_cla.zero_grad()
                loss_cla.backward()  #retain_graph=True
                opt_enc.step()
                # opt_cla.step()

            else:
                raise NotImplementedError

        losses_cla.update(loss_cla.item(), input.size(0))
        precs_cla.update(prec_cla.item(), input.size(0))
        losses_enc_dim.update(loss_enc_dim.item(), input.size(0))

        # ITERATOR
        desc = ('{2} Epoch:{0} | '
                'Loss_dim {Loss_dim:.4f} | '
                'Loss_cla {Loss_cla:.4f} | '
                'prec_cla {prec_cla:.3f} |'.format(epoch,
                                                   prec,
                                                   loop_msg,
                                                   Loss_dim=losses_enc_dim.avg,
                                                   Loss_cla=losses_cla.avg,
                                                   prec_cla=precs_cla.avg))

        # USER-DEFINED HOOK
        # if has_attr(args, 'iteration_hook'):
        #     args.iteration_hook(testee, i, loop_type, inp, target)

        iterator.set_description(desc)
        iterator.refresh()

    return precs_cla.avg, losses_enc_dim.avg
Пример #6
0
def main(args):
    """Fit a linear decoder on a trained INVASE selector and report MSEs.

    Loads a trained INVASE model from ``args.trained_path``, tunes a
    ``LinearDecoder`` to reconstruct inputs from the selector's output,
    saves a reconstruction visualization (pdf/png), and writes the final
    train/test reconstruction MSEs to ``decode_results.json``.

    Args:
        args: namespace with trained_path, device, decoder_epochs, eval_freq.
    """
    path = args.trained_path
    ckpt_path = os.path.join(path, "checkpoint")
    config_path = os.path.join(path, "config.json")
    decode_result_path = os.path.join(path, "decode_results.json")

    # Reload the experiment configurations
    with open(config_path, "r") as fp:
        trainer_args_dict = json.load(fp)
    trainer_args = Namespace(**trainer_args_dict)

    # Get the data (the dims returned by get_data are overridden below by the
    # dataset's own declared sizes)
    dim, label_dim, train_loader, test_loader = get_data(trainer_args)
    dim = train_loader.dataset.input_size
    label_dim = train_loader.dataset.output_size

    # Load from the checkpoint
    trainer = INVASETrainer(dim, label_dim, trainer_args, path)
    trainer = load_ckpt(trainer, ckpt_path)

    # Construct the decoder
    decoder = LinearDecoder(dim)
    optimizer = optim.Adam(decoder.parameters(), 0.1, weight_decay=1e-4)
    loss_fn = nn.MSELoss()

    # Dataset statistics used to undo input normalization
    mean = torch.tensor(train_loader.dataset.means)
    std = torch.tensor(train_loader.dataset.stds)

    # Tune the decoder; the selector itself stays frozen in eval mode.
    for i in range(args.decoder_epochs):
        MSE = AverageMeter()
        b_loader = tqdm(train_loader)
        trainer.model.eval()
        for x_batch, y_batch, _ in b_loader:
            b_loader.set_description(f"EpochProvision: DecodingMSE: {MSE.avg}")

            x_batch, y_batch = x_batch.to(args.device), y_batch.to(args.device)
            optimizer.zero_grad()
            # Generate a batch of selections
            selection_probability = trainer.model(x_batch,
                                                  fw_module="selector")
            # Reconstruct the input from the selected features
            used, reconstruction = decoder(selection_probability, x_batch)

            # Convert back to pixel space before measuring the error
            reconstruction = reconstruction * std + mean
            x_batch = x_batch * std + mean
            loss = loss_fn(reconstruction, x_batch)
            MSE.update(loss.detach().item(), y_batch.shape[0])
            loss.backward()
            optimizer.step()

        # Halve the learning rate every eval_freq epochs
        if (i + 1) % args.eval_freq == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] / 2

    # Visualize the last training batch: original / reconstruction / selection.
    # NOTE(review): .numpy() requires CPU tensors — assumes args.device is CPU
    # (or that batches were moved back); confirm before running on CUDA.
    fig, axs = plt.subplots(N_IMAGES, 3, figsize=(10, 5))
    flat_shape = x_batch.shape[1]
    img_dim = int(np.sqrt(flat_shape))
    for i in range(N_IMAGES):
        im = x_batch[i].detach().numpy().reshape((img_dim, img_dim))
        im_rec = reconstruction[i].detach().numpy().reshape((img_dim, img_dim))
        im_chosen = used[i].detach().numpy().reshape((img_dim, img_dim))
        axs[i][0].imshow(im)
        axs[i][1].imshow(im_rec)
        axs[i][2].imshow(im_chosen)
        axs[i][0].set_axis_off()
        axs[i][1].set_axis_off()
        axs[i][2].set_axis_off()

    axs[0][0].set_title("Original Image", fontsize=18)
    axs[0][1].set_title("Reconstructed Image", fontsize=18)
    axs[0][2].set_title("Chosen Pixels", fontsize=18)

    fig.savefig(os.path.join(path, "reconstruction_viz.pdf"))
    fig.savefig(os.path.join(path, "reconstruction_viz.png"))
    plt.close(fig)

    # Final evaluation on both splits.
    modes = [("Train", train_loader), ("Test", test_loader)]
    decoder.eval()
    trainer.model.eval()
    result = dict()
    for mode, loader in modes:
        # BUGFIX: use a fresh meter per split — previously a single meter was
        # shared across splits, so the reported "Test" MSE also averaged in
        # every "Train" batch.
        MSE = AverageMeter()
        b_loader = tqdm(loader)
        with torch.no_grad():  # inference only — no autograd graphs needed
            for x_batch, y_batch, _ in b_loader:
                b_loader.set_description(f"EpochProvision: DecodingMSE: {MSE.avg}")
                x_batch, y_batch = x_batch.to(args.device), y_batch.to(args.device)
                selection_probability = trainer.model(x_batch,
                                                      fw_module="selector")
                used, reconstruction = decoder(selection_probability, x_batch)
                reconstruction = reconstruction * std + mean
                x_batch = x_batch * std + mean
                loss = loss_fn(reconstruction, x_batch)
                MSE.update(loss.detach().item(), y_batch.shape[0])

        print(f"{mode} Final: ", MSE.avg)
        result[mode] = MSE.avg

    with open(decode_result_path, "w") as fp:
        json.dump(result, fp)
Пример #7
0
    def train_step(self, train_loader):
        """Run one INVASE training epoch (actor + critic + baseline jointly).

        Returns a dict with the epoch averages: "CriticAcc", "BaselineAcc"
        and "ActorLoss".
        """
        device = self.args.device
        self.model.train()

        critic_meter = AverageMeter()
        baseline_meter = AverageMeter()
        actor_meter = AverageMeter()

        progress = tqdm(train_loader)
        for x_batch, y_batch, _ in progress:
            progress.set_description(
                f"EpochProvision: Critic: {critic_meter.avg}, Baseline: {baseline_meter.avg}, Actor: {actor_meter.avg}"
            )
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            self.optimizer.zero_grad()
            # y_batch is one-hot; recover integer class labels.
            labels = torch.argmax(y_batch, dim=1).long()

            # Sample a binary feature-selection mask from the selector.
            selection_probability = self.model(x_batch, fw_module="selector")
            selection = torch.bernoulli(selection_probability).detach()

            # Critic: predict from the masked input.
            critic_out = self.model(x_batch * selection, fw_module="predictor")
            critic_loss = self.critic_loss(critic_out, labels)
            # Baseline: predict from the full input.
            baseline_out = self.model(x_batch, fw_module="baseline")
            baseline_loss = self.baseline_loss(baseline_out, labels)

            # Assemble the detached reward signal consumed by the actor loss.
            batch_data = torch.cat(
                [
                    selection.clone().detach(),
                    self.softmax(critic_out).clone().detach(),
                    self.softmax(baseline_out).clone().detach(),
                    y_batch.float(),
                ],
                dim=1,
            )

            # Actor: re-run the selector and score it against the reward.
            actor_output = self.model(x_batch, fw_module="selector")
            actor_loss = self.actor_loss(batch_data, actor_output)

            (actor_loss + critic_loss + baseline_loss).backward()
            self.optimizer.step()

            n = labels.shape[0]
            critic_meter.update(accuracy(critic_out, labels)[0].detach().item(), n)
            baseline_meter.update(accuracy(baseline_out, labels)[0].detach().item(), n)
            actor_meter.update(actor_loss.detach().item(), n)

        return {
            "CriticAcc": critic_meter.avg,
            "BaselineAcc": baseline_meter.avg,
            "ActorLoss": actor_meter.avg,
        }