Example #1
def _init(string_list, tf_funcs, custom_funcs, logger=None, **kwargs):
    """
    Helper for 'init_losses' or 'init_metrics'.
    Please refer to their docstrings.

    Args:
        string_list:  (list)   List of strings, each giving a name of a metric
                               or loss to use for training. The name should
                               refer to a function or class in either tf_funcs
                               or custom_funcs modules.
        tf_funcs:     (module) A tf.keras module of losses or metrics,
                               or a list of such modules to look through.
        custom_funcs: (module) A custom module of losses or metrics
        logger:       (Logger) A Logger object
        **kwargs:     (dict)   Parameters passed to all losses or metrics which
                               are represented by a class (i.e. not a function)

    Returns:
        A list of len(string_list) references to the requested loss or
        metric classes or functions.
    """
    initialized = []
    tf_funcs = ensure_list_or_tuple(tf_funcs)
    for func_or_class in ensure_list_or_tuple(string_list):
        modules_found = list(filter(None, [getattr(m, func_or_class, None)
                                           for m in tf_funcs]))
        if modules_found:
            initialized.append(modules_found[0])  # use the first match found
        else:
            # Fall back to look in custom module
            initialized.append(getattr(custom_funcs, func_or_class))
    return initialized
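
A minimal usage sketch of _init; the module name my_custom_losses and the class SparseDiceLoss are hypothetical, assumed here only for illustration:

from tensorflow.keras import losses as tf_losses
import my_custom_losses  # hypothetical module defining e.g. SparseDiceLoss

resolved = _init(
    string_list=["SparseCategoricalCrossentropy", "SparseDiceLoss"],
    tf_funcs=tf_losses,             # wrapped in a list by ensure_list_or_tuple
    custom_funcs=my_custom_losses,  # fallback when a name is not in tf.keras
)
# resolved[0] is the tf.keras.losses.SparseCategoricalCrossentropy class;
# resolved[1] is my_custom_losses.SparseDiceLoss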
Example #2
    def compile_model(self, optimizer, optimizer_kwargs, loss, metrics,
                      **kwargs):
        """
        Compile the tf.keras Model instance stored in self.model.
        Sets the loss function, optimizer and metrics.

        Args:
            optimizer:        (string) The name of a tf.keras.optimizers Optimizer
            optimizer_kwargs: (dict)   Key-word arguments passed to the Optimizer
            loss:             (string) The name of a tf.keras.losses or
                                       MultiPlanarUNet loss function
            metrics:          (list)   List of tf.keras.metrics or
                                       MultiPlanarUNet metrics.
            **kwargs:         (dict)   Key-word arguments passed to losses
                                       and/or metrics that accept such.
        """
        # Make sure sparse metrics and loss are specified as sparse
        metrics = ensure_list_or_tuple(metrics)
        losses = ensure_list_or_tuple(loss)
        ensure_sparse(metrics + losses)

        # Initialize optimizer
        optimizer = optimizers.__dict__[optimizer]
        optimizer = optimizer(**optimizer_kwargs)

        # Initialize loss(es) and metrics from tf.keras or MultiPlanarUNet
        losses = init_losses(losses, self.logger, **kwargs)
        metrics = init_metrics(metrics, self.logger, **kwargs)

        # Compile the model
        self.model.compile(optimizer=optimizer, loss=losses, metrics=metrics)
        self.logger("Optimizer:   %s" % optimizer)
        self.logger("Loss funcs:  %s" % losses)
        self.logger("Metrics:     %s" % init_metrics)
        return self
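
A hedged usage sketch of compile_model, assuming an object (here called trainer) that carries self.model and self.logger; the argument names follow the docstring above:

trainer.compile_model(
    optimizer="Adam",              # resolved through optimizers.__dict__
    optimizer_kwargs={"learning_rate": 1e-4},
    loss="SparseCategoricalCrossentropy",
    metrics=["sparse_categorical_accuracy"],
)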
Example #3
    def compile_model(self,
                      optimizer,
                      loss,
                      metrics,
                      reduction,
                      check_sparse=False,
                      optimizer_kwargs=None,
                      loss_kwargs=None,
                      **kwargs):
        """
        Compile the tf.keras Model instance stored in self.model.
        Sets the loss function, optimizer and metrics.

        Args:
            optimizer:        (string) The name of a tf.keras.optimizers Optimizer
            optimizer_kwargs: (dict)   Key-word arguments passed to the Optimizer
            loss:             (string) The name of a tf.keras.losses or
                                       mpunet loss function
            metrics:          (list)   List of tf.keras.metrics or
                                       mpunet metrics.
            reduction:        (tf.keras.losses.Reduction) Reduction type passed
                                       to each loss at initialization.
            check_sparse:     (bool)   If True, assert that sparse losses and
                                       metrics are specified in their sparse
                                       variants.
            loss_kwargs:      (dict)   Key-word arguments passed to each loss
                                       at initialization.
            **kwargs:         (dict)   Key-word arguments passed to losses
                                       and/or metrics that accept such.
        """
        # Replace mutable default arguments
        optimizer_kwargs = optimizer_kwargs or {}
        loss_kwargs = loss_kwargs or {}

        # Make sure sparse metrics and loss are specified as sparse
        metrics = ensure_list_or_tuple(metrics)
        losses = ensure_list_or_tuple(loss)
        if check_sparse:
            ensure_sparse(metrics + losses)

        # Initialize optimizer, loss(es) and metric(s) from tf.keras or
        # mpunet
        optimizer = init_optimizer(optimizer, self.logger, **optimizer_kwargs)
        losses = init_losses(losses, self.logger, **kwargs)
        for i, loss in enumerate(losses):
            try:
                losses[i] = loss(reduction=reduction, **loss_kwargs)
            except (ValueError, TypeError):
                raise TypeError("All loss functions must currently be "
                                "callable and accept the 'reduction' "
                                "parameter specifying a "
                                "tf.keras.losses.Reduction type. If you "
                                "specified a keras loss function such as "
                                "'sparse_categorical_crossentropy', change "
                                "this to its corresponding loss class "
                                "'SparseCategoricalCrossentropy'. If "
                                "you implemented a custom loss function, "
                                "please raise an issue on GitHub.")
        metrics = init_metrics(metrics, self.logger, **kwargs)

        # Compile the model
        self.model.compile(optimizer=optimizer, loss=losses, metrics=metrics)
        self.logger("Optimizer:   %s" % optimizer)
        self.logger("Loss funcs:  %s" % losses)
        self.logger("Metrics:     %s" % init_metrics)
        return self
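
The reduction loop above requires loss classes rather than plain loss functions. A short sketch of the distinction it enforces, using stock tf.keras names:

import tensorflow as tf

# A loss class accepts a Reduction at construction time...
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE)

# ...whereas the plain function form does not, which is what triggers the
# TypeError branch above:
# tf.keras.losses.sparse_categorical_crossentropy(y, p, reduction=...)  # TypeError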
Example #4
    def log_image(self, print_calling_method=False):
        """
        Log basic stats for this ImagePair.
        """
        self.logger(
            "%s\n"
            "--- loaded:     %s\n"
            "--- shape:      %s\n"
            "--- bg class:   %i\n"
            "--- bg value:   %s\n"
            "--- scaler:     %s\n"
            "--- real shape: %s\n"
            "--- pixdim:     %s" %
            (self.identifier, self.is_loaded, self.shape, self._bg_class,
             self._bg_value, ensure_list_or_tuple(self._scaler)[0],
             np.round(get_real_image_size(self), 3),
             np.round(get_pix_dim(self), 3)),
            print_calling_method=print_calling_method)
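
ensure_list_or_tuple appears throughout these examples but is not shown; a minimal sketch of the behavior the call sites assume (the real implementation may differ):

def ensure_list_or_tuple(obj):
    # Pass lists and tuples through unchanged; wrap anything else in a list.
    return obj if isinstance(obj, (list, tuple)) else [obj]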
Example #5
    def evaluate(self):
        """
        Run the model on all validation batches and collect metrics.

        Returns:
            class_wise_metrics:      dict mapping task name to per-class dice,
                                     recall and precision arrays.
            mean_batch_wise_metrics: dict mapping task name to the mean of each
                                     batch-wise metric over all batches.
        """
        # Get tensors to run and their names
        if hasattr(self.model, "loss_functions"):
            metrics = self.model.loss_functions + self.model.metrics
        else:
            metrics = self.model.metrics
        metrics_names = self.model.metrics_names
        self.model.reset_metrics()
        assert len(metrics_names) == len(metrics)

        # Prepare dictionaries for storing pr. task metric results
        TPs, relevant, selected, batch_wise_metrics = {}, {}, {}, {}
        for task_name, n_classes in zip(self.task_names, self.n_classes):
            TPs[task_name] = np.zeros(shape=(n_classes, ), dtype=np.uint64)
            relevant[task_name] = np.zeros(shape=(n_classes, ),
                                           dtype=np.uint64)
            selected[task_name] = np.zeros(shape=(n_classes, ),
                                           dtype=np.uint64)
            batch_wise_metrics[task_name] = defaultdict(list)

        # Prepare queue and thread for computing counts
        from queue import Queue
        from threading import Thread, Lock
        count_queue = Queue(maxsize=self.steps)
        count_thread = Thread(target=self._count_cm_elements_from_queue,
                              daemon=True,
                              args=[
                                  count_queue, self.steps, TPs, relevant,
                                  selected, self.task_names, self.n_classes,
                                  Lock()
                              ])
        count_thread.start()

        # Fetch validation batches from the generator(s)
        pool = ThreadPoolExecutor(max_workers=3)
        batches = pool.map(self.data.__getitem__, np.arange(self.steps))

        # Predict on all
        self.logger("")
        for i, (X, y, _) in enumerate(batches):
            if self.verbose:
                print("   Validation: %i/%i" % (i + 1, self.steps),
                      end="\r",
                      flush=True)
            X = ensure_list_or_tuple(X)
            y = ensure_list_or_tuple(y)

            # Predict and put values in the queue for counting
            pred = self.model.predict_on_batch(X)
            pred = ensure_list_or_tuple(pred)
            count_queue.put([pred, y])

            for p_task, y_task, task in zip(pred, y, self.task_names):
                # Run all metrics
                for metric, name in zip(metrics, metrics_names):
                    m = tf.reduce_mean(metric(y_task, p_task))
                    batch_wise_metrics[task][name].append(m.numpy())
        pool.shutdown(wait=True)

        # Compute the mean over batch-wise metrics
        mean_batch_wise_metrics = {}
        for task in self.task_names:
            mean_batch_wise_metrics[task] = {}
            for metric in metrics_names:
                ms = batch_wise_metrics[task][metric]
                mean_batch_wise_metrics[task][metric] = np.mean(ms)
        self.model.reset_metrics()
        self.logger("")

        # Terminate count thread
        print("Waiting for counting queue to terminate...\n")
        count_thread.join()
        count_queue.join()

        # Compute per-class metrics (dice+precision+recall)
        class_wise_metrics = {}
        for task in self.task_names:
            precisions, recalls, dices = self._compute_dice(tp=TPs[task],
                                                            rel=relevant[task],
                                                            sel=selected[task])
            if self.ignore_bg:
                precisions[0] = np.nan
                recalls[0] = np.nan
                dices[0] = np.nan
            class_wise_metrics[task] = {
                "dice": dices,
                "recall": recalls,
                "precision": precisions
            }
        return class_wise_metrics, mean_batch_wise_metrics
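
_compute_dice is referenced but not shown. A sketch of the per-class computation its call site implies, following the standard count-based definitions (precision = TP/selected, recall = TP/relevant, dice = 2*TP/(relevant + selected)); the actual implementation may differ:

import numpy as np

def _compute_dice(tp, rel, sel):
    # tp, rel, sel: per-class counts of true positives, ground-truth
    # positives (relevant) and predicted positives (selected).
    tp = tp.astype(np.float64)
    rel = rel.astype(np.float64)
    sel = sel.astype(np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        precisions = tp / sel           # TP / (TP + FP)
        recalls = tp / rel              # TP / (TP + FN)
        dices = 2 * tp / (rel + sel)    # 2TP / (2TP + FP + FN)
    return precisions, recalls, dices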