Ejemplo n.º 1
0
def main():
    """Adversarially reprogram a CIFAR-10 ResNet18 to classify MNIST.

    Pipeline:
      1. Load pretrained ResNet18 weights and report CIFAR-10 test accuracy.
      2. Wrap the frozen model in ReProgramCIFAR10ToMNIST and train only the
         reprogramming weights on MNIST with SGD.
      3. Report the reprogrammed model's accuracy on the MNIST test set.
    """
    model = ResNet18().to('cuda')
    # torch.load already returns a fresh state dict and load_state_dict does
    # not mutate its argument, so the defensive .copy() was unnecessary.
    model_weights = torch.load(MODEL_LOAD_PATH)
    """
    Uncomment following block if you trained a network on newer versions of PyTorch
    and are now copying the .pt file to your old machine
    all_keys = model_weights.items()
    valid_weights = OrderedDict()
    print(type(all_keys))
    for i, (k,v) in enumerate(all_keys):
        if 'num_batches_tracked' in k:
            print('Found num_batches_tracked')
        else:
            valid_weights[k] = v
    """
    model.load_state_dict(model_weights)
    model.eval()

    _, _, test_loader = utils.get_cifar10_data_loaders()
    correct = 0
    total = 0
    # Pure evaluation: disable autograd bookkeeping.
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            images, labels = images.to('cuda'), labels.to('cuda')
            logits = model(images)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
            utils.progress(i + 1, len(test_loader),
                           'Batch [{}/{}]'.format(i + 1, len(test_loader)))
    # Divide by the number of samples actually evaluated rather than a
    # hard-coded 10000, so a subsampled/truncated loader still reports
    # a correct percentage.
    print('Accuracy on test set of CIFAR10 = {}%'.format(
        float(correct) * 100.0 / total))

    reprogrammed = ReProgramCIFAR10ToMNIST(model)
    # Snapshot the initial adversarial program image for later comparison.
    save_tensor = reprogrammed.weight_matrix * reprogrammed.reprogram_weights
    torchvision.utils.save_image(save_tensor.view(1, 3, 32, 32),
                                 SAVE_DIR + 'reprogram_init.png')

    train_loader, val_loader, test_loader = utils.get_mnist_data_loaders()
    # Only the reprogramming weights are optimized; the backbone stays frozen.
    # These hyperparameters seem to work best -- feel free to tune them.
    optim = torch.optim.SGD([reprogrammed.reprogram_weights],
                            lr=1e-1,
                            momentum=0.9)
    xent = nn.CrossEntropyLoss()
    n_epochs = 64
    for epoch in range(n_epochs):
        print('Epoch {}'.format(epoch + 1))
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to('cuda'), labels.to('cuda')
            logits = reprogrammed(images)
            # Ensure logits sit on the GPU in case the wrapper returns
            # CPU tensors (the original code did this too).
            logits = logits.to('cuda')
            optim.zero_grad()
            loss = xent(logits, labels)
            loss.backward()
            optim.step()
            #reprogrammed.visualize_adversarial_program()
            batch_acc = (logits.argmax(dim=1) == labels).sum().item() \
                * 100.0 / images.size(0)
            utils.progress(i + 1, len(train_loader),
                           'Batch [{}/{}] Loss = {} Batch Acc = {}%'.format(
                               i + 1, len(train_loader), loss.item(), batch_acc))
        reprogrammed.visualize(images)

    correct = 0
    total = 0
    reprogrammed.eval()
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            images, labels = images.to('cuda'), labels.to('cuda')
            logits = reprogrammed(images)
            logits = logits.to('cuda')
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
            utils.progress(i + 1, len(test_loader),
                           'Batch [{}/{}]'.format(i + 1, len(test_loader)))
    print('Accuracy on MNIST test set = {}%'.format(
        float(correct) * 100.0 / total))
    print('Done')
Ejemplo n.º 2
0
# Load the substitute model used to craft black-box (transfer) attacks.
sub = SmallCNN().cuda()
sub.load_state_dict(torch.load('./substitute_models/mnist_trades.pt'))
sub.eval()

# size_average=False has been deprecated since PyTorch 0.4; reduction='sum'
# is the exact modern equivalent (summed per-batch loss, as these attacks
# expect). One stateless loss module is shared by all four adversaries.
attack_loss = nn.CrossEntropyLoss(reduction='sum')
adversaries = [
    # FGSM crafted on the target model (white-box) ...
    GradientSignAttack(model, attack_loss, eps=0.3),
    # ... and on the substitute (black-box transfer).
    GradientSignAttack(sub, attack_loss, eps=0.3),
    LinfBasicIterativeAttack(model,
                             attack_loss,
                             eps=0.3,
                             nb_iter=40,
                             eps_iter=0.01),
    LinfBasicIterativeAttack(sub,
                             attack_loss,
                             eps=0.3,
                             nb_iter=40,
                             eps_iter=0.01)
]

_, _, test_loader = get_mnist_data_loaders()
for adversary in adversaries:
    correct_adv = 0
    for i, (x_batch, y_batch) in enumerate(test_loader):
        x_batch, y_batch = x_batch.cuda(), y_batch.cuda()
        # Craft the adversarial batch, then always evaluate the TARGET model;
        # for the substitute-based adversaries this measures transferability.
        adv_x_batch = adversary.perturb(x_batch, y_batch)
        logits = model(adv_x_batch)
        _, preds = torch.max(logits, 1)
        correct_adv += (preds == y_batch).sum().item()
        progress(i + 1, len(test_loader),
                 'correct_adv = {}'.format(correct_adv))
Ejemplo n.º 3
0
from torchvision.models import resnet18

# module from the example
from utils import get_mnist_data_loaders

# Experiment configuration.
seed = 12  # RNG seed (presumably consumed later in the file; not used in this chunk)
debug = False
train_batch_size = 128
val_batch_size = 512


# NOTE(review): RandomHorizontalFlip mirrors MNIST digits, which turns
# asymmetric glyphs (2, 3, 7, ...) into shapes that are no longer valid
# digits -- confirm this augmentation is intentional for this experiment.
train_transform = Compose([RandomHorizontalFlip(), ToTensor(), Normalize((0.1307,), (0.3081,))])
val_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])


# Dataset location can be overridden via the DATASET_PATH environment variable.
path = os.getenv("DATASET_PATH", "/tmp/mnist")
train_loader, val_loader = get_mnist_data_loaders(
    path, train_transform, train_batch_size, val_transform, val_batch_size
)

# ResNet-18 with a 10-way classification head; the stem convolution is
# replaced with a 1-input-channel 3x3 conv so single-channel MNIST
# images can be fed in directly.
model = resnet18(num_classes=10)
model.conv1 = nn.Conv2d(1, 64, 3)

learning_rate = 0.01

optimizer = SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

num_epochs = 5
val_interval = 2  # presumably "validate every N epochs" -- verify against the trainer