def train_loop(model, loader, optimizer, criterion, epoch="-",
               normalization=None, store=None, adv=False):
    acc_meter = AverageMeter()
    # iterator = tqdm(iter(loader), total=len(loader), position=0, leave=True)
    for data, target in loader:
        data, target = data.cuda(), target.cuda()
        model.train()

        # Optionally replace the clean batch with L2-PGD adversarial examples
        if adv:
            data = utils.L2PGD(model, data, target, normalization,
                               step_size=0.5, Nsteps=20, eps=1.25,
                               targeted=False, use_tqdm=False)

        optimizer.zero_grad()
        logits = utils.forward_pass(model, data, normalization)
        loss = criterion(logits, target)
        loss.backward()
        optimizer.step()

        # Track accuracy on the (possibly adversarial) training batch
        with torch.no_grad():
            model.eval()
            val = utils.accuracy(model, data, target, normalization)
            acc_meter.update(val, data.shape[0])

        # Commented out to reduce the amount of logs in Colab
        # iterator.set_description(f"Epoch: {epoch}, Adv: {adv}, Train accuracy={acc_meter.avg:.2f}")
        # iterator.refresh()

    if store:
        store.tensorboard.add_scalar("train_accuracy", acc_meter.avg, epoch)
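# The loops above and below rely on an AverageMeter helper that is not defined
# in this section. A minimal sketch of the usual running-average meter it is
# assumed to be (an assumption, not necessarily this project's exact class):
class AverageMeter:
    """Tracks a running sum and average of a scalar over weighted updates."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the batch statistic, n the number of samples it was computed on.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count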
def train_step(self, loader):
    self.discriminator.train()
    self.generator.train()
    device = self.args.device

    loss_dict = dict()
    loss_dict["D_loss"] = AverageMeter()
    loss_dict["G_loss"] = AverageMeter()
    loss_dict["MSE_loss"] = AverageMeter()

    b_loader = tqdm(loader)
    for _, x_batch, _, m_batch in b_loader:
        x_batch, m_batch = x_batch.to(device), m_batch.to(device)

        # Generator update: fool the discriminator and reconstruct observed entries
        self.g_optimizer.zero_grad()
        sample, random_combined, x_hat = self.generator(x_batch, m_batch)
        G_loss, mse_loss = self.generator_loss(m_batch,
                                               self.discriminator(x_hat, m_batch),
                                               random_combined, sample)
        generator_loss = G_loss + self.alpha * mse_loss
        generator_loss.backward()
        self.g_optimizer.step()

        # Discriminator update: distinguish observed from imputed entries
        self.d_optimizer.zero_grad()
        D_prob = self.discriminator(x_hat.detach(), m_batch)
        D_loss = self.discriminator_loss(m_batch, D_prob)
        D_loss.backward()
        self.d_optimizer.step()

        N = x_batch.shape[0]
        loss_dict["D_loss"].update(D_loss.detach().item(), N)
        loss_dict["G_loss"].update(G_loss.detach().item(), N)
        loss_dict["MSE_loss"].update(mse_loss.detach().item(), N)

        desc = [f"{k}: {v.avg:.4f}" for k, v in loss_dict.items()]
        b_loader.set_description(" ".join(desc))

    # Collapse the meters to their running averages before returning
    for k, v in loss_dict.items():
        loss_dict[k] = v.avg
    return loss_dict
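# generator_loss and discriminator_loss are defined elsewhere in the repository.
# The sketch below shows the standard GAIN formulation they are assumed to
# follow (hypothetical names, an assumption rather than the repository's code):
# the discriminator does BCE against the observation mask, while the generator
# tries to fool it on imputed entries and reconstruct the observed ones.
import torch

def gain_discriminator_loss_sketch(m, d_prob, eps=1e-8):
    # m: observation mask (1 = observed), d_prob: discriminator probabilities.
    return -torch.mean(m * torch.log(d_prob + eps)
                       + (1 - m) * torch.log(1 - d_prob + eps))

def gain_generator_loss_sketch(m, d_prob, random_combined, sample, eps=1e-8):
    # Adversarial term on the imputed (unobserved) entries.
    g_loss = -torch.mean((1 - m) * torch.log(d_prob + eps))
    # Reconstruction term on the observed entries only.
    mse_loss = torch.mean((m * random_combined - m * sample) ** 2) / torch.mean(m)
    return g_loss, mse_loss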
def eval_metrics(self, loader, mode="Train", feature_metrics=True, pred_metrics=True):
    result_dict = {}
    if feature_metrics:
        names = [
            mode + "-TPR-Mean", mode + "-TPR-STD",
            mode + "-FDR-Mean", mode + "-FDR-STD"
        ]
        for name in names:
            result_dict[name] = AverageMeter()
    if pred_metrics:
        names = [mode + "-AUC", mode + "-APR", mode + "-ACC"]
        for name in names:
            result_dict[name] = AverageMeter()

    g_hats, y_hats = [], []
    g_trues, y_trues = [], []
    with torch.no_grad():
        for x, y, g in loader:
            x = x.to(self.args.device)
            y_hat = self.model.predict(x).detach().cpu().numpy()
            g_hat = self.model.importance_score(x).detach().cpu().numpy()

            if pred_metrics:
                auc, apr, acc = prediction_performance_metric(y, y_hat)
                result_dict[mode + "-AUC"].update(auc, y.shape[0])
                result_dict[mode + "-APR"].update(apr, y.shape[0])
                result_dict[mode + "-ACC"].update(acc, y.shape[0])

            if feature_metrics:
                # Binarize the importance scores before evaluating feature selection
                importance_score = 1. * (g_hat > 0.5)
                # Evaluate the performance of feature importance
                mean_tpr, std_tpr, mean_fdr, std_fdr = feature_performance_metric(
                    g.detach().cpu().numpy(), importance_score)
                result_dict[mode + "-TPR-Mean"].update(mean_tpr, y.shape[0])
                result_dict[mode + "-TPR-STD"].update(std_tpr, y.shape[0])
                result_dict[mode + "-FDR-Mean"].update(mean_fdr, y.shape[0])
                result_dict[mode + "-FDR-STD"].update(std_fdr, y.shape[0])

            g_hats.append(g_hat)
            y_hats.append(y_hat)
            g_trues.append(g.detach().cpu().numpy())
            y_trues.append(y.detach().cpu().numpy())

    # Collapse the meters to their running averages
    for metric, val in result_dict.items():
        result_dict[metric] = val.avg

    g_hat = np.concatenate(g_hats, axis=0)
    y_hat = np.concatenate(y_hats, axis=0)
    g_true = np.concatenate(g_trues, axis=0)
    y_true = np.concatenate(y_trues, axis=0)
    return result_dict, g_hat, y_hat, g_true, y_true
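# prediction_performance_metric and feature_performance_metric are imported from
# elsewhere in the repository. A plausible scikit-learn sketch of the prediction
# metrics, assuming y is one-hot and y_hat holds class probabilities
# (hypothetical name, not the repository's implementation):
import numpy as np
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score

def prediction_performance_metric_sketch(y, y_hat):
    y, y_hat = np.asarray(y), np.asarray(y_hat)
    auc = roc_auc_score(y, y_hat)            # area under the ROC curve
    apr = average_precision_score(y, y_hat)  # area under the precision-recall curve
    acc = accuracy_score(np.argmax(y, axis=1), np.argmax(y_hat, axis=1))
    return auc, apr, acc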
def eval_loop(model, loader, epoch="-", normalization=None, store=None, adv=False):
    acc_meter = AverageMeter()
    iterator = tqdm(iter(loader), total=len(loader), position=0, leave=True)
    model.eval()
    for data, target in iterator:
        data, target = data.cuda(), target.cuda()

        # Optionally evaluate on L2-PGD adversarial examples
        if adv:
            data = utils.L2PGD(model, data, target, normalization,
                               step_size=0.5, Nsteps=20, eps=1.25,
                               targeted=False, use_tqdm=False)

        val = utils.accuracy(model, data, target, normalization)
        acc_meter.update(val, data.shape[0])

        iterator.set_description(f"Epoch: {epoch}, Adv: {adv}, TEST accuracy={acc_meter.avg:.2f}")
        iterator.refresh()

    if store:
        store.tensorboard.add_scalar(f"test_accuracy_{adv}", acc_meter.avg, epoch)
        print(f"test_accuracy_{adv}: {acc_meter.avg:.2f}")
        store['result'].update_row({f"test_accuracy_{adv}": acc_meter.avg, "epoch": epoch})
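# A minimal sketch of how train_loop and eval_loop might be driven together.
# The run_training name, optimizer choice, and hyperparameters are assumptions
# for illustration, not part of the repository:
import torch

def run_training(model, train_loader, test_loader, normalization,
                 epochs=30, lr=0.1, store=None, adv_train=False):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(epochs):
        # One epoch of (optionally adversarial) training, then clean and
        # adversarial evaluation on the test set.
        train_loop(model, train_loader, optimizer, criterion, epoch=epoch,
                   normalization=normalization, store=store, adv=adv_train)
        eval_loop(model, test_loader, epoch=epoch,
                  normalization=normalization, store=store, adv=False)
        eval_loop(model, test_loader, epoch=epoch,
                  normalization=normalization, store=store, adv=True)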
def _model_loop(args, loop_type, loader, atm, opts, epoch, advs, writer):
    if loop_type not in ['train', 'val']:
        err_msg = "loop_type ({0}) must be 'train' or 'val'".format(loop_type)
        raise ValueError(err_msg)
    is_train = (loop_type == 'train')

    adv_eval, = advs
    prec = 'NatPrec' if not adv_eval else 'AdvPrec'
    loop_msg = 'Train' if loop_type == 'train' else 'Val'

    # Switch to train/eval mode depending on the loop type
    atm = atm.train() if is_train else atm.eval()

    # If adv training (or evaluating), set eps and random_restarts appropriately
    eps = calc_fadein_eps(epoch, args.eps_fadein_epochs, args.eps) \
        if is_train else args.eps
    random_restarts = 0 if is_train else args.random_restarts

    attack_kwargs = {
        'constraint': args.constraint,
        'eps': eps,
        'step_size': args.attack_lr,
        'iterations': args.attack_steps,
        'random_start': False,
        'random_restarts': random_restarts,
        'use_best': bool(args.use_best)
    }

    if is_train:
        opt_enc, opt_dim_local, opt_dim_global, opt_cla = opts
    else:
        opt_enc, opt_dim_local, opt_dim_global, opt_cla = None, None, None, None

    losses_cla = AverageMeter()
    precs_cla = AverageMeter()
    losses_enc_dim = AverageMeter()

    iterator = tqdm(enumerate(loader), total=len(loader))
    for i, (input, target) in iterator:
        target = target.cuda(non_blocking=True)

        # Compute loss: eval
        if not is_train:
            # if adv_mi_type == 'lo':
            attack_kwargs['custom_loss'] = partial(
                atm.attacker.model.custom_loss_func, loss_type='dim')
            # elif adv_mi_type == 'up':
            #     attack_kwargs['custom_loss'] = atm.attacker.model.cal_adv_mi_up_loss_dim
            loss_enc_dim, _, _ = atm.forward_custom(
                input=input,
                target=None,       # no need for target when computing the MI loss
                loss_type='dim',
                make_adv=adv_eval,
                detach=True,       # does not matter in eval mode
                enc_in_eval=True,
                **attack_kwargs)

            attack_kwargs['custom_loss'] = partial(
                atm.attacker.model.custom_loss_func, loss_type='cla')
            _, loss_cla, prec_cla = atm.forward_custom(
                input=input,
                target=target,
                loss_type='cla',
                make_adv=adv_eval,
                detach=True,
                enc_in_eval=True,
                **attack_kwargs)

        # Compute loss: train
        else:
            if args.task == 'estimate-mi':
                target = None
                loss_type = 'dim'
                make_adv = (args.estimator_loss == 'worst')
                detach = True
                enc_in_eval = True
            elif args.task == 'train-encoder':
                target = None
                loss_type = 'dim'
                make_adv = (args.estimator_loss == 'worst')
                detach = True
                enc_in_eval = False
            elif args.task == 'train-classifier':
                loss_type = 'cla'
                make_adv = (args.classifier_loss == 'robust')
                detach = True
                enc_in_eval = True
            elif args.task == 'train-model':
                loss_type = 'cla'
                make_adv = (args.classifier_loss == 'robust')
                detach = False
                enc_in_eval = False
            else:
                raise NotImplementedError

            attack_kwargs['custom_loss'] = partial(
                atm.attacker.model.custom_loss_func, loss_type=loss_type)
            loss_enc_dim, loss_cla, prec_cla = atm.forward_custom(
                input=input,
                loss_type=loss_type,
                target=target,
                make_adv=make_adv,
                detach=detach,
                enc_in_eval=enc_in_eval,
                **attack_kwargs)

        # Compute gradients and take an optimization step
        if is_train:
            if args.task == 'estimate-mi':
                opt_dim_local.zero_grad()
                opt_dim_global.zero_grad()
                loss_enc_dim.backward()
                opt_dim_local.step()
                opt_dim_global.step()
            elif args.task == 'train-encoder':
                opt_enc.zero_grad()
                opt_dim_local.zero_grad()
                opt_dim_global.zero_grad()
                loss_enc_dim.backward()
                opt_enc.step()
                opt_dim_local.step()
                opt_dim_global.step()
            elif args.task == 'train-classifier':
                opt_cla.zero_grad()
                loss_cla.backward()  # retain_graph=True
                opt_cla.step()
            elif args.task == 'train-model':
                opt_enc.zero_grad()
                # opt_cla.zero_grad()
                loss_cla.backward()  # retain_graph=True
                opt_enc.step()
                # opt_cla.step()
            else:
                raise NotImplementedError

        losses_cla.update(loss_cla.item(), input.size(0))
        precs_cla.update(prec_cla.item(), input.size(0))
        losses_enc_dim.update(loss_enc_dim.item(), input.size(0))

        # ITERATOR
        desc = ('{2} Epoch:{0} | '
                'Loss_dim {Loss_dim:.4f} | '
                'Loss_cla {Loss_cla:.4f} | '
                'prec_cla {prec_cla:.3f} |'.format(
                    epoch, prec, loop_msg,
                    Loss_dim=losses_enc_dim.avg,
                    Loss_cla=losses_cla.avg,
                    prec_cla=precs_cla.avg))

        # USER-DEFINED HOOK
        # if has_attr(args, 'iteration_hook'):
        #     args.iteration_hook(testee, i, loop_type, inp, target)

        iterator.set_description(desc)
        iterator.refresh()

    return precs_cla.avg, losses_enc_dim.avg
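# calc_fadein_eps is referenced above but not defined in this section. A common
# implementation linearly ramps the attack budget over the first few epochs;
# the sketch below is an assumption about its behavior, not the actual code:
def calc_fadein_eps_sketch(epoch, eps_fadein_epochs, eps):
    # Ramp from 0 to eps over eps_fadein_epochs epochs, then stay at eps.
    if not eps_fadein_epochs or eps_fadein_epochs <= 0:
        return eps
    return eps * min(float(epoch + 1) / eps_fadein_epochs, 1.0)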
def main(args):
    path = args.trained_path
    ckpt_path = os.path.join(path, "checkpoint")
    config_path = os.path.join(path, "config.json")
    decode_result_path = os.path.join(path, "decode_results.json")

    # Reload the experiment configuration
    with open(config_path, "r") as fp:
        trainer_args_dict = json.load(fp)
    trainer_args = Namespace(**trainer_args_dict)

    # Get the data
    dim, label_dim, train_loader, test_loader = get_data(trainer_args)
    dim = train_loader.dataset.input_size
    label_dim = train_loader.dataset.output_size

    # Load from the checkpoint
    trainer = INVASETrainer(dim, label_dim, trainer_args, path)
    trainer = load_ckpt(trainer, ckpt_path)

    # Construct the decoder
    decoder = LinearDecoder(dim).to(args.device)
    optimizer = optim.Adam(decoder.parameters(), 0.1, weight_decay=1e-4)
    loss_fn = nn.MSELoss()

    # Dataset statistics used to undo input normalization
    mean = torch.tensor(train_loader.dataset.means).to(args.device)
    std = torch.tensor(train_loader.dataset.stds).to(args.device)

    # Tune the decoder on top of the frozen selector
    for i in range(args.decoder_epochs):
        MSE = AverageMeter()
        b_loader = tqdm(train_loader)
        trainer.model.eval()
        for x_batch, y_batch, _ in b_loader:
            b_loader.set_description(f"Epoch {i}: Decoding MSE: {MSE.avg:.4f}")
            x_batch, y_batch = x_batch.to(args.device), y_batch.to(args.device)
            optimizer.zero_grad()

            # Generate a batch of selections
            selection_probability = trainer.model(x_batch, fw_module="selector")

            # Reconstruct the input from the selected features
            used, reconstruction = decoder(selection_probability, x_batch)

            # Convert back to pixel space before computing the loss
            reconstruction = reconstruction * std + mean
            x_batch = x_batch * std + mean
            loss = loss_fn(reconstruction, x_batch)
            MSE.update(loss.detach().item(), y_batch.shape[0])
            loss.backward()
            optimizer.step()

        if (i + 1) % args.eval_freq == 0:
            # Halve the learning rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] / 2

            # Visualize a few reconstructions from the last batch
            fig, axs = plt.subplots(N_IMAGES, 3, figsize=(10, 5))
            flat_shape = x_batch.shape[1]
            img_dim = int(np.sqrt(flat_shape))
            for j in range(N_IMAGES):  # use j so the epoch counter i is not clobbered
                im = x_batch[j].detach().cpu().numpy().reshape((img_dim, img_dim))
                im_rec = reconstruction[j].detach().cpu().numpy().reshape((img_dim, img_dim))
                im_chosen = used[j].detach().cpu().numpy().reshape((img_dim, img_dim))
                axs[j][0].imshow(im)
                axs[j][1].imshow(im_rec)
                axs[j][2].imshow(im_chosen)
                axs[j][0].set_axis_off()
                axs[j][1].set_axis_off()
                axs[j][2].set_axis_off()
            axs[0][0].set_title("Original Image", fontsize=18)
            axs[0][1].set_title("Reconstructed Image", fontsize=18)
            axs[0][2].set_title("Chosen Pixels", fontsize=18)
            fig.savefig(os.path.join(path, "reconstruction_viz.pdf"))
            fig.savefig(os.path.join(path, "reconstruction_viz.png"))
            plt.close(fig)

    # Final evaluation of the tuned decoder on both splits
    modes = [("Train", train_loader), ("Test", test_loader)]
    decoder.eval()
    trainer.model.eval()
    result = dict()
    for mode, loader in modes:
        MSE = AverageMeter()  # reset per split so Test is not mixed with Train
        b_loader = tqdm(loader)
        for x_batch, y_batch, _ in b_loader:
            b_loader.set_description(f"{mode} decoding MSE: {MSE.avg:.4f}")
            x_batch, y_batch = x_batch.to(args.device), y_batch.to(args.device)
            selection_probability = trainer.model(x_batch, fw_module="selector")
            used, reconstruction = decoder(selection_probability, x_batch)
            reconstruction = reconstruction * std + mean
            x_batch = x_batch * std + mean
            loss = loss_fn(reconstruction, x_batch)
            MSE.update(loss.detach().item(), y_batch.shape[0])
        print(f"{mode} Final: ", MSE.avg)
        result[mode] = MSE.avg

    with open(decode_result_path, "w") as fp:
        json.dump(result, fp)
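# LinearDecoder is defined elsewhere in the repository. One plausible reading of
# the interface used above (mask the input with the selector's output and
# reconstruct it with a single linear layer) is sketched below; the project's
# actual LinearDecoder may differ:
import torch
import torch.nn as nn

class LinearDecoderSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, selection_probability, x):
        # "used" is the hard-selected subset of pixels the decoder can see.
        used = (selection_probability > 0.5).float() * x
        reconstruction = self.linear(used)
        return used, reconstruction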
def train_step(self, train_loader):
    device = self.args.device
    self.model.train()

    CriticAcc = AverageMeter()
    BaselineAcc = AverageMeter()
    ActorLoss = AverageMeter()

    b_loader = tqdm(train_loader)
    for x_batch, y_batch, _ in b_loader:
        b_loader.set_description(
            f"Critic: {CriticAcc.avg:.4f}, Baseline: {BaselineAcc.avg:.4f}, Actor: {ActorLoss.avg:.4f}"
        )
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        self.optimizer.zero_grad()
        labels = torch.argmax(y_batch, dim=1).long()

        # Generate a batch of selections by sampling from the selector's probabilities
        selection_probability = self.model(x_batch, fw_module="selector")
        selection = torch.bernoulli(selection_probability).detach()

        # Critic (predictor) objective: classify from the selected features only
        critic_input = x_batch * selection
        critic_out = self.model(critic_input, fw_module="predictor")
        critic_loss = self.critic_loss(critic_out, labels)

        # Baseline objective: classify from all features
        baseline_out = self.model(x_batch, fw_module="baseline")
        baseline_loss = self.baseline_loss(baseline_out, labels)

        batch_data = torch.cat([
            selection.clone().detach(),
            self.softmax(critic_out).clone().detach(),
            self.softmax(baseline_out).clone().detach(),
            y_batch.float()
        ], dim=1)

        # Actor (selector) objective
        actor_output = self.model(x_batch, fw_module="selector")
        actor_loss = self.actor_loss(batch_data, actor_output)

        total_loss = actor_loss + critic_loss + baseline_loss
        total_loss.backward()
        self.optimizer.step()

        N = labels.shape[0]
        critic_acc = accuracy(critic_out, labels)[0]
        baseline_acc = accuracy(baseline_out, labels)[0]
        CriticAcc.update(critic_acc.detach().item(), N)
        BaselineAcc.update(baseline_acc.detach().item(), N)
        ActorLoss.update(actor_loss.detach().item(), N)

    summary = {
        "CriticAcc": CriticAcc.avg,
        "BaselineAcc": BaselineAcc.avg,
        "ActorLoss": ActorLoss.avg
    }
    return summary
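# actor_loss is defined elsewhere; the layout of batch_data above
# ([selection | critic probs | baseline probs | one-hot y]) matches the standard
# INVASE actor objective. The sketch below is an assumption based on that
# standard formulation, not the repository's implementation:
import torch

def invase_actor_loss_sketch(batch_data, actor_output, label_dim, lamda=0.1, eps=1e-8):
    d = batch_data.shape[1] - 3 * label_dim
    selection = batch_data[:, :d]
    critic_prob = batch_data[:, d:d + label_dim]
    baseline_prob = batch_data[:, d + label_dim:d + 2 * label_dim]
    y_true = batch_data[:, d + 2 * label_dim:]

    # Reward: how much better the critic (selected features only) explains y
    # than the baseline that sees every feature.
    reward = (torch.sum(y_true * torch.log(critic_prob + eps), dim=1)
              - torch.sum(y_true * torch.log(baseline_prob + eps), dim=1))

    # Policy gradient: log-likelihood of the sampled selections under the
    # selector's probabilities, weighted by the reward.
    log_pi = torch.sum(selection * torch.log(actor_output + eps)
                       + (1 - selection) * torch.log(1 - actor_output + eps), dim=1)

    # Negative expected reward plus a sparsity penalty on the selections.
    return torch.mean(-reward * log_pi) + lamda * torch.mean(actor_output)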