Ejemplo n.º 1
0
    def run(self):
        """Check that the trainer is fully set up, then train the model.

        Every value in ``self.trainer_status`` must be truthy before training
        starts. Training is delegated to ``train_model``; the two values it
        returns are stored on ``self.best_model`` and ``self.history``.

        Raises:
            AssertionError: If any entry of ``self.trainer_status`` is falsy.
                The error is raised explicitly instead of through an
                ``assert`` statement so the check still runs when Python is
                started with ``-O`` (which strips ``assert`` statements).
        """
        # Stop at the first component (in mapping order) that is not ready.
        for key, is_ready in self.trainer_status.items():
            if not is_ready:
                raise AssertionError(
                    '[ERROR] The {} has not been generated and you cannot proceed.'.format(
                        key))
        print('[INFO] Trainer pass OK for training.')

        print('[INFO] Commencing model training.')
        # Train the model. The validation loader is passed to train_model
        # under the 'test' key.
        self.best_model, self.history = train_model(
            model=self.model,
            criterion=self.criterion,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            device=self.device,
            dataloaders={
                'train': self.train_loader,
                'test': self.val_loader
            },
            dataset_sizes=self.dataset_sizes,
            num_epochs=self.config.TRAIN.NUM_EPOCHS,
            return_history=True,
            log_history=self.config.SYSTEM.LOG_HISTORY,
            working_dir=self.config.DIRS.WORKING_DIR)
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Device::', device)


# Build the network through model_func and move it onto the selected device.
# The manual replacement of the final layer is left disabled below; the class
# count is passed to model_func via num_classes instead.
# NOTE(review): presumably model_func sizes the output layer itself - confirm.
#num_ftrs = model_ft.fc.in_features
#model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_func(
    model_name,
    pretrained=True,
    num_classes=len(class_names),
).to(device)

# Loss function handed to train_model.
criterion = nn.CrossEntropyLoss()

# SGD over every parameter of the model (nothing is frozen here).
optimizer_ft = optim.SGD(
    model_ft.parameters(),
    lr=0.001,
    momentum=0.9,
)

# StepLR: scale the learning rate by gamma=0.1 every 7 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer_ft,
    step_size=7,
    gamma=0.1,
)

# Run training; train_model returns a model and its training history.
model_ft, history = train_model(
    model=model_ft,
    criterion=criterion,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    device=device,
    dataloaders=dataloaders,
    dataset_sizes=dataset_sizes,
    num_epochs=num_epochs,
    return_history=True,
    log_history=True,
    working_dir=working_dir,
)

# Write the returned model to working_dir twice: once through save_model_full
# and once through save_model_dict.
save_model_full(
    model=model_ft,
    PATH=os.path.join(working_dir, 'caltech_birds_{}_full.pth'.format(model_name)),
)
save_model_dict(
    model=model_ft,
    PATH=os.path.join(working_dir, 'caltech_birds_{}_dict.pth'.format(model_name)),
)