Example #1
    def __init__(
            self,
            model: torch.nn.Module,
            criterion: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            device: torch.device,
            save_root: str,
            train_dataset: torch.utils.data.Dataset,
            valid_dataset: Optional[torch.utils.data.Dataset] = None,
            valid_metrics: Optional[Dict] = None,
            exp_name: Optional[str] = None,
            batchsize: int = 1,
            num_workers: int = 0,
            schedulers: Optional[Dict[Any, Any]] = None,
            overlay_alpha: float = 0.2,
            enable_tensorboard: bool = True,
            tensorboard_root_path: Optional[str] = None,
            model_has_softmax_outputs: bool = False,
            ignore_errors: bool = False,
            ipython_on_error: bool = False,
            classes: Optional[Sequence[int]] = None,
    ):
        self.ignore_errors = ignore_errors
        self.ipython_on_error = ipython_on_error
        self.device = device
        self.model = model.to(device)
        self.criterion = criterion.to(device)
        self.optimizer = optimizer
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.valid_metrics = valid_metrics
        self.overlay_alpha = overlay_alpha
        self.save_root = os.path.expanduser(save_root)
        self.batchsize = batchsize
        self.num_workers = num_workers
        # TODO: This could be automatically determined by parsing the model
        self.model_has_softmax_outputs = model_has_softmax_outputs

        self._tracker = HistoryTracker()
        self._timer = Timer()
        self._first_plot = True
        self._shell_info = dedent("""
            Entering IPython training shell. To continue, hit Ctrl-D twice.
            To terminate, set self.terminate = True and then hit Ctrl-D twice.
        """).strip()

        if exp_name is None:  # Auto-generate a name based on model name and ISO timestamp
            timestamp = datetime.datetime.now().strftime('%y-%m-%d_%H-%M-%S')
            exp_name = model.__class__.__name__ + '__' + timestamp
        self.exp_name = exp_name
        self.save_path = os.path.join(save_root, exp_name)
        os.makedirs(self.save_path, exist_ok=True)  # TODO: Warn if directory already exists

        self.terminate = False
        self.step = 0
        if schedulers is None:
            schedulers = {'lr': StepLR(optimizer, 1000, 1)}  # No-op scheduler
        self.schedulers = schedulers

        # Determine optional dataset properties
        self.classes = classes
        self.num_classes = None
        if hasattr(self.train_dataset, 'classes'):
            self.classes = self.train_dataset.classes
            self.num_classes = len(self.train_dataset.classes)
        self.previews_enabled = hasattr(valid_dataset, 'preview_batch') \
            and getattr(valid_dataset, 'preview_shape', None) is not None

        if not tensorboard_available and enable_tensorboard:
            enable_tensorboard = False
            logger.warning('Tensorboard is not available, so it has to be disabled.')
        self.tb = None  # Tensorboard handler
        if enable_tensorboard:
            if tensorboard_root_path is None:
                tb_path = self.save_path
            else:
                tensorboard_root_path = os.path.expanduser(tensorboard_root_path)
                tb_path = os.path.join(tensorboard_root_path, self.exp_name)
                os.makedirs(tb_path, exist_ok=True)
            # TODO: Make always_flush user-configurable here:
            self.tb = TensorBoardLogger(log_dir=tb_path, always_flush=False)

        self.train_loader = DelayedDataLoader(
            self.train_dataset, batch_size=self.batchsize, shuffle=True,
            num_workers=self.num_workers, pin_memory=True,
            timeout=30  # timeout arg requires https://github.com/pytorch/pytorch/commit/1661370ac5f88ef11fedbeac8d0398e8369fc1f3
        )
        # num_workers is set to 0 for valid_loader because validation background processes sometimes
        # fail silently and stop responding, bringing down the whole training process.
        # This issue might be related to https://github.com/pytorch/pytorch/issues/1355.
        # The performance impact of disabling multiprocessing here is low in normal settings,
        # because the validation loader doesn't perform expensive augmentations, but just reads
        # data from hdf5s.
        if valid_dataset is not None:
            self.valid_loader = DelayedDataLoader(
                self.valid_dataset, self.batchsize, num_workers=0, pin_memory=False,
                timeout=30
            )
        self.best_val_loss = np.inf  # Best recorded validation loss

        self.valid_metrics = {} if valid_metrics is None else valid_metrics
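
Note: the __init__ above appears to come from an elektronn3-style Trainer class. Assuming that, a minimal construction sketch could look like the following; the Trainer import path, the dummy datasets and the tiny model are illustrative assumptions, not part of the snippet above.

import torch
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset

from elektronn3.training import Trainer  # assumed import path for the class defined above

# Tiny dummy data so the sketch is self-contained
inputs = torch.randn(16, 1, 32, 32)
targets = torch.randint(0, 2, (16, 32, 32))  # per-pixel class indices
train_ds = TensorDataset(inputs, targets)
valid_ds = TensorDataset(inputs[:4], targets[:4])

model = torch.nn.Conv2d(1, 2, 3, padding=1)  # stand-in for a real segmentation model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

trainer = Trainer(
    model=model,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=optimizer,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    save_root='~/e3_training',
    train_dataset=train_ds,
    valid_dataset=valid_ds,
    batchsize=8,
    num_workers=0,
    schedulers={'lr': StepLR(optimizer, step_size=1000, gamma=0.99)},
)
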
Example #2
    def train(self, max_steps: int = 1, max_runtime=3600 * 24 * 7) -> None:
        """Train the network for ``max_steps`` steps.

        After each training epoch, validation performance is measured and
        visualizations are computed and logged to tensorboard."""
        self.start_time = datetime.datetime.now()
        self.end_time = self.start_time + datetime.timedelta(seconds=max_runtime)
        while not self.terminate:
            try:
                # --> self.train()
                self.model.train()

                # Scalar training stats that should be logged and written to tensorboard later
                stats: Dict[str, float] = {'tr_loss': 0.0}
                # Other scalars to be logged
                misc: Dict[str, float] = {}
                # Hold image tensors for real-time training sample visualization in tensorboard
                images: Dict[str, torch.Tensor] = {}

                running_acc = 0
                running_mean_target = 0
                running_vx_size = 0
                timer = Timer()
                for inp, target in self.train_loader:
                    inp, target = inp.to(self.device), target.to(self.device)

                    # forward pass
                    out = self.model(inp)
                    loss = self.criterion(out, target)
                    if torch.isnan(loss):
                        logger.error('NaN loss detected! Aborting training.')
                        raise NaNException

                    # update step
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    # Prevent accidental autograd overheads after optimizer step
                    inp.detach_()
                    target.detach_()
                    out.detach_()
                    loss.detach_()

                    # get training performance
                    stats['tr_loss'] += float(loss)
                    acc = metrics.bin_accuracy(target, out)  # TODO
                    mean_target = target.to(torch.float32).mean()
                    print(f'{self.step:6d}, loss: {loss:.4f}', end='\r')
                    self._tracker.update_timeline([self._timer.t_passed, float(loss), mean_target])

                    # Preserve training batch and network output for later visualization
                    images['inp'] = inp
                    images['target'] = target
                    images['out'] = out
                    # this was changed to support ReduceLROnPlateau which does not implement get_lr
                    misc['learning_rate'] = self.optimizer.param_groups[0]["lr"] # .get_lr()[-1]
                    # update schedules
                    for sched in self.schedulers.values():
                        # support ReduceLROnPlateau; doc. uses validation loss instead
                        # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
                        if "metrics" in inspect.signature(sched.step).parameters:
                            sched.step(metrics=float(loss))
                        else:
                            sched.step()

                    running_acc += acc
                    running_mean_target += mean_target
                    running_vx_size += inp.numel()

                    self.step += 1
                    if self.step >= max_steps:
                        logger.info(f'max_steps ({max_steps}) exceeded. Terminating...')
                        self.terminate = True
                        break
                    if datetime.datetime.now() >= self.end_time:
                        logger.info(f'max_runtime ({max_runtime} seconds) exceeded. Terminating...')
                        self.terminate = True
                        break
                stats['tr_accuracy'] = running_acc / len(self.train_loader)
                stats['tr_loss'] /= len(self.train_loader)
                misc['tr_speed'] = len(self.train_loader) / timer.t_passed
                misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx
                mean_target = running_mean_target / len(self.train_loader)
                if self.valid_dataset is None:
                    stats['val_loss'], stats['val_accuracy'] = float('nan'), float('nan')
                else:
                    valid_stats = self.validate()
                    stats.update(valid_stats)


                # Update history tracker (kind of made obsolete by tensorboard)
                # TODO: Decide what to do with this, now that most things are already in tensorboard.
                if self.step // len(self.train_dataset) > 1:
                    tr_loss_gain = self._tracker.history[-1][2] - stats['tr_loss']
                else:
                    tr_loss_gain = 0
                self._tracker.update_history([
                    self.step, self._timer.t_passed, stats['tr_loss'], stats['val_loss'],
                    tr_loss_gain, stats['tr_accuracy'], stats['val_accuracy'], misc['learning_rate'], 0, 0
                ])  # 0's correspond to mom and gradnet (?)
                t = pretty_string_time(self._timer.t_passed)
                loss_smooth = self._tracker.loss._ema

                # Logging to stdout, text log file
                text = "%05i L_m=%.3f, L=%.2f, tr_acc=%05.2f%%, " % (self.step, loss_smooth, stats['tr_loss'], stats['tr_accuracy'])
                text += "val_acc=%05.2f%s, prev=%04.1f, L_diff=%+.1e, " % (stats['val_accuracy'], "%", mean_target * 100, tr_loss_gain)
                text += "LR=%.2e, %.2f it/s, %.2f MVx/s, %s" % (misc['learning_rate'], misc['tr_speed'], misc['tr_speed_vx'], t)
                logger.info(text)

                # Plot tracker stats to pngs in save_path
                self._tracker.plot(self.save_path)

                # Reporting to tensorboard logger
                if self.tb:
                    self.tb_log_scalars(stats, 'stats')
                    self.tb_log_scalars(misc, 'misc')
                    if self.previews_enabled:
                        self.tb_log_preview()
                    self.tb_log_sample_images(images, group='tr_samples')
                    self.tb.writer.flush()

                # Save trained model state
                self.save_model()
                if stats['val_loss'] < self.best_val_loss:
                    self.best_val_loss = stats['val_loss']
                    self.save_model(suffix='_best')
            except KeyboardInterrupt:
                IPython.embed(header=self._shell_info)
                if self.terminate:
                    return
            except Exception as e:
                traceback.print_exc()
                if self.ignore_errors:
                    # Just print the traceback and try to carry on with training.
                    # This can go wrong in unexpected ways, so don't leave the training unattended.
                    pass
                elif self.ipython_on_error:
                    print("\nEntering Command line such that Exception can be "
                          "further inspected by user.\n\n")
                    IPython.embed(header=self._shell_info)
                    if self.terminate:
                        return
                else:
                    raise e
        self.save_model(suffix='_final')
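
Note: the scheduler-stepping pattern in the loop above (checking whether step() accepts a `metrics` argument) works for both plain schedulers and ReduceLROnPlateau. A standalone sketch of just that pattern, with placeholder model and loss values:

import inspect

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
schedulers = {
    'step': StepLR(optimizer, step_size=10, gamma=0.5),
    'plateau': ReduceLROnPlateau(optimizer, patience=5),
}

loss = 0.42       # stand-in for float(loss) from the training loop
optimizer.step()  # in the real loop this follows loss.backward()
for sched in schedulers.values():
    if 'metrics' in inspect.signature(sched.step).parameters:
        sched.step(metrics=loss)  # ReduceLROnPlateau wants the monitored value
    else:
        sched.step()
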
Example #3
    def _train(self, max_steps, max_runtime):
        self.model.train()

        # Scalar training stats that should be logged and written to tensorboard later
        stats: Dict[str, float] = {'tr_loss': 0.0}
        # Other scalars to be logged
        misc: Dict[str, float] = {}
        # Hold image tensors for real-time training sample visualization in tensorboard
        images: Dict[str, np.ndarray] = {}

        running_acc = 0
        running_mean_target = 0
        running_vx_size = 0
        timer = Timer()
        pbar = tqdm(enumerate(self.train_loader), 'Training', total=len(self.train_loader))
        for i, (inp, target, scal) in pbar:
            # Everything with a "d" prefix refers to tensors on self.device (i.e. probably on GPU)
            dinp = inp.to(self.device, non_blocking=True)
            dtarget = target.to(self.device, non_blocking=True)
            dscal = scal.to(self.device, non_blocking=True)

            # forward pass
            dout = self.model(dinp, dscal)
            dloss = self.criterion(dout, dtarget)
            if torch.isnan(dloss):
                logger.error('NaN loss detected! Aborting training.')
                raise NaNException

            # update step
            self.optimizer.zero_grad()
            if self.mixed_precision:
                with self.amp_handle.scale_loss(dloss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                dloss.backward()
            self.optimizer.step()
            # End of core training loop on self.device

            # TODO: Evaluate performance impact of these copies and maybe avoid doing these so often
            out = dout.detach().cpu()  # Copy model output to host memory for metrics, visualization

            with torch.no_grad():
                loss = float(dloss)
                stats['tr_loss'] += loss
                acc = float(metrics.bin_accuracy(target, out))
                mean_target = float(target.to(torch.float32).mean())
                pbar.set_description(f'Training (loss {loss:.4f})')
                self._tracker.update_timeline([self._timer.t_passed, loss, mean_target])

            # this was changed to support ReduceLROnPlateau which does not implement get_lr
            misc['learning_rate'] = self.optimizer.param_groups[0]["lr"]  # .get_lr()[-1]
            # update schedules
            for sched in self.schedulers.values():
                # support ReduceLROnPlateau; doc. uses validation loss instead
                # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
                if "metrics" in inspect.signature(sched.step).parameters:
                    sched.step(metrics=loss)
                else:
                    sched.step()

            running_acc += acc
            running_mean_target += mean_target
            running_vx_size += inp.numel()

            self.step += 1
            if self.step >= max_steps:
                logger.info(f'max_steps ({max_steps}) exceeded. Terminating...')
                self.terminate = True
            if datetime.datetime.now() >= self.end_time:
                logger.info(f'max_runtime ({max_runtime} seconds) exceeded. Terminating...')
                self.terminate = True
            if i == len(self.train_loader) - 1 or self.terminate:
                # Last step in this epoch or in the whole training
                # Preserve last training batch and network output for later visualization
                images['inp'] = inp.numpy()
                images['target'] = target.numpy()
                images['out'] = out.numpy()

            if self.terminate:
                break

        stats['tr_accuracy'] = running_acc / len(self.train_loader)
        stats['tr_loss'] /= len(self.train_loader)
        misc['tr_speed'] = len(self.train_loader) / timer.t_passed
        misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx
        misc['mean_target'] = running_mean_target / len(self.train_loader)

        return stats, misc, images
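
Note: the amp_handle.scale_loss(...) context above is the old NVIDIA apex API. For reference only (not part of the original code), a roughly equivalent step using PyTorch's native torch.cuda.amp could look like this sketch:

import torch

use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

def amp_train_step(model, criterion, optimizer, inp, target, device):
    """One mixed-precision training step (native AMP instead of apex)."""
    inp = inp.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        out = model(inp)
        loss = criterion(out, target)
    if torch.isnan(loss):
        raise RuntimeError('NaN loss detected')  # stand-in for the NaNException used above
    scaler.scale(loss).backward()  # scaled backward, like amp_handle.scale_loss(...)
    scaler.step(optimizer)         # unscales gradients, then optimizer.step()
    scaler.update()
    return float(loss)
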
Example #4
    def _train(self, max_steps, max_runtime):

        out_channels = self.out_channels

        def _channel_metric(metric, c, out_channels=out_channels, mean=False):
            """Returns an evaluator that calculates the ``metric``
            and selects its value for channel ``c``."""

            def evaluator(target, out):
                #pred = metrics._argmax(out)
                m = metric(target, out, num_classes=out_channels, ignore=out_channels - 1, mean=mean)
                return m[c]

            return evaluator

        tr_evaluators = {**{
            f'tr_DSC_c{c}': _channel_metric(metrics.dice_coefficient, c=c) for c in range(out_channels)
        }, **{
            f'tr_precision_c{c}': _channel_metric(metrics.precision, c=c) for c in range(out_channels)
        }, **{
            f'tr_recall_c{c}': _channel_metric(metrics.recall, c=c) for c in range(out_channels)
        }}
        # Scalar training stats that should be logged and written to tensorboard later
        stats: Dict[str, Union[float, List[float]]] = {stat: [] for stat in ['tr_loss', 'tr_loss_mean', 'tr_accuracy']}
        stats.update({name: [] for name in tr_evaluators.keys()})
        file_stats = {}
        # Other scalars to be logged
        misc: Dict[str, Union[float, List[float]]] = {key: [] for key in ['mean_target']}
        # Hold image tensors for real-time training sample visualization in tensorboard
        images: Dict[str, np.ndarray] = {}

        self.model.train()
        self.optimizer.zero_grad()
        running_vx_size = 0  # Counts input sizes (number of pixels/voxels) of training batches
        timer = Timer()
        import gc
        gc.collect()
        batch_iter = tqdm(self.train_loader, 'Training', total=len(self.train_loader))
        for i, batch in enumerate(batch_iter):
            if self.step in self.extra_save_steps:
                self._save_model(f'_step{self.step}', verbose=True)
            # Everything with a "d" prefix refers to tensors on self.device (i.e. probably on GPU)
            inp, target = batch['inp'], batch['target']
            cube_meta = batch['cube_meta']
            fname = batch['fname']
            dinp = inp.to(self.device, non_blocking=True)
            lc = self.loss_crop  # optional symmetric crop applied to targets and outputs before the loss
            if lc:
                dtarget = target[:, :, lc:-lc, lc:-lc, lc:-lc].to(self.device, non_blocking=True)
            else:
                dtarget = target.to(self.device, non_blocking=True)
            weight = cube_meta[0].to(device=self.device, dtype=self.criterion.weight.dtype, non_blocking=True)
            prev_weight = self.criterion.weight.clone()
            self.criterion.weight = weight

            if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
                ignore_mask = (1 - dtarget[0][-1]).view(1, 1, *dtarget.shape[2:])
                dense_weight = self.criterion.weight.view(1, -1, 1, 1, 1)
                # Weighted targets, excluding the background and ignore channels
                positive_target_mask = (weight.view(1, -1, 1, 1, 1) * dtarget)[0][1:-1].sum(dim=0).view(1, 1, *dtarget.shape[2:])
                needs_positive_target_mark = (dense_weight.sum() == 0).type(positive_target_mask.dtype)
                self.criterion.weight = (ignore_mask * dense_weight
                                         + needs_positive_target_mark * positive_target_mask * prev_weight.view(1, -1, 1, 1, 1))

            # forward pass
            dout = self.model(dinp)
            if lc:
                dout = dout[:, :, lc:-lc, lc:-lc, lc:-lc]

            #print(dout.dtype, dout.shape, dtarget.dtype, dtarget.shape, dout.min(), dout.max())
            dloss = self.criterion(dout, dtarget)
            #dcumloss = dloss if i == 0 else dcumloss + dloss
            #print(dloss, dloss.size())
            #dloss = (dloss * prev_weight * weight).mean()
            if torch.isnan(dloss).sum():
                logger.error('NaN loss detected! Aborting training.')
                raise NaNException

            if self.mixed_precision:
                from apex import amp
                with amp.scale_loss(dloss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                # update step
                dloss.backward()

            if i % self.optimizer_iterations == self.optimizer_iterations - 1:
                self.optimizer.step()
                # TODO (lp): calling zero_grad() here makes gradients disappear from tb histograms
                self.optimizer.zero_grad()
                #loss2 = float(self.criterion(self.model(dinp), dtarget))
                #print(f'loss gain factor {np.divide(float(dloss), (float(dloss)-loss2))})')
            # End of core training loop on self.device

            with torch.no_grad():
                loss = float(dloss)
                # TODO: Evaluate performance impact of these copies and maybe avoid doing these so often
                out_class = dout.argmax(dim=1).detach().cpu()
                multi_class_target = target.argmax(1) if len(target.shape) > 4 else target  # TODO
                if lc:
                    multi_class_target = multi_class_target[:, lc:-lc, lc:-lc, lc:-lc]
                acc = metrics.accuracy(multi_class_target, out_class, out_channels, mean=False).numpy()
                acc = np.average(acc[~np.isnan(acc)])  # mean over channels, ignoring NaN entries
                mean_target = float(multi_class_target.to(torch.float32).mean())

                # import h5py
                # dsc5 = channel_metric(metrics.dice_coefficient, c=5, out_channels=out_channels)(multi_class_target, out_class)
                # after_step = '+' if i % self.optimizer_iterations == 0 else ''
                # with h5py.File(os.path.join(self.save_path, f'batch {self.step}{after_step} loss={float(dloss)} dsc5={dsc5}.h5'), "w") as f:
                #     f.create_dataset('raw', data=inp.squeeze(dim=0), compression="gzip")
                #     f.create_dataset('labels', data=multi_class_target.numpy().astype(np.uint16), compression="gzip")
                #     f.create_dataset('pred', data=dout.squeeze(dim=0).detach().cpu().numpy(), compression="gzip")

                if fname[0] not in file_stats:
                    file_stats[fname[0]] = []
                file_stats[fname[0]] += [float('nan')] * (i - len(file_stats[fname[0]])) + [loss]

                stats['tr_loss'].append(loss)
                stats['tr_loss_mean'] += [float('nan')] * (i - len(stats['tr_loss_mean']))
                if i % self.optimizer_iterations == self.optimizer_iterations - 1:
                    stats['tr_loss_mean'] += [np.mean(stats['tr_loss'][-self.optimizer_iterations:])]
                stats['tr_accuracy'].append(acc)
                for name, evaluator in tr_evaluators.items():
                    stats[name].append(evaluator(multi_class_target, out_class))

                misc['mean_target'].append(mean_target)
                # if loss-loss2 == 0 and not torch.any(out_class != multi_class_target):
                #     print('grad', self.model.up_convs[0].conv2.weight.grad)
                #     IPython.embed()
                #if loss - 0.99 < 1e-3:
                #    print('asd', loss, loss2)
                #    IPython.embed()
                batch_iter.set_description(f'Training (loss {loss:.4f})')
                #pbar.set_description(f'Training (loss {loss} / {float(dcumloss)})')
                #pbar.set_description(f'Training (loss {loss} / {np.divide(loss, (loss-loss2))})')
                self._tracker.update_timeline([self._timer.t_passed, loss, mean_target])

            self.criterion.weight = prev_weight

            # Not using .get_lr()[-1] because ReduceLROnPlateau does not implement get_lr()
            misc['learning_rate'] = self.optimizer.param_groups[0]['lr']  # LR for this iteration
            # update schedules
            for sched in self.schedulers.values():
                # support ReduceLROnPlateau; doc. uses validation loss instead
                # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
                if "metrics" in inspect.signature(sched.step).parameters:
                    sched.step(metrics=loss)
                else:
                    sched.step()
            # Append LR of the next iteration (after sched.step()) for local LR minima detection
            self._lr_nhood.append(self.optimizer.param_groups[0]['lr'])
            self._handle_lr()

            running_vx_size += inp.numel()

            #if stats['tr_loss_mean'][-1] < self.best_tr_loss:
            #   self.best_tr_loss = stats['tr_loss'][-1]
            #   self._save_model(suffix='_best_train', loss=stats['tr_loss'][-1])

            self.step += 1
            if self.step >= max_steps:
                logger.info(f'max_steps ({max_steps}) exceeded. Terminating...')
                self.terminate = True
            if datetime.datetime.now() >= self.end_time:
                logger.info(f'max_runtime ({max_runtime} seconds) exceeded. Terminating...')
                self.terminate = True
            if i == len(self.train_loader) - 1 or self.terminate:
                # Last step in this epoch or in the whole training
                # Preserve last training batch and network output for later visualization
                images['fname'] = Path(fname[0]).stem
                images['inp'] = inp.numpy()
                images['target'] = multi_class_target.numpy()
                images['out'] = dout.detach().cpu().numpy()
                self._put_current_attention_maps_into(images)

            if self.terminate:
                break

        stats['tr_loss_std'] = np.std(stats['tr_loss'])
        misc['tr_speed'] = len(self.train_loader) / timer.t_passed
        misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx

        return stats, file_stats, misc, images
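
Note: the optimizer_iterations logic above accumulates gradients over several batches before each optimizer step. A minimal self-contained sketch of that pattern follows; it adds an optional loss rescaling that the snippet above does not do.

import torch
from torch.utils.data import DataLoader, TensorDataset

def train_epoch_accum(model, criterion, optimizer, loader, device, optimizer_iterations=4):
    """Accumulate gradients over `optimizer_iterations` batches before each optimizer step."""
    model.train()
    optimizer.zero_grad()
    for i, (inp, target) in enumerate(loader):
        inp, target = inp.to(device), target.to(device)
        loss = criterion(model(inp), target)
        # Dividing keeps the effective gradient comparable to one large batch;
        # the snippet above accumulates the raw (unscaled) gradients instead.
        (loss / optimizer_iterations).backward()
        if i % optimizer_iterations == optimizer_iterations - 1:
            optimizer.step()
            optimizer.zero_grad()

# Tiny usage example with dummy data
model = torch.nn.Linear(8, 2)
loader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,))), batch_size=4)
train_epoch_accum(model, torch.nn.CrossEntropyLoss(),
                  torch.optim.SGD(model.parameters(), lr=0.1), loader, torch.device('cpu'))
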
Example #5
    def __init__(
            self,
            model: torch.nn.Module,
            criterion: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            device: torch.device,
            save_root: str,
            train_dataset: torch.utils.data.Dataset,
            valid_dataset: Optional[torch.utils.data.Dataset] = None,
            valid_metrics: Optional[Dict] = None,
            preview_batch: Optional[torch.Tensor] = None,
            preview_tile_shape: Optional[Tuple[int, ...]] = None,
            preview_overlap_shape: Optional[Tuple[int, ...]] = None,
            preview_interval: int = 5,
            offset: Optional[Sequence[int]] = None,
            exp_name: Optional[str] = None,
            example_input: Optional[torch.Tensor] = None,
            enable_save_trace: bool = False,
            batchsize: int = 1,
            num_workers: int = 0,
            schedulers: Optional[Dict[Any, Any]] = None,
            overlay_alpha: float = 0.2,
            enable_videos: bool = True,
            enable_tensorboard: bool = True,
            tensorboard_root_path: Optional[str] = None,
            apply_softmax_for_prediction: bool = True,
            ignore_errors: bool = False,
            ipython_shell: bool = True,
            num_classes: Optional[int] = None,
            sample_plotting_handler: Optional[Callable] = None,
            preview_plotting_handler: Optional[Callable] = None,
            mixed_precision: bool = False,
    ):
        if preview_batch is not None and\
                (preview_tile_shape is None or preview_overlap_shape is None):
            raise ValueError(
                'If preview_batch is set, you will also need to specify '
                'preview_tile_shape and preview_overlap_shape!'
            )
        if num_workers > 1 and 'PatchCreator' in str(type(train_dataset)):
            logger.warning(
                'Training with num_workers > 1 can cause instabilities if '
                'you are using PatchCreator.\nBe advised that PatchCreator '
                'might randomly deliver broken batches during training and '
                'can crash it at any time.\n'
                'Please set num_workers to 1 or 0.\n'
            )
        self.ignore_errors = ignore_errors
        self.ipython_shell = ipython_shell
        self.device = device
        try:
            model.to(device)
        except RuntimeError as exc:
            if isinstance(model, torch.jit.ScriptModule):
                # "RuntimeError: to is not supported on TracedModules"
                # But .cuda() works for some reason. Using this messy
                # workaround in the hope that we can drop it soon.
                # TODO: Remove this when ScriptModule.to() is supported
                # See https://github.com/pytorch/pytorch/issues/7354
                if 'cuda' in str(self.device):  # (Ignoring device number!)
                    model.cuda()
            else:
                raise exc
        self.model = model
        self.criterion = criterion.to(device)
        self.optimizer = optimizer
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.valid_metrics = valid_metrics
        self.preview_batch = preview_batch
        self.preview_tile_shape = preview_tile_shape
        self.preview_overlap_shape = preview_overlap_shape
        self.preview_interval = preview_interval
        self.offset = offset
        self.overlay_alpha = overlay_alpha
        self.save_root = os.path.expanduser(save_root)
        self.example_input = example_input
        self.enable_save_trace = enable_save_trace
        self.batchsize = batchsize
        self.num_workers = num_workers
        self.apply_softmax_for_prediction = apply_softmax_for_prediction
        self.sample_plotting_handler = sample_plotting_handler
        self.preview_plotting_handler = preview_plotting_handler
        self.mixed_precision = mixed_precision

        self._tracker = HistoryTracker()
        self._timer = Timer()
        self._first_plot = True
        self._shell_info = dedent("""
            Entering IPython training shell. To continue, hit Ctrl-D twice.
            To terminate, set self.terminate = True and then hit Ctrl-D twice.
        """).strip()

        if self.mixed_precision:
            from apex import amp
            self.amp_handle = amp.init()

        if exp_name is None:  # Auto-generate a name based on model name and ISO timestamp
            timestamp = datetime.datetime.now().strftime('%y-%m-%d_%H-%M-%S')
            exp_name = model.__class__.__name__ + '__' + timestamp
        self.exp_name = exp_name
        self.save_path = os.path.join(save_root, exp_name)
        if os.path.isdir(self.save_path):
            raise RuntimeError(
                f'{self.save_path} already exists.\nPlease choose a '
                'different combination of save_root and exp_name.'
            )
        os.makedirs(self.save_path)
        logger.info(f'Writing files to save_path {self.save_path}/\n')

        self.terminate = False
        self.step = 0
        self.epoch = 0
        if schedulers is None:
            schedulers = {'lr': StepLR(optimizer, 1000, 1)}  # No-op scheduler
        self.schedulers = schedulers

        self.num_classes = num_classes
        if enable_videos:
            try:
                import moviepy
            except ImportError:
                logger.warning('moviepy is not installed. Disabling video logs.')
                enable_videos = False
        self.enable_videos = enable_videos
        self.tb = None  # Tensorboard handler
        if enable_tensorboard:
            if self.sample_plotting_handler is None:
                self.sample_plotting_handler = handlers._tb_log_sample_images
            if self.preview_plotting_handler is None:
                self.preview_plotting_handler = handlers._tb_log_preview

            if tensorboard_root_path is None:
                tb_path = self.save_path
            else:
                tensorboard_root_path = os.path.expanduser(tensorboard_root_path)
                tb_path = os.path.join(tensorboard_root_path, self.exp_name)
                os.makedirs(tb_path, exist_ok=True)
            # TODO: Make always_flush user-configurable here:
            self.tb = tensorboardX.SummaryWriter(log_dir=tb_path)

        self.train_loader = DelayedDataLoader(
            self.train_dataset, batch_size=self.batchsize, shuffle=True,
            num_workers=self.num_workers, pin_memory=True,
            timeout=60
        )
        # num_workers is set to 0 for valid_loader because validation background processes sometimes
        # fail silently and stop responding, bringing down the whole training process.
        # This issue might be related to https://github.com/pytorch/pytorch/issues/1355.
        # The performance impact of disabling multiprocessing here is low in normal settings,
        # because the validation loader doesn't perform expensive augmentations, but just reads
        # data from hdf5s.
        if valid_dataset is not None:
            self.valid_loader = DelayedDataLoader(
                self.valid_dataset, self.batchsize, num_workers=0, pin_memory=True,
                timeout=60
            )
        self.best_val_loss = np.inf  # Best recorded validation loss

        self.valid_metrics = {} if valid_metrics is None else valid_metrics
Example #6
    def _train(self, max_steps, max_runtime):
        """Train for one epoch or until max_steps or max_runtime is reached"""
        self.model.train()

        # Scalar training stats that should be logged and written to tensorboard later
        stats: Dict[str, Union[float, List[float]]] = {stat: [] for stat in ['tr_loss']}
        # Other scalars to be logged
        misc: Dict[str, Union[float, List[float]]] = {key: [] for key in ['mean_target']}
        # Hold image tensors for real-time training sample visualization in tensorboard
        images: Dict[str, np.ndarray] = {}

        running_vx_size = 0  # Counts input sizes (number of pixels/voxels) of training batches
        timer = Timer()
        batch_iter = tqdm(self.train_loader,
                          'Training',
                          total=len(self.train_loader),
                          dynamic_ncols=True,
                          **self.tqdm_kwargs)
        for i, batch in enumerate(batch_iter):
            if self.step in self.extra_save_steps:
                self._save_model(f'_step{self.step}', verbose=True)

            dloss, dout_imgs = self._train_step_triplet(batch)

            with torch.no_grad():
                loss = float(dloss)
                mean_target = 0.  # Dummy value
                misc['mean_target'].append(mean_target)
                stats['tr_loss'].append(loss)
                batch_iter.set_description(f'Training (loss {loss:.4f})')
                self._tracker.update_timeline(
                    [self._timer.t_passed, loss, mean_target])

            # Not using .get_lr()[-1] because ReduceLROnPlateau does not implement get_lr()
            misc['learning_rate'] = self.optimizer.param_groups[0]['lr']  # LR for this iteration
            self._scheduler_step(loss)

            running_vx_size += batch['anchor'].numel()

            self._incr_step(max_runtime, max_steps)
            if i == len(self.train_loader) - 1 or self.terminate:
                # Last step in this epoch or in the whole training
                # Preserve last training batch and network output for later visualization
                for key, img in batch.items():
                    if isinstance(img, torch.Tensor):
                        img = img.detach().cpu().numpy()
                    images[key] = img
                self._put_current_attention_maps_into(images)

                # TODO: The plotting handler abstraction is inadequate here. Figure out how
                #       we can handle plotting cleanly in one place.
                # Outputs are visualized here, while inputs are visualized in the plotting handler
                #  which is called in _run()...
                for name, img in dout_imgs.items():
                    img = img.detach()[0].cpu().numpy()  # select first item of batch
                    if img.ndim == 4:  # 3D data: take the center slice of the depth dim -> 2D
                        img = img[:, img.shape[1] // 2]
                    for c in range(img.shape[0]):
                        self.tb.add_figure(
                            f'tr_samples/{name}_c{c}',
                            handlers.plot_image(img[c], cmap='gray'),
                            global_step=self.step)

            if self.terminate:
                break

        stats['tr_loss_std'] = np.std(stats['tr_loss'])
        misc['tr_speed'] = len(self.train_loader) / timer.t_passed
        misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx

        return stats, misc, images
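
Note: the tensorboard block above visualizes 3D network outputs by taking the central slice along the depth dimension. The same trick as a standalone helper (function and variable names are illustrative):

import numpy as np

def center_slice(volume: np.ndarray) -> np.ndarray:
    """Reduce a (C, D, H, W) volume to (C, H, W) by taking the central depth slice."""
    if volume.ndim == 4:
        volume = volume[:, volume.shape[1] // 2]
    return volume

vol = np.random.rand(2, 16, 64, 64).astype(np.float32)
imgs = center_slice(vol)
assert imgs.shape == (2, 64, 64)  # one 2D image per channel, ready for tb.add_figure
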
Example #7
    def _train(self, max_steps, max_runtime):
        """Train for one epoch or until max_steps or max_runtime is reached"""
        self.model.train()

        # Scalar training stats that should be logged and written to tensorboard later
        stats: Dict[str, Union[float, List[float]]] = {stat: [] for stat in ['tr_loss']}
        # Other scalars to be logged
        misc: Dict[str, Union[float, List[float]]] = {key: [] for key in ['mean_target']}
        # Hold image tensors for real-time training sample visualization in tensorboard
        images: Dict[str, np.ndarray] = {}

        running_vx_size = 0  # Counts input sizes (number of pixels/voxels) of training batches
        timer = Timer()
        batch_iter = tqdm(self.train_loader,
                          'Training',
                          total=len(self.train_loader),
                          dynamic_ncols=True)
        unlabeled_iter = None if self.unlabeled_dataset is None else iter(
            self.unlabeled_loader)
        for i, batch in enumerate(batch_iter):
            if self.step in self.extra_save_steps:
                self._save_model(f'_step{self.step}', verbose=True)

            if unlabeled_iter is not None:
                batch['unlabeled'] = next(unlabeled_iter)
            dloss, dout = self._train_step(batch)

            with torch.no_grad():
                loss = float(dloss)
                target = batch.get('target')
                mean_target = float(target.to(
                    torch.float32).mean()) if target is not None else 0.
                misc['mean_target'].append(mean_target)
                stats['tr_loss'].append(loss)
                batch_iter.set_description(f'Training (loss {loss:.4f})')
                self._tracker.update_timeline(
                    [self._timer.t_passed, loss, mean_target])

            # Not using .get_lr()[-1] because ReduceLROnPlateau does not implement get_lr()
            misc['learning_rate'] = self.optimizer.param_groups[0]['lr']  # LR for this iteration
            self._scheduler_step(loss)

            running_vx_size += batch['inp'].numel()

            self._incr_step(max_runtime, max_steps)
            if i == len(self.train_loader) - 1 or self.terminate:
                # Last step in this epoch or in the whole training
                # Preserve last training batch and network output for later visualization
                images['inp'] = batch['inp'].numpy()
                if 'target' in batch:
                    images['target'] = batch['target'].numpy()
                if 'unlabeled' in batch:
                    images['unlabeled'] = batch['unlabeled']
                images['out'] = dout.detach().cpu().numpy()
                self._put_current_attention_maps_into(images)

            if self.terminate:
                break

        stats['tr_loss_std'] = np.std(stats['tr_loss'])
        misc['tr_speed'] = len(self.train_loader) / timer.t_passed
        misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx

        return stats, misc, images
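
Note: the loop above draws one unlabeled batch per labeled batch via next(unlabeled_iter), which assumes the unlabeled loader is at least as long as the labeled one. A sketch of the same pairing that restarts the unlabeled iterator when it is exhausted (an addition, not in the original):

import torch
from torch.utils.data import DataLoader, TensorDataset

labeled = DataLoader(TensorDataset(torch.randn(32, 3), torch.randint(0, 2, (32,))), batch_size=4)
unlabeled = DataLoader(TensorDataset(torch.randn(20, 3)), batch_size=4, shuffle=True)

unlabeled_iter = iter(unlabeled)
for inp, target in labeled:
    try:
        (u_inp,) = next(unlabeled_iter)
    except StopIteration:  # unlabeled data exhausted -> start a new pass
        unlabeled_iter = iter(unlabeled)
        (u_inp,) = next(unlabeled_iter)
    # ...compute the supervised loss on (inp, target) and a self-supervised loss on u_inp
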
Example #8
    def __init__(
        self,
        model: torch.nn.Module,
        criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        save_root: str,
        train_dataset: torch.utils.data.Dataset,
        valid_dataset: Optional[torch.utils.data.Dataset] = None,
        unlabeled_dataset: Optional[torch.utils.data.Dataset] = None,
        valid_metrics: Optional[Dict] = None,
        ss_criterion: Optional[torch.nn.Module] = None,
        preview_batch: Optional[torch.Tensor] = None,
        preview_interval: int = 5,
        inference_kwargs: Optional[Dict[str, Any]] = None,
        hparams: Optional[Dict[str, Any]] = None,
        extra_save_steps: Sequence[int] = (),
        exp_name: Optional[str] = None,
        example_input: Optional[torch.Tensor] = None,
        enable_save_trace: bool = False,
        save_jit: Optional[str] = None,
        batch_size: int = 1,
        num_workers: int = 0,
        schedulers: Optional[Dict[Any, Any]] = None,
        overlay_alpha: float = 0.4,
        enable_videos: bool = False,
        enable_tensorboard: bool = True,
        tensorboard_root_path: Optional[str] = None,
        ignore_errors: bool = False,
        ipython_shell: bool = False,
        out_channels: Optional[int] = None,
        sample_plotting_handler: Optional[Callable] = None,
        preview_plotting_handler: Optional[Callable] = None,
        mixed_precision: bool = False,
    ):
        inference_kwargs = {} if inference_kwargs is None else inference_kwargs
        if preview_batch is not None and (
                'tile_shape' not in inference_kwargs or
            ('overlap_shape' not in inference_kwargs
             and 'offset' not in inference_kwargs)):
            raise ValueError(
                'If preview_batch is set, you will also need to specify '
                'tile_shape and overlap_shape or offset in inference_kwargs!')
        if enable_save_trace:
            logger.warning(
                'enable_save_trace is deprecated. Please use the save_jit option instead.'
            )
            assert save_jit in [None, 'trace']
            save_jit = 'trace'

        # Ensure that all nn.Modules are on the right device
        model.to(device)
        if isinstance(criterion, torch.nn.Module):
            criterion.to(device)
        if isinstance(ss_criterion, torch.nn.Module):
            ss_criterion.to(device)

        self.ignore_errors = ignore_errors
        self.ipython_shell = ipython_shell
        self.device = device
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.valid_metrics = valid_metrics
        self.ss_criterion = ss_criterion
        self.preview_batch = preview_batch
        self.preview_interval = preview_interval
        self.inference_kwargs = inference_kwargs
        self.extra_save_steps = extra_save_steps
        self.overlay_alpha = overlay_alpha
        self.save_root = os.path.expanduser(save_root)
        self.example_input = example_input
        self.save_jit = save_jit
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.sample_plotting_handler = sample_plotting_handler
        self.preview_plotting_handler = preview_plotting_handler
        self.mixed_precision = mixed_precision

        self._tracker = HistoryTracker()
        self._timer = Timer()
        self._first_plot = True
        self._shell_info = dedent("""
            Entering IPython training shell. To continue, hit Ctrl-D twice.
            To terminate, set self.terminate = True and then hit Ctrl-D twice.
        """).strip()

        self.inference_kwargs.setdefault('batch_size', 1)
        self.inference_kwargs.setdefault('verbose', True)
        self.inference_kwargs.setdefault('apply_softmax', True)

        if self.unlabeled_dataset is not None and self.ss_criterion is None:
            raise ValueError(
                'If an unlabeled_dataset is supplied, you must also set ss_criterion.'
            )

        if hparams is None:
            hparams = {}
        else:
            for k, v in hparams.items():
                if isinstance(v, (tuple, list)):
                    # Convert to str because tensorboardX doesn't support
                    # tuples and lists in add_hparams()
                    hparams[k] = str(v)
        self.hparams = hparams

        if self.mixed_precision:
            from apex import amp
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O1')

        if exp_name is None:  # Auto-generate a name based on model name and ISO timestamp
            timestamp = datetime.datetime.now().strftime('%y-%m-%d_%H-%M-%S')
            exp_name = model.__class__.__name__ + '__' + timestamp
        self.exp_name = exp_name
        self.save_path = os.path.join(save_root, exp_name)
        if os.path.isdir(self.save_path):
            raise RuntimeError(
                f'{self.save_path} already exists.\nPlease choose a '
                'different combination of save_root and exp_name.')
        os.makedirs(self.save_path)
        _change_log_file_to(f'{self.save_path}/elektronn3.log')
        logger.info(f'Writing files to save_path {self.save_path}/\n')

        self.terminate = False
        self.step = 0
        self.epoch = 0
        if schedulers is None:
            schedulers = {'lr': StepLR(optimizer, 1000, 1)}  # No-op scheduler
        self.schedulers = schedulers
        self.__lr_closetozero_alreadytriggered = False  # Used in periodic scheduler handling
        self._lr_nhood = deque(
            maxlen=3
        )  # Keeps track of the last, current and next learning rate

        self.out_channels = out_channels
        if enable_videos:
            try:
                import moviepy
            except ImportError:
                logger.warning(
                    'moviepy is not installed. Disabling video logs.')
                enable_videos = False
        self.enable_videos = enable_videos
        self.tb = None  # Tensorboard handler
        if enable_tensorboard:
            if self.sample_plotting_handler is None:
                self.sample_plotting_handler = handlers._tb_log_sample_images
            if self.preview_plotting_handler is None:
                self.preview_plotting_handler = handlers._tb_log_preview

            if tensorboard_root_path is None:
                tb_path = self.save_path
            else:
                tensorboard_root_path = os.path.expanduser(
                    tensorboard_root_path)
                tb_path = os.path.join(tensorboard_root_path, self.exp_name)
                os.makedirs(tb_path, exist_ok=True)
            self.tb = tensorboardX.SummaryWriter(logdir=tb_path, flush_secs=20)

            if self.hparams:
                self.tb.add_hparams(hparam_dict=self.hparams, metric_dict={})

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            timeout=60 if self.num_workers > 0 else 0,
            worker_init_fn=_worker_init_fn)
        if valid_dataset is not None:
            self.valid_loader = DataLoader(self.valid_dataset,
                                           self.batch_size,
                                           shuffle=True,
                                           num_workers=self.num_workers,
                                           pin_memory=True,
                                           worker_init_fn=_worker_init_fn)
        if self.unlabeled_dataset is not None:
            self.unlabeled_loader = DataLoader(
                self.unlabeled_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.num_workers,
                pin_memory=True,
                timeout=60 if self.num_workers > 0 else 0,
                worker_init_fn=_worker_init_fn)

        self.best_val_loss = np.inf  # Best recorded validation loss
        self.best_tr_loss = np.inf

        self.valid_metrics = {} if valid_metrics is None else valid_metrics
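
Note: the _worker_init_fn passed to the DataLoaders above is not shown in this snippet. A common implementation (an assumption, not necessarily the original one) reseeds the per-worker RNGs so random augmentations differ across workers:

import random

import numpy as np
import torch

def _worker_init_fn(worker_id: int) -> None:
    """Seed Python's and NumPy's RNGs from the per-worker torch seed."""
    # torch already derives a distinct seed for each DataLoader worker, so
    # worker_id is not needed here; reuse that seed for the other RNGs.
    seed = torch.initial_seed() % 2**32
    random.seed(seed)
    np.random.seed(seed)
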
Example #9
    def run(self, max_steps: int = 1) -> None:
        """Train the network for ``max_steps`` steps.

        After each training epoch, validation performance is measured and
        visualizations are computed and logged to tensorboard."""
        while self.step < max_steps:
            try:
                # --> self.train()
                self.model.train()

                # Scalar training stats that should be logged and written to tensorboard later
                stats: Dict[str, float] = {'tr_loss_G': .0, 'tr_loss_D': .0}
                # Other scalars to be logged
                misc: Dict[str, float] = {
                    'G_loss_advreg': .0,
                    'G_loss_tnet': .0,
                    'G_loss_l2': .0,
                    'D_loss_fake': .0,
                    'D_loss_real': .0
                }
                # Hold image tensors for real-time training sample visualization in tensorboard
                images: Dict[str, torch.Tensor] = {}

                running_error = 0
                running_mean_target = 0
                running_vx_size = 0
                timer = Timer()
                latent_points_fake = []
                latent_points_real = []
                for inp in self.train_loader:  # ref., pos., neg. samples
                    if inp.size()[1] != 3:
                        raise ValueError(
                            "Data must not contain targets. "
                            "Input data shape is assumed to be "
                            "(N, 3, ch, x, y), where the first two"
                            " images in each sample is the similar"
                            " pair, while the third one is the "
                            "distant one.")
                    inp0 = Variable(inp[:, 0].to(self.device))
                    inp1 = Variable(inp[:, 1].to(self.device))
                    inp2 = Variable(inp[:, 2].to(self.device))
                    self.optimizer.zero_grad()
                    # forward pass
                    dA, dB, z0, z1, z2 = self.model(inp0, inp1, inp2)
                    z_fake_gauss = torch.squeeze(torch.cat((z0, z1, z2),
                                                           dim=1))
                    target = torch.FloatTensor(dA.size()).fill_(-1).to(
                        self.device)
                    target = Variable(target)
                    loss = self.criterion(dA, dB, target)
                    L_l2 = torch.mean(torch.cat(
                        (z0.norm(1, dim=1), z1.norm(1, dim=1), z2.norm(1, dim=1)), dim=0))
                    misc['G_loss_l2'] += self.alpha * float(L_l2)
                    loss = loss + self.alpha * L_l2
                    misc['G_loss_tnet'] += (1 - self.alpha2) * float(loss)  # log actual loss
                    if torch.isnan(loss):
                        logger.error(f'NaN loss detected after {self.step} '
                                     'steps! Aborting training.')
                        raise NaNException

                    # Adversarial part to enforce latent variable distribution
                    # to be Normal / whatever prior is used
                    if self.alpha2 > 0:
                        self.optimizer_discr.zero_grad()
                        # adversarial labels
                        valid = Variable(torch.Tensor(inp0.size()[0],
                                                      1).fill_(1.0),
                                         requires_grad=False).to(self.device)
                        fake = Variable(torch.Tensor(inp0.shape[0],
                                                     1).fill_(0.0),
                                        requires_grad=False).to(self.device)

                        # --- Generator / TripletNet
                        self.model_discr.eval()
                        # TripletNet latent space should be classified as valid
                        L_advreg = self.criterion_discr(
                            self.model_discr(z_fake_gauss), valid)
                        # average adversarial reg. and triplet loss
                        loss = (1 - self.alpha2) * loss + self.alpha2 * L_advreg
                        # perform generator step
                        loss.backward()
                        self.optimizer.step()

                        # --- Discriminator
                        self.model.eval()
                        self.model_discr.train()
                        # rebuild graph (model output) to get clean backprop.
                        z_real_gauss = Variable(
                            self.latent_distr(inp0.size()[0],
                                              z0.size()[-1] * 3)).to(
                                                  self.device)
                        _, _, z_fake_gauss0, z_fake_gauss1, z_fake_gauss2 = self.model(
                            inp0, inp1, inp2)
                        z_fake_gauss = torch.squeeze(
                            torch.cat(
                                (z_fake_gauss0, z_fake_gauss1, z_fake_gauss2),
                                dim=1))
                        # Compute discriminator outputs and loss
                        L_real_gauss = self.criterion_discr(
                            self.model_discr(z_real_gauss), valid)
                        L_fake_gauss = self.criterion_discr(
                            self.model_discr(z_fake_gauss), fake)
                        L_discr = 0.5 * (L_real_gauss + L_fake_gauss)
                        L_discr.backward()  # Backprop loss
                        self.optimizer_discr.step()  # Apply optimization step
                        self.model.train()  # set back to training mode

                        # # clean and report
                        L_discr.detach()
                        L_advreg.detach()
                        L_real_gauss.detach()
                        L_fake_gauss.detach()
                        stats['tr_loss_D'] += float(L_discr)
                        misc['G_loss_advreg'] += self.alpha2 * float(
                            L_advreg)  # log actual part of advreg
                        misc['D_loss_real'] += float(L_real_gauss)
                        misc['D_loss_fake'] += float(L_fake_gauss)
                        latent_points_real.append(
                            z_real_gauss.detach().cpu().numpy())
                    else:
                        loss.backward()
                        self.optimizer.step()

                    latent_points_fake.append(
                        z_fake_gauss.detach().cpu().numpy())
                    # # Prevent accidental autograd overheads after optimizer step
                    inp.detach()
                    target.detach()
                    dA.detach()
                    dB.detach()
                    z0.detach()
                    z1.detach()
                    z2.detach()
                    loss.detach()
                    L_l2.detach()

                    # get training performance
                    stats['tr_loss_G'] += float(loss)
                    error = calculate_error(dA, dB)
                    mean_target = target.to(torch.float32).mean()
                    print(f'{self.step:6d}, loss: {loss:.4f}', end='\r')
                    self._tracker.update_timeline(
                        [self._timer.t_passed,
                         float(loss), mean_target])

                    # Preserve training batch and network output for later visualization
                    images['inp_ref'] = inp0.cpu().numpy()
                    images['inp_+'] = inp1.cpu().numpy()
                    images['inp_-'] = inp2.cpu().numpy()
                    # Read the learning rates directly from param_groups instead of
                    # get_lr(); this also supports ReduceLROnPlateau, which does not
                    # implement get_lr()
                    misc['learning_rate_G'] = self.optimizer.param_groups[0]["lr"]
                    misc['learning_rate_D'] = (
                        self.optimizer_discr.param_groups[0]["lr"])
                    # update schedules
                    for sched in self.schedulers.values():
                        # ReduceLROnPlateau.step() expects a metric; the docs suggest
                        # the validation loss, but the training loss is passed here:
                        # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
                        if "metrics" in inspect.signature(sched.step).parameters:
                            sched.step(metrics=float(loss))
                        else:
                            sched.step()
                    running_error += error
                    running_mean_target += mean_target
                    running_vx_size += inp.numel()

                    self.step += 1
                    if self.step >= max_steps:
                        break
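                # Average the statistics accumulated per batch above over the epoch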
                stats['tr_err_G'] = float(running_error) / len(
                    self.train_loader)
                stats['tr_loss_G'] /= len(self.train_loader)
                stats['tr_loss_D'] /= len(self.train_loader)
                misc['G_loss_advreg'] /= len(self.train_loader)
                misc['G_loss_tnet'] /= len(self.train_loader)
                misc['G_loss_l2'] /= len(self.train_loader)
                misc['D_loss_fake'] /= len(self.train_loader)
                misc['D_loss_real'] /= len(self.train_loader)
                misc['tr_speed'] = len(self.train_loader) / timer.t_passed
                misc['tr_speed_vx'] = (
                    running_vx_size / timer.t_passed / 1e6)  # MVx/s
                mean_target = running_mean_target / len(self.train_loader)
                # Only run validation in ~10% of epochs to save computation
                if self.valid_dataset is None or np.random.randint(0, 10) != 1:
                    stats['val_loss_G'] = float('nan')
                    stats['val_err_G'] = float('nan')
                else:
                    stats['val_loss_G'], stats['val_err_G'] = self._validate()
                # TODO: Report more metrics, e.g. dice error

                # Update history tracker (kind of made obsolete by tensorboard)
                # TODO: Decide what to do with this, now that most things are already in tensorboard.
                if self.step // len(self.train_dataset) > 1:
                    tr_loss_gain = self._tracker.history[-1][2] - stats[
                        'tr_loss_G']
                else:
                    tr_loss_gain = 0
                self._tracker.update_history([
                    self.step, self._timer.t_passed, stats['tr_loss_G'],
                    stats['val_loss_G'], tr_loss_gain, stats['tr_err_G'],
                    stats['val_err_G'], misc['learning_rate_G'], 0, 0
                ])  # 0's correspond to mom and gradnet (?)
                t = pretty_string_time(self._timer.t_passed)
                loss_smooth = self._tracker.loss._ema

                # Logging to stdout, text log file
                text = "%05i L_m=%.3f, L=%.2f, tr=%05.2f%%, " % (
                    self.step, loss_smooth, stats['tr_loss_G'],
                    stats['tr_err_G'])
                text += "vl=%05.2f%%, prev=%04.1f, L_diff=%+.1e, " % (
                    stats['val_err_G'], mean_target * 100, tr_loss_gain)
                text += "LR=%.2e, %.2f it/s, %.2f MVx/s, %s" % (
                    misc['learning_rate_G'], misc['tr_speed'],
                    misc['tr_speed_vx'], t)
                logger.info(text)

                # Plot tracker stats to pngs in save_path
                self._tracker.plot(self.save_path)

                # Reporting to tensorboard logger
                if self.tb:
                    self._tb_log_scalars(stats, 'stats')
                    self._tb_log_scalars(misc, 'misc')
                    self.tb_log_sample_images(images, group='tr_samples')

                # Log histograms of the latent activations to tensorboard
                # ("fake" = encoder outputs, "real" = samples from the prior)
                if self.tb and len(latent_points_fake) > 0:
                    fig, ax = plt.subplots()
                    sns.distplot(np.concatenate(latent_points_fake).flatten())
                    # plt.savefig(os.path.join(self.save_path,
                    #                          'latent_fake_{}.png'.format(self.step)))
                    fig.canvas.draw()
                    img_data = np.array(fig.canvas.renderer._renderer)
                    self.tb.add_figure(f'latent_distr/latent_fake',
                                       plot_image(img_data),
                                       global_step=self.step)
                    plt.close()

                if self.tb and len(latent_points_real) > 0:
                    fig, ax = plt.subplots()
                    sns.distplot(np.concatenate(latent_points_real).flatten())
                    # plt.savefig(os.path.join(self.save_path,
                    #                          'latent_real_{}.png'.format(self.step)))
                    # Grab the rendered pixel buffer and dump it into a numpy array
                    fig.canvas.draw()
                    img_data = np.array(fig.canvas.renderer._renderer)
                    self.tb.add_figure('latent_distr/latent_real',
                                       plot_image(img_data),
                                       global_step=self.step)
                    plt.close(fig)

                # Save the latest model state. The same checkpoint file is
                # overwritten every epoch, since saving under a new file name each
                # time (e.g. f'model-{self.step:06d}.pth') leads to heaps of large
                # files.
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.save_path, 'model-checkpoint.pth'))
                # TODO: Also save "best" model, not only the latest one, which is often overfitted.
                #       -> "best" in which regard? Lowest validation loss, validation error?
                #          We can't blindly trust these metrics and may have to calculate
                #          additional metrics (with focus on object boundary correctness).
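                # A minimal sketch of such "best model" saving, assuming the
                # validation loss as the criterion and a hypothetical attribute
                # self._best_val_loss initialized to float('inf') in __init__:
                #
                #   if stats['val_loss_G'] < self._best_val_loss:
                #       self._best_val_loss = stats['val_loss_G']
                #       torch.save(self.model.state_dict(),
                #                  os.path.join(self.save_path, 'model-best.pth'))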
            except KeyboardInterrupt:
                IPython.embed(header=self._shell_info)
                if self.terminate:
                    return
            except Exception as e:
                traceback.print_exc()
                if self.ignore_errors:
                    # Just print the traceback and try to carry on with training.
                    # This can go wrong in unexpected ways, so don't leave the training unattended.
                    pass
                elif self.ipython_on_error:
                    print("\nEntering IPython shell so that the exception can be "
                          "inspected interactively by the user.\n\n")
                    IPython.embed(header=self._shell_info)
                    if self.terminate:
                        return
                else:
                    raise e
        torch.save(
            self.model.state_dict(),
            os.path.join(self.save_path, f'model-final-{self.step:06d}.pth'))