Example #1
    def _print_val_results(precisions, recalls, dices, epoch, name, classes,
                           ignore_bg, logger):
        # Log the results
        # We add them to a pd dataframe just for the pretty print output
        index = ["cls %i" % i for i in classes]
        val_results = pd.DataFrame(
            {
                "precision":
                [np.nan] + list(precisions) if ignore_bg else precisions,
                "recall": [np.nan] + list(recalls) if ignore_bg else recalls,
                "dice": [np.nan] + list(dices) if ignore_bg else dices,
            },
            index=index)
        # Transpose the results to have metrics in rows
        val_results = val_results.T
        # Add mean and set in first row
        means = [precisions.mean(), recalls.mean(), dices.mean()]
        val_results["mean"] = means
        cols = list(val_results.columns)
        cols.insert(0, cols.pop(cols.index('mean')))
        val_results = val_results.loc[:, cols]  # .ix was removed in modern pandas

        # Print the df to screen
        logger(
            highlighted("\n" + ("%s Validation Results for "
                                "Epoch %i" % (name, epoch)).lstrip(" ")))
        logger(val_results.round(4))
        logger("")
Example #2
    def _print_val_results(self, precisions, recalls, dices, metrics, epoch,
                           name, classes):
        # Log the results
        # We add them to a pd dataframe just for the pretty print output
        index = ["cls %i" % i for i in classes]
        metric_keys, metric_vals = map(list, zip(*metrics.items()))
        col_order = metric_keys + ["precision", "recall", "dice"]
        nan_arr = np.full(len(precisions), np.nan)
        value_dict = {"precision": precisions,
                      "recall": recalls,
                      "dice": dices}
        value_dict.update({key: nan_arr for key in metrics})
        val_results = pd.DataFrame(value_dict,
                                   index=index).loc[:, col_order]  # ensure order
        # Transpose the results to have metrics in rows
        val_results = val_results.T
        # Add mean and set in first row
        means = metric_vals + [precisions.mean(), recalls.mean(), dices.mean()]
        val_results["mean"] = means
        cols = list(val_results.columns)
        cols.insert(0, cols.pop(cols.index('mean')))
        val_results = val_results.loc[:, cols]

        # Print the df to screen
        self.logger(highlighted(("[%s] Validation Results for "
                                 "Epoch %i" % (name, epoch)).lstrip(" ")))
        print_string = val_results.round(self.print_round).to_string()
        self.logger(print_string.replace("NaN", "---") + "\n")
Example #3
    def on_epoch_end(self, epoch, logs=None):
        # Avoid a mutable default argument; Keras passes logs explicitly
        logs = logs if logs is not None else {}
        scores = self.eval()
        mean_dice = scores.mean()
        s = "Mean dice for epoch %d: %.4f\nPr. class: %s" % (epoch, mean_dice,
                                                             scores)
        self.logger(highlighted(s))
        self.scores.append(mean_dice)

        # Add to log
        logs["val_dice"] = mean_dice
Example #4
    def on_epoch_end(self, epoch, logs=None):
        # Avoid a mutable default argument; Keras passes logs explicitly
        logs = logs if logs is not None else {}

        # Predict and get CM
        for name, tp, rel, sel in zip(self.task_names, *self.predict()):
            # Compute precisions, recalls and dices
            precisions = tp / sel
            recalls = tp / rel
            dices = (2 * precisions * recalls) / (precisions + recalls)

            # Ignore BG
            precisions = precisions[1:]
            recalls = recalls[1:]
            dices = dices[1:]

            # Set NaN --> 0.
            precisions[np.isnan(precisions)] = 0.
            recalls[np.isnan(recalls)] = 0.
            dices[np.isnan(dices)] = 0.

            sp = "Mean precision for epoch %d: %.4f - Pr. class: %s" % (
                epoch, precisions.mean(), np.round(precisions, 4))
            sr = "Mean recall for epoch %d:    %.4f - Pr. class: %s" % (
                epoch, recalls.mean(), np.round(recalls, 4))
            sf = "Mean dice for epoch %d:      %.4f - Pr. class: %s" % (
                epoch, dices.mean(), np.round(dices, 4))

            self.logger(
                highlighted("\n" +
                            ("%s Validation Results" % name).lstrip(" ")))
            self.logger(sp + "\n" + sr + "\n" + sf)

            # Add to log
            if name:
                name += "_"
            logs["%sval_dice" % name] = dices.mean()
            logs["%sval_precision" % name] = precisions.mean()
            logs["%sval_recall" % name] = recalls.mean()

        if len(self.task_names) > 1:
            self.logger("\nMean across tasks")
            # If multi-task, compute mean over tasks and add to logs
            fetch = ("val_dice", "val_precision", "val_recall")
            for f in fetch:
                res = np.mean(
                    [logs["%s_%s" % (name, f)] for name in self.task_names])
                logs[f] = res
                self.logger(("Mean %s for epoch %d:" %
                             (f.split("_")[1], epoch)).ljust(30) +
                            "%.4f" % res)
            self.logger("")
Example #5
    def log(self):
        self.logger(highlighted("\nAudit for %i images" % len(self.nii_paths)))
        self.logger("Total memory GiB:  %.3f" % self.total_memory_gib)
        if self.n_classes is not None:
            self.logger("Number of classes: %i" % self.n_classes)
        self.logger("\n2D:\n"
                    "Real space span:   %.3f\n"
                    "Sample dim:        %.3f" %
                    (self.real_space_span_2D, self.sample_dim_2D))
        self.logger(
            "\n3D:\n"
            "Sample dim:        %i\n"
            "Real space span:   %.3f\n"
            "Box span:          %.3f" %
            (self.sample_dim_3D, self.real_space_span_3D, self.real_box_span))
Example #6
    def on_epoch_end(self, epoch, logs=None):
        # Avoid a mutable default argument; Keras passes logs explicitly
        logs = logs if logs is not None else {}
        self.logger("\n")
        # Predict and get CM
        TPs, relevant, selected, metrics = self.predict()
        for name in self.IDs:
            tp, rel, sel = TPs[name], relevant[name], selected[name]
            precisions, recalls, dices = self._compute_dice(tp=tp,
                                                            sel=sel,
                                                            rel=rel)
            classes = np.arange(len(dices))

            # Add to log
            n = (name + "_") if len(self.IDs) > 1 else ""
            logs[f"{n}val_dice"] = dices.mean().round(self.log_round)
            logs[f"{n}val_precision"] = precisions.mean().round(self.log_round)
            logs[f"{n}val_recall"] = recalls.mean().round(self.log_round)
            for m_name, value in metrics[name].items():
                logs[f"{n}val_{m_name}"] = value.round(self.log_round)

            if self.verbose:
                self._print_val_results(precisions=precisions,
                                        recalls=recalls,
                                        dices=dices,
                                        epoch=epoch,
                                        name=name,
                                        classes=classes,
                                        logger=self.logger)

        if len(self.IDs) > 1:
            # Print cross-dataset mean values
            if self.verbose:
                self.logger(
                    highlighted(f"[ALL DATASETS] Means Across Classes"
                                f" for Epoch {epoch}"))
            fetch = ("val_dice", "val_precision", "val_recall")
            m_fetch = tuple(["val_" + s for s in self.model.metrics_names])
            to_print = {}
            for f in fetch + m_fetch:
                scores = [logs["%s_%s" % (name, f)] for name in self.IDs]
                res = np.mean(scores)
                logs[f] = res.round(self.log_round)  # Add to log file
                to_print[f.split("_")[-1]] = list(scores) + [res]
            if self.verbose:
                df = pd.DataFrame(to_print)
                df.index = self.IDs + ["mean"]
                print(df.round(self.print_round))
            self.logger("")
Example #7
def _run_fusion_training(sets, logger, hparams, min_val_images, is_validation,
                         views, n_classes, unet, fusion_model_org,
                         fusion_model, early_stopping, fm_batch_size, epochs,
                         eval_prob, fusion_weights_path):

    for _round, _set in enumerate(sets):
        s = "Set %i/%i:\n%s" % (_round + 1, len(sets), _set)
        logger("\n%s" % highlighted(s))

        # Reload data
        images = ImagePairLoader(**hparams["val_data"])
        if len(images) < min_val_images:
            images.add_images(ImagePairLoader(**hparams["train_data"]))

        # Get list of ImagePair objects to run on
        image_set_dict = {m.id: m for m in images if m.id in _set}

        # Fetch points from the set images
        points_collection = []
        targets_collection = []
        N_im = len(image_set_dict)
        for num_im, image_id in enumerate(list(image_set_dict.keys())):
            logger("")
            logger(
                highlighted("(%i/%i) Running on %s (%s)" %
                            (num_im + 1, N_im, image_id,
                             "val" if is_validation[image_id] else "train")))

            # Set the current ImagePair
            image = image_set_dict[image_id]
            images.images = [image]

            # Load views
            kwargs = dict(hparams["fit"])  # copy so hparams["fit"] is not mutated
            kwargs.update(hparams["build"])
            seq = images.get_sequencer(views=views, **kwargs)

            # Get voxel grid in real space
            voxel_grid_real_space = get_voxel_grid_real_space(image)

            # Get array to store predictions across all views
            targets = image.labels.reshape(-1, 1)
            points = np.empty(shape=(len(targets), len(views), n_classes),
                              dtype=np.float32)
            points.fill(np.nan)

            # Predict on all views
            for k, v in enumerate(views):
                print("\n%s" % "View: %s" % v)
                points[:, k, :] = predict_and_map(
                    model=unet,
                    seq=seq,
                    image=image,
                    view=v,
                    voxel_grid_real_space=voxel_grid_real_space,
                    n_planes='same+20',
                    targets=targets,
                    eval_prob=eval_prob).reshape(-1, n_classes)

            # Clean up a bit
            del image_set_dict[image_id]
            del image  # should be GC'ed at this point anyway

            # add to collections
            points_collection.append(points)
            targets_collection.append(targets)

        # Stack points into one matrix
        logger("Stacking points...")
        X, y = stack_collections(points_collection, targets_collection)

        # Shuffle train
        print("Shuffling points...")
        X, y = shuffle(X, y)

        print("Getting validation set...")
        val_ind = int(0.20 * X.shape[0])
        X_val, y_val = X[:val_ind], y[:val_ind]
        X, y = X[val_ind:], y[val_ind:]

        # Prepare dice score callback for validation data
        val_cb = ValDiceScores((X_val, y_val), n_classes, 50000, logger)

        # Callbacks
        cbs = [
            val_cb,
            CSVLogger(filename="logs/fusion_training.csv",
                      separator=",",
                      append=True),
            PrintLayerWeights(fusion_model_org.layers[-1],
                              every=1,
                              first=1000,
                              per_epoch=True,
                              logger=logger)
        ]

        es = EarlyStopping(monitor='val_dice',
                           min_delta=0.0,
                           patience=early_stopping,
                           verbose=1,
                           mode='max')
        cbs.append(es)

        # Start training
        try:
            fusion_model.fit(X,
                             y,
                             batch_size=fm_batch_size,
                             epochs=epochs,
                             callbacks=cbs,
                             verbose=1)
        except KeyboardInterrupt:
            pass
        fusion_model_org.save_weights(fusion_weights_path)
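
stack_collections is a project helper not shown in these examples. Given the shapes used above (per-image points of shape (n_voxels, n_views, n_classes) and targets of shape (n_voxels, 1)), a plausible stand-in (an assumption) is plain concatenation along the first axis:

    import numpy as np

    def stack_collections(points_collection, targets_collection):
        # Assumed stand-in: concatenate the per-image arrays into
        # one (N, n_views, n_classes) matrix and one (N, 1) target vector
        X = np.concatenate(points_collection, axis=0)
        y = np.concatenate(targets_collection, axis=0)
        return X, y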