Example #1
def model_initializer(hparams, continue_training, base_path, logger=None):
    logger = logger or ScreenLogger()

    # Init model
    model = init_model(hparams["build"], logger)

    if continue_training:
        from MultiPlanarUNet.utils import get_last_model, get_lr_at_epoch, \
                                        clear_csv_after_epoch
        model_path, epoch = get_last_model(os.path.join(base_path, "model"))
        model.load_weights(model_path, by_name=True)
        hparams["fit"]["init_epoch"] = epoch + 1

        # Get the LR at the continued epoch
        lr, name = get_lr_at_epoch(epoch, os.path.join(base_path, "logs"))
        hparams["fit"]["optimizer_kwargs"][name] = lr

        # Remove entries in training.csv file that occurred after the
        # continued epoch
        clear_csv_after_epoch(epoch,
                              os.path.join(base_path, "logs", "training.csv"))

        logger("[NOTICE] Training continues from:\n"
               "Model: %s\n"
               "Epoch: %i\n"
               "LR:    %s" % (os.path.split(model_path)[-1], epoch, lr))
    else:
        hparams["fit"]["init_epoch"] = 0

    return model
Example #2
def plot_all_training_curves(glob_path,
                             out_path,
                             raise_error=False,
                             logger=None,
                             **kwargs):
    logger = logger or ScreenLogger()
    try:
        from glob import glob
        paths = glob(glob_path)
        if not paths:
            raise OSError("File pattern {} gave none or too many matches " \
                          "({})".format(glob_path, paths))
        out_folder = os.path.split(out_path)[0]
        for p in paths:
            if len(paths) > 1:
                # Set unique names
                uniq = os.path.splitext(os.path.split(p)[-1])[0]
                f_name = uniq + "_" + os.path.split(out_path)[-1]
                save_path = os.path.join(out_folder, f_name)
            else:
                save_path = out_path
            plot_training_curves(p, save_path, **kwargs)
    except Exception as e:
        s = "Could not plot training curves. ({})".format(e)
        if raise_error:
            raise RuntimeError(s) from e
        else:
            logger.warn(s)
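A minimal usage sketch of the function above; the paths are made up for illustration, and any extra kwargs are forwarded to plot_training_curves:

# Hypothetical paths: plot every matched training.csv to an image next to out_path
plot_all_training_curves(glob_path="my_project/logs/*training.csv",
                         out_path="my_project/images/curves.png",
                         raise_error=False)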
Example #3
def init_model(build_hparams, logger=None, clear_previous=True):
    """
    From a set of hyperparameters 'build_hparams' (dict) initializes the
    model specified under build_hparams['model_class_name'].

    Typically, this function is not called directly, but used by the
    higher-level 'initialize_model' function.

    Args:
        build_hparams:  A dictionary of model build hyperparameters
        logger:         A Logger instance
        clear_previous: Clear previous tf sessions

    Returns:
        A tf.keras Model instance
    """
    from utime import models
    logger = logger or ScreenLogger()
    if clear_previous:
        import tensorflow as tf
        tf.keras.backend.clear_session()
    # Build new model of the specified type
    cls_name = build_hparams["model_class_name"]
    logger("Creating new model of type '%s'" % cls_name)
    return models.__dict__[cls_name](logger=logger, **build_hparams)
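A hedged usage sketch: the class name 'UTime' and the extra build keys below are assumptions; they must match a class exported by utime.models and its constructor signature:

# Hypothetical build section; only 'model_class_name' is required by init_model itself
build_hparams = {"model_class_name": "UTime",
                 "n_classes": 5,
                 "batch_shape": [64, 3000, 1]}
model = init_model(build_hparams)  # logger defaults to a ScreenLogger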
Example #4
    def __init__(self,
                 val_sequence,
                 steps,
                 logger=None,
                 verbose=True,
                 ignore_class_zero=True):
        """
        Args:
            val_sequence: A MultiPlanarUNet.sequence object from which validation
                          batches can be sampled via its __getitem__ method.
            steps:        Number of batches to sample from val_sequence in each
                          validation epoch
            logger:       An instance of a MultiPlanar Logger that prints to screen
                          and/or file
            verbose:      Print progress to screen - OBS does not use Logger
            ignore_class_zero: If True, ignore class 0 (the background class)
        """
        super().__init__()
        self.logger = logger or ScreenLogger()
        self.data = val_sequence
        self.steps = steps
        self.verbose = verbose
        self.ignore_bg = ignore_class_zero

        self.n_classes = self.data.n_classes
        if isinstance(self.n_classes, int):
            self.task_names = [""]
            self.n_classes = [self.n_classes]
        else:
            self.task_names = self.data.task_names
Example #5
    def __init__(self,
                 batch_shape,
                 n_classes,
                 padding="valid",
                 activation="relu",
                 use_dropout=True,
                 use_bn=True,
                 classify=True,
                 flatten=True,
                 l2_reg=0.0,
                 logger=None,
                 log=True,
                 build=True,
                 **unused):
        super(DeepFeatureNet, self).__init__()
        self.logger = logger or ScreenLogger()
        self.batch_shape = standardize_batch_shape(batch_shape)
        self.n_classes = n_classes
        self.use_dropout = use_dropout
        self.padding = padding
        self.activation = activation
        self.use_bn = use_bn
        self.classify = classify
        self.flatten = flatten
        self.l2_reg = l2_reg
        self.reg = None
        self.model_name = "DeepFeatureNet"

        # Build model and init base keras Model class
        if build:
            with tf.name_scope(self.model_name):
                super(DeepFeatureNet, self).__init__(*self.init_model())
            if log:
                self.log()
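A sketch of direct construction of the class above; the [batch, time, channels]-style shape below is an assumption about what standardize_batch_shape accepts:

# Hypothetical instantiation of the DeepFeatureNet shown above
model = DeepFeatureNet(batch_shape=[64, 3000, 1],
                       n_classes=5,
                       use_dropout=True,
                       use_bn=True)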
Example #6
    def __init__(self,
                 layer,
                 every=10,
                 first=10,
                 per_epoch=False,
                 logger=None):
        """
        Args:
            layer:      A tf.keras layer
            every:      Print the weights every 'every' batch or epoch if
                        per_epoch=True
            first:      Print the first 'first' elements of each weight matrix
            per_epoch:  Print after 'every' epoch instead of batch
            logger:     An instance of a MultiPlanar Logger that prints to screen
                        and/or file
        """
        super().__init__()
        if isinstance(layer, int):
            self.layer = self.model.layers[layer]
        else:
            self.layer = layer
        self.first = first
        self.every = every
        self.logger = logger or ScreenLogger()

        self.per_epoch = per_epoch
        if per_epoch:
            # Apply on every epoch instead of per batches
            self.on_epoch_begin = self.on_batch_begin
            self.on_batch_begin = lambda x, y: None

        self.log()
Example #7
    def __init__(self, image, labels, affine,
                 bg_value=0, bg_class=0, logger=None):

        # Ensure 4D
        if image.ndim != 4:
            raise ValueError("Input img of dim %i must be dim 4. "
                             "If image has only 1 channel, use "
                             "np.expand_dims(img, -1)." % image.ndim)

        # Set logger
        self.logger = logger if logger is not None else ScreenLogger()

        # Number of channels in the input image
        self.im_shape = image.shape
        self.n_channels = self.im_shape[-1]
        self.im_dtype = image.dtype

        # Store potential transformation to regular grid
        self.rot_mat = None

        # Define interpolators
        self.im_intrps, self.lab_intrp = self._init_interpolators(image,
                                                                  labels,
                                                                  bg_value,
                                                                  bg_class,
                                                                  affine)
Example #8
    def __init__(self, logger=None, verbose=1):
        super().__init__()
        self.logger = logger or ScreenLogger()
        self.verbose = bool(verbose)

        # Timing attributes
        self.train_begin_time = None
        self.prev_epoch_time = None
Example #9
    def __init__(self, n_classes, logger=None, **kwargs):
        self.logger = logger if logger is not None else ScreenLogger()
        self.__name__ = "BatchWeightedCrossEntropyWithLogits"

        self.n_classes = n_classes

        self._log()
Example #10
    def __init__(self, logger=None):
        import MultiPlanarUNet
        from MultiPlanarUNet.logging import ScreenLogger
        code_path = MultiPlanarUNet.__path__
        assert len(code_path) == 1
        self.logger = logger or ScreenLogger()
        self.git_path = os.path.split(os.path.abspath(code_path[0]))[0]
        self._mem_path = None
Example #11
    def __init__(self, sequences, no_log=False, logger=None):
        _assert_comparable_sequencers(sequences)
        self.sequences = sequences
        self.IDs = [s.identifier.split("/")[0] for s in self.sequences]
        self.n_classes = self.sequences[0].n_classes
        self.logger = logger or ScreenLogger()
        if not no_log:
            self.log()
Example #12
    def __init__(self, logger=None):
        """
        Args:
            logger: An instance of a MultiPlanar Logger that prints to screen
                    and/or file
        """
        super().__init__()
        self.logger = logger or ScreenLogger()
Example #13
    def __init__(self, model, logger=None):
        self.model = model
        self.logger = logger if logger is not None else ScreenLogger()
        self.target_tensor = None

        # Extra reference to original (non multiple-GPU) model
        # Is set from train.py as needed
        self.org_model = None
Example #14
    def __init__(self, sequencers, task_names, logger=None):
        super().__init__()
        self.logger = logger or ScreenLogger()
        self.task_names = task_names
        self.sequencers = sequencers
        self.log()

        # Redirect setattrs to the sub-sequences
        self.redirect = True
Example #15
    def __init__(self,
                 image_pair_loader,
                 dim,
                 batch_size,
                 n_classes,
                 real_space_span=None,
                 noise_sd=0.,
                 force_all_fg="auto",
                 fg_batch_fraction=0.50,
                 label_crop=None,
                 logger=None,
                 is_validation=False,
                 list_of_augmenters=None,
                 sparse=True,
                 **kwargs):
        super().__init__()

        # Validation or training batch generator?
        self.is_validation = is_validation

        # Set logger or default print
        self.logger = logger or ScreenLogger()

        # Set views and attributes for plane sample generation
        self.sample_dim = dim
        self.real_space_span = real_space_span
        self.noise_sd = noise_sd if not self.is_validation else 0.

        # Set data
        self.image_pair_loader = image_pair_loader
        self.images = image_pair_loader.images

        # Augmenter, applied to batch at creation time
        # Do not augment validation data
        self.list_of_augmenters = list_of_augmenters if not self.is_validation else None

        # Batch creation options
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.sparse = sparse

        # Minimum fraction of slices in each batch with FG
        self.force_all_fg_switch = force_all_fg
        self.fg_batch_fraction = fg_batch_fraction

        # Store labels?
        self.store_y = False
        self.stored_y = []

        # Foreground label settings
        self.fg_classes = np.arange(1, self.n_classes)
        if self.fg_classes.shape[0] == 0:
            self.fg_classes = [1]

        # Set potential label crop
        self.label_crop = (np.array([[0, 0], [0, 0]])
                           if label_crop is None else label_crop)
Example #16
def init_model(build_hparams, logger=None):
    from MultiPlanarUNet import models
    logger = logger or ScreenLogger()

    # Build new model of the specified type
    cls_name = build_hparams["model_class_name"]
    logger("Creating new model of type '%s'" % cls_name)

    return models.__dict__[cls_name](logger=logger, **build_hparams)
Example #17
def prepare_for_continued_training(hparams, project_dir, logger=None):
    """
    Prepares the hyperparameter set and project directory for continued
    training.

    Finds the latest (highest epoch number) parameter file in the 'model'
    subdir of 'project_dir' and bases the continued training on this file.
    If no file is found, training starts from scratch as normal (note: no
    error is raised, but None is returned instead of a path to a parameter
    file).

    The hparams['fit']['init_epoch'] parameter will be set to match the found
    parameter file or to 0 if no file was found. Note that if init_epoch is set
    to 0 all rows in the training.csv file will be deleted.

    The hparams['fit']['optimizer_kwargs']['learning_rate'] parameter will
    be set according to the value stored in the project_dir/logs/training.csv
    file at the corresponding epoch (left at its default if no init_epoch was
    found).

    Args:
        hparams:      (YAMLHParams) The hyperparameters to use for training
        project_dir:  (string)      The path to the current project directory
        logger:       (Logger)      An optional Logger instance

    Returns:
        A path to the model weight files to use for continued training.
        Will be None if no model files were found
    """
    from MultiPlanarUNet.utils import (get_last_model, get_lr_at_epoch,
                                       get_last_epoch, clear_csv_after_epoch)
    model_path, epoch = get_last_model(os.path.join(project_dir, "model"))
    if model_path:
        model_name = os.path.split(model_path)[-1]
    else:
        model_name = None
    csv_path = os.path.join(project_dir, "logs", "training.csv")
    if epoch == 0:
        epoch = get_last_epoch(csv_path)
    else:
        if epoch is None:
            epoch = 0
        clear_csv_after_epoch(epoch, csv_path)
    hparams["fit"]["init_epoch"] = epoch + 1
    # Get the LR at the continued epoch
    lr, name = get_lr_at_epoch(epoch, os.path.join(project_dir, "logs"))
    if lr:
        hparams["fit"]["optimizer_kwargs"][name] = lr
    logger = logger or ScreenLogger()
    logger("[NOTICE] Training continues from:\n"
           "Model: {}\n"
           "Epoch: {}\n"
           "LR:    {}".format(
               model_name or "<No model found - "
               "Starting for scratch!>", epoch, lr))
    return model_path
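A minimal sketch combining this helper with a weight-loading step, assuming 'hparams' is a YAMLHParams instance with 'build' and 'fit' sections (the project path is made up):

# Hypothetical continued-training flow
model = init_model(hparams["build"])
weights = prepare_for_continued_training(hparams, project_dir="./my_project")
if weights:
    # None means no previous parameter file was found; training starts fresh
    model.load_weights(weights, by_name=True)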
Example #18
    def __init__(self, callback, start_from=0, logger=None):
        """
        Args:
            callback:   A tf.keras callback
            start_from: Delay the activity of 'callback' until this epoch
                        'start_from'
            logger:     An instance of a MultiPlanar Logger that prints to
                        screen and/or file
        """
        self.logger = logger or ScreenLogger()
        self.callback = callback
        self.start_from = start_from
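The wrapper's class name is not shown above; assuming it is called something like DelayedCallback, usage could look like this sketch:

from tensorflow.keras.callbacks import ModelCheckpoint

# Hypothetical: keep checkpointing inactive until epoch 50
delayed_ckpt = DelayedCallback(callback=ModelCheckpoint("model/weights_{epoch:02d}.h5"),
                               start_from=50)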
Example #19
    def __init__(self,
                 n_classes,
                 type_weight="Square",
                 sparse=True,
                 logger=None,
                 **kwargs):
        self.type_weight = type_weight
        self.n_classes = n_classes
        self.sparse = sparse
        self.logger = logger or ScreenLogger()
        self.__name__ = "GeneralizedDiceLoss"
        self.log()
Example #20
def load_from_file(model, file_path, logger=None, by_name=True):
    """
    Load parameters from file 'file_path' into model 'model'.

    Args:
        model:      A tf.keras Model instance
        file_path:  A path to a parameter file (h5 format typically)
        logger:     An optional Logger instance
        by_name:    Load parameters by layer names instead of order (default).
    """
    model.load_weights(file_path, by_name=by_name)
    logger = logger or ScreenLogger()
    logger("Loading parameters from:\n{}".format(file_path))
Example #21
    def __init__(self, logger=None):
        self.logger = logger or ScreenLogger()

        # Prepare signal
        self.stop_signal = Event()
        self.run_signal = Event()
        self.set_signal = Event()

        # Stores list of available GPUs
        self._free_GPUs = Queue()

        super(GPUMonitor, self).__init__(target=self._monitor)
        self.start()
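A hedged usage sketch of the GPUMonitor thread above; the _monitor loop is not shown, so the shutdown call below is an assumption based on the Event attributes:

# Hypothetical: construction starts the monitoring thread immediately
monitor = GPUMonitor()
# ... run training ...
monitor.stop_signal.set()  # assumed shutdown handshake; _monitor body not shown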
Example #22
    def __init__(self, sequencers, batch_size, no_log=False, logger=None):
        # Make sure we can use the 0th sequencer as a reference that respects
        # all the sequences (same batch-size, margins etc.)
        _assert_comparable_sequencers(sequencers)
        super().__init__()
        self.logger = logger or ScreenLogger()
        self.sequences = sequencers
        self.batch_size = batch_size
        self.margin = sequencers[0].margin
        self.n_classes = sequencers[0].n_classes
        for s in self.sequences:
            s.batch_size = 1
        if not no_log:
            self.log()
Example #23
    def __init__(self, train_data, val_data=None, logger=None):
        """
        Args:
            train_data: A MultiPlanarUNet.sequence object representing the
                        training data
            val_data:   A MultiPlanarUNet.sequence object representing the
                        validation data
            logger:     An instance of a MultiPlanar Logger that prints to
                        screen and/or file
        """
        super().__init__()
        self.data = (("train", train_data), ("val", val_data))
        self.logger = logger or ScreenLogger()
        self.active = True
Example #24
    def __init__(self,
                 log_dir="logs",
                 out_dir="logs",
                 fname="curve.png",
                 csv_regex="*training.csv",
                 logger=None):
        super().__init__()
        out_dir = os.path.abspath(out_dir)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        self.csv_regex = os.path.join(os.path.abspath(log_dir), csv_regex)
        self.save_path = os.path.join(out_dir, fname)
        self.logger = logger or ScreenLogger()
Example #25
def save_images(train, val, out_dir, logger):
    logger = logger or ScreenLogger()
    # Write a few images to disk
    im_path = out_dir
    if not os.path.exists(im_path):
        os.mkdir(im_path)

    training = train[0]
    if val is not None and len(val) != 0:
        validation = val[0]
        v_len = len(validation[0])
    else:
        validation = None
        v_len = 0

    logger("Saving %i sample images in '<project_dir>/images' folder" %
           ((len(training[0]) + v_len) * 2))
    for rr in range(2):
        for k, temp in enumerate((training, validation)):
            if temp is None:
                # No validation data
                continue
            X, Y, W = temp
            for i, (xx, yy, ww) in enumerate(zip(X, Y, W)):
                # Make figure
                fig = plt.figure(figsize=(10, 4))
                ax1 = fig.add_subplot(121)
                ax2 = fig.add_subplot(122)

                # Plot image and overlayed labels
                chnl, view, _ = imshow_with_label_overlay(ax1, xx, yy)

                # Plot histogram
                ax2.hist(xx.flatten(), bins=200)

                # Set labels
                ax1.set_title("Channel %i - Axis %i - "
                              "Weight %.3f" % (chnl, view, ww),
                              size=18)

                # Get path
                out_path = im_path + "/%s%i.png" % ("train" if k == 0 else
                                                    "val", len(X) * rr + i)

                with np.testing.suppress_warnings() as sup:
                    sup.filter(UserWarning)
                    fig.savefig(out_path)
                plt.close(fig)
Example #26
def model_initializer(hparams,
                      continue_training,
                      project_dir,
                      initialize_from=None,
                      logger=None):
    logger = logger or ScreenLogger()

    # Init model
    model = init_model(hparams["build"], logger)

    if continue_training:
        if initialize_from:
            raise ValueError("Failed to initialize model with both "
                             "continue_training and initialize_from set.")
        from MultiPlanarUNet.utils import get_last_model, get_lr_at_epoch, \
                                          clear_csv_after_epoch, get_last_epoch
        model_path, epoch = get_last_model(os.path.join(project_dir, "model"))
        if model_path:
            model.load_weights(model_path, by_name=True)
            model_name = os.path.split(model_path)[-1]
        else:
            model_name = "<No model found>"
        csv_path = os.path.join(project_dir, "logs", "training.csv")
        if epoch == 0:
            epoch = get_last_epoch(csv_path)
        else:
            if epoch is None:
                epoch = 0
            clear_csv_after_epoch(epoch, csv_path)
        hparams["fit"]["init_epoch"] = epoch + 1

        # Get the LR at the continued epoch
        lr, name = get_lr_at_epoch(epoch, os.path.join(project_dir, "logs"))
        if lr:
            hparams["fit"]["optimizer_kwargs"][name] = lr

        logger("[NOTICE] Training continues from:\n"
               "Model: %s\n"
               "Epoch: %i\n"
               "LR:    %s" % (model_name, epoch, lr))
    else:
        hparams["fit"]["init_epoch"] = 0
        if initialize_from:
            model.load_weights(initialize_from, by_name=True)
            logger("[NOTICE] Initializing parameters from:\n"
                   "{}".format(initialize_from))
    return model
Example #27
    def __init__(self,
                 yaml_path,
                 logger=None,
                 no_log=False,
                 no_version_control=False,
                 **kwargs):
        dict.__init__(self, **kwargs)

        # Set logger or default print
        self.logger = logger or ScreenLogger()

        # Set YAML path
        self.yaml_path = os.path.abspath(yaml_path)
        self.string_rep = ""
        self.project_path = os.path.split(self.yaml_path)[0]
        if not os.path.exists(self.yaml_path):
            raise OSError("YAML path '%s' does not exist" % self.yaml_path)
        else:
            with open(self.yaml_path, "r") as yaml_file:
                for line in yaml_file:
                    self.string_rep += line
            hparams = YAML(typ="safe").load(self.string_rep)

        # Set dict elements
        self.update({k: hparams[k] for k in hparams if k[:4] != "__CB"})

        if self.get('fit') and self["fit"].get("callbacks"):
            # Convert potential callback paths to absolute paths
            cb = _cb_paths_to_abs_paths(callbacks=self["fit"]["callbacks"],
                                        patterns=("log.?dir", "file.?name",
                                                  "file.?path"),
                                        project_dir=self.project_path)
            self["fit"]["callbacks"] = cb

        # Log basic information here...
        self.no_log = no_log
        if not self.no_log:
            self.logger("YAML path:    %s" % self.yaml_path)

        # Version controlling
        _check_deprecated_params(self, self.logger)
        if not no_version_control:
            package = kwargs.get('package') or "MultiPlanarUNet"
            has_git = _check_version(self, self.logger, package)
            if has_git:
                _set_version(self, self.logger if not no_log else None,
                             package)
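A hedged usage sketch, assuming this is the YAMLHParams container referenced in Example #17 and that the loaded file contains 'build' and 'fit' sections (those key names are assumptions):

# Hypothetical: load a project's hyperparameter file and read its sections
hparams = YAMLHParams("./my_project/train_hparams.yaml")
build_kwargs = hparams["build"]                          # assumed section name
optimizer_kwargs = hparams["fit"]["optimizer_kwargs"]    # assumed section name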
Example #28
    def __init__(self, class_weights, logger=None, *args, **kwargs):
        """
        class_weights: An Nx1 array of class weights for N classes.
        """

        if class_weights is None or class_weights is False:
            raise ValueError("No class weights passed.")
        self.logger = logger if logger is not None else ScreenLogger()
        self.__name__ = "WeightedSemanticCCE"

        self.weights = _to_tensor(class_weights, dtype=tf.float32)
        self.n_classes = self.weights.get_shape()[0]

        with print_options_context(precision=3, suppress=True):
            self.logger("Class weights:\n%s" % class_weights)

        # Log to console/file
        self._log()
Example #29
    def __init__(self, validation_data, n_classes, batch_size=8, logger=None):
        """
        Args:
            validation_data: A tuple (X, y) of two ndarrays of validation data
                             and corresponding labels.
                             Any shape accepted by the model.
                             Labels must be integer targets (not one-hot)
            n_classes:       Number of classes, including background
            batch_size:      Batch size used for prediction
            logger:          An instance of a MultiPlanar Logger that prints to
                             screen and/or file
        """
        super().__init__()
        self.logger = logger or ScreenLogger()
        self.X_val, self.y_val = validation_data
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.scores = []
Example #30
    def __init__(self,
                 n_classes,
                 gamma_dice=0.3,
                 gamma_cross=0.3,
                 weight_dice=1,
                 weight_cross=1,
                 logger=None,
                 int_targets=True,
                 **kwargs):
        self.gamma_dice = gamma_dice
        self.gamma_cross = gamma_cross
        self.weight_dice = weight_dice
        self.weight_cross = weight_cross

        self.n_classes = n_classes
        self.int_targets = int_targets
        self.logger = logger or ScreenLogger()
        self.__name__ = "ExponentialLogarithmicLoss"
        self.log()