def __init__(self, hparams, gpu=None, inference=False): self.hparams = hparams self.gpu = gpu self.inference = inference self.start_training = time() # ininialize model architecture self.__setup_model(inference=inference, gpu=gpu) # define model parameters self.__setup_model_hparams() # declare preprocessing object self.postprocessing = Post_Processing() self.__seed_everything(42)
class Model: """ This class handles basic methods for handling the model: 1. Fit the model 2. Make predictions 3. Make inference predictions 3. Save 4. Load weights 5. Restore the model 6. Restore the model with averaged weights """ def __init__(self, hparams, gpu=None, inference=False): self.hparams = hparams self.gpu = gpu self.inference = inference self.start_training = time() # ininialize model architecture self.__setup_model(inference=inference, gpu=gpu) # define model parameters self.__setup_model_hparams() # declare preprocessing object self.postprocessing = Post_Processing() self.__seed_everything(42) def fit(self, train, valid): # setup train and val dataloaders train_loader = DataLoader( train, batch_size=self.hparams['batch_size'], shuffle=True, num_workers=self.hparams['num_workers'], ) valid_loader = DataLoader( valid, batch_size=self.hparams['batch_size'], shuffle=False, num_workers=self.hparams['num_workers'], ) # tensorboard writer = SummaryWriter( f"runs/{self.hparams['model_name']}_{self.start_training}") print('Start training the model') for epoch in range(self.hparams['n_epochs']): # training mode self.model.train() avg_loss = 0.0 for X_batch, y_batch in tqdm(train_loader): # push the data into the GPU X_batch = X_batch.float().to(self.device) y_batch = y_batch.float().to(self.device) # clean gradients from the previous step self.optimizer.zero_grad() # get model predictions pred = self.model(X_batch) # process main loss y_batch = y_batch.permute(0, 2, 3, 1) pred = pred.permute(0, 2, 3, 1) pred = pred.reshape(-1, pred.shape[-1]) y_batch = y_batch.reshape(-1, y_batch.shape[-1]) train_loss = self.loss(pred, y_batch) # remove data from GPU y_batch = y_batch.float().cpu().detach() pred = pred.float().cpu().detach() X_batch = X_batch.float().cpu().detach() # calc loss avg_loss += train_loss.item() / len(train_loader) # gradient clipping if self.apply_clipping: torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.5) # backprop train_loss.backward() # iptimizer step self.optimizer.step() pred = self.postprocessing.run(pred.numpy()) y_batch = self.postprocessing.run(y_batch.numpy()) # calculate a step for metrics self.metric.calc_running_score(labels=y_batch, outputs=pred) # calc train metrics metric_train = self.metric.compute() # evaluate the model print('Model evaluation') # val mode self.model.eval() self.optimizer.zero_grad() avg_val_loss = 0.0 with torch.no_grad(): for X_batch, y_batch in tqdm(valid_loader): # push the data into the GPU X_batch = X_batch.float().to(self.device) y_batch = y_batch.float().to(self.device) # get predictions pred = self.model(X_batch) # calculate main loss y_batch = y_batch.permute(0, 2, 3, 1) pred = pred.permute(0, 2, 3, 1) pred = pred.reshape(-1, pred.shape[-1]) y_batch = y_batch.reshape(-1, y_batch.shape[-1]) avg_val_loss += self.loss( pred, y_batch).item() / len(valid_loader) # remove data from GPU X_batch = X_batch.float().cpu().detach() pred = pred.float().cpu().detach() y_batch = y_batch.float().cpu().detach() pred = self.postprocessing.run(pred.numpy()) y_batch = self.postprocessing.run(y_batch.numpy()) # calculate a step for metrics self.metric.calc_running_score(labels=y_batch, outputs=pred) # calc val metrics metric_val = self.metric.compute() # early stopping for scheduler if self.hparams['scheduler_name'] == 'ReduceLROnPlateau': self.scheduler.step(metric_val) else: self.scheduler.step() es_result = self.early_stopping(score=metric_val, model=self.model, threshold=None) # print statistics if self.hparams['verbose_train']: print( '| Epoch: ', epoch + 1, '| Train_loss: ', avg_loss, '| Val_loss: ', avg_val_loss, '| Metric_train: ', metric_train, '| Metric_val: ', metric_val, '| Current LR: ', self.__get_lr(self.optimizer), ) # add data to tensorboard writer.add_scalars( 'Loss', { 'Train_loss': avg_loss, 'Val_loss': avg_val_loss }, epoch, ) writer.add_scalars('Metric', { 'Metric_train': metric_train, 'Metric_val': metric_val }, epoch) # early stopping procesudre if es_result == 2: print("Early Stopping") print( f'global best val_loss model score {self.early_stopping.best_score}' ) break elif es_result == 1: print(f'save global val_loss model score {metric_val}') writer.close() # load the best model trained so fat self.model = self.early_stopping.load_best_weights() return self.start_training def predict(self, X_test): """ This function makes: 1. batch-wise predictions 2. calculation of the metric for each sample 3. calculation of the metric for the entire dataset Parameters ---------- X_test Returns ------- """ # evaluate the model self.model.eval() test_loader = torch.utils.data.DataLoader( X_test, batch_size=self.hparams['batch_size'], shuffle=False, num_workers=0, ) error_samplewise = [] self.metric.reset() print('Getting predictions') with torch.no_grad(): for i, (X_batch, y_batch) in enumerate(tqdm(test_loader)): X_batch = X_batch.float().to(self.device) y_batch = y_batch.float().to(self.device) pred = self.model(X_batch) # calculate main loss y_batch = y_batch.permute(0, 2, 3, 1) pred = pred.permute(0, 2, 3, 1) pred = pred.reshape(-1, pred.shape[-1]) y_batch = y_batch.reshape(-1, y_batch.shape[-1]) pred = pred.cpu().detach().numpy() X_batch = X_batch.cpu().detach().numpy() y_batch = y_batch.cpu().detach().numpy() # calculate a sample-wise error error_samplewise += self.metric.calc_running_score_samplewise( labels=y_batch, outputs=pred) pred = self.postprocessing.run(pred) y_batch = self.postprocessing.run(y_batch) self.metric.calc_running_score(labels=y_batch, outputs=pred) fold_score = self.metric.compute() error_samplewise = np.array(error_samplewise) return error_samplewise, fold_score def save(self, model_path): print('Saving the model') # states (weights + optimizers) if self.gpu != None: if len(self.gpu) > 1: torch.save(self.model.module.state_dict(), model_path + '.pt') else: torch.save(self.model.state_dict(), model_path + '.pt') else: torch.save(self.model.state_dict(), model_path) # hparams with open(f"{model_path}_hparams.yml", 'w') as file: yaml.dump(self.hparams, file) return True def load(self, model_name): self.model.load_state_dict( torch.load(model_name + '.pt', map_location=self.device)) self.model.eval() return True @classmethod def restore(cls, model_name: str, gpu: list, inference: bool): if gpu is not None: assert all([isinstance(i, int) for i in gpu]), "All gpu indexes should be integer" # load hparams hparams = yaml.load(open(model_name + "_hparams.yml"), Loader=yaml.FullLoader) # construct class self = cls(hparams, gpu=gpu, inference=inference) # load weights + optimizer state self.load(model_name=model_name) return self ################## Utils ##################### def __get_lr(self, optimizer): for param_group in optimizer.param_groups: return param_group['lr'] def __setup_model(self, inference, gpu): # TODO: re-write to pure DDP if inference or gpu is None: self.device = torch.device('cpu') self.model = UNet(hparams=self.hparams['model']).to(self.device) else: if torch.cuda.device_count() > 1: if len(gpu) > 1: print("Number of GPUs will be used: ", len(gpu)) self.device = torch.device(f"cuda:{gpu[0]}" if torch.cuda. is_available() else "cpu") self.model = UNet(hparams=self.hparams['model']).to( self.device) self.model = DP(self.model, device_ids=gpu, output_device=gpu[0]) else: print("Only one GPU will be used") self.device = torch.device(f"cuda:{gpu[0]}" if torch.cuda. is_available() else "cpu") self.model = UNet(hparams=self.hparams['model']).to( self.device) else: self.device = torch.device( f"cuda:{gpu[0]}" if torch.cuda.is_available() else "cpu") self.model = UNet(hparams=self.hparams['model']).to( self.device) print('Only one GPU is available') print('Cuda available: ', torch.cuda.is_available()) return True def __setup_model_hparams(self): # 1. define losses self.loss = Dice_loss() # 2. define model metric self.metric = Dice_score(self.hparams['model']['n_classes']) # 3. define optimizer self.optimizer = eval(f"torch.optim.{self.hparams['optimizer_name']}")( params=self.model.parameters(), **self.hparams['optimizer_hparams']) # 4. define scheduler self.scheduler = eval( f"torch.optim.lr_scheduler.{self.hparams['scheduler_name']}")( optimizer=self.optimizer, **self.hparams['scheduler_hparams']) # 5. define early stopping self.early_stopping = EarlyStopping( checkpoint_path=self.hparams['checkpoint_path'] + f'/checkpoint_{self.start_training}' + '.pt', patience=self.hparams['patience'], delta=self.hparams['min_delta'], is_maximize=True, ) # 6. set gradient clipping self.apply_clipping = self.hparams['clipping'] # clipping of gradients # 7. Set scaler for optimizer self.scaler = torch.cuda.amp.GradScaler() return True def __seed_everything(self, seed): np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) torch.manual_seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False