# Third-party imports used by the trainers below. Project-local modules
# (get_instance, helpers, utils.lr_scheduler, HTML, AverageMeter, batch_PSNR,
# batch_SSIM) are assumed to be importable from the repository.
import datetime
import json
import logging
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils import tensorboard
from tqdm import tqdm


class BaseTrainer:
    def __init__(self, model, resume, config, iters_per_epoch, train_logger=None):
        self.model = model
        self.config = config
        self.train_logger = train_logger
        self.logger = logging.getLogger(self.__class__.__name__)
        self.do_validation = self.config['trainer']['val']
        self.start_epoch = 1
        self.improved = False
        self.not_improved_count = 0

        # SETTING THE DEVICE
        self.device, available_gpus, self.str_device = self._get_available_devices(self.config['n_gpu'])
        self.name = self.model.__class__.__name__
        self.model = torch.nn.DataParallel(self.model, device_ids=available_gpus)
        self.model.to(self.device)

        # CONFIGS
        cfg_trainer = self.config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']

        # OPTIMIZER
        # print(self.model.__class__.__name__)
        if self.name == 'CCT_Unet':
            trainable_params = [{'params': filter(lambda p: p.requires_grad,
                                                  self.model.module.get_other_params())}]
        else:
            # the backbone is trained with a learning rate ten times smaller than the rest
            trainable_params = [{'params': filter(lambda p: p.requires_grad,
                                                  self.model.module.get_other_params())},
                                {'params': filter(lambda p: p.requires_grad,
                                                  self.model.module.get_backbone_params()),
                                 'lr': config['optimizer']['args']['lr'] / 10}]
        self.optimizer = get_instance(torch.optim, 'optimizer', config, trainable_params)

        # sanity check: every model parameter must be covered by the optimizer
        model_params = sum([i.shape.numel() for i in list(model.parameters())])
        opt_params = sum([i.shape.numel() for j in self.optimizer.param_groups for i in j['params']])
        assert opt_params == model_params, 'some parameters are missing from the optimizer'

        self.lr_scheduler = getattr(utils.lr_scheduler, config['lr_scheduler'])(
            optimizer=self.optimizer, num_epochs=self.epochs, iters_per_epoch=iters_per_epoch)

        # MONITORING
        self.monitor = cfg_trainer.get('monitor', 'off')
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = -math.inf if self.mnt_mode == 'max' else math.inf
            self.early_stopping = cfg_trainer.get('early_stop', math.inf)

        # CHECKPOINTS & TENSORBOARD
        date_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
        run_name = config['experim_name']
        self.checkpoint_dir = os.path.join(cfg_trainer['save_dir'], run_name)
        helpers.dir_exists(self.checkpoint_dir)
        config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
        with open(config_save_path, 'w') as handle:
            json.dump(self.config, handle, indent=4, sort_keys=True)

        writer_dir = os.path.join(cfg_trainer['log_dir'], run_name)
        self.writer = tensorboard.SummaryWriter(writer_dir)
        self.html_results = HTML(web_dir=config['trainer']['save_dir'],
                                 exp_name=config['experim_name'],
                                 save_name=config['experim_name'],
                                 config=config, resume=resume)

        if resume:
            self._resume_checkpoint(resume)

    def _get_available_devices(self, n_gpu):
        sys_gpu = torch.cuda.device_count()
        if sys_gpu == 0:
            self.logger.warning('No GPUs detected, using the CPU')
            n_gpu = 0
        elif n_gpu > sys_gpu:
            self.logger.warning(f'Number of GPUs requested is {n_gpu} but only {sys_gpu} are available')
            n_gpu = sys_gpu
        device = torch.device('cuda:0' if n_gpu > 0 else 'cpu')
        str_device = 'gpu' if n_gpu > 0 else 'cpu'
        self.logger.info(f'Detected GPUs: {sys_gpu} Requested: {n_gpu}')
        available_gpus = list(range(n_gpu))
        return device, available_gpus, str_device

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            results = self._train_epoch(epoch)
            if self.do_validation and epoch % self.config['trainer']['val_per_epochs'] == 0:
                results = self._valid_epoch(epoch)
                self.logger.info('\n\n')
                for k, v in results.items():
                    self.logger.info(f'{str(k):15s}: {v}')

            # build the epoch log even when no train_logger is given,
            # since best-model monitoring below reads from it
            log = {'epoch': epoch, **results}
            if self.train_logger is not None:
                self.train_logger.add_entry(log)

            # CHECKING IF THIS IS THE BEST MODEL (ONLY FOR VAL)
            if self.mnt_mode != 'off' and epoch % self.config['trainer']['val_per_epochs'] == 0:
                try:
                    if self.mnt_mode == 'min':
                        self.improved = (log[self.mnt_metric] < self.mnt_best)
                    else:
                        self.improved = (log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        f'The metric being tracked ({self.mnt_metric}) has not been computed. Training stops.')
                    break

                if self.improved:
                    self.mnt_best = log[self.mnt_metric]
                    self.not_improved_count = 0
                else:
                    self.not_improved_count += 1

                # if self.not_improved_count > self.early_stopping:
                #     self.logger.info(f'\nPerformance didn\'t improve for {self.early_stopping} epochs')
                #     self.logger.warning('Training stopped')
                #     break

            # SAVE CHECKPOINT
            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=self.improved)

        self.html_results.save()
        self.writer.flush()
        self.writer.close()

    def _save_checkpoint(self, epoch, save_best=False):
        state = {
            'arch': type(self.model).__name__,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint.pth')
        self.logger.info(f'\nSaving a checkpoint: {filename} ...')
        torch.save(state, filename)

        if save_best:
            filename = os.path.join(self.checkpoint_dir, 'best_model.pth')
            torch.save(state, filename)
            self.logger.info("Saving current best: best_model.pth")

    def _resume_checkpoint(self, resume_path):
        self.logger.info(f'Loading checkpoint : {resume_path}')
        checkpoint = torch.load(resume_path)

        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        self.not_improved_count = 0

        try:
            self.model.load_state_dict(checkpoint['state_dict'])
        except Exception as e:
            print(f'Error when loading: {e}')
            self.model.load_state_dict(checkpoint['state_dict'], strict=False)

        if "logger" in checkpoint.keys():
            self.train_logger = checkpoint['logger']
        self.logger.info(f'Checkpoint <{resume_path}> (epoch {self.start_epoch}) was loaded')

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _valid_epoch(self, epoch):
        raise NotImplementedError

    def _eval_metrics(self, output, target):
        raise NotImplementedError
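# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original code). BaseTrainer deliberately
# leaves _train_epoch / _valid_epoch / _eval_metrics unimplemented, so it is
# meant to be subclassed. The minimal, hypothetical subclass below shows how
# those hooks are expected to be filled in; it assumes a model exposing
# get_backbone_params() / get_other_params() (required by BaseTrainer's
# optimizer setup) and a config containing the keys BaseTrainer reads. The
# stepping of the project's custom lr_scheduler is implementation-specific and
# omitted here.
class ExampleTrainer(BaseTrainer):
    def __init__(self, model, resume, config, train_loader, val_loader=None, train_logger=None):
        super().__init__(model, resume, config,
                         iters_per_epoch=len(train_loader),
                         train_logger=train_logger)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = torch.nn.CrossEntropyLoss()

    def _train_epoch(self, epoch):
        self.model.train()
        running_loss = 0.0
        for inputs, targets in self.train_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(inputs), targets)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
        return {'loss': running_loss / len(self.train_loader)}

    def _valid_epoch(self, epoch):
        self.model.eval()
        running_loss = 0.0
        with torch.no_grad():
            for inputs, targets in self.val_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                running_loss += self.criterion(self.model(inputs), targets).item()
        # the returned key must match the metric named in config['trainer']['monitor']
        # for best-model tracking in BaseTrainer.train() to work
        return {'val_loss': running_loss / len(self.val_loader)}

    def _eval_metrics(self, output, target):
        return {}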
class Trainer:
    def __init__(self, config, trainloader, valloader, model, train_logger, seed, resume, device):
        self.config = config
        self.train_loader = trainloader
        self.val_loader = valloader
        self.model = model
        self.device = device
        if self.device != "cpu":
            self.model = self.model.to(device)
        print(self.model.get_num_params())

        self.set_optimization()
        if self.config["optimizer"]["args"]["stepscheduler"] is True:
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.main_optimizer,
                step_size=self.config["optimizer"]["args"]["step"],
                gamma=self.config["optimizer"]["args"]["gamma"])

        # sum-reduced MSE for both datasets (size_average=False is deprecated;
        # reduction="sum" is the equivalent modern form)
        if self.config["dataset"] == "fastMRI":
            self.criterion = torch.nn.MSELoss(reduction="sum")
        elif self.config["dataset"] == "BSD500":
            self.criterion = torch.nn.MSELoss(reduction="sum")

        self.train_logger = train_logger
        self.logger = logging.getLogger(self.__class__.__name__)
        self.start_epoch = 1
        self.epochs = config["trainer"]['epochs']
        self.save_period = config["trainer"]['save_period']
        self.seed = seed

        self.wrt_mode, self.wrt_step = 'train_', 0
        self.log_step = config['trainer'].get('log_per_iter',
                                              int(np.sqrt(config["val_loader"]["batch_size"])))
        if config['trainer']['log_per_iter']:
            self.log_step = int(self.log_step / config["val_loader"]["batch_size"]) + 1

        # CHECKPOINTS & TENSORBOARD
        date_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
        run_name = config['experim_name'] + '_' + str(seed)
        self.checkpoint_dir = os.path.join(config["trainer"]['save_dir'], config["experim_name"], run_name)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
        with open(config_save_path, 'w') as handle:
            json.dump(self.config, handle, indent=4, sort_keys=True)

        writer_dir = os.path.join(config["trainer"]['log_dir'], config["experim_name"], run_name)
        self.writer = tensorboard.SummaryWriter(writer_dir)
        self.html_results = HTML(web_dir=self.checkpoint_dir,
                                 exp_name=config['experim_name'],
                                 save_name=config['experim_name'],
                                 config=config, resume=resume)

        if resume:
            self._resume_checkpoint(resume)

    def set_optimization(self):
        """Build the main optimizer and, if two optimizer types are given,
        an auxiliary optimizer for the deepspline/apl parameters."""
        self.optim_names = self.config["optimizer"]["type"]

        # main optimizer/scheduler
        if len(self.optim_names) == 2:
            try:
                # main optimizer only for network parameters
                main_params_iter = self.model.parameters_no_deepspline_apl()
            except AttributeError:
                print('Cannot use aux optimizer.')
                raise
        else:
            # single optimizer for all parameters
            main_params_iter = self.model.parameters()

        self.main_optimizer = self.construct_optimizer(main_params_iter, self.optim_names[0], 'main')

        self.aux_optimizer = None
        if len(self.optim_names) == 2:
            # aux optimizer/scheduler for deepspline/apl parameters
            try:
                if self.model.deepspline is not None:
                    aux_params_iter = self.model.parameters_deepspline()
                # elif self.net.apl is not None:
                #     aux_params_iter = self.net.parameters_apl()
            except AttributeError:
                print('Cannot use aux optimizer.')
                raise
            self.aux_optimizer = self.construct_optimizer(aux_params_iter, self.optim_names[1], 'aux')

    def construct_optimizer(self, params_iter, optim_name, mode='main'):
        """Instantiate an Adam or SGD optimizer over the given parameter iterator."""
        # both modes currently use the same learning rate; weight decay is added manually
        lr = self.config["optimizer"]["args"]["lr"] if mode == 'main' else self.config["optimizer"]["args"]["lr"]

        if optim_name == 'Adam':
            optimizer = torch.optim.Adam(params_iter, lr=lr)
        elif optim_name == 'SGD':
            optimizer = torch.optim.SGD(params_iter, lr=lr)
        else:
            raise ValueError('Need to provide a valid optimizer type')

        return optimizer

    def train(self):
        # losses
        Total_loss_train = []
        MSE_loss_val = []
        lr = []
        if self.config["dataset"] == "BSD500":
            psnr_val = []
            ssim_val = []
        self.model.init_hyperparams()

        # loop
        for epoch in range(self.start_epoch, self.epochs + 1):
            if epoch == self.epochs and self.config["model"]["sparsify_activations"]:
                print('\nLast epoch: freezing network for sparsifying the activations '
                      'and evaluating training accuracy.')
                self.model.eval()  # set network in evaluation mode
                self.model.sparsify_activations()
                self.model.freeze_parameters()

            results = self._train_epoch(epoch)
            Total_loss_train.append(results['mse_loss'])

            if epoch % self.config['trainer']['val_per_epochs'] == 0:
                results = self._valid_epoch(epoch)
                MSE_loss_val.append(results['val_loss'])
                if self.config["dataset"] == "BSD500":
                    psnr_val.append(results['val_psnr'])
                    ssim_val.append(results['val_ssim'])
                self.logger.info('\n\n')
                for k, v in results.items():
                    self.logger.info(f'{str(k):15s}: {v}')
                if epoch == self.epochs:
                    with open(f'{self.checkpoint_dir}/val.txt', 'w') as f:
                        for k, v in results.items():
                            f.write("%s\n" % (k + ':' + f'{v}'))

            if self.train_logger is not None:
                log = {'epoch': epoch, **results}
                self.train_logger.add_entry(log)

            for i, opt_group in enumerate(self.main_optimizer.param_groups):
                lr.append(opt_group['lr'])

            # SAVE CHECKPOINT
            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch)

        self.html_results.save()
        self.writer.flush()
        self.writer.close()

        # plot learning curves
        Total_loss_train = np.array(Total_loss_train)
        MSE_loss_val = np.array(MSE_loss_val)
        if self.config["dataset"] == "BSD500":
            psnr_val = np.array(psnr_val)
            ssim_val = np.array(ssim_val)
        epochs = np.arange(self.config["trainer"]["val_per_epochs"],
                           self.config["trainer"]["epochs"] + self.config["trainer"]["val_per_epochs"],
                           self.config["trainer"]["val_per_epochs"])

        fig, ax1 = plt.subplots()
        ax2 = ax1.twinx()
        ax1.plot(Total_loss_train, 'g-', linewidth=1)
        ax2.plot(epochs, MSE_loss_val, 'b-', linewidth=1)
        ax1.set_xlabel('epochs')
        ax1.set_ylabel('Train loss', color='g')
        ax2.set_ylabel('Validation loss', color='b')
        ax1.set_title("Learning curves")
        fig.savefig(f'{self.checkpoint_dir}/curves.png')
        plt.show()

        if self.config["dataset"] == "BSD500":
            fig, ax1 = plt.subplots()
            ax2 = ax1.twinx()
            ax1.plot(epochs, psnr_val, 'g-', linewidth=1)
            ax2.plot(epochs, ssim_val, 'b-', linewidth=1)
            ax1.set_xlabel('epochs')
            ax1.set_ylabel('Val PSNR', color='g')
            ax2.set_ylabel('Val SSIM', color='b')
            ax1.set_title("Metric curves")
            fig.savefig(f'{self.checkpoint_dir}/metrics.png')
            plt.show()

        fig, ax1 = plt.subplots()
        ax1.plot(lr, 'k-', linewidth=1)
        ax1.set_xlabel('epochs')
        ax1.set_ylabel('lr', color='b')
        ax1.set_title("Lr curves")
        fig.savefig(f'{self.checkpoint_dir}/lr.png')
        plt.show()

    def _train_epoch(self, epoch):
        self.html_results.save()
        self.logger.info('\n')
        self.model.train()

        tbar = tqdm(self.train_loader, ncols=135)
        self._reset_metrics()
        for batch_idx, data in enumerate(tbar):
            if self.config["dataset"] == "fastMRI":
                cropp1, cropp2, cropp3, cropp4, cropp5, cropp6, cropp7, cropp8, \
                    target1, target2, target3, target4, _ = data
                cropp = torch.cat([cropp1, cropp2, cropp3, cropp4,
                                   cropp5, cropp6, cropp7, cropp8], dim=0)
                target = torch.cat([target1, target2, target3, target4,
                                    target1, target2, target3, target4], dim=0)
            elif self.config["dataset"] == "BSD500":
                cropp, target = data

            if self.device != 'cpu':
                cropp, target = cropp.to(self.device, non_blocking=True), target.to(self.device, non_blocking=True)

            self.optimizer_zero_grad()
            batch_size = cropp.shape[0]
            output = self.model(cropp)

            # data fidelity
            if self.config["dataset"] == "fastMRI":
                data_fidelity = self.criterion(output, target) / batch_size
            elif self.config["dataset"] == "BSD500":
                data_fidelity = self.criterion(output, target) / (output.size()[0] * 2)
            # data_fidelity.backward(retain_graph=True)

            # regularization
            regularization = torch.zeros_like(data_fidelity)
            if self.model.weight_decay_regularization is True:
                # the regularization weight is multiplied inside weight_decay()
                regularization = regularization + self.model.weight_decay()
            if self.model.tv_bv_regularization is True:
                # the regularization weight is multiplied inside TV_BV()
                tv_bv, tv_bv_unweighted = self.model.TV_BV()
                regularization = regularization + tv_bv
                # losses.append(tv_bv_unweighted)

            total_loss = data_fidelity + regularization
            total_loss.backward()
            self.optimizer_step()

            if self.config["model"]["spectral_norm"] == "Parseval":
                with torch.no_grad():
                    self.model.perseval_normalization(self.config["model"]["beta"])

            self._update_losses(total_loss.detach().cpu().numpy())
            log = self._log_values()

            if batch_idx % self.log_step == 0:
                self.wrt_step = (epoch - 1) * len(self.train_loader) + batch_idx
                self._write_scalars_tb(log)

            del total_loss, output
            tbar.set_description('T ({}) | TotalLoss {:.3f} |'.format(epoch, self.total_mse_loss.average))

        if self.config["optimizer"]["args"]["stepscheduler"] is True:
            self.scheduler.step(epoch=epoch - 1)

        return log

    def optimizer_zero_grad(self):
        """Zero the gradients of the main and (if present) the auxiliary optimizer."""
        self.main_optimizer.zero_grad()
        if self.aux_optimizer is not None:
            self.aux_optimizer.zero_grad()

    def optimizer_step(self):
        """Step the main and (if present) the auxiliary optimizer."""
        self.main_optimizer.step()
        if self.aux_optimizer is not None:
            self.aux_optimizer.step()

        # Do the projection step to constrain the Lipschitz constant to 1
        if ((self.model.activation_type == 'deepBspline_lipschitz_orthoprojection')
                or (self.model.activation_type == 'deepBspline_lipschitz_maxprojection')):
            for module in self.model.modules_deepspline():
                module.do_lipschitz_projection()

    def _valid_epoch(self, epoch):
        if self.val_loader is None:
            self.logger.warning('No data loader was passed for the validation step, no validation is performed!')
            return {}

        self.logger.info('\n###### EVALUATION ######')
        self.model.eval()
        self.wrt_mode = 'val'

        total_loss_val = AverageMeter()
        psnr_val = 0
        ssim_val = 0
        tbar = tqdm(self.val_loader, ncols=130)
        with torch.no_grad():
            for batch_idx, data in enumerate(tbar):
                if self.config["dataset"] == "fastMRI":
                    cropp1, cropp2, cropp3, cropp4, target1, target2, target3, target4, _ = data
                    cropp = torch.cat([cropp1, cropp2, cropp3, cropp4], dim=0)
                    target = torch.cat([target1, target2, target3, target4], dim=0)
                elif self.config["dataset"] == "BSD500":
                    cropp, target = data

                if self.device != 'cpu':
                    cropp, target = cropp.to(self.device, non_blocking=True), target.to(self.device, non_blocking=True)

                batch_size = cropp.shape[0]
                output = self.model(cropp)

                # LOSS
                if self.config["dataset"] == "fastMRI":
                    loss = self.criterion(output, target) / batch_size
                elif self.config["dataset"] == "BSD500":
                    loss = self.criterion(output, target) / (output.size()[0] * 2)
                total_loss_val.update(loss.cpu())

                out_val = torch.clamp(output, 0., 1.)
                psnr_val += batch_PSNR(out_val, target, 1.)
                ssim_val += batch_SSIM(out_val, target, 1.)

                # PRINT INFO
                tbar.set_description('EVAL ({}) | MSELoss: {:.3f} |'.format(epoch, total_loss_val.average))

            # METRICS TO TENSORBOARD
            self.wrt_step = epoch * len(self.val_loader)
            self.writer.add_scalar(f'{self.wrt_mode}/loss', total_loss_val.average, self.wrt_step)
            psnr_val /= len(self.val_loader)
            ssim_val /= len(self.val_loader)
            self.writer.add_scalar(f'{self.wrt_mode}/Test PSNR', psnr_val, self.wrt_step)
            self.writer.add_scalar(f'{self.wrt_mode}/Test SSIM', ssim_val, self.wrt_step)

            log = {'val_loss': total_loss_val.average}
            self.html_results.add_results(epoch=epoch, results=log)
            self.html_results.save()
            log["val_psnr"] = psnr_val
            log["val_ssim"] = ssim_val

        return log

    def _reset_metrics(self):
        self.total_mse_loss = AverageMeter()

    def _update_losses(self, batch_loss):
        self.total_mse_loss.update(batch_loss.mean())

    def _log_values(self):
        logs = {}
        logs['mse_loss'] = self.total_mse_loss.average
        return logs

    def _write_scalars_tb(self, logs):
        for k, v in logs.items():
            self.writer.add_scalar(f'train/{k}', v, self.wrt_step)

    def _save_checkpoint(self, epoch):
        state = {
            'arch': type(self.model).__name__,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'config': self.config
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint.pth')
        self.logger.info(f'\nSaving a checkpoint: {filename} ...')
        torch.save(state, filename)

    def _resume_checkpoint(self, resume_path):
        self.logger.info(f'Loading checkpoint : {resume_path}')
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        try:
            self.model.load_state_dict(checkpoint['state_dict'])
        except Exception as e:
            print(f'Error when loading: {e}')
            self.model.load_state_dict(checkpoint['state_dict'], strict=False)
        if "logger" in checkpoint.keys():
            self.train_logger = checkpoint['logger']
        self.logger.info(f'Checkpoint <{resume_path}> (epoch {self.start_epoch}) was loaded')
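# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original code). The Trainer above reads
# all of its settings from a nested config dict; the example below lists only
# the keys that are actually accessed in this file, with hypothetical
# placeholder values (the real config.json may contain more). The model is
# assumed to be the project's deepspline network, which exposes the hooks the
# Trainer calls (get_num_params, init_hyperparams, weight_decay, TV_BV, ...).
example_config = {
    'experim_name': 'bsd500_denoising',
    'dataset': 'BSD500',                     # or 'fastMRI'
    'model': {
        'sparsify_activations': False,
        'spectral_norm': 'Parseval',         # triggers perseval_normalization()
        'beta': 0.5,
    },
    'optimizer': {
        'type': ['Adam'],                    # two entries enable the aux optimizer
        'args': {'lr': 1e-3, 'stepscheduler': True, 'step': 30, 'gamma': 0.5},
    },
    'val_loader': {'batch_size': 8},
    'trainer': {
        'epochs': 100,
        'save_period': 10,
        'val_per_epochs': 5,
        'log_per_iter': True,
        'save_dir': 'saved/',
        'log_dir': 'saved/runs/',
    },
}

# Hypothetical usage, assuming train_loader / val_loader / model are built elsewhere:
# trainer = Trainer(config=example_config, trainloader=train_loader,
#                   valloader=val_loader, model=model, train_logger=None,
#                   seed=0, resume=None, device='cuda')
# trainer.train()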