class Trainer(object):
    """Train and evaluate a LeNet5 model on MNIST, logging through a Visualizer.

    Typical use is ``Trainer().run()``, which builds the model, loads the
    data, then alternates one training pass and one test pass per epoch and
    writes a checkpoint after every epoch.
    """

    def __init__(self):
        self.running_loss = 0.0  # sum of per-epoch losses over the whole run
        self.epochs = 20
        self.current_epoch = 0
        self.start_epoch = 0  # first epoch to run; > 0 only when resuming
        self.start_time = None
        self.epoch_start_time = None
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.loss_fn = None
        self.training_data = None
        self.test_data = None
        # Device that the model and every batch are placed on; chosen for
        # real in setup_model().
        self.device = torch.device('cpu')
        self.vis = Visualizer()

    def setup_model(self, resume=False):
        """Create the model, optimizer and LR scheduler.

        Args:
            resume: when true, restore model/optimizer state from the default
                checkpoint file and continue from the epoch stored in it.
        """
        print("Loading Model")
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # Move the model first so the optimizer (and any state restored into
        # it) is created against parameters that already live on the target
        # device.
        self.model = LeNet5().to(self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.005)
        self.scheduler = lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[2, 5, 8, 12], gamma=0.1)
        self.start_epoch = 0
        if resume:
            print("Resuming from saved model")
            self.start_epoch = self.load_saved_model()
            # The checkpoint written by run() carries no scheduler state, so
            # realign the scheduler's epoch counter with the restored epoch;
            # the decayed learning rate itself comes back with the optimizer
            # state.
            self.scheduler.last_epoch = self.start_epoch
        if self.device.type == 'cuda':
            print("Using GPU")

    def load_saved_model(self, checkpoint='checkpoint.pth.tar'):
        """Restore model and optimizer state from a checkpoint file.

        Args:
            checkpoint: path of a file holding a dict with ``'state_dict'``
                and ``'optimizer'`` entries and, optionally, ``'epoch'``.

        Returns:
            The index of the next epoch to run: the stored ``'epoch'`` plus
            one, or 0 when the checkpoint has no ``'epoch'`` entry.
        """
        # map_location lets a checkpoint saved on a GPU machine be loaded on
        # a CPU-only one (and vice versa).
        state = torch.load(checkpoint, map_location=self.device)
        self.model.load_state_dict(state['state_dict'])
        self.optimizer.load_state_dict(state['optimizer'])
        return state.get('epoch', -1) + 1

    def load_data(self):
        """Build the train and test loaders with a shared preprocessing chain."""
        self.vis.write_log("Loading and Preprocessing MNIST Data")
        # The training set is created without a transform first because the
        # normalisation statistics are read from the dataset itself.
        self.training_data = DataLoader(mnist(set_type='train'), batch_size=1)
        train_set = self.training_data.dataset
        trsfrms = transforms.Compose([
            ZeroPad(pad_size=2),
            Normalize(mean=train_set.pix_mean, stdev=train_set.stdev),
            ToTensor(),
        ])
        train_set.transform = trsfrms
        # The test set is normalised with the *training* statistics.
        self.test_data = DataLoader(
            mnist(set_type='test', transform=trsfrms), batch_size=1)
        self.vis.write_log("Loading & Preprocessing Finished")

    def run(self):
        """Run training module, train then test"""
        self.vis.write_log(
            f"Training Module Started at {datetime.now().isoformat(' ', timespec='seconds')}"
        )
        args = get_args()
        # Honour the --resume flag; setup_model() also sets self.start_epoch.
        self.setup_model(resume=args.resume)
        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        self.load_data()
        self.running_loss = 0.0
        self.start_time = time.time()
        for epoch in range(self.start_epoch, self.epochs):
            self.current_epoch = epoch
            self.epoch_start_time = time.time()
            self.train()
            self.test()
            self.vis.write_log("Creating checkpoint")
            save_model({
                'epoch': self.current_epoch,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict()
            })

    def train(self):
        """Does one training iteration"""
        epoch_loss = 0
        self.model.train(True)
        for sample in self.training_data:
            image = sample['image'].to(self.device)
            # TODO: Detect loss type and do the right transformation on label
            # (MSELoss would need a different target encoding).
            # Cross-entropy wants a class index: take the position of the
            # first non-zero entry of the label. Assumes a single one-hot
            # label per batch (batch_size=1) -- confirm against the dataset.
            label = sample['label'].squeeze().nonzero().select(0, 0)
            label = label.to(self.device)

            y_pred = self.model(image)
            loss = self.loss_fn(y_pred, label)
            epoch_loss += loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        # The milestones are epoch indices, so the schedule advances once per
        # epoch, after that epoch's optimizer updates.
        self.scheduler.step()
        self.running_loss += epoch_loss
        self.vis.update_loss_plot(self.current_epoch + 1, epoch_loss)

    def test(self):
        """Tests model using test set"""
        self.model.train(False)
        correct = 0
        total = 0
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for sample in self.test_data:
                image = sample['image'].to(self.device)
                label = sample['label'].to(self.device)
                y_pred = self.model(image)
                predicted = torch.max(y_pred, 1)[1]
                expected = torch.max(label, 1)[1]
                # Count per sample so the figure stays a true accuracy for
                # any batch size.
                correct += (predicted == expected).sum().item()
                total += expected.size(0)
        # Guard against an empty test set instead of dividing by zero.
        test_accuracy = correct / total if total else 0.0
        self.vis.update_test_accuracy_plot(self.current_epoch + 1,
                                           test_accuracy)
        self.vis.write_log(
            f"Epoch: {self.current_epoch + 1}\tRunning Loss: {self.running_loss:.2f}\tEpoch time: {(time.time() - self.epoch_start_time):.2f} sec"
        )
        self.vis.write_log(f"Test Accuracy: {test_accuracy:.2%}")
        self.vis.write_log(
            f"Elapsed time: {(time.time() - self.start_time):.2f} sec")