Example #1
def train(cf, logger):
    """
    performs the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs. logs to file and tensorboard.
    """
    logger.info(
        'performing training in {}D over fold {} on experiment {} with model {}'
        .format(cf.dim, cf.fold, cf.exp_dir, cf.model))
    logger.time("train_val")

    # -------------- inits and settings -----------------
    net = model.net(cf, logger).cuda()
    if cf.optimizer == "ADAMW":
        optimizer = torch.optim.AdamW(utils.parse_params_for_optim(
            net,
            weight_decay=cf.weight_decay,
            exclude_from_wd=cf.exclude_from_wd),
                                      lr=cf.learning_rate[0])
    elif cf.optimizer == "SGD":
        optimizer = torch.optim.SGD(utils.parse_params_for_optim(
            net, weight_decay=cf.weight_decay),
                                    lr=cf.learning_rate[0],
                                    momentum=0.3)
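    # optional plateau scheduling: ReduceLROnPlateau cuts the LR by
    # cf.lr_decay_factor when the monitored criterion stalls for
    # cf.scheduling_patience epochs (stepped at the end of each epoch below).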
    if cf.dynamic_lr_scheduling:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=cf.scheduling_mode,
            factor=cf.lr_decay_factor,
            patience=cf.scheduling_patience)
    model_selector = utils.ModelSelector(cf, logger)

    starting_epoch = 1
    if cf.resume:
        checkpoint_path = os.path.join(cf.fold_dir, "last_state.pth")
        starting_epoch, net, optimizer, model_selector = \
            utils.load_checkpoint(checkpoint_path, net, optimizer, model_selector)
        logger.info('resumed from checkpoint {} to epoch {}'.format(
            checkpoint_path, starting_epoch))

    # prepare monitoring
    monitor_metrics = utils.prepare_monitoring(cf)

    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)

    # -------------- training -----------------
    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}/{}'.format(
            epoch, cf.num_epochs))
        logger.time("train_epoch")

        net.train()

        train_results_list = []
        train_evaluator = Evaluator(cf, logger, mode='train')

        for i in range(cf.num_train_batches):
            logger.time("train_batch_loadfw")
            batch = next(batch_gen['train'])
            batch_gen['train'].generator.stats['roi_counts'] += \
                batch['roi_counts']
            batch_gen['train'].generator.stats['empty_counts'] += \
                batch['empty_counts']

            logger.time("train_batch_loadfw")
            logger.time("train_batch_netfw")
            results_dict = net.train_forward(batch)
            logger.time("train_batch_netfw")
            logger.time("train_batch_bw")
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()
            if cf.clip_norm:
                torch.nn.utils.clip_grad_norm_(
                    net.parameters(), cf.clip_norm,
                    norm_type=2)  # gradient clipping
            optimizer.step()
            train_results_list.append(
                ({k: v for k, v in results_dict.items() if k != "seg_preds"},
                 batch["pid"]))  # slim res dict
            if not cf.server_env:
                print(
                    "\rFinished training batch " +
                    "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw)."
                    .format(
                        i + 1, cf.num_train_batches,
                        logger.get_time("train_batch_loadfw") +
                        logger.get_time("train_batch_netfw") +
                        logger.time("train_batch_bw"),
                        logger.get_time("train_batch_loadfw", reset=True),
                        logger.get_time("train_batch_netfw", reset=True),
                        logger.get_time("train_batch_bw", reset=True)),
                    end="",
                    flush=True)
        print()

        #--------------- train eval ----------------
        if (epoch - 1) % cf.plot_frequency == 0:
            # view an example batch
            utils.split_off_process(
                plg.view_batch,
                cf,
                batch,
                results_dict,
                has_colorchannels=cf.has_colorchannels,
                show_gt_labels=True,
                get_time="train-example plot",
                out_file=os.path.join(
                    cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold)))

        logger.time("evals")
        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(
            train_results_list, monitor_metrics['train'])
        logger.time("evals")
        logger.time("train_epoch", toggle=False)
        del train_results_list

        #----------- validation ------------
        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        logger.time("val_epoch")
        with torch.no_grad():
            net.eval()
            val_results_list = []
            val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)
            val_predictor = Predictor(cf, net, logger, mode='val')

            for i in range(batch_gen['n_val']):
                logger.time("val_batch")
                batch = next(batch_gen[cf.val_mode])
                if cf.val_mode == 'val_patient':
                    results_dict = val_predictor.predict_patient(batch)
                elif cf.val_mode == 'val_sampling':
                    results_dict = net.train_forward(batch, is_validation=True)
                val_results_list.append([results_dict, batch["pid"]])
                if not cf.server_env:
                    print("\rFinished validation {} {}/{} in {:.1f}s.".format(
                        'patient' if cf.val_mode == 'val_patient' else 'batch',
                        i + 1, batch_gen['n_val'], logger.time("val_batch")),
                          end="",
                          flush=True)
            print()

            #------------ val eval -------------
            if (epoch - 1) % cf.plot_frequency == 0:
                utils.split_off_process(plg.view_batch,
                                        cf,
                                        batch,
                                        results_dict,
                                        has_colorchannels=cf.has_colorchannels,
                                        show_gt_labels=True,
                                        get_time="val-example plot",
                                        out_file=os.path.join(
                                            cf.plot_dir,
                                            'batch_example_val_{}.png'.format(
                                                cf.fold)))

            logger.time("evals")
            _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(
                val_results_list, monitor_metrics['val'])

            model_selector.run_model_selection(net, optimizer, monitor_metrics,
                                               epoch)
            del val_results_list
            #----------- monitoring -------------
            monitor_metrics.update({
                "lr": {
                    str(g): group['lr']
                    for (g, group) in enumerate(optimizer.param_groups)
                }
            })
            logger.metrics2tboard(monitor_metrics, global_step=epoch)
            logger.time("evals")

            logger.info(
                'finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'
                .format(
                    epoch, cf.num_epochs,
                    logger.get_time("train_epoch") + logger.time("val_epoch"),
                    logger.get_time("train_epoch"),
                    logger.get_time("train_epoch", reset=True) /
                    cf.num_train_batches, logger.get_time("val_epoch"),
                    logger.get_time("val_epoch", reset=True) /
                    batch_gen["n_val"]))
            logger.info("time for evals: {:.2f}s".format(
                logger.get_time("evals", reset=True)))

        #-------------- scheduling -----------------
        if cf.dynamic_lr_scheduling:
            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch - 1]

    logger.time("train_val")
    logger.info("Training and validating over {} epochs took {}".format(
        cf.num_epochs, logger.get_time("train_val", format="hms", reset=True)))
    batch_gen['train'].generator.print_stats(logger, plot=True)
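
For reference, the inner step of the loop above is the standard PyTorch update pattern: zero the gradients, backpropagate the loss, optionally clip the gradient norm, then step the optimizer. A self-contained sketch of just that step on a dummy model (the model, data, and clip value are placeholders, not the repository's):

import torch

# Minimal sketch of the update step used in the training loop above.
net = torch.nn.Linear(4, 2)                     # dummy stand-in for model.net
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)
clip_norm = 1.0                                 # stands in for cf.clip_norm

x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = torch.nn.functional.mse_loss(net(x), y)  # stands in for results_dict['torch_loss']
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(net.parameters(), clip_norm, norm_type=2)
optimizer.step()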
Example #2
    def return_metrics(self, monitor_metrics=None):
        """
        calculates AP/AUC scores for internal dataframe. called directly from evaluate_predictions during training for monitoring,
        or from score_test_df during inference (for single folds or aggregated test set). Loops over foreground classes
        and score_levels (typically 'roi' and 'patient'), gets scores and stores them. Optionally creates plots of
        prediction histograms and roc/prc curves.
        :param monitor_metrics: dict of dicts with all metrics of previous epochs.
        this function adds metrics for current epoch and returns the same object.
        :return: all_stats: list. Contains dicts with resulting scores for each combination of foreground class and
        score_level.
        :return: monitor_metrics
        """

        # -------------- monitoring independent of class, score level ------------
        if monitor_metrics is not None:
            for l_name in self.epoch_losses:
                monitor_metrics[l_name] = [self.epoch_losses[l_name]]

        df = self.test_df

        all_stats = []
        for cl in list(self.cf.class_dict.keys()):
            cl_df = df[df.pred_class == cl]

            for score_level in self.cf.report_score_level:
                stats_dict = {}
                stats_dict['name'] = 'fold_{} {} cl_{}'.format(
                    self.cf.fold, score_level, cl)

                if score_level == 'rois':
                    # kick out dummy entries for true negative patients. not needed on roi-level.
                    spec_df = cl_df[cl_df.det_type != 'patient_tn']
                    stats_dict['ap'] = get_roi_ap_from_df([
                        spec_df, self.cf.min_det_thresh, self.cf.per_patient_ap
                    ])
                    # AUC not sensible on roi-level, since true negative box predictions do not exist. Would reward
                    # higher amounts of low confidence false positives.
                    stats_dict['auc'] = np.nan
                    stats_dict['roc'] = np.nan
                    stats_dict['prc'] = np.nan

                    # for the aggregated test set case, additionally get the scores for averaging over fold results.
                    if len(df.fold.unique()) > 1:
                        aps = []
                        for fold in df.fold.unique():
                            fold_df = spec_df[spec_df.fold == fold]
                            aps.append(
                                get_roi_ap_from_df([
                                    fold_df, self.cf.min_det_thresh,
                                    self.cf.per_patient_ap
                                ]))
                        stats_dict['mean_ap'] = np.mean(aps)
                        stats_dict['mean_auc'] = 0

                # on patient level, aggregate predictions per patient (pid): The patient predicted score is the highest
                # confidence prediction for this class. The patient class label is 1 if roi of this class exists in patient, else 0.
                if score_level == 'patient':
                    spec_df = cl_df.groupby(['pid'], as_index=False).agg(
                        {'class_label': 'max', 'pred_score': 'max',
                         'fold': 'first'})

                    if len(spec_df.class_label.unique()) > 1:
                        stats_dict['auc'] = roc_auc_score(
                            spec_df.class_label.tolist(),
                            spec_df.pred_score.tolist())
                        stats_dict['roc'] = roc_curve(
                            spec_df.class_label.tolist(),
                            spec_df.pred_score.tolist())
                    else:
                        stats_dict['auc'] = np.nan
                        stats_dict['roc'] = np.nan

                    if (spec_df.class_label == 1).any():
                        stats_dict['ap'] = average_precision_score(
                            spec_df.class_label.tolist(),
                            spec_df.pred_score.tolist())
                        stats_dict['prc'] = precision_recall_curve(
                            spec_df.class_label.tolist(),
                            spec_df.pred_score.tolist())
                    else:
                        stats_dict['ap'] = np.nan
                        stats_dict['prc'] = np.nan

                    # for the aggregated test set case, additionally get the scores for averaging over fold results.
                    if len(df.fold.unique()) > 1:
                        aucs = []
                        aps = []
                        for fold in df.fold.unique():
                            fold_df = spec_df[spec_df.fold == fold]
                            if len(fold_df.class_label.unique()) > 1:
                                aucs.append(
                                    roc_auc_score(fold_df.class_label.tolist(),
                                                  fold_df.pred_score.tolist()))
                            if (fold_df.class_label == 1).any():
                                aps.append(
                                    average_precision_score(
                                        fold_df.class_label.tolist(),
                                        fold_df.pred_score.tolist()))
                        stats_dict['mean_auc'] = np.mean(aucs)
                        stats_dict['mean_ap'] = np.mean(aps)

                # fill new results into monitor_metrics dict. for simplicity, only one class (of interest) is monitored on patient level.
                if monitor_metrics is not None and not (
                        score_level == 'patient'
                        and cl != self.cf.patient_class_of_interest):
                    score_level_name = ('patient' if score_level == 'patient'
                                        else self.cf.class_dict[cl])
                    monitor_metrics[score_level_name + '_ap'].append(
                        stats_dict['ap'] if stats_dict['ap'] > 0 else np.nan)
                    if score_level == 'patient':
                        monitor_metrics[score_level_name + '_auc'].append(
                            stats_dict['auc'] if stats_dict['auc'] > 0
                            else np.nan)

                if self.cf.plot_prediction_histograms:
                    out_filename = os.path.join(
                        self.hist_dir, 'pred_hist_{}_{}_{}_cl{}'.format(
                            self.cf.fold,
                            'val' if 'val' in self.mode else self.mode,
                            score_level, cl))
                    type_list = (None if score_level == 'patient'
                                 else spec_df.det_type.tolist())
                    utils.split_off_process(plotting.plot_prediction_hist,
                                            spec_df.class_label.tolist(),
                                            spec_df.pred_score.tolist(),
                                            type_list, out_filename)

                all_stats.append(stats_dict)

                # analysis of the hyper-parameter cf.min_det_thresh, for optimization on the validation set.
                if self.cf.scan_det_thresh:
                    conf_threshs = list(np.arange(0.9, 1, 0.01))
                    pool = Pool(processes=10)
                    mp_inputs = [[spec_df, ii, self.cf.per_patient_ap]
                                 for ii in conf_threshs]
                    aps = pool.map(get_roi_ap_from_df, mp_inputs, chunksize=1)
                    pool.close()
                    pool.join()
                    self.logger.info(
                        'results from scanning over det_threshs: {}'.format(
                            [[i, j] for i, j in zip(conf_threshs, aps)]))

        if self.cf.plot_stat_curves:
            out_filename = os.path.join(
                self.curves_dir,
                '{}_{}_stat_curves'.format(self.cf.fold, self.mode))
            utils.split_off_process(plotting.plot_stat_curves, all_stats,
                                    out_filename)

        # get average stats over foreground classes on roi level.
        avg_ap = np.mean([d['ap'] for d in all_stats if 'rois' in d['name']])
        all_stats.append({
            'name': 'average_foreground_roi',
            'auc': 0,
            'ap': avg_ap
        })
        if len(df.fold.unique()) > 1:
            avg_mean_ap = np.mean(
                [d['mean_ap'] for d in all_stats if 'rois' in d['name']])
            all_stats[-1]['mean_ap'] = avg_mean_ap
            all_stats[-1]['mean_auc'] = 0

        # in small data sets, values of the model_selection_criterion can be identical across epochs, which breaks
        # the ranking of model_selector. Thus, perturb identical values by a negligible random term.
        for sc in self.cf.model_selection_criteria:
            if ('val' in self.mode and monitor_metrics[sc][-1] is not None
                    and monitor_metrics[sc].count(monitor_metrics[sc][-1]) > 1):
                monitor_metrics[sc][-1] += 1e-6 * np.random.rand()

        return all_stats, monitor_metrics
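
To make the patient-level aggregation above concrete: the patient score is the highest-confidence prediction and the patient label is the max class_label over the patient's ROIs. A self-contained sketch with fabricated toy data (not from the repository):

import pandas as pd
from sklearn.metrics import roc_auc_score

# Toy per-ROI predictions: several detections per patient (pid).
df = pd.DataFrame({
    'pid':         [0, 0, 1, 1, 2],
    'class_label': [1, 1, 0, 0, 1],
    'pred_score':  [0.9, 0.4, 0.2, 0.3, 0.7],
})
# Collapse to one row per patient, mirroring the groupby above.
per_patient = df.groupby('pid', as_index=False).agg(
    {'class_label': 'max', 'pred_score': 'max'})
print('patient-level AUC:',
      roc_auc_score(per_patient.class_label, per_patient.pred_score))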
Example #3
def train(logger):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.
    """
    logger.info(
        'performing training in {}D over fold {} on experiment {} with model {}'
        .format(cf.dim, cf.fold, cf.exp_dir, cf.model))

    net = model.net(cf, logger).cuda()
    if hasattr(cf, "optimizer") and cf.optimizer.lower() == "adam":
        logger.info("Using Adam optimizer.")
        optimizer = torch.optim.Adam(utils.parse_params_for_optim(
            net,
            weight_decay=cf.weight_decay,
            exclude_from_wd=cf.exclude_from_wd),
                                     lr=cf.learning_rate[0])
    else:
        logger.info("Using AdamW optimizer.")
        optimizer = torch.optim.AdamW(utils.parse_params_for_optim(
            net,
            weight_decay=cf.weight_decay,
            exclude_from_wd=cf.exclude_from_wd),
                                      lr=cf.learning_rate[0])

    if cf.dynamic_lr_scheduling:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=cf.scheduling_mode,
            factor=cf.lr_decay_factor,
            patience=cf.scheduling_patience)

    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)

    starting_epoch = 1

    # prepare monitoring
    monitor_metrics = utils.prepare_monitoring(cf)

    if cf.resume:
        checkpoint_path = os.path.join(cf.fold_dir, "last_checkpoint")
        starting_epoch, net, optimizer, monitor_metrics = \
            utils.load_checkpoint(checkpoint_path, net, optimizer)
        logger.info('resumed from checkpoint {} to epoch {}'.format(
            checkpoint_path, starting_epoch))

    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)

    # Prepare MLFlow
    best_loss = 1e3
    step = 1
    mlflow.log_artifacts(cf.exp_dir, "exp")

    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        start_time = time.time()

        net.train()
        train_results_list = []
        bix = 0
        seen_pids = []
        while True:
            bix = bix + 1
            try:
                batch = next(batch_gen['train'])
            except StopIteration:
                break
            seen_pids.extend(batch['pid'])
            # print(f'\rtr. batch {bix}: {batch["pid"]}')
            tic_fw = time.time()
            results_dict = net.train_forward(batch)
            tic_bw = time.time()
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()
            optimizer.step()
            print(
                '\rtr. batch {0} (ep. {1}) fw {2:.2f}s / bw {3:.2f} s / total {4:.2f} s || '
                .format(bix, epoch, tic_bw - tic_fw,
                        time.time() - tic_bw,
                        time.time() - tic_fw) + results_dict['logger_string'],
                flush=True,
                end="")
            train_results_list.append(
                ({k: v for k, v in results_dict.items() if k != "seg_preds"},
                 batch["pid"]))
        print(f"Seen pids (unique): {len(np.unique(seen_pids))}")
        print()

        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(
            train_results_list, monitor_metrics['train'])

        logger.info('generating training example plot.')
        utils.split_off_process(
            plot_batch_prediction,
            batch,
            results_dict,
            cf,
            outfile=os.path.join(cf.plot_dir,
                                 'pred_example_{}_train.png'.format(cf.fold)))

        train_time = time.time() - start_time

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                while True:
                    try:
                        batch = next(batch_gen[cf.val_mode])
                    except StopIteration:
                        break
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch,
                                                         is_validation=True)
                    val_results_list.append(
                        ({k: v for k, v in results_dict.items()
                          if k != "seg_preds"}, batch["pid"]))

                _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(
                    val_results_list, monitor_metrics['val'])
                best_model_path = model_selector.run_model_selection(
                    net, optimizer, monitor_metrics, epoch)
                # Save best model
                mlflow.log_artifacts(
                    best_model_path,
                    os.path.join("exp", os.path.basename(cf.fold_dir),
                                 'best_checkpoint'))

            # Save logs and plots
            mlflow.log_artifacts(os.path.join(cf.exp_dir, "logs"),
                                 os.path.join("exp", 'logs'))
            mlflow.log_artifacts(
                cf.plot_dir, os.path.join("exp",
                                          os.path.basename(cf.plot_dir)))

            # update monitoring and prediction plots
            monitor_metrics.update({
                "lr": {
                    str(g): group['lr']
                    for (g, group) in enumerate(optimizer.param_groups)
                }
            })

            # replace tboard metrics with MLFlow
            #logger.metrics2tboard(monitor_metrics, global_step=epoch)
            mlflow.log_metric('learning rate', optimizer.param_groups[0]['lr'],
                              cf.num_epochs * cf.fold + epoch)
            for key in ['train', 'val']:
                for tag, val in monitor_metrics[key].items():
                    # maybe remove list wrapping, recording in evaluator?
                    val = val[-1]
                    if not np.isnan(val):
                        mlflow.log_metric(f'{key}_{tag}', val,
                                          cf.num_epochs * cf.fold + epoch)

            epoch_time = time.time() - start_time
            logger.info('trained epoch {}: took {} ({} train / {} val)'.format(
                epoch, utils.get_formatted_duration(epoch_time, "ms"),
                utils.get_formatted_duration(train_time, "ms"),
                utils.get_formatted_duration(epoch_time - train_time, "ms")))
            batch = next(batch_gen['val_sampling'])
            results_dict = net.train_forward(batch, is_validation=True)
            logger.info('generating validation-sampling example plot.')
            utils.split_off_process(plot_batch_prediction,
                                    batch,
                                    results_dict,
                                    cf,
                                    outfile=os.path.join(
                                        cf.plot_dir,
                                        'pred_example_{}_val.png'.format(
                                            cf.fold)))

        # -------------- scheduling -----------------
        if cf.dynamic_lr_scheduling:
            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch - 1]
    # Save whole experiment to MLFlow
    mlflow.log_artifacts(cf.exp_dir, "exp")
Example #4
def train(logger):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.
    """
    logger.info(
        'performing training in {}D over fold {} on experiment {} with model {}'
        .format(cf.dim, cf.fold, cf.exp_dir, cf.model))

    net = model.net(cf, logger).cuda()
    if hasattr(cf, "optimizer") and cf.optimizer.lower() == "adam":
        logger.info("Using Adam optimizer.")
        optimizer = torch.optim.Adam(utils.parse_params_for_optim(
            net,
            weight_decay=cf.weight_decay,
            exclude_from_wd=cf.exclude_from_wd),
                                     lr=cf.learning_rate[0])
    else:
        logger.info("Using AdamW optimizer.")
        optimizer = torch.optim.AdamW(utils.parse_params_for_optim(
            net,
            weight_decay=cf.weight_decay,
            exclude_from_wd=cf.exclude_from_wd),
                                      lr=cf.learning_rate[0])

    if cf.dynamic_lr_scheduling:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=cf.scheduling_mode,
            factor=cf.lr_decay_factor,
            patience=cf.scheduling_patience)

    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)

    starting_epoch = 1

    # prepare monitoring
    monitor_metrics = utils.prepare_monitoring(cf)

    if cf.resume:
        checkpoint_path = os.path.join(cf.fold_dir, "last_checkpoint")
        starting_epoch, net, optimizer, monitor_metrics = \
            utils.load_checkpoint(checkpoint_path, net, optimizer)
        logger.info('resumed from checkpoint {} to epoch {}'.format(
            checkpoint_path, starting_epoch))

    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)

    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        start_time = time.time()

        net.train()
        train_results_list = []
        for bix in range(cf.num_train_batches):
            batch = next(batch_gen['train'])
            tic_fw = time.time()
            results_dict = net.train_forward(batch)
            tic_bw = time.time()
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()
            optimizer.step()
            print(
                '\rtr. batch {0}/{1} (ep. {2}) fw {3:.2f}s / bw {4:.2f} s / total {5:.2f} s || '
                .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
                        time.time() - tic_bw,
                        time.time() - tic_fw) + results_dict['logger_string'],
                flush=True,
                end="")
            train_results_list.append(
                ({k: v for k, v in results_dict.items() if k != "seg_preds"},
                 batch["pid"]))
        print()

        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(
            train_results_list, monitor_metrics['train'])

        logger.info('generating training example plot.')
        utils.split_off_process(
            plot_batch_prediction,
            batch,
            results_dict,
            cf,
            outfile=os.path.join(cf.plot_dir,
                                 'pred_example_{}_train.png'.format(cf.fold)))

        train_time = time.time() - start_time

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                for _ in range(batch_gen['n_val']):
                    batch = next(batch_gen[cf.val_mode])
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch,
                                                         is_validation=True)
                    #val_results_list.append([results_dict['boxes'], batch['pid']])
                    val_results_list.append(
                        ({k: v for k, v in results_dict.items()
                          if k != "seg_preds"}, batch["pid"]))

                _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(
                    val_results_list, monitor_metrics['val'])
                model_selector.run_model_selection(net, optimizer,
                                                   monitor_metrics, epoch)

            # update monitoring and prediction plots
            monitor_metrics.update({
                "lr": {
                    str(g): group['lr']
                    for (g, group) in enumerate(optimizer.param_groups)
                }
            })
            logger.metrics2tboard(monitor_metrics, global_step=epoch)

            epoch_time = time.time() - start_time
            logger.info('trained epoch {}: took {} ({} train / {} val)'.format(
                epoch, utils.get_formatted_duration(epoch_time, "ms"),
                utils.get_formatted_duration(train_time, "ms"),
                utils.get_formatted_duration(epoch_time - train_time, "ms")))
            batch = next(batch_gen['val_sampling'])
            results_dict = net.train_forward(batch, is_validation=True)
            logger.info('generating validation-sampling example plot.')
            utils.split_off_process(plot_batch_prediction,
                                    batch,
                                    results_dict,
                                    cf,
                                    outfile=os.path.join(
                                        cf.plot_dir,
                                        'pred_example_{}_val.png'.format(
                                            cf.fold)))

        # -------------- scheduling -----------------
        if cf.dynamic_lr_scheduling:
            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch - 1]
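
Both scheduling branches above are stock PyTorch: ReduceLROnPlateau reacts to a monitored validation metric, while the else-branch indexes a precomputed per-epoch rate list. A self-contained sketch contrasting the two on a dummy parameter (loss values fabricated for illustration):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)

# Dynamic branch: halve the LR once the monitored value stalls for
# `patience` consecutive steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2)
for val_loss in [1.0, 0.9, 0.9, 0.9, 0.9]:  # plateaus after the second step
    scheduler.step(val_loss)
print('dynamic LR:', optimizer.param_groups[0]['lr'])  # reduced to 5e-4

# Static branch: a precomputed per-epoch schedule, as in the else-branch above.
learning_rate = [1e-3 * 0.95 ** e for e in range(5)]
for epoch in range(1, 6):
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate[epoch - 1]
print('static LR after last epoch:', optimizer.param_groups[0]['lr'])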