class Solver(object):
    """Training/evaluation harness for AlexNet on CIFAR-10.

    Minimizes an information-regularized objective:
        loss = cross_entropy + 0.5 * beta * Iw
    where Iw is an information term reported by the model (see getIw).
    Per-epoch metrics are appended to ``<name>.csv`` by run().
    """

    def __init__(self, config):
        """Store hyper-parameters from ``config``; heavy state (model, data
        loaders, optimizer) is created later by load_model()/load_data()."""
        self.model = None
        self.name = config.name
        self.lr = config.lr
        self.momentum = config.momentum
        self.beta = config.beta                # weight of the Iw penalty
        self.max_alpha = config.max_alpha
        self.epochs = config.epochs
        self.patience = config.patience        # early-stopping window; < 0 disables
        self.N = config.N                      # training-set size, must be <= 50000
        self.batch_size = config.batch_size
        self.random_labels = config.random_labels  # shuffle labels (memorization exp.)
        self.use_bn = config.batchnorm
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.device = None
        self.cuda = config.cuda
        self.train_loader = None
        self.test_loader = None

    def load_data(self):
        """Build CIFAR-10 train/test loaders (28x28 center crop, normalized).

        Raises:
            ValueError: if self.N exceeds the CIFAR-10 training-set size.
        """
        # ToTensor scales pixel values from [0,255] to [0,1]
        mean_var = (125.3 / 255, 123.0 / 255, 113.9 / 255), (63.0 / 255, 62.1 / 255, 66.7 / 255)
        transform = transforms.Compose([
            transforms.CenterCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(*mean_var, inplace=True)
        ])
        train_set = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=DOWNLOAD, transform=transform)
        test_set = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=DOWNLOAD, transform=transform)
        if self.random_labels:
            # destroy the data->label association (randomization experiment)
            np.random.shuffle(train_set.targets)
            np.random.shuffle(test_set.targets)
        # explicit validation instead of assert (asserts vanish under -O)
        if self.N > 50000:
            raise ValueError('N must be <= 50000, got %d' % self.N)
        if self.N < 50000:
            # keep data and targets the same length so the dataset stays consistent
            train_set.data = train_set.data[:self.N]
            train_set.targets = train_set.targets[:self.N]
            # downsize the test set to improve speed for small N
            test_set.data = test_set.data[:self.N]
            test_set.targets = test_set.targets[:self.N]
        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_set, batch_size=self.batch_size, shuffle=True, drop_last=True)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_set, batch_size=self.batch_size, shuffle=False, drop_last=True)

    def load_model(self):
        """Instantiate the AlexNet model, SGD optimizer, StepLR scheduler and
        NLL criterion on the selected device (CUDA if self.cuda)."""
        if self.cuda:
            self.device = torch.device('cuda')
            cudnn.benchmark = True  # autotune conv kernels for fixed input shape
        else:
            self.device = torch.device('cpu')
        self.model = AlexNet(device=self.device, B=self.batch_size,
                             max_alpha=self.max_alpha, use_bn=self.use_bn).to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr,
                                   momentum=self.momentum)
        # decay LR by the default factor (0.1) every 140 epochs
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=140)
        self.criterion = nn.NLLLoss().to(self.device)

    def getIw(self):
        """Return the information term Iw, rescaled to full weight dimension.

        Iw should be normalized with respect to N. Via reparameterization we
        optimize alpha with only 1920 dimensions, but Iw should scale with
        the dimension of the weights (7*7*64*384), hence the rescaling factor.
        """
        return 7 * 7 * 64 * 384 / 1920 * self.model.getIw() / self.batch_size

    def do_batch(self, train, epoch):
        """Run one full pass over the train (train=True) or test loader.

        Returns:
            (accuracy, mean loss, mean cross-entropy, mean Iw, a)
            where ``a`` is the model's alpha statistic after the last batch.

        Raises:
            RuntimeError: if the loader is empty (N < batch_size with
                drop_last=True), which would otherwise cause a
                ZeroDivisionError below.
        """
        loader = self.train_loader if train else self.test_loader
        num_batches = len(loader)
        if num_batches == 0:
            raise RuntimeError(
                'empty loader (is N < batch_size with drop_last=True?)')
        total_ce, total_Iw, total_loss = 0, 0, 0
        total_correct = 0
        total = 0
        pbar = tqdm(loader)
        for data, target in pbar:
            data, target = data.to(self.device), target.to(self.device)
            if train:
                self.optimizer.zero_grad()
            output = self.model(data)
            # NLLLoss is averaged across observations for each minibatch;
            # model outputs probabilities, so take log (EPS avoids log(0))
            ce = self.criterion(torch.log(output + EPS), target)
            Iw = self.getIw()
            loss = ce + 0.5 * self.beta * Iw
            if train:
                loss.backward()
                self.optimizer.step()
            total_ce += ce.item()
            total_Iw += Iw.item()
            total_loss += loss.item()
            # argmax over dim 1 (the class dimension)
            prediction = output.argmax(1)
            total_correct += np.sum(
                prediction.cpu().numpy() == target.cpu().numpy())
            total += target.size(0)
            a = self.model.get_a()
            pbar.set_description('Train' if train else 'Test')
            pbar.set_postfix(N=self.N, b=self.beta, ep=epoch,
                             acc=100. * total_correct / total,
                             loss=total_loss / num_batches,
                             ce=total_ce / num_batches,
                             Iw=total_Iw / num_batches, a=a)
        return (total_correct / total, total_loss / num_batches,
                total_ce / num_batches, total_Iw / num_batches, a)

    def train(self, epoch):
        """One training epoch; returns do_batch()'s metric tuple."""
        self.model.train()
        return self.do_batch(train=True, epoch=epoch)

    def test(self, epoch):
        """One evaluation epoch (no gradients); returns do_batch()'s tuple."""
        self.model.eval()
        with torch.no_grad():
            return self.do_batch(train=False, epoch=epoch)

    def save(self, name=None):
        """Compute the checkpoint path; actual saving is deliberately disabled.

        NOTE(review): torch.save is commented out in the original — kept that
        way to preserve behavior; re-enable if checkpoints are needed.
        """
        model_out_path = (name or self.name) + ".pth"
        # torch.save(self.model, model_out_path)
        # print("Checkpoint saved to {}".format(model_out_path))

    def run(self):
        """Full experiment: train up to self.epochs epochs with optional early
        stopping, append per-epoch metrics to ``<name>.csv`` and return the
        final (train_acc, test_acc)."""
        self.load_data()
        self.load_model()
        results = []
        best_acc, best_ep = -1, -1
        for epoch in range(1, self.epochs + 1):
            train_acc, train_loss, train_ce, train_Iw, train_a = self.train(epoch)
            # step once per epoch; passing the epoch index to step() is
            # deprecated in torch and equivalent here (one call per epoch)
            self.scheduler.step()
            test_acc, test_loss, test_ce, test_Iw, test_a = self.test(epoch)
            results.append([
                self.N, self.beta, train_acc, test_acc, train_loss, test_loss,
                train_ce, test_ce, train_Iw, test_Iw, train_a, test_a
            ])
            if test_acc > best_acc:
                best_acc, best_ep = test_acc, epoch
            if self.patience >= 0:
                # early stopping: no test-accuracy improvement for `patience` epochs
                if best_ep < epoch - self.patience:
                    break
        # newline='' avoids blank rows on Windows (csv module requirement)
        with open(self.name + '.csv', 'a', newline='') as f:
            w = csv.writer(f)
            w.writerows(results)
        self.save()
        return train_acc, test_acc