import torch
import torch.nn as nn


class Model:
    def __init__(self, num_epochs=5, num_classes=10, batch_size=100, learning_rate=0.001):
        self.num_epochs = num_epochs
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.model = ConvNet(num_classes)

        # Loss and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

    def train(self, train_loader):
        total_step = len(train_loader)
        for epoch in range(self.num_epochs):
            for i, (images, labels) in enumerate(train_loader):
                # Forward pass
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                # Backward and optimize
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if (i + 1) % 100 == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                        epoch + 1, self.num_epochs, i + 1, total_step, loss.item()))

    def eval(self, test_loader):
        self.model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for images, labels in test_loader:
                outputs = self.model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            # Report accuracy over the full test set
            print('Test Accuracy of the model: {:.2f} %'.format(100 * correct / total))

    def save(self):
        # Save the model checkpoint
        torch.save(self.model.state_dict(), 'model.ckpt')
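
# A minimal usage sketch, not part of the original file: it assumes torchvision
# is available and that ConvNet is defined elsewhere in this repo. The loader
# settings (batch_size=100, data root '../data') are illustrative assumptions.
if __name__ == '__main__':
    from torchvision import datasets, transforms

    transform = transforms.Compose([transforms.ToTensor()])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True, transform=transform),
        batch_size=100, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True, transform=transform),
        batch_size=100, shuffle=False)

    trainer = Model(num_epochs=5)
    trainer.train(train_loader)  # prints loss every 100 steps
    trainer.eval(test_loader)    # prints overall test accuracy
    trainer.save()               # writes model.ckpt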
# datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
#     transforms.ToTensor(),
# ])),
# batch_size=1, shuffle=False, sampler=torch.utils.data.SubsetRandomSampler(list(
#     range(100))))

# Define what device we are using
print("CUDA Available: ", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pretrained_model = "models/cnn_mnist.ckpt"
model = ConvNet().to(device)
# pretrained_model = "models/lenet_mnist_model.pth"
# model = Net().to(device)
model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))
model.eval()

gp_model, likelihood = load_combined_model('models/gp_mnist.dat')
gp_model.eval()
likelihood.eval()


# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon * sign_data_grad
    # Adding clipping to maintain [0,1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image
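
# A minimal sketch of how fgsm_attack might be driven, not part of the original
# file: it assumes a test_loader with batch_size=1 (as in the commented-out
# loader above) and attacks the CNN alone; the function name test_attack and
# the epsilon value are illustrative assumptions only.
import torch.nn.functional as F

def test_attack(model, device, test_loader, epsilon):
    correct = 0
    for image, label in test_loader:
        image, label = image.to(device), label.to(device)
        image.requires_grad = True  # FGSM needs gradients w.r.t. the input

        output = model(image)
        init_pred = output.max(1, keepdim=True)[1]
        if init_pred.item() != label.item():
            continue  # skip images the model already misclassifies

        # Gradient of the loss w.r.t. the input image
        loss = F.cross_entropy(output, label)
        model.zero_grad()
        loss.backward()

        # Perturb the image and re-classify it
        perturbed = fgsm_attack(image, epsilon, image.grad.data)
        final_pred = model(perturbed).max(1, keepdim=True)[1]
        if final_pred.item() == label.item():
            correct += 1

    acc = correct / float(len(test_loader))
    print("Epsilon: {}\tTest Accuracy = {}/{} = {:.4f}".format(
        epsilon, correct, len(test_loader), acc))
    return acc

# Example call (epsilon chosen arbitrarily for illustration):
# test_attack(model, device, test_loader, epsilon=0.1)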