def save_best_model(self, output_dir=None, filename_postfix=None):
    """Save ``self.best_model`` both as a full model and as a state dictionary.

    Two files are written into ``output_dir``:

    * ``caltech_birds_<MODEL_NAME>_full<postfix>.pth`` via ``save_model_full``
    * ``caltech_birds_<MODEL_NAME>_dict<postfix>.pth`` via ``save_model_dict``

    where ``<MODEL_NAME>`` is ``self.config.MODEL.MODEL_NAME``.

    Args:
        output_dir: Directory to save into. When ``None``, defaults to
            ``self.config.DIRS.WORKING_DIR``.
        filename_postfix: Extra text inserted into both filenames just before
            the ``.pth`` extension. When ``None``, no extra text is added.

    Raises:
        AssertionError: If ``self.best_model`` is ``None`` (the model has
            most likely not been trained yet).
    """
    # Explicit raise instead of an ``assert`` statement so the guard is not
    # stripped under ``python -O``; AssertionError is kept so callers see the
    # same exception type as before.
    if self.best_model is None:
        raise AssertionError(
            '[ERROR] The best model attribute is empty, you likely need to train the model first.')

    # Ability to save to a bespoke place, or default to the working directory.
    if output_dir is None:
        output_dir = self.config.DIRS.WORKING_DIR
    # Optional extra text for the filenames.
    if filename_postfix is None:
        filename_postfix = ''

    model_name = self.config.MODEL.MODEL_NAME
    full_path = os.path.join(
        output_dir, 'caltech_birds_{}_full{}.pth'.format(model_name, filename_postfix))
    state_path = os.path.join(
        output_dir, 'caltech_birds_{}_dict{}.pth'.format(model_name, filename_postfix))

    # Save out the best model in both formats.
    save_model_full(model=self.best_model, PATH=full_path)
    save_model_dict(model=self.best_model, PATH=state_path)

    print(
        '[INFO] Model has been successfully saved to the following directory: {}'
        .format(output_dir))
    # The directory was reported above, so report just the actual filenames.
    # Each message names the file it describes: full model vs. state dict.
    print('[INFO] The full model filename is: {}'.format(os.path.basename(full_path)))
    print('[INFO] The state dictionary filename is: {}'.format(os.path.basename(state_path)))
print('Device::', device)

# Multi-class classification objective.
criterion = nn.CrossEntropyLoss()

# Build the network with its output layer sized from the number of classes
# (num_classes=len(class_names)), then move it onto the target device.
model_ft = model_func(model_name, pretrained=True, num_classes=len(class_names)).to(device)

# Plain SGD with momentum over every parameter of the network.
optimizer_ft = optim.SGD(params=model_ft.parameters(), lr=0.001, momentum=0.9)

# Step schedule: multiply the learning rate by 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer_ft, step_size=7, gamma=0.1)

# Run training; the returned model replaces model_ft and the history is kept.
model_ft, history = train_model(
    model=model_ft,
    criterion=criterion,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    device=device,
    dataloaders=dataloaders,
    dataset_sizes=dataset_sizes,
    num_epochs=num_epochs,
    return_history=True,
    log_history=True,
    working_dir=working_dir,
)

# Persist the trained model twice: whole-object and state-dict formats.
save_model_full(
    PATH=os.path.join(working_dir, 'caltech_birds_{}_full.pth'.format(model_name)),
    model=model_ft,
)
save_model_dict(
    PATH=os.path.join(working_dir, 'caltech_birds_{}_dict.pth'.format(model_name)),
    model=model_ft,
)