Example #1
A training loop for adversarial logit pairing (ALP): the cross-entropy loss on the clean batch is augmented with the mean absolute difference between the clean and adversarial logits.
import torch
import torch.nn.functional as F


def train_alp(model, device, train_loader, optimizer, epoch, train_losses):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Track gradients w.r.t. the input so the attack can differentiate
        # through it; v is the attack's (initially zero) momentum buffer.
        data.requires_grad_(True)
        v = torch.zeros_like(data)
        xv = (data, v)

        # Adversarial objective: the negated NLL, so that minimizing it
        # maximizes the classification loss.
        def adv_loss(x, y=target):
            return -F.nll_loss(model(x), y)

        # Single attack step (T=1) with step size 0.075 and momentum gamma=0.
        xx, mmsgf = fgsmm(adv_loss, xv, T=1, lr=0.075, gamma=0.)

        output = model(data)

        # Clean cross-entropy plus the logit-pairing penalty: the mean
        # absolute gap between clean and adversarial logits.
        loss = F.nll_loss(output, target) + torch.mean(
            torch.abs(model.logits(data) - model.logits(xx[0])))

        loss.backward()
        optimizer.step()
        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
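
Both examples call an fgsmm helper that is defined elsewhere in the repository. As a rough guide to its contract, here is a minimal sketch, assuming fgsmm runs T momentum signed-gradient descent steps on the given loss, threading the (input, momentum) pair through and returning the final pair together with the last signed gradient. The body below is a hypothetical reconstruction, not the original implementation.

import torch

def fgsmm(loss_fn, xv, T=1, lr=0.075, gamma=0.):
    # Hypothetical reconstruction: T steps of momentum signed-gradient
    # descent on loss_fn. xv is an (iterate, momentum-buffer) pair and
    # gamma weights the momentum term.
    x, v = xv
    x = x.detach().clone().requires_grad_(True)
    v = v.detach().clone()
    g = torch.zeros_like(x)
    for _ in range(T):
        g, = torch.autograd.grad(loss_fn(x), x)
        with torch.no_grad():
            v = gamma * v - lr * g.sign()  # descend loss_fn
            x = (x + v).requires_grad_(True)
    return (x.detach(), v), g.sign()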
Example #2
A training loop that combines adversarial training with an energy-matching regularizer: the attack strength is randomized per batch, and the loss penalizes the gap between the mean energies of the adversarial and clean inputs.
import numpy as np
import torch
import torch.nn.functional as F


def train_arapx(model, device, train_loader, optimizer, epoch, train_losses):
    # Energy of an input in the energy-based-model sense:
    # E(x) = -logsumexp over the class logits.
    def energy(x):
        return -torch.logsumexp(model.logits(x), dim=1)

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Track gradients w.r.t. the input so the attack can differentiate
        # through it; v is the attack's (initially zero) momentum buffer.
        data.requires_grad_(True)
        v = torch.zeros_like(data)
        xv = (data, v)

        # Adversarial objective: the negated NLL, so that minimizing it
        # maximizes the classification loss.
        def adv_loss(x, y=target):
            return -F.nll_loss(model(x), y)

        xxs = []
        N = 1
        # Using N > 1 would require a Bayesian model to avoid overfitting.
        for _ in range(N):
            # Randomize the attack: T ~ 1 + Poisson(1) steps, and a step
            # size drawn uniformly (Beta(1, 1)) from [lr_min, lr_max].
            T = 1 + np.random.poisson(1)
            lr_min, lr_max = 0.05, 0.15
            lr_a = np.random.beta(1, 1) * (lr_max - lr_min) + lr_min
            gamma = 0.0

            xx, mmsgf = fgsmm(adv_loss, xv, T=T, lr=lr_a, gamma=gamma)
            xxs.append(xx[0])

        xx = torch.cat(xxs, 0)

        # One forward pass over the clean batch and its adversarial copies.
        output = model(torch.cat((data, xx), 0))

        # Cross-entropy on clean + adversarial inputs, plus a penalty that
        # matches the mean energies of the adversarial and clean batches.
        loss = F.nll_loss(output, torch.cat(
            (target, target.repeat(N)),
            0)) + torch.abs(energy(xx).mean() - energy(data).mean())

        loss.backward()
        optimizer.step()
        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
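
For context, here is a minimal driver sketch that would exercise either training function. The Net class is hypothetical and not shown here; the only requirements the examples impose are that forward() returns log-probabilities (for F.nll_loss) and that the model exposes a logits(x) method returning raw class scores.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)

model = Net().to(device)  # Net is an assumed model class, not shown here
optimizer = optim.Adam(model.parameters(), lr=1e-3)
train_losses = []
for epoch in range(1, 11):
    train_alp(model, device, train_loader, optimizer, epoch, train_losses)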