def __init__(
        self,
        model: torch.nn.Module,
        criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        save_root: str,
        train_dataset: torch.utils.data.Dataset,
        valid_dataset: Optional[torch.utils.data.Dataset] = None,
        valid_metrics: Optional[Dict] = None,
        exp_name: Optional[str] = None,
        batchsize: int = 1,
        num_workers: int = 0,
        schedulers: Optional[Dict[Any, Any]] = None,
        overlay_alpha: float = 0.2,
        enable_tensorboard: bool = True,
        tensorboard_root_path: Optional[str] = None,
        model_has_softmax_outputs: bool = False,
        ignore_errors: bool = False,
        ipython_on_error: bool = False,
        classes: Optional[Sequence[int]] = None,
):
    self.ignore_errors = ignore_errors
    self.ipython_on_error = ipython_on_error
    self.device = device
    self.model = model.to(device)
    self.criterion = criterion.to(device)
    self.optimizer = optimizer
    self.train_dataset = train_dataset
    self.valid_dataset = valid_dataset
    self.valid_metrics = valid_metrics
    self.overlay_alpha = overlay_alpha
    self.save_root = os.path.expanduser(save_root)
    self.batchsize = batchsize
    self.num_workers = num_workers
    # TODO: This could be automatically determined by parsing the model
    self.model_has_softmax_outputs = model_has_softmax_outputs

    self._tracker = HistoryTracker()
    self._timer = Timer()
    self._first_plot = True
    self._shell_info = dedent("""
        Entering IPython training shell. To continue, hit Ctrl-D twice.
        To terminate, set self.terminate = True and then hit Ctrl-D twice.
    """).strip()

    if exp_name is None:  # Auto-generate a name based on model name and ISO timestamp
        timestamp = datetime.datetime.now().strftime('%y-%m-%d_%H-%M-%S')
        exp_name = model.__class__.__name__ + '__' + timestamp
    self.exp_name = exp_name
    self.save_path = os.path.join(save_root, exp_name)
    os.makedirs(self.save_path, exist_ok=True)  # TODO: Warn if directory already exists
    self.terminate = False
    self.step = 0
    if schedulers is None:
        schedulers = {'lr': StepLR(optimizer, 1000, 1)}  # No-op scheduler
    self.schedulers = schedulers

    # Determine optional dataset properties
    self.classes = classes
    self.num_classes = None
    if hasattr(self.train_dataset, 'classes'):
        self.classes = self.train_dataset.classes
        self.num_classes = len(self.train_dataset.classes)
    self.previews_enabled = hasattr(valid_dataset, 'preview_shape') \
        and valid_dataset.preview_shape is not None

    if not tensorboard_available and enable_tensorboard:
        enable_tensorboard = False
        logger.warning('Tensorboard is not available, so it has to be disabled.')
    self.tb = None  # Tensorboard handler
    if enable_tensorboard:
        if tensorboard_root_path is None:
            tb_path = self.save_path
        else:
            tensorboard_root_path = os.path.expanduser(tensorboard_root_path)
            tb_path = os.path.join(tensorboard_root_path, self.exp_name)
            os.makedirs(tb_path, exist_ok=True)
        # TODO: Make always_flush user-configurable here:
        self.tb = TensorBoardLogger(log_dir=tb_path, always_flush=False)

    self.train_loader = DelayedDataLoader(
        self.train_dataset, batch_size=self.batchsize, shuffle=True,
        num_workers=self.num_workers, pin_memory=True,
        timeout=30
        # timeout arg requires https://github.com/pytorch/pytorch/commit/1661370ac5f88ef11fedbeac8d0398e8369fc1f3
    )
    # num_workers is set to 0 for valid_loader because validation background processes sometimes
    # fail silently and stop responding, bringing down the whole training process.
    # This issue might be related to https://github.com/pytorch/pytorch/issues/1355.
    # The performance impact of disabling multiprocessing here is low in normal settings,
    # because the validation loader doesn't perform expensive augmentations, but just reads
    # data from hdf5s.
    if valid_dataset is not None:
        self.valid_loader = DelayedDataLoader(
            self.valid_dataset, self.batchsize, num_workers=0, pin_memory=False,
            timeout=30
        )

    self.best_val_loss = np.inf  # Best recorded validation loss

    self.valid_metrics = {} if valid_metrics is None else valid_metrics
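# --- Illustrative sketch (added, not part of the original code) -------------
# The default ``schedulers`` entry above is effectively a no-op: StepLR with
# gamma=1 multiplies the learning rate by 1 every 1000 steps, so the LR never
# changes. Minimal standalone sketch of that behavior; the dummy model and
# optimizer are placeholders, not names from this codebase:
import torch
from torch.optim.lr_scheduler import StepLR

_model = torch.nn.Linear(2, 2)
_optimizer = torch.optim.SGD(_model.parameters(), lr=0.1)
_noop_sched = StepLR(_optimizer, step_size=1000, gamma=1)
for _ in range(3000):
    _noop_sched.step()  # LR is multiplied by 1 at steps 1000, 2000, 3000
assert _optimizer.param_groups[0]['lr'] == 0.1  # unchanged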
def train(self, max_steps: int = 1, max_runtime=3600 * 24 * 7) -> None:
    """Train the network for ``max_steps`` steps.

    After each training epoch, validation performance is measured and
    visualizations are computed and logged to tensorboard."""
    self.start_time = datetime.datetime.now()
    self.end_time = self.start_time + datetime.timedelta(seconds=max_runtime)
    while not self.terminate:
        try:
            # --> self.train()
            self.model.train()

            # Scalar training stats that should be logged and written to tensorboard later
            stats: Dict[str, float] = {'tr_loss': 0.0}
            # Other scalars to be logged
            misc: Dict[str, float] = {}
            # Hold image tensors for real-time training sample visualization in tensorboard
            images: Dict[str, torch.Tensor] = {}

            running_acc = 0
            running_mean_target = 0
            running_vx_size = 0
            timer = Timer()
            for inp, target in self.train_loader:
                inp, target = inp.to(self.device), target.to(self.device)

                # forward pass
                out = self.model(inp)
                loss = self.criterion(out, target)
                if torch.isnan(loss):
                    logger.error('NaN loss detected! Aborting training.')
                    raise NaNException

                # update step
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # Prevent accidental autograd overheads after optimizer step
                inp.detach_()
                target.detach_()
                out.detach_()
                loss.detach_()

                # get training performance
                stats['tr_loss'] += float(loss)
                acc = metrics.bin_accuracy(target, out)  # TODO
                mean_target = target.to(torch.float32).mean()
                print(f'{self.step:6d}, loss: {loss:.4f}', end='\r')
                self._tracker.update_timeline([self._timer.t_passed, float(loss), mean_target])

                # Preserve training batch and network output for later visualization
                images['inp'] = inp
                images['target'] = target
                images['out'] = out

                # this was changed to support ReduceLROnPlateau which does not implement get_lr
                misc['learning_rate'] = self.optimizer.param_groups[0]["lr"]  # .get_lr()[-1]
                # update schedules
                for sched in self.schedulers.values():
                    # support ReduceLROnPlateau; doc. uses validation loss instead
                    # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
                    if "metrics" in inspect.signature(sched.step).parameters:
                        sched.step(metrics=float(loss))
                    else:
                        sched.step()

                running_acc += acc
                running_mean_target += mean_target
                running_vx_size += inp.numel()

                self.step += 1
                if self.step >= max_steps:
                    logger.info(f'max_steps ({max_steps}) exceeded. Terminating...')
                    self.terminate = True
                    break
                if datetime.datetime.now() >= self.end_time:
                    logger.info(f'max_runtime ({max_runtime} seconds) exceeded. Terminating...')
                    self.terminate = True
                    break

            stats['tr_accuracy'] = running_acc / len(self.train_loader)
            stats['tr_loss'] /= len(self.train_loader)
            misc['tr_speed'] = len(self.train_loader) / timer.t_passed
            misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx
            mean_target = running_mean_target / len(self.train_loader)
            if self.valid_dataset is None:
                stats['val_loss'], stats['val_accuracy'] = float('nan'), float('nan')
            else:
                valid_stats = self.validate()
                stats.update(valid_stats)

            # Update history tracker (kind of made obsolete by tensorboard)
            # TODO: Decide what to do with this, now that most things are already in tensorboard.
            if self.step // len(self.train_dataset) > 1:
                tr_loss_gain = self._tracker.history[-1][2] - stats['tr_loss']
            else:
                tr_loss_gain = 0
            self._tracker.update_history([
                self.step, self._timer.t_passed, stats['tr_loss'], stats['val_loss'],
                tr_loss_gain, stats['tr_accuracy'], stats['val_accuracy'],
                misc['learning_rate'], 0, 0
            ])  # 0's correspond to mom and gradnet (?)

            t = pretty_string_time(self._timer.t_passed)
            loss_smooth = self._tracker.loss._ema

            # Logging to stdout, text log file
            text = "%05i L_m=%.3f, L=%.2f, tr_acc=%05.2f%%, " % (
                self.step, loss_smooth, stats['tr_loss'], stats['tr_accuracy'])
            text += "val_acc=%05.2f%s, prev=%04.1f, L_diff=%+.1e, " % (
                stats['val_accuracy'], "%", mean_target * 100, tr_loss_gain)
            text += "LR=%.2e, %.2f it/s, %.2f MVx/s, %s" % (
                misc['learning_rate'], misc['tr_speed'], misc['tr_speed_vx'], t)
            logger.info(text)

            # Plot tracker stats to pngs in save_path
            self._tracker.plot(self.save_path)

            # Reporting to tensorboard logger
            if self.tb:
                self.tb_log_scalars(stats, 'stats')
                self.tb_log_scalars(misc, 'misc')
                if self.previews_enabled:
                    self.tb_log_preview()
                self.tb_log_sample_images(images, group='tr_samples')
                self.tb.writer.flush()

            # Save trained model state
            self.save_model()
            if stats['val_loss'] < self.best_val_loss:
                self.best_val_loss = stats['val_loss']
                self.save_model(suffix='_best')
        except KeyboardInterrupt:
            IPython.embed(header=self._shell_info)
            if self.terminate:
                return
        except Exception as e:
            traceback.print_exc()
            if self.ignore_errors:
                # Just print the traceback and try to carry on with training.
                # This can go wrong in unexpected ways, so don't leave the training unattended.
                pass
            elif self.ipython_on_error:
                print("\nEntering Command line such that Exception can be "
                      "further inspected by user.\n\n")
                IPython.embed(header=self._shell_info)
                if self.terminate:
                    return
            else:
                raise e
    self.save_model(suffix='_final')
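# --- Illustrative sketch (added, not part of the original code) -------------
# The scheduler loop above dispatches on the signature of ``sched.step()``:
# ReduceLROnPlateau expects a ``metrics`` value, while most other schedulers
# take no arguments. Standalone sketch of that duck-typing check; the dummy
# optimizer and the loss value are placeholders:
import inspect
import torch
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

_opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
_schedulers = {'lr': StepLR(_opt, 10, 0.5), 'plateau': ReduceLROnPlateau(_opt)}
_loss = 0.42  # stand-in for the current training loss
for _sched in _schedulers.values():
    if 'metrics' in inspect.signature(_sched.step).parameters:
        _sched.step(metrics=_loss)  # ReduceLROnPlateau branch
    else:
        _sched.step()  # all other schedulers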
def _train(self, max_steps, max_runtime):
    self.model.train()

    # Scalar training stats that should be logged and written to tensorboard later
    stats: Dict[str, float] = {'tr_loss': 0.0}
    # Other scalars to be logged
    misc: Dict[str, float] = {}
    # Hold image tensors for real-time training sample visualization in tensorboard
    images: Dict[str, np.ndarray] = {}

    running_acc = 0
    running_mean_target = 0
    running_vx_size = 0
    timer = Timer()
    pbar = tqdm(enumerate(self.train_loader), 'Training', total=len(self.train_loader))
    for i, (inp, target, scal) in pbar:
        # Everything with a "d" prefix refers to tensors on self.device (i.e. probably on GPU)
        dinp = inp.to(self.device, non_blocking=True)
        dtarget = target.to(self.device, non_blocking=True)
        dscal = scal.to(self.device, non_blocking=True)

        # forward pass
        dout = self.model(dinp, dscal)
        dloss = self.criterion(dout, dtarget)
        if torch.isnan(dloss):
            logger.error('NaN loss detected! Aborting training.')
            raise NaNException

        # update step
        self.optimizer.zero_grad()
        if self.mixed_precision:
            with self.amp_handle.scale_loss(dloss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            dloss.backward()
        self.optimizer.step()
        # End of core training loop on self.device

        # TODO: Evaluate performance impact of these copies and maybe avoid doing these so often
        out = dout.detach().cpu()  # Copy model output to host memory for metrics, visualization
        with torch.no_grad():
            loss = float(dloss)
            stats['tr_loss'] += loss
            acc = float(metrics.bin_accuracy(target, out))
            mean_target = float(target.to(torch.float32).mean())
            pbar.set_description(f'Training (loss {loss:.4f})')
            self._tracker.update_timeline([self._timer.t_passed, loss, mean_target])

        # this was changed to support ReduceLROnPlateau which does not implement get_lr
        misc['learning_rate'] = self.optimizer.param_groups[0]["lr"]  # .get_lr()[-1]
        # update schedules
        for sched in self.schedulers.values():
            # support ReduceLROnPlateau; doc. uses validation loss instead
            # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
            if "metrics" in inspect.signature(sched.step).parameters:
                sched.step(metrics=loss)
            else:
                sched.step()

        running_acc += acc
        running_mean_target += mean_target
        running_vx_size += inp.numel()

        self.step += 1
        if self.step >= max_steps:
            logger.info(f'max_steps ({max_steps}) exceeded. Terminating...')
            self.terminate = True
        if datetime.datetime.now() >= self.end_time:
            logger.info(f'max_runtime ({max_runtime} seconds) exceeded. Terminating...')
            self.terminate = True
        if i == len(self.train_loader) - 1 or self.terminate:
            # Last step in this epoch or in the whole training
            # Preserve last training batch and network output for later visualization
            images['inp'] = inp.numpy()
            images['target'] = target.numpy()
            images['out'] = out.numpy()
        if self.terminate:
            break

    stats['tr_accuracy'] = running_acc / len(self.train_loader)
    stats['tr_loss'] /= len(self.train_loader)
    misc['tr_speed'] = len(self.train_loader) / timer.t_passed
    misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx
    misc['mean_target'] = running_mean_target / len(self.train_loader)

    return stats, misc, images
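# --- Illustrative sketch (added, not part of the original code) -------------
# After the optimizer step, the loop above copies the output to host memory via
# ``.detach().cpu()`` and computes metrics under ``torch.no_grad()`` so no
# autograd graph is kept alive. Standalone sketch of that pattern; the random
# tensors stand in for a real batch and model output:
import torch

_dout = torch.randn(4, 2, 8, 8, requires_grad=True)  # pretend model output on the device
_target = torch.randint(0, 2, (4, 8, 8))
_out = _dout.detach().cpu()  # host copy without gradient history
with torch.no_grad():
    _pred = _out.argmax(dim=1)
    _acc = (_pred == _target).float().mean()
print(f'accuracy: {_acc.item():.3f}')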
def _train(self, max_steps, max_runtime):
    out_channels = self.out_channels

    def _channel_metric(metric, c, out_channels=out_channels, mean=False):
        """Returns an evaluator that calculates the ``metric``
        and selects its value for channel ``c``."""
        def evaluator(target, out):
            # pred = metrics._argmax(out)
            m = metric(target, out, num_classes=out_channels, ignore=out_channels - 1, mean=mean)
            return m[c]
        return evaluator

    tr_evaluators = {**{
        f'tr_DSC_c{c}': _channel_metric(metrics.dice_coefficient, c=c) for c in range(out_channels)
    }, **{
        f'tr_precision_c{c}': _channel_metric(metrics.precision, c=c) for c in range(out_channels)
    }, **{
        f'tr_recall_c{c}': _channel_metric(metrics.recall, c=c) for c in range(out_channels)
    }}

    # Scalar training stats that should be logged and written to tensorboard later
    stats: Dict[str, Union[float, List[float]]] = {
        stat: [] for stat in ['tr_loss', 'tr_loss_mean', 'tr_accuracy']}
    stats.update({name: [] for name in tr_evaluators.keys()})
    file_stats = {}
    # Other scalars to be logged
    misc: Dict[str, Union[float, List[float]]] = {misc: [] for misc in ['mean_target']}
    # Hold image tensors for real-time training sample visualization in tensorboard
    images: Dict[str, np.ndarray] = {}

    self.model.train()
    self.optimizer.zero_grad()
    running_vx_size = 0  # Counts input sizes (number of pixels/voxels) of training batches
    timer = Timer()
    import gc
    gc.collect()
    batch_iter = tqdm(self.train_loader, 'Training', total=len(self.train_loader))
    for i, batch in enumerate(batch_iter):
        if self.step in self.extra_save_steps:
            self._save_model(f'_step{self.step}', verbose=True)

        # Everything with a "d" prefix refers to tensors on self.device (i.e. probably on GPU)
        inp, target = batch['inp'], batch['target']
        cube_meta = batch['cube_meta']
        fname = batch['fname']
        dinp = inp.to(self.device, non_blocking=True)
        dtarget = target[:, :, self.loss_crop:-self.loss_crop, self.loss_crop:-self.loss_crop, self.loss_crop:-self.loss_crop].to(
            self.device, non_blocking=True) if self.loss_crop else target.to(self.device, non_blocking=True)
        weight = cube_meta[0].to(device=self.device, dtype=self.criterion.weight.dtype, non_blocking=True)
        prev_weight = self.criterion.weight.clone()
        self.criterion.weight = weight

        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
            ignore_mask = (1 - dtarget[0][-1]).view(1, 1, *dtarget.shape[2:])
            dense_weight = self.criterion.weight.view(1, -1, 1, 1, 1)
            positive_target_mask = (weight.view(1, -1, 1, 1, 1) * dtarget)[0][1:-1].sum(dim=0).view(
                1, 1, *dtarget.shape[2:])  # weighted targets w/o background and ignore channels
            needs_positive_target_mark = (dense_weight.sum() == 0).type(positive_target_mask.dtype)
            self.criterion.weight = ignore_mask * dense_weight \
                + needs_positive_target_mark * positive_target_mask * prev_weight.view(1, -1, 1, 1, 1)

        # forward pass
        dout = self.model(dinp)[:, :, self.loss_crop:-self.loss_crop, self.loss_crop:-self.loss_crop, self.loss_crop:-self.loss_crop] \
            if self.loss_crop else self.model(dinp)
        # print(dout.dtype, dout.shape, dtarget.dtype, dtarget.shape, dout.min(), dout.max())
        dloss = self.criterion(dout, dtarget)
        # dcumloss = dloss if i == 0 else dcumloss + dloss
        # print(dloss, dloss.size())
        # dloss = (dloss * prev_weight * weight).mean()
        if torch.isnan(dloss).sum():
            logger.error('NaN loss detected! Aborting training.')
            raise NaNException

        if self.mixed_precision:
            from apex import amp
            with amp.scale_loss(dloss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # update step
            dloss.backward()
        if i % self.optimizer_iterations == self.optimizer_iterations - 1:
            self.optimizer.step()
            # TODO (lp): calling zero_grad() here makes gradients disappear from tb histograms
            self.optimizer.zero_grad()
        # loss2 = float(self.criterion(self.model(dinp), dtarget))
        # print(f'loss gain factor {np.divide(float(dloss), (float(dloss)-loss2))})')
        # End of core training loop on self.device

        with torch.no_grad():
            loss = float(dloss)
            # TODO: Evaluate performance impact of these copies and maybe avoid doing these so often
            out_class = dout.argmax(dim=1).detach().cpu()
            multi_class_target = target.argmax(1) if len(target.shape) > 4 else target  # TODO
            if self.loss_crop:
                multi_class_target = multi_class_target[:, self.loss_crop:-self.loss_crop,
                                                        self.loss_crop:-self.loss_crop,
                                                        self.loss_crop:-self.loss_crop]
            acc = metrics.accuracy(multi_class_target, out_class, out_channels, mean=False).numpy()
            acc = np.average(acc[~np.isnan(acc)])  # , weights=)
            mean_target = float(multi_class_target.to(torch.float32).mean())

            # import h5py
            # dsc5 = channel_metric(metrics.dice_coefficient, c=5, out_channels=out_channels)(multi_class_target, out_class)
            # after_step = '+' if i % self.optimizer_iterations == 0 else ''
            # with h5py.File(os.path.join(self.save_path, f'batch {self.step}{after_step} loss={float(dloss)} dsc5={dsc5}.h5'), "w") as f:
            #     f.create_dataset('raw', data=inp.squeeze(dim=0), compression="gzip")
            #     f.create_dataset('labels', data=multi_class_target.numpy().astype(np.uint16), compression="gzip")
            #     f.create_dataset('pred', data=dout.squeeze(dim=0).detach().cpu().numpy(), compression="gzip")

            if fname[0] not in file_stats:
                file_stats[fname[0]] = []
            file_stats[fname[0]] += [float('nan')] * (i - len(file_stats[fname[0]])) + [loss]

            stats['tr_loss'].append(loss)
            stats['tr_loss_mean'] += [float('nan')] * (i - len(stats['tr_loss_mean']))
            if i % self.optimizer_iterations == self.optimizer_iterations - 1:
                stats['tr_loss_mean'] += [np.mean(stats['tr_loss'][-self.optimizer_iterations:])]
            stats['tr_accuracy'].append(acc)
            for name, evaluator in tr_evaluators.items():
                stats[name].append(evaluator(multi_class_target, out_class))
            misc['mean_target'].append(mean_target)

            # if loss-loss2 == 0 and not torch.any(out_class != multi_class_target):
            #     print('grad', self.model.up_convs[0].conv2.weight.grad)
            #     IPython.embed()
            # if loss - 0.99 < 1e-3:
            #     print('asd', loss, loss2)
            #     IPython.embed()

            batch_iter.set_description(f'Training (loss {loss:.4f})')
            # pbar.set_description(f'Training (loss {loss} / {float(dcumloss)})')
            # pbar.set_description(f'Training (loss {loss} / {np.divide(loss, (loss-loss2))})')
            self._tracker.update_timeline([self._timer.t_passed, loss, mean_target])
        self.criterion.weight = prev_weight

        # Not using .get_lr()[-1] because ReduceLROnPlateau does not implement get_lr()
        misc['learning_rate'] = self.optimizer.param_groups[0]['lr']  # LR for this iteration
        # update schedules
        for sched in self.schedulers.values():
            # support ReduceLROnPlateau; doc. uses validation loss instead
            # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
            if "metrics" in inspect.signature(sched.step).parameters:
                sched.step(metrics=loss)
            else:
                sched.step()
        # Append LR of the next iteration (after sched.step()) for local LR minima detection
        self._lr_nhood.append(self.optimizer.param_groups[0]['lr'])
        self._handle_lr()

        running_vx_size += inp.numel()

        # if stats['tr_loss_mean'][-1] < self.best_tr_loss:
        #     self.best_tr_loss = stats['tr_loss'][-1]
        #     self._save_model(suffix='_best_train', loss=stats['tr_loss'][-1])

        self.step += 1
        if self.step >= max_steps:
            logger.info(f'max_steps ({max_steps}) exceeded. Terminating...')
            self.terminate = True
        if datetime.datetime.now() >= self.end_time:
            logger.info(f'max_runtime ({max_runtime} seconds) exceeded. Terminating...')
            self.terminate = True
        if i == len(self.train_loader) - 1 or self.terminate:
            # Last step in this epoch or in the whole training
            # Preserve last training batch and network output for later visualization
            images['fname'] = Path(fname[0]).stem
            images['inp'] = inp.numpy()
            images['target'] = multi_class_target.numpy()
            images['out'] = dout.detach().cpu().numpy()
            self._put_current_attention_maps_into(images)
        if self.terminate:
            break

    stats['tr_loss_std'] = np.std(stats['tr_loss'])
    misc['tr_speed'] = len(self.train_loader) / timer.t_passed
    misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx

    return stats, file_stats, misc, images
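# --- Illustrative sketch (added, not part of the original code) -------------
# The loop above accumulates gradients over ``optimizer_iterations`` batches and
# only then calls optimizer.step() / zero_grad(), emulating a larger effective
# batch size. Standalone sketch of that accumulation schedule; the dummy model,
# data and the value of ``optimizer_iterations`` are placeholders:
import torch

_model = torch.nn.Linear(4, 2)
_optimizer = torch.optim.SGD(_model.parameters(), lr=0.01)
_criterion = torch.nn.CrossEntropyLoss()
optimizer_iterations = 4  # effective batch size = 4 * loader batch size

_optimizer.zero_grad()
for i in range(16):  # stand-in for enumerate(train_loader)
    inp = torch.randn(8, 4)
    target = torch.randint(0, 2, (8,))
    loss = _criterion(_model(inp), target)
    loss.backward()  # gradients accumulate across iterations
    if i % optimizer_iterations == optimizer_iterations - 1:
        _optimizer.step()
        _optimizer.zero_grad()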
def __init__(
        self,
        model: torch.nn.Module,
        criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        save_root: str,
        train_dataset: torch.utils.data.Dataset,
        valid_dataset: Optional[torch.utils.data.Dataset] = None,
        valid_metrics: Optional[Dict] = None,
        preview_batch: Optional[torch.Tensor] = None,
        preview_tile_shape: Optional[Tuple[int, ...]] = None,
        preview_overlap_shape: Optional[Tuple[int, ...]] = None,
        preview_interval: int = 5,
        offset: Optional[Sequence[int]] = None,
        exp_name: Optional[str] = None,
        example_input: Optional[torch.Tensor] = None,
        enable_save_trace: bool = False,
        batchsize: int = 1,
        num_workers: int = 0,
        schedulers: Optional[Dict[Any, Any]] = None,
        overlay_alpha: float = 0.2,
        enable_videos: bool = True,
        enable_tensorboard: bool = True,
        tensorboard_root_path: Optional[str] = None,
        apply_softmax_for_prediction: bool = True,
        ignore_errors: bool = False,
        ipython_shell: bool = True,
        num_classes: Optional[int] = None,
        sample_plotting_handler: Optional[Callable] = None,
        preview_plotting_handler: Optional[Callable] = None,
        mixed_precision: bool = False,
):
    if preview_batch is not None and \
            (preview_tile_shape is None or preview_overlap_shape is None):
        raise ValueError(
            'If preview_batch is set, you will also need to specify '
            'preview_tile_shape and preview_overlap_shape!'
        )
    if num_workers > 1 and 'PatchCreator' in str(type(train_dataset)):
        logger.warning(
            'Training with num_workers > 1 can cause instabilities if '
            'you are using PatchCreator.\nBe advised that PatchCreator '
            'might randomly deliver broken batches in your training and '
            'can crash it at any point of time.\n'
            'Please set num_workers to 1 or 0.\n'
        )
    self.ignore_errors = ignore_errors
    self.ipython_shell = ipython_shell
    self.device = device
    try:
        model.to(device)
    except RuntimeError as exc:
        if isinstance(model, torch.jit.ScriptModule):
            # "RuntimeError: to is not supported on TracedModules"
            # But .cuda() works for some reason. Using this messy
            # workaround in the hope that we can drop it soon.
            # TODO: Remove this when ScriptModule.to() is supported
            # See https://github.com/pytorch/pytorch/issues/7354
            if 'cuda' in str(self.device):  # (Ignoring device number!)
                model.cuda()
        else:
            raise exc
    self.model = model
    self.criterion = criterion.to(device)
    self.optimizer = optimizer
    self.train_dataset = train_dataset
    self.valid_dataset = valid_dataset
    self.valid_metrics = valid_metrics
    self.preview_batch = preview_batch
    self.preview_tile_shape = preview_tile_shape
    self.preview_overlap_shape = preview_overlap_shape
    self.preview_interval = preview_interval
    self.offset = offset
    self.overlay_alpha = overlay_alpha
    self.save_root = os.path.expanduser(save_root)
    self.example_input = example_input
    self.enable_save_trace = enable_save_trace
    self.batchsize = batchsize
    self.num_workers = num_workers
    self.apply_softmax_for_prediction = apply_softmax_for_prediction
    self.sample_plotting_handler = sample_plotting_handler
    self.preview_plotting_handler = preview_plotting_handler
    self.mixed_precision = mixed_precision

    self._tracker = HistoryTracker()
    self._timer = Timer()
    self._first_plot = True
    self._shell_info = dedent("""
        Entering IPython training shell. To continue, hit Ctrl-D twice.
        To terminate, set self.terminate = True and then hit Ctrl-D twice.
    """).strip()

    if self.mixed_precision:
        from apex import amp
        self.amp_handle = amp.init()

    if exp_name is None:  # Auto-generate a name based on model name and ISO timestamp
        timestamp = datetime.datetime.now().strftime('%y-%m-%d_%H-%M-%S')
        exp_name = model.__class__.__name__ + '__' + timestamp
    self.exp_name = exp_name
    self.save_path = os.path.join(save_root, exp_name)
    if os.path.isdir(self.save_path):
        raise RuntimeError(
            f'{self.save_path} already exists.\nPlease choose a '
            'different combination of save_root and exp_name.'
        )
    os.makedirs(self.save_path)
    logger.info(f'Writing files to save_path {self.save_path}/\n')

    self.terminate = False
    self.step = 0
    self.epoch = 0
    if schedulers is None:
        schedulers = {'lr': StepLR(optimizer, 1000, 1)}  # No-op scheduler
    self.schedulers = schedulers

    self.num_classes = num_classes

    if enable_videos:
        try:
            import moviepy
        except ImportError:
            logger.warning('moviepy is not installed. Disabling video logs.')
            enable_videos = False
    self.enable_videos = enable_videos
    self.tb = None  # Tensorboard handler
    if enable_tensorboard:
        if self.sample_plotting_handler is None:
            self.sample_plotting_handler = handlers._tb_log_sample_images
        if self.preview_plotting_handler is None:
            self.preview_plotting_handler = handlers._tb_log_preview

        if tensorboard_root_path is None:
            tb_path = self.save_path
        else:
            tensorboard_root_path = os.path.expanduser(tensorboard_root_path)
            tb_path = os.path.join(tensorboard_root_path, self.exp_name)
            os.makedirs(tb_path, exist_ok=True)
        # TODO: Make always_flush user-configurable here:
        self.tb = tensorboardX.SummaryWriter(log_dir=tb_path)

    self.train_loader = DelayedDataLoader(
        self.train_dataset, batch_size=self.batchsize, shuffle=True,
        num_workers=self.num_workers, pin_memory=True,
        timeout=60
    )
    # num_workers is set to 0 for valid_loader because validation background processes sometimes
    # fail silently and stop responding, bringing down the whole training process.
    # This issue might be related to https://github.com/pytorch/pytorch/issues/1355.
    # The performance impact of disabling multiprocessing here is low in normal settings,
    # because the validation loader doesn't perform expensive augmentations, but just reads
    # data from hdf5s.
    if valid_dataset is not None:
        self.valid_loader = DelayedDataLoader(
            self.valid_dataset, self.batchsize, num_workers=0, pin_memory=True,
            timeout=60
        )

    self.best_val_loss = np.inf  # Best recorded validation loss

    self.valid_metrics = {} if valid_metrics is None else valid_metrics
def _train(self, max_steps, max_runtime):
    """Train for one epoch or until max_steps or max_runtime is reached"""
    self.model.train()

    # Scalar training stats that should be logged and written to tensorboard later
    stats: Dict[str, Union[float, List[float]]] = {stat: [] for stat in ['tr_loss']}
    # Other scalars to be logged
    misc: Dict[str, Union[float, List[float]]] = {misc: [] for misc in ['mean_target']}
    # Hold image tensors for real-time training sample visualization in tensorboard
    images: Dict[str, np.ndarray] = {}

    running_vx_size = 0  # Counts input sizes (number of pixels/voxels) of training batches
    timer = Timer()
    batch_iter = tqdm(
        self.train_loader, 'Training', total=len(self.train_loader),
        dynamic_ncols=True, **self.tqdm_kwargs
    )
    for i, batch in enumerate(batch_iter):
        if self.step in self.extra_save_steps:
            self._save_model(f'_step{self.step}', verbose=True)

        dloss, dout_imgs = self._train_step_triplet(batch)

        with torch.no_grad():
            loss = float(dloss)
            mean_target = 0.  # Dummy value
            misc['mean_target'].append(mean_target)
            stats['tr_loss'].append(loss)
            batch_iter.set_description(f'Training (loss {loss:.4f})')
            self._tracker.update_timeline([self._timer.t_passed, loss, mean_target])

        # Not using .get_lr()[-1] because ReduceLROnPlateau does not implement get_lr()
        misc['learning_rate'] = self.optimizer.param_groups[0]['lr']  # LR for this iteration
        self._scheduler_step(loss)

        running_vx_size += batch['anchor'].numel()

        self._incr_step(max_runtime, max_steps)

        if i == len(self.train_loader) - 1 or self.terminate:
            # Last step in this epoch or in the whole training
            # Preserve last training batch and network output for later visualization
            for key, img in batch.items():
                if isinstance(img, torch.Tensor):
                    img = img.detach().cpu().numpy()
                images[key] = img
            self._put_current_attention_maps_into(images)
            # TODO: The plotting handler abstraction is inadequate here. Figure out how
            #       we can handle plotting cleanly in one place.
            # Outputs are visualized here, while inputs are visualized in the plotting handler
            # which is called in _run()...
            for name, img in dout_imgs.items():
                img = img.detach()[0].cpu().numpy()  # select first item of batch
                for c in range(img.shape[0]):
                    if img.ndim == 4:  # 3D data
                        img = img[:, img.shape[0] // 2]  # take center slice of depth dim -> 2D
                    self.tb.add_figure(
                        f'tr_samples/{name}_c{c}',
                        handlers.plot_image(img[c], cmap='gray'),
                        global_step=self.step)

        if self.terminate:
            break

    stats['tr_loss_std'] = np.std(stats['tr_loss'])
    misc['tr_speed'] = len(self.train_loader) / timer.t_passed
    misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx

    return stats, misc, images
def _train(self, max_steps, max_runtime):
    """Train for one epoch or until max_steps or max_runtime is reached"""
    self.model.train()

    # Scalar training stats that should be logged and written to tensorboard later
    stats: Dict[str, Union[float, List[float]]] = {stat: [] for stat in ['tr_loss']}
    # Other scalars to be logged
    misc: Dict[str, Union[float, List[float]]] = {misc: [] for misc in ['mean_target']}
    # Hold image tensors for real-time training sample visualization in tensorboard
    images: Dict[str, np.ndarray] = {}

    running_vx_size = 0  # Counts input sizes (number of pixels/voxels) of training batches
    timer = Timer()
    batch_iter = tqdm(
        self.train_loader, 'Training', total=len(self.train_loader), dynamic_ncols=True
    )
    unlabeled_iter = None if self.unlabeled_dataset is None else iter(self.unlabeled_loader)
    for i, batch in enumerate(batch_iter):
        if self.step in self.extra_save_steps:
            self._save_model(f'_step{self.step}', verbose=True)

        if unlabeled_iter is not None:
            batch['unlabeled'] = next(unlabeled_iter)

        dloss, dout = self._train_step(batch)

        with torch.no_grad():
            loss = float(dloss)
            target = batch.get('target')
            mean_target = float(target.to(torch.float32).mean()) if target is not None else 0.
            misc['mean_target'].append(mean_target)
            stats['tr_loss'].append(loss)
            batch_iter.set_description(f'Training (loss {loss:.4f})')
            self._tracker.update_timeline([self._timer.t_passed, loss, mean_target])

        # Not using .get_lr()[-1] because ReduceLROnPlateau does not implement get_lr()
        misc['learning_rate'] = self.optimizer.param_groups[0]['lr']  # LR for this iteration
        self._scheduler_step(loss)

        running_vx_size += batch['inp'].numel()

        self._incr_step(max_runtime, max_steps)

        if i == len(self.train_loader) - 1 or self.terminate:
            # Last step in this epoch or in the whole training
            # Preserve last training batch and network output for later visualization
            images['inp'] = batch['inp'].numpy()
            if 'target' in batch:
                images['target'] = batch['target'].numpy()
            if 'unlabeled' in batch:
                images['unlabeled'] = batch['unlabeled']
            images['out'] = dout.detach().cpu().numpy()
            self._put_current_attention_maps_into(images)
        if self.terminate:
            break

    stats['tr_loss_std'] = np.std(stats['tr_loss'])
    misc['tr_speed'] = len(self.train_loader) / timer.t_passed
    misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx

    return stats, misc, images
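# --- Illustrative sketch (added, not part of the original code) -------------
# The loop above draws one unlabeled batch per labeled batch by advancing a
# manual iterator with next(). Standalone sketch of that pairing; plain lists
# stand in for the DataLoaders, and a real setup has to ensure the unlabeled
# loader yields at least as many batches per epoch, otherwise next() raises
# StopIteration:
labeled_loader = [{'inp': f'labeled_batch_{i}'} for i in range(3)]
unlabeled_loader = [f'unlabeled_batch_{i}' for i in range(3)]

unlabeled_iter = iter(unlabeled_loader)
for batch in labeled_loader:
    batch['unlabeled'] = next(unlabeled_iter)  # pair one unlabeled batch with each labeled batch
    print(batch)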
def __init__(
        self,
        model: torch.nn.Module,
        criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        save_root: str,
        train_dataset: torch.utils.data.Dataset,
        valid_dataset: Optional[torch.utils.data.Dataset] = None,
        unlabeled_dataset: Optional[torch.utils.data.Dataset] = None,
        valid_metrics: Optional[Dict] = None,
        ss_criterion: Optional[torch.nn.Module] = None,
        preview_batch: Optional[torch.Tensor] = None,
        preview_interval: int = 5,
        inference_kwargs: Optional[Dict[str, Any]] = None,
        hparams: Optional[Dict[str, Any]] = None,
        extra_save_steps: Sequence[int] = (),
        exp_name: Optional[str] = None,
        example_input: Optional[torch.Tensor] = None,
        enable_save_trace: bool = False,
        save_jit: Optional[str] = None,
        batch_size: int = 1,
        num_workers: int = 0,
        schedulers: Optional[Dict[Any, Any]] = None,
        overlay_alpha: float = 0.4,
        enable_videos: bool = False,
        enable_tensorboard: bool = True,
        tensorboard_root_path: Optional[str] = None,
        ignore_errors: bool = False,
        ipython_shell: bool = False,
        out_channels: Optional[int] = None,
        sample_plotting_handler: Optional[Callable] = None,
        preview_plotting_handler: Optional[Callable] = None,
        mixed_precision: bool = False,
):
    inference_kwargs = {} if inference_kwargs is None else inference_kwargs
    if preview_batch is not None and (
            'tile_shape' not in inference_kwargs or
            ('overlap_shape' not in inference_kwargs and 'offset' not in inference_kwargs)):
        raise ValueError(
            'If preview_batch is set, you will also need to specify '
            'tile_shape and overlap_shape or offset in inference_kwargs!')
    if enable_save_trace:
        logger.warning(
            'enable_save_trace is deprecated. Please use the save_jit option instead.')
        assert save_jit in [None, 'trace']
        save_jit = 'trace'
    # Ensure that all nn.Modules are on the right device
    model.to(device)
    if isinstance(criterion, torch.nn.Module):
        criterion.to(device)
    if isinstance(ss_criterion, torch.nn.Module):
        ss_criterion.to(device)

    self.ignore_errors = ignore_errors
    self.ipython_shell = ipython_shell
    self.device = device
    self.model = model
    self.criterion = criterion
    self.optimizer = optimizer
    self.train_dataset = train_dataset
    self.valid_dataset = valid_dataset
    self.unlabeled_dataset = unlabeled_dataset
    self.valid_metrics = valid_metrics
    self.ss_criterion = ss_criterion
    self.preview_batch = preview_batch
    self.preview_interval = preview_interval
    self.inference_kwargs = inference_kwargs
    self.extra_save_steps = extra_save_steps
    self.overlay_alpha = overlay_alpha
    self.save_root = os.path.expanduser(save_root)
    self.example_input = example_input
    self.save_jit = save_jit
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.sample_plotting_handler = sample_plotting_handler
    self.preview_plotting_handler = preview_plotting_handler
    self.mixed_precision = mixed_precision

    self._tracker = HistoryTracker()
    self._timer = Timer()
    self._first_plot = True
    self._shell_info = dedent("""
        Entering IPython training shell. To continue, hit Ctrl-D twice.
        To terminate, set self.terminate = True and then hit Ctrl-D twice.
    """).strip()

    self.inference_kwargs.setdefault('batch_size', 1)
    self.inference_kwargs.setdefault('verbose', True)
    self.inference_kwargs.setdefault('apply_softmax', True)

    if self.unlabeled_dataset is not None and self.ss_criterion is None:
        raise ValueError(
            'If an unlabeled_dataset is supplied, you must also set ss_criterion.')

    if hparams is None:
        hparams = {}
    else:
        for k, v in hparams.items():
            if isinstance(v, (tuple, list)):
                # Convert to str because tensorboardX doesn't support
                # tuples and lists in add_hparams()
                hparams[k] = str(v)
    self.hparams = hparams

    if self.mixed_precision:
        from apex import amp
        self.model, self.optimizer = amp.initialize(
            self.model, self.optimizer, opt_level='O1')

    if exp_name is None:  # Auto-generate a name based on model name and ISO timestamp
        timestamp = datetime.datetime.now().strftime('%y-%m-%d_%H-%M-%S')
        exp_name = model.__class__.__name__ + '__' + timestamp
    self.exp_name = exp_name
    self.save_path = os.path.join(save_root, exp_name)
    if os.path.isdir(self.save_path):
        raise RuntimeError(
            f'{self.save_path} already exists.\nPlease choose a '
            'different combination of save_root and exp_name.')
    os.makedirs(self.save_path)
    _change_log_file_to(f'{self.save_path}/elektronn3.log')
    logger.info(f'Writing files to save_path {self.save_path}/\n')

    self.terminate = False
    self.step = 0
    self.epoch = 0
    if schedulers is None:
        schedulers = {'lr': StepLR(optimizer, 1000, 1)}  # No-op scheduler
    self.schedulers = schedulers
    self.__lr_closetozero_alreadytriggered = False  # Used in periodic scheduler handling
    self._lr_nhood = deque(maxlen=3)  # Keeps track of the last, current and next learning rate

    self.out_channels = out_channels

    if enable_videos:
        try:
            import moviepy
        except ImportError:
            logger.warning('moviepy is not installed. Disabling video logs.')
            enable_videos = False
    self.enable_videos = enable_videos
    self.tb = None  # Tensorboard handler
    if enable_tensorboard:
        if self.sample_plotting_handler is None:
            self.sample_plotting_handler = handlers._tb_log_sample_images
        if self.preview_plotting_handler is None:
            self.preview_plotting_handler = handlers._tb_log_preview

        if tensorboard_root_path is None:
            tb_path = self.save_path
        else:
            tensorboard_root_path = os.path.expanduser(tensorboard_root_path)
            tb_path = os.path.join(tensorboard_root_path, self.exp_name)
            os.makedirs(tb_path, exist_ok=True)
        self.tb = tensorboardX.SummaryWriter(logdir=tb_path, flush_secs=20)
        if self.hparams:
            self.tb.add_hparams(hparam_dict=self.hparams, metric_dict={})

    self.train_loader = DataLoader(
        self.train_dataset, batch_size=self.batch_size, shuffle=True,
        num_workers=self.num_workers, pin_memory=True,
        timeout=60 if self.num_workers > 0 else 0,
        worker_init_fn=_worker_init_fn)
    if valid_dataset is not None:
        self.valid_loader = DataLoader(
            self.valid_dataset, self.batch_size, shuffle=True,
            num_workers=self.num_workers, pin_memory=True,
            worker_init_fn=_worker_init_fn)
    if self.unlabeled_dataset is not None:
        self.unlabeled_loader = DataLoader(
            self.unlabeled_dataset, batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_workers, pin_memory=True,
            timeout=60 if self.num_workers > 0 else 0,
            worker_init_fn=_worker_init_fn)

    self.best_val_loss = np.inf  # Best recorded validation loss
    self.best_tr_loss = np.inf

    self.valid_metrics = {} if valid_metrics is None else valid_metrics
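# --- Illustrative sketch (added, not part of the original code) -------------
# The DataLoaders above pass a ``_worker_init_fn``, which is not defined in this
# excerpt. A typical implementation of such a hook (hypothetical sketch, not
# necessarily what this codebase does) re-seeds numpy in each worker process so
# that augmentation RNG differs between workers:
import numpy as np
import torch


def _worker_init_fn_sketch(worker_id: int) -> None:
    # Derive a per-worker seed from PyTorch's per-worker base seed.
    seed = (torch.initial_seed() + worker_id) % 2**32
    np.random.seed(seed)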
def run(self, max_steps: int = 1) -> None:
    """Train the network for ``max_steps`` steps.

    After each training epoch, validation performance is measured and
    visualizations are computed and logged to tensorboard."""
    while self.step < max_steps:
        try:
            # --> self.train()
            self.model.train()

            # Scalar training stats that should be logged and written to tensorboard later
            stats: Dict[str, float] = {'tr_loss_G': .0, 'tr_loss_D': .0}
            # Other scalars to be logged
            misc: Dict[str, float] = {
                'G_loss_advreg': .0, 'G_loss_tnet': .0, 'G_loss_l2': .0,
                'D_loss_fake': .0, 'D_loss_real': .0
            }
            # Hold image tensors for real-time training sample visualization in tensorboard
            images: Dict[str, torch.Tensor] = {}

            running_error = 0
            running_mean_target = 0
            running_vx_size = 0
            timer = Timer()
            latent_points_fake = []
            latent_points_real = []
            for inp in self.train_loader:  # ref., pos., neg. samples
                if inp.size()[1] != 3:
                    raise ValueError(
                        "Data must not contain targets. "
                        "Input data shape is assumed to be "
                        "(N, 3, ch, x, y), where the first two"
                        " images in each sample is the similar"
                        " pair, while the third one is the "
                        "distant one.")
                inp0 = Variable(inp[:, 0].to(self.device))
                inp1 = Variable(inp[:, 1].to(self.device))
                inp2 = Variable(inp[:, 2].to(self.device))

                self.optimizer.zero_grad()

                # forward pass
                dA, dB, z0, z1, z2 = self.model(inp0, inp1, inp2)
                z_fake_gauss = torch.squeeze(torch.cat((z0, z1, z2), dim=1))
                target = torch.FloatTensor(dA.size()).fill_(-1).to(self.device)
                target = Variable(target)
                loss = self.criterion(dA, dB, target)
                L_l2 = torch.mean(torch.cat(
                    (z0.norm(1, dim=1), z1.norm(1, dim=1), z2.norm(1, dim=1)), dim=0))
                misc['G_loss_l2'] += self.alpha * float(L_l2)
                loss = loss + self.alpha * L_l2
                misc['G_loss_tnet'] += (1 - self.alpha2) * float(loss)  # log actual loss
                if torch.isnan(loss):
                    logger.error(f'NaN loss detected after {self.step} '
                                 'steps! Aborting training.')
                    raise NaNException

                # Adversarial part to enforce latent variable distribution
                # to be Normal / whatever prior is used
                if self.alpha2 > 0:
                    self.optimizer_discr.zero_grad()

                    # adversarial labels
                    valid = Variable(torch.Tensor(inp0.size()[0], 1).fill_(1.0),
                                     requires_grad=False).to(self.device)
                    fake = Variable(torch.Tensor(inp0.shape[0], 1).fill_(0.0),
                                    requires_grad=False).to(self.device)

                    # --- Generator / TripletNet
                    self.model_discr.eval()
                    # TripletNet latent space should be classified as valid
                    L_advreg = self.criterion_discr(self.model_discr(z_fake_gauss), valid)
                    # average adversarial reg. and triplet-loss
                    loss = (1 - self.alpha2) * loss + self.alpha2 * L_advreg

                    # perform generator step
                    loss.backward()
                    self.optimizer.step()

                    # --- Discriminator
                    self.model.eval()
                    self.model_discr.train()
                    # rebuild graph (model output) to get clean backprop.
                    z_real_gauss = Variable(
                        self.latent_distr(inp0.size()[0], z0.size()[-1] * 3)).to(self.device)
                    _, _, z_fake_gauss0, z_fake_gauss1, z_fake_gauss2 = self.model(
                        inp0, inp1, inp2)
                    z_fake_gauss = torch.squeeze(torch.cat(
                        (z_fake_gauss0, z_fake_gauss1, z_fake_gauss2), dim=1))

                    # Compute discriminator outputs and loss
                    L_real_gauss = self.criterion_discr(self.model_discr(z_real_gauss), valid)
                    L_fake_gauss = self.criterion_discr(self.model_discr(z_fake_gauss), fake)
                    L_discr = 0.5 * (L_real_gauss + L_fake_gauss)
                    L_discr.backward()  # Backprop loss
                    self.optimizer_discr.step()  # Apply optimization step
                    self.model.train()  # set back to training mode

                    # clean and report
                    L_discr.detach()
                    L_advreg.detach()
                    L_real_gauss.detach()
                    L_fake_gauss.detach()
                    stats['tr_loss_D'] += float(L_discr)
                    misc['G_loss_advreg'] += self.alpha2 * float(L_advreg)  # log actual part of advreg
                    misc['D_loss_real'] += float(L_real_gauss)
                    misc['D_loss_fake'] += float(L_fake_gauss)
                    latent_points_real.append(z_real_gauss.detach().cpu().numpy())
                else:
                    loss.backward()
                    self.optimizer.step()
                latent_points_fake.append(z_fake_gauss.detach().cpu().numpy())

                # Prevent accidental autograd overheads after optimizer step
                inp.detach()
                target.detach()
                dA.detach()
                dB.detach()
                z0.detach()
                z1.detach()
                z2.detach()
                loss.detach()
                L_l2.detach()

                # get training performance
                stats['tr_loss_G'] += float(loss)
                error = calculate_error(dA, dB)
                mean_target = target.to(torch.float32).mean()
                print(f'{self.step:6d}, loss: {loss:.4f}', end='\r')
                self._tracker.update_timeline([self._timer.t_passed, float(loss), mean_target])

                # Preserve training batch and network output for later visualization
                images['inp_ref'] = inp0.cpu().numpy()
                images['inp_+'] = inp1.cpu().numpy()
                images['inp_-'] = inp2.cpu().numpy()

                # Not using .get_lr()[-1] because ReduceLROnPlateau does not implement get_lr()
                misc['learning_rate_G'] = self.optimizer.param_groups[0]['lr']
                misc['learning_rate_D'] = self.optimizer_discr.param_groups[0]['lr']
                # update schedules
                for sched in self.schedulers.values():
                    # support ReduceLROnPlateau; doc. uses validation loss instead
                    # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
                    if "metrics" in inspect.signature(sched.step).parameters:
                        sched.step(metrics=float(loss))
                    else:
                        sched.step()

                running_error += error
                running_mean_target += mean_target
                running_vx_size += inp.numel()

                self.step += 1
                if self.step >= max_steps:
                    break

            stats['tr_err_G'] = float(running_error) / len(self.train_loader)
            stats['tr_loss_G'] /= len(self.train_loader)
            stats['tr_loss_D'] /= len(self.train_loader)
            misc['G_loss_advreg'] /= len(self.train_loader)
            misc['G_loss_tnet'] /= len(self.train_loader)
            misc['G_loss_l2'] /= len(self.train_loader)
            misc['D_loss_fake'] /= len(self.train_loader)
            misc['D_loss_real'] /= len(self.train_loader)
            misc['tr_speed'] = len(self.train_loader) / timer.t_passed
            misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx
            mean_target = running_mean_target / len(self.train_loader)
            if (self.valid_dataset is None) or (1 != np.random.randint(0, 10)):
                # only validate 10% of the times
                stats['val_loss_G'], stats['val_err_G'] = float('nan'), float('nan')
            else:
                stats['val_loss_G'], stats['val_err_G'] = self._validate()
            # TODO: Report more metrics, e.g. dice error

            # Update history tracker (kind of made obsolete by tensorboard)
            # TODO: Decide what to do with this, now that most things are already in tensorboard.
            if self.step // len(self.train_dataset) > 1:
                tr_loss_gain = self._tracker.history[-1][2] - stats['tr_loss_G']
            else:
                tr_loss_gain = 0
            self._tracker.update_history([
                self.step, self._timer.t_passed, stats['tr_loss_G'], stats['val_loss_G'],
                tr_loss_gain, stats['tr_err_G'], stats['val_err_G'],
                misc['learning_rate_G'], 0, 0
            ])  # 0's correspond to mom and gradnet (?)

            t = pretty_string_time(self._timer.t_passed)
            loss_smooth = self._tracker.loss._ema

            # Logging to stdout, text log file
            text = "%05i L_m=%.3f, L=%.2f, tr=%05.2f%%, " % (
                self.step, loss_smooth, stats['tr_loss_G'], stats['tr_err_G'])
            text += "vl=%05.2f%s, prev=%04.1f, L_diff=%+.1e, " % (
                stats['val_err_G'], "%", mean_target * 100, tr_loss_gain)
            text += "LR=%.2e, %.2f it/s, %.2f MVx/s, %s" % (
                misc['learning_rate_G'], misc['tr_speed'], misc['tr_speed_vx'], t)
            logger.info(text)

            # Plot tracker stats to pngs in save_path
            self._tracker.plot(self.save_path)

            # Reporting to tensorboard logger
            if self.tb:
                self._tb_log_scalars(stats, 'stats')
                self._tb_log_scalars(misc, 'misc')
                self.tb_log_sample_images(images, group='tr_samples')

                # save histograms of latent points
                if len(latent_points_fake) > 0:
                    fig, ax = plt.subplots()
                    sns.distplot(np.concatenate(latent_points_fake).flatten())
                    # plt.savefig(os.path.join(self.save_path,
                    #                          'latent_fake_{}.png'.format(self.step)))
                    fig.canvas.draw()
                    # grab the pixel buffer and dump it into a numpy array
                    img_data = np.array(fig.canvas.renderer._renderer)
                    self.tb.add_figure('latent_distr/latent_fake',
                                       plot_image(img_data), global_step=self.step)
                    plt.close()
                if len(latent_points_real) > 0:
                    fig, ax = plt.subplots()
                    sns.distplot(np.concatenate(latent_points_real).flatten())
                    # plt.savefig(os.path.join(self.save_path,
                    #                          'latent_real_{}.png'.format(self.step)))
                    fig.canvas.draw()
                    img_data = np.array(fig.canvas.renderer._renderer)
                    self.tb.add_figure('latent_distr/latent_real',
                                       plot_image(img_data), global_step=self.step)
                    plt.close()

            # Save trained model state
            torch.save(
                self.model.state_dict(),
                # os.path.join(self.save_path, f'model-{self.step:06d}.pth')
                # Saving with different file names leads to heaps of large files,
                os.path.join(self.save_path, 'model-checkpoint.pth'))
            # TODO: Also save "best" model, not only the latest one, which is often overfitted.
            #       -> "best" in which regard? Lowest validation loss, validation error?
            #          We can't blindly trust these metrics and may have to calculate
            #          additional metrics (with focus on object boundary correctness).
        except KeyboardInterrupt:
            IPython.embed(header=self._shell_info)
            if self.terminate:
                return
        except Exception as e:
            traceback.print_exc()
            if self.ignore_errors:
                # Just print the traceback and try to carry on with training.
                # This can go wrong in unexpected ways, so don't leave the training unattended.
                pass
            elif self.ipython_shell:
                print("\nEntering Command line such that Exception can be "
                      "further inspected by user.\n\n")
                IPython.embed(header=self._shell_info)
                if self.terminate:
                    return
            else:
                raise e
    torch.save(
        self.model.state_dict(),
        os.path.join(self.save_path, f'model-final-{self.step:06d}.pth'))
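# --- Illustrative sketch (added, not part of the original code) -------------
# The TODO above asks for a "best" checkpoint in addition to the rolling latest
# one. A minimal sketch of such a policy, in the spirit of the other Trainer
# versions in this file (save_model(suffix='_best')); the paths, dummy model and
# validation losses are placeholders:
import os
import tempfile
import torch

save_path = tempfile.mkdtemp(prefix='ckpt_sketch_')
model = torch.nn.Linear(2, 2)
best_val_loss = float('inf')
for step, val_loss in enumerate([0.9, 0.6, 0.7]):  # stand-in validation losses
    # Always overwrite the rolling "latest" checkpoint
    torch.save(model.state_dict(), os.path.join(save_path, 'model-checkpoint.pth'))
    # Additionally keep a copy whenever validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), os.path.join(save_path, 'model-checkpoint_best.pth'))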