def train(epoch):
    """Run one optimisation pass over ``trainloader``.

    Every batch is converted to a float32 NumPy array, passed through
    ``blockwise_scramble``, transposed with axes ``(0, 3, 1, 2)``, passed
    through ``block_location_shuffle`` (keyed by the module-level ``_shf``)
    and turned back into a tensor before the forward/backward step.

    Args:
        epoch: Not referenced inside this function.

    Returns:
        tuple: ``(loss_sum, accuracy)`` where ``loss_sum`` is the sum of the
        per-batch ``criterion`` values and ``accuracy`` is
        ``100 * correct / total`` over all samples seen.
    """
    net.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for batch, labels in trainloader:
        # Pre-processing is done in NumPy on the host before moving to device.
        pixels = batch.numpy().astype('float32')
        # Presumably blockwise_scramble returns channel-last data, hence the
        # transpose back to (0, 3, 1, 2) -- TODO confirm against the helper.
        pixels = np.transpose(blockwise_scramble(pixels), (0, 3, 1, 2))
        pixels = block_location_shuffle(_shf, pixels)

        batch = torch.from_numpy(pixels).to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = net(batch)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping: accumulate the scalar loss and the top-1 hit count.
        loss_sum += batch_loss.item()
        guesses = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += guesses.eq(labels).sum().item()

    return loss_sum, 100. * n_correct / n_seen
# Example no. 2
# 0
def train(epoch):
    """Run one optimisation pass over ``trainloader`` with two extra penalties.

    Batches get the same pre-processing as in evaluation: float32 NumPy
    conversion, ``blockwise_scramble``, a ``(0, 3, 1, 2)`` transpose and
    ``block_location_shuffle`` keyed by the module-level ``_shf``.

    ``net`` is expected to return ``(outputs, mat, feature)``. The optimised
    loss is the ``criterion`` value plus:

    * a "doubly stochastic constraint" term: for each ``i`` in ``range(64)``,
      the sum of absolute values minus the square root of the sum of squares,
      taken over ``mat[i, :]`` and over ``mat[:, i]``; the total is scaled by
      ``0.001 / (64 * 64)``;
    * a "natural image prior" term: ``0.1 * total_variation_norm(feature)``
      divided by the batch size.

    Args:
        epoch: Not referenced inside this function.

    Returns:
        tuple: ``(loss_sum, accuracy)`` where ``loss_sum`` is the sum of the
        per-batch ``criterion`` values only (penalties excluded) and
        ``accuracy`` is ``100 * correct / total`` over all samples seen.
    """
    net.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    dsc_weight = 0.001
    tv_weight = 1e-1

    for batch, labels in trainloader:
        # Pre-processing is done in NumPy on the host before moving to device.
        pixels = batch.numpy().astype('float32')
        pixels = np.transpose(blockwise_scramble(pixels), (0, 3, 1, 2))
        pixels = block_location_shuffle(_shf, pixels)

        batch = torch.from_numpy(pixels).to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        logits, mat, feature = net(batch)
        ce_loss = criterion(logits, labels)

        # Doubly stochastic constraint: accumulate (sum|v| - sqrt(sum v^2))
        # for row i and then column i, in that order, for i = 0..63.
        penalty = 0
        for idx in range(64):
            row = mat[idx, :]
            penalty += torch.abs(row).sum() - torch.sqrt((row * row).sum())
            col = mat[:, idx]
            penalty += torch.abs(col).sum() - torch.sqrt((col * col).sum())
        penalty = dsc_weight * penalty / (64 * 64)

        # Total-variation prior on the returned feature, averaged per sample.
        tv_prior = tv_weight * total_variation_norm(feature) / batch.size()[0]

        objective = ce_loss + penalty + tv_prior
        objective.backward()
        optimizer.step()

        # Only the criterion value is reported, not the regularised objective.
        loss_sum += ce_loss.item()
        guesses = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += guesses.eq(labels).sum().item()

    return loss_sum, 100. * n_correct / n_seen
def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint on a new best accuracy.

    Batches receive the same pre-processing as in training: float32 NumPy
    conversion, ``blockwise_scramble``, a ``(0, 3, 1, 2)`` transpose and
    ``block_location_shuffle`` keyed by the module-level ``_shf``.

    When the measured accuracy exceeds the module-level ``best_acc``, a dict
    with the model ``state_dict``, the accuracy and ``epoch`` is written to
    ``'./' + args.training_model_name`` and ``best_acc`` is updated.

    Args:
        epoch: Epoch number stored in the checkpoint under the ``'epoch'`` key.

    Returns:
        tuple: ``(test_loss, accuracy)`` where ``test_loss`` is the sum of the
        per-batch ``criterion`` values and ``accuracy`` is
        ``100 * correct / total`` over all samples seen.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            imgs = inputs.numpy().astype('float32')

            # block scrambling
            imgs = np.transpose(blockwise_scramble(imgs), (0, 3, 1, 2))

            # block shuffling
            imgs = block_location_shuffle(_shf, imgs)
            inputs = torch.from_numpy(imgs)

            inputs, targets = inputs.to(device), targets.to(device)

            # The optimizer is deliberately not touched here: under no_grad
            # no gradients are produced, so evaluation should not mutate
            # optimizer/parameter gradient state.
            outputs = net(inputs)
            true_loss = criterion(outputs, targets)

            test_loss += true_loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100. * correct / total
    if best_acc < acc:
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './' + args.training_model_name)
        best_acc = acc
    return test_loss, acc