import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# `fgsmm` is assumed to be defined elsewhere in this codebase: it takes the
# adversarial objective, an (input, velocity) pair, a step count T, a step
# size lr and a coefficient gamma, and returns a perturbed pair plus an
# auxiliary value.


def train_alp(model, device, train_loader, optimizer, epoch, train_losses):
    # Adversarial logit pairing: NLL loss on clean data plus the mean absolute
    # difference between clean and adversarial logits.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        data.requires_grad_(True)  # input gradients are needed for the attack
        v = torch.zeros_like(data)
        xv = (data, v)

        def adv_loss(x, y=target):
            # Negated classification loss: the attack ascends this objective.
            return -F.nll_loss(model(x), y)

        # Single attack step with a fixed step size; xx[0] is the perturbed input.
        xx, mmsgf = fgsmm(adv_loss, xv, T=1, lr=0.075, gamma=0.)
        output = model(data)
        loss = F.nll_loss(output, target) + torch.mean(
            torch.abs(model.logits(data) - model.logits(xx[0])))
        loss.backward()
        optimizer.step()
        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
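# --- Illustrative sketch (hypothetical, not used by the training code) -----
# Shows the logit-pairing penalty from train_alp in isolation. The linear
# layer stands in for model.logits and random noise stands in for the fgsmm
# perturbation; every name below is an illustrative assumption.
def _alp_pairing_sketch():
    flat_logits = nn.Linear(28 * 28, 10)                   # stand-in for model.logits
    x_clean = torch.randn(4, 1, 28, 28)                    # toy MNIST-sized batch
    x_adv = x_clean + 0.075 * torch.randn_like(x_clean)    # stand-in perturbation

    def logits(x):
        return flat_logits(x.view(x.size(0), -1))

    # The pairing term added to the NLL loss in train_alp: mean absolute
    # difference between clean and adversarial logits.
    return torch.mean(torch.abs(logits(x_clean) - logits(x_adv)))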
def train_arapx(model, device, train_loader, optimizer, epoch, train_losses):
    def energy(x):
        # Energy of an input under the classifier: E(x) = -logsumexp(logits(x)).
        return -torch.logsumexp(model.logits(x), dim=1)

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        data.requires_grad_(True)  # input gradients are needed for the attack
        v = torch.zeros_like(data)
        xv = (data, v)

        def adv_loss(x, y=target):
            return -F.nll_loss(model(x), y)

        xxs = []
        N = 1  # To do N > 1, we would have to use a Bayesian model to avoid overfitting.
        for _ in range(N):
            # Randomise the attack: step count T ~ 1 + Poisson(1) and a step
            # size drawn uniformly (Beta(1, 1)) from [lr_min, lr_max].
            T = 1 + np.random.poisson(1)
            # T = 5
            lr_min, lr_max = 0.05, 0.15
            lr_a = np.random.beta(1, 1) * (lr_max - lr_min) + lr_min
            # lr = 0.05
            gamma = 0.0
            xx, mmsgf = fgsmm(adv_loss, xv, T=T, lr=lr_a, gamma=gamma)
            xxs.append(xx[0])
        xx = torch.cat(tuple(xxs), 0)

        # Classify the clean and adversarial inputs jointly; the extra term
        # matches the mean energies of the adversarial and clean batches.
        output = model(torch.cat((data, xx), 0))
        loss = (F.nll_loss(output, torch.cat((target, target.repeat(N)), 0))
                + torch.abs(energy(xx).mean() - energy(data).mean()))
        loss.backward()
        optimizer.step()
        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
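# --- Illustrative sketch (hypothetical, not used by the training code) -----
# Shows the energy-matching penalty from train_arapx in isolation:
# E(x) = -logsumexp(logits(x)), regularised by |mean E(x_adv) - mean E(x_clean)|.
# The linear layer and the noise perturbation are stand-ins for the real model
# and the fgsmm attack.
def _energy_matching_sketch():
    flat_logits = nn.Linear(28 * 28, 10)                 # stand-in for model.logits
    x_clean = torch.randn(4, 1, 28, 28)                  # toy MNIST-sized batch
    x_adv = x_clean + 0.1 * torch.randn_like(x_clean)    # stand-in perturbation

    def energy(x):
        return -torch.logsumexp(flat_logits(x.view(x.size(0), -1)), dim=1)

    # The regulariser added to the NLL loss in train_arapx.
    return torch.abs(energy(x_adv).mean() - energy(x_clean).mean())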