def train(epoch):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    # Penalty weights: param3 scales the doubly stochastic constraint on the
    # 64x64 matrix `mat`, param4 scales the total-variation prior on `feature`.
    param3 = 0.001
    param4 = 1e-1
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        outputs, mat, feature = net(inputs)
        true_loss = criterion(outputs, targets)

        # Doubly stochastic constraint: for each row and column of `mat`, the
        # L1 norm exceeds the L2 norm unless at most one entry is non-zero,
        # so this penalty pushes `mat` towards a permutation-like matrix.
        dsc = 0
        for i in range(64):
            dsc += torch.abs(mat[i, :]).sum() - torch.sqrt((mat[i, :] * mat[i, :]).sum())
            dsc += torch.abs(mat[:, i]).sum() - torch.sqrt((mat[:, i] * mat[:, i]).sum())

        dsc = param3 * dsc / (64 * 64)

        # Natural-image prior: total-variation norm of the intermediate
        # feature map, averaged over the batch.
        natural_image_prior = param4 * total_variation_norm(feature) / inputs.size(0)

        loss = true_loss + dsc + natural_image_prior
        loss.backward()
        optimizer.step()

        train_loss += true_loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return train_loss, 100.*correct/total
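
The call to total_variation_norm above refers to a helper defined elsewhere in the repository. As a rough reference only, the sketch below shows one common way to compute a total-variation penalty over a batch of feature maps; the body is an assumption, not the repository's implementation, and it only presumes `feature` is an (N, C, H, W) tensor.

import torch

def total_variation_norm(feature):
    # Sum of absolute differences between neighbouring activations along the
    # width and height axes, taken over the whole batch. (Assumed definition;
    # the original repository may normalise or weight this differently.)
    diff_w = torch.abs(feature[:, :, :, 1:] - feature[:, :, :, :-1]).sum()
    diff_h = torch.abs(feature[:, :, 1:, :] - feature[:, :, :-1, :]).sum()
    return diff_w + diff_h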
Example #2
def train(epoch):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    # Penalty weights, as in the previous example.
    param3 = 0.001
    param4 = 1e-1
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        imgs = inputs.numpy().astype('float32')

        # Block scrambling of the input images, followed by a reorder of the
        # axes to NCHW layout (a rough sketch of the helpers follows this
        # example).
        x_stack = blockwise_scramble(imgs)
        imgs = np.transpose(x_stack, (0, 3, 1, 2))

        # Block shuffling: rearrange the block locations according to the
        # fixed permutation `_shf`.
        imgs = block_location_shuffle(_shf, imgs)
        inputs = torch.from_numpy(imgs)

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        outputs, mat, feature = net(inputs)
        true_loss = criterion(outputs, targets)

        # doubly stochastic constraint
        dsc = 0
        for i in range(64):
            dsc += torch.abs(mat[i, :]).sum() - torch.sqrt(
                (mat[i, :] * mat[i, :]).sum())
            dsc += torch.abs(mat[:, i]).sum() - torch.sqrt(
                (mat[:, i] * mat[:, i]).sum())

        dsc = param3 * dsc / (64 * 64)
        natural_image_prior = param4 * total_variation_norm(
            feature) / inputs.size()[0]

        loss = true_loss + dsc + natural_image_prior
        loss.backward()
        optimizer.step()

        train_loss += true_loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return train_loss, 100. * correct / total
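
blockwise_scramble, block_location_shuffle, and the permutation _shf come from the surrounding repository and are not shown here. The stand-ins below are only a sketch of the general idea suggested by the names and by how the results are used above (the scramble output is NHWC and is transposed to NCHW before the block shuffle): pixels are permuted inside each fixed-size block, then whole blocks are moved to new locations. Block size, seeding, and array layout are assumptions, not the original definitions.

import numpy as np

BLOCK = 4  # assumed block size

def blockwise_scramble(imgs, block=BLOCK, seed=0):
    """imgs: (N, H, W, C) float32 (assumed layout). Returns the same shape
    with pixels permuted inside every block, using one fixed permutation."""
    rng = np.random.RandomState(seed)
    perm = rng.permutation(block * block)
    n, h, w, c = imgs.shape
    gh, gw = h // block, w // block
    out = imgs.reshape(n, gh, block, gw, block, c)
    out = out.transpose(0, 1, 3, 2, 4, 5)              # (n, gh, gw, bh, bw, c)
    out = out.reshape(n, gh, gw, block * block, c)
    out = out[:, :, :, perm, :]                        # scramble within blocks
    out = out.reshape(n, gh, gw, block, block, c)
    out = out.transpose(0, 1, 3, 2, 4, 5)
    return out.reshape(n, h, w, c)

def block_location_shuffle(shf, imgs, block=BLOCK):
    """imgs: (N, C, H, W). `shf` is a permutation of the block indices
    (length (H//block) * (W//block)); blocks are rearranged accordingly."""
    n, c, h, w = imgs.shape
    gh, gw = h // block, w // block
    blocks = imgs.reshape(n, c, gh, block, gw, block)
    blocks = blocks.transpose(0, 1, 2, 4, 3, 5)        # (n, c, gh, gw, bh, bw)
    blocks = blocks.reshape(n, c, gh * gw, block, block)
    blocks = blocks[:, :, shf, :, :]                   # move whole blocks
    blocks = blocks.reshape(n, c, gh, gw, block, block)
    blocks = blocks.transpose(0, 1, 2, 4, 3, 5)
    return blocks.reshape(n, c, h, w)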