Beispiel #1
0
 def defineModel(self, *args, **kwargs):
     """Instantiate a pretrained torchvision VGG model for ``self.architecture``.

     Supported architectures are ``'vgg11'``, ``'vgg13'``, ``'vgg16'`` and
     ``'vgg19'``. If ``self.architecture`` is anything else (e.g. ``None``),
     the user is prompted through ``IOUtils`` either to accept the default
     (``'vgg19'``) or to choose one of the supported names. The prompt repeats
     until a supported name is held, and ``self.architecture`` is updated with
     that choice.

     On return, ``self.model`` holds the pretrained model and
     ``self.framework`` refers to the same object.

     ``*args`` and ``**kwargs`` are accepted for interface compatibility and
     are not used.
     """
     # The values are the constructor *functions*, not their results. Building
     # this table therefore instantiates (and downloads) nothing; only the one
     # constructor called at the bottom does that.
     constructors = {
         'vgg11': models.vgg11,
         'vgg13': models.vgg13,
         'vgg16': models.vgg16,
         'vgg19': models.vgg19,
     }
     # Loop (rather than recurse) until we hold a supported architecture name.
     while self.architecture not in constructors:
         # No usable architecture provided. Let them define one now, or use a default
         shouldUseDefault = IOUtils.yesOrNo(
             "No architecture was provided. Press (y) to use the default [vgg19], or (n) to define your own architecture."
         )
         if shouldUseDefault:
             self.architecture = 'vgg19'
         else:
             supportedModels = list(constructors)
             self.architecture = IOUtils.getResponse(
                 f"Choose model architecture. Options are {supportedModels}:",
                 supportedModels)
     # NOTE(review): newer torchvision releases deprecate ``pretrained=True``
     # in favour of ``weights=...``; confirm against the installed version.
     self.model = constructors[self.architecture](pretrained=True)
     self.framework = self.model
Beispiel #2
0
 def promptSave(self):
     """Ask whether to save the trained model and, if confirmed, save a checkpoint.

     The user is asked via ``IOUtils`` for a yes/no answer, then for a file
     name. The checkpoint dict is written with ``torch.save`` to
     ``os.path.join(self.save_dir, <filename>)``. ``self.save_dir`` falls back
     to ``'model_checkpoints'`` when it is empty/falsy, and the target's parent
     directory is created if it does not exist. If the user declines, a notice
     is shown and nothing is written.
     """
     IOUtils.notify("Training complete. Save the trained model?")
     shouldSave = IOUtils.yesOrNo(
         "Press 'y' to save or 'n' to end without saving.")
     if not shouldSave:
         IOUtils.notify("Not saving model. Program terminated.")
         return

     savePath = IOUtils.getResponse(
         "Enter filename (it should end in .pth)")
     checkpoint = {
         # NOTE(review): this stores the model object itself (self.framework
         # is assigned self.model in defineModel), not an architecture name
         # such as self.architecture. torch.save pickles the whole module.
         # Loaders on recent torch (weights_only=True by default in
         # torch.load) may refuse it, so confirm what the loading code
         # expects before changing it.
         'architecture': self.framework,
         'classifier': self.model.classifier,
         # Hard-coded sizes, presumably describing the custom classifier head
         # (VGG feature size 25088 -> hidden [1646, 1584] -> 102 classes).
         # TODO confirm they match self.model.classifier.
         'input_size': 25088,
         'output_size': 102,
         'hidden_layers': [1646, 1584],
         'state_dict': self.model.state_dict(),
     }
     self.save_dir = self.save_dir or 'model_checkpoints'
     target = os.path.join(self.save_dir, savePath)
     # Create exactly the directory we write into. Prefixing "./" would turn
     # an absolute save_dir into a CWD-relative one and make the save fail.
     # Creating the target's parent also covers filenames with subdirectories.
     os.makedirs(os.path.dirname(target), exist_ok=True)
     torch.save(checkpoint, target)