def __init__(self,
                 log_dir='./logs/tmp/',
                 image_freq=10,
                 feed_inputs_4_display=None,
                 flow=False,
                 dpi=200,
                 f_size=(5, 5),
                 interpol='bilinear'):
        """
        Callback that writes inputs, ground truth and predictions as images into a
        TensorBoard summary file so the learning progress can be inspected visually.

        According to the original design, the samples in ``feed_inputs_4_display``
        are predicted when the epoch-end hook runs. They are written to the summary
        file together with their inputs and ground truth.

        :param log_dir: str, TensorBoard summary folder; the image summaries are
            written into the sub-directory '<log_dir>/images', created if missing
        :param image_freq: int, run the export only every n-th epoch to save disk
            space and training time
        :param feed_inputs_4_display: dict mapping a summary key (e.g. 'train', 'val')
            to an (x_tensor, y_tensor) tuple with the samples to predict and visualise;
            presumably shaped (n, x, y, 1), y possibly with `classes` channels
            instead - TODO confirm against the caller
        :param flow: bool, stored as ``self.flow``
        :param dpi: int, stored as ``self.dpi`` (presumably the figure resolution)
        :param f_size: tuple, stored as ``self.f_size`` (presumably the figure size)
        :param interpol: str, stored as ``self.interpol`` (presumably an image
            interpolation mode used for plotting)
        """
        super(CustomImageWritertf2, self).__init__()

        # export schedule and plotting settings
        self.freq, self.flow = image_freq, flow
        self.f_size, self.dpi, self.interpol = f_size, dpi, interpol

        # internal state; presumably an epoch counter and a warm-up length - TODO confirm
        self.e, self.n_start_epochs = 0, 20

        self.feed_inputs_4_display = feed_inputs_4_display

        # the image summaries get their own sub-directory below log_dir
        image_log_dir = os.path.join(log_dir, 'images')
        ensure_dir(image_log_dir)
        self.writer = tensorflow.summary.create_file_writer(image_log_dir)
    def __init__(self, log_dir='./logs', **kwargs):
        """
        TensorBoard callback that keeps training and validation scores in two
        sibling sub-directories so both can be plotted together.

        :param log_dir: str, root TensorBoard summary folder. '<log_dir>/training'
            is handed to the parent TensorBoard callback and '<log_dir>/validation'
            is stored in ``self.val_log_dir``. Both folders are created if missing.
        :param kwargs: further TensorBoard configuration, forwarded unchanged
        """
        # the parent TensorBoard callback writes into '<log_dir>/training'
        train_dir = os.path.join(log_dir, 'training')
        ensure_dir(train_dir)
        super(TrainValTensorBoard, self).__init__(train_dir, **kwargs)

        # validation metrics get their own sibling folder
        self.val_log_dir = os.path.join(log_dir, 'validation')
        ensure_dir(self.val_log_dir)
 def __init__(self, model_path, model_freq):
     """
     Store the output folder and the save frequency.

     :param model_path: str, output folder; created if it does not exist
     :param model_freq: int, stored as ``self.N`` (presumably "save every N-th
         epoch" - TODO confirm)
     """
     self.model_path, self.N = model_path, model_freq
     self.epoch_w = 0  # presumably counts epochs between saves - TODO confirm
     ensure_dir(self.model_path)
def get_callbacks(config=None,
                  batch_generator=None,
                  metrics=None,
                  validation_generator=None):
    """
    Assemble the Keras callbacks used for training.

    Callbacks are returned in this order:

    1. Ax2SaxWriter (only if ``batch_generator`` is given)
    2. ModelCheckpoint (best weights only, '<MODEL_PATH>/model.h5')
    3. ReduceLROnPlateau
    4. LRTensorBoard
    5. OptimizerChanger (if ``metrics`` are given), otherwise EarlyStopping

    :param config: dict with the optional keys MODEL_PATH, TENSORBOARD_LOG_DIR,
        SAVE_MODEL_FUNCTION, SAVE_MODEL_MODE, MONITOR_FUNCTION, MONITOR_MODE,
        DECAY_FACTOR, MIN_LR and MODEL_PATIENCE. Every key falls back to a
        default; None is treated as an empty dict.
    :param batch_generator: training generator. If truthy, an Ax2SaxWriter
        image callback is added. It is also forwarded to OptimizerChanger.
    :param metrics: if truthy, OptimizerChanger runs finetune_with_SGD once the
        monitored value stops improving (patience 15); otherwise EarlyStopping
        is used
    :param validation_generator: validation generator, forwarded to OptimizerChanger
    :return: list of callbacks for keras fit_generator
    """
    if config is None:
        config = {}

    # Read every key with a default so that a partial (or None) config works.
    # MODEL_PATH and TENSORBOARD_LOG_DIR used to be indexed directly and raised
    # KeyError, even right after `config = {}` above.
    model_path = config.get('MODEL_PATH', 'temp/models')
    tensorboard_log_dir = config.get('TENSORBOARD_LOG_DIR', 'temp/tf_log')
    monitor = config.get('MONITOR_FUNCTION', 'loss')
    monitor_mode = config.get('MONITOR_MODE', 'min')

    ensure_dir(model_path)
    callbacks = []

    if batch_generator:
        callbacks.append(
            Ax2SaxWriter(log_dir=tensorboard_log_dir,
                         image_freq=1,
                         val_gen=batch_generator,
                         flow=False,
                         dpi=200,
                         f_size=[12, 4]))

    # keep only the weights of the best epoch according to SAVE_MODEL_FUNCTION
    callbacks.append(
        ModelCheckpoint(
            os.path.join(model_path, 'model.h5'),
            verbose=1,
            save_best_only=True,
            save_weights_only=True,
            monitor=config.get('SAVE_MODEL_FUNCTION', 'loss'),
            mode=config.get('SAVE_MODEL_MODE', 'min'),
            save_freq='epoch'))

    callbacks.append(
        tensorflow.keras.callbacks.ReduceLROnPlateau(
            monitor=monitor,
            factor=config.get('DECAY_FACTOR', 0.1),
            patience=5,
            verbose=1,
            cooldown=2,
            # NOTE: defaults to 'auto' here, unlike the 'min' used by the other callbacks
            mode=config.get('MONITOR_MODE', 'auto'),
            min_lr=config.get('MIN_LR', 1e-10)))

    # Disabled alternatives:
    # - WeightsSaver every 2 epochs
    # - a full-model ModelCheckpoint
    # - the plain Keras TensorBoard callback (LRTensorBoard below is used instead)
    # - an SGDRScheduler (cyclic lr; did not work better than simple lr decay)
    # - a LearningRateScheduler(PolynomialDecay(...))
    callbacks.append(
        LRTensorBoard(log_dir=tensorboard_log_dir,
                      histogram_freq=0,
                      write_graph=False,
                      write_images=False,
                      update_freq='epoch',
                      profile_batch=0,
                      embeddings_freq=0))

    if metrics:
        # Switch the optimizer to SGD once Adam stops improving.
        # The OptimizerChanger calls back into training without metrics to avoid
        # recursive fine-tuning.
        logging.info(
            'optimizer will be changed to SGD after adam does not improve any more'
        )
        # idea based on: https://arxiv.org/pdf/1712.07628.pdf
        callbacks.append(
            OptimizerChanger(on_train_end=finetune_with_SGD,
                             train_generator=batch_generator,
                             val_generator=validation_generator,
                             config=config,
                             metrics=metrics,
                             patience=15,
                             verbose=1,
                             monitor=monitor,
                             mode=monitor_mode))
    else:
        # No metrics given: stop the training after MODEL_PATIENCE (default 10)
        # epochs without improvement.
        callbacks.append(
            EarlyStopping(patience=config.get('MODEL_PATIENCE', 10),
                          verbose=1,
                          monitor=monitor,
                          mode=monitor_mode))

    return callbacks
def _hist_in_percent(ax, values, bins, title):
    """Draw a histogram of ``values`` on ``ax`` whose y axis shows the share of values in percent."""
    ax.hist(values, weights=np.ones(len(values)) / len(values), bins=bins)
    ax.set_title(title)
    ax.yaxis.set_major_formatter(PercentFormatter(1))
    return ax


def plot_value_histogram(nda, f_name='histogram.jpg', image=True, reports_path='reports/figures/4D_description'):
    '''
    Plot value histograms of a numpy array of any shape and save them as an image.

    With ``image=True``, four histograms are drawn side by side:
    1. all values (100 buckets)
    2. values <= the 0.999 quantile (20 buckets)
    3. values <= the 0.75 quantile (20 buckets)
    4. values <= the 0.5 quantile (as many buckets as distinct values, with an
       x tick at every distinct value)

    With ``image=False`` (e.g. masks), a single histogram of all values > 0 with
    3 buckets is drawn. The y axis shows the share of values in percent; the
    x axis is linearly spaced.

    Parameters
    ----------
    nda : np.ndarray of any shape, flattened before plotting
    f_name : str, file name of the saved figure, also used as the figure title
    image : bool, four-panel quantile view (True) or single mask histogram (False)
    reports_path : str, output folder, created if missing

    Returns
    -------
    None. The figure is saved to ``reports_path/f_name`` and shown via plt.show().
    '''
    ensure_dir(reports_path)
    flat = nda.flatten()
    plt.close('all')

    if not image:
        # masks: only the non-background (> 0) values are of interest
        fig = plt.figure(figsize=[6, 6])
        positive = flat[flat > 0]
        # the number of distinct values is an integer, so no float formatting
        _hist_in_percent(fig.add_subplot(111), positive, bins=3,
                         title='Mask with {} values'.format(len(Counter(positive))))
    else:
        fig = plt.figure(figsize=[20, 6])
        _hist_in_percent(fig.add_subplot(141), flat, bins=100,
                         title='1. quantile = {0:.2f}'.format(flat.max()))

        # panels 2 and 3: values up to the 0.999 and 0.75 quantile, 20 buckets each
        for position, q, label in ((142, .999, '0.999'), (143, .75, '0.75')):
            q_value = np.quantile(flat, q)
            _hist_in_percent(fig.add_subplot(position), flat[flat <= q_value], bins=20,
                             title='{} quantile = {:.2f}'.format(label, q_value))

        # panel 4: values up to the median, one bucket and one x tick per distinct value
        median = np.quantile(flat, .5)
        below_median = flat[flat <= median]
        counts = Counter(below_median)
        ax4 = _hist_in_percent(fig.add_subplot(144), below_median, bins=len(counts),
                               title='0.5 quantile = {0:.2f}'.format(median))
        ax4.set_xticks([value for value, _ in counts.most_common()])

    fig.suptitle(f_name, y=1.08)
    fig.tight_layout()
    fig.savefig(os.path.join(reports_path, f_name))
    plt.show()
def plot_4d_vol(img_4d, timesteps=None, save=False, path='temp/', mask_4d=None, f_name='4d_volume.png'):
    '''
    Plot a 4D volume as a grid with one row per timestep and one column per z-slice.

    Parameters
    ----------
    img_4d : np.ndarray or sitk.Image with axes (t, z, x, y) and an optional
        trailing channel axis. With 4 channels the first (background) channel
        is dropped; a single trailing channel is removed.
    timesteps : list of int, row labels; must match img_4d.shape[0]. None (the
        default) or a list with at most one entry is replaced by
        range(img_4d.shape[0]).
    save : bool, save the figure via save_plot (returns None) instead of returning it
    path : str, folder for the saved figure (created if missing); also used in
        the log message
    mask_4d : optional np.ndarray indexed like img_4d along (t, z), overlaid on
        each slice. With 3 or 4 channels only the last 3 are used.
    f_name : str, file name passed to save_plot

    Returns
    -------
    The matplotlib figure if ``save`` is False, otherwise None.
    '''
    if isinstance(img_4d, sitk.Image):
        img_4d = sitk.GetArrayFromImage(img_4d)

    # `timesteps=None` replaces the former mutable default `[0]`. Both, and any
    # list with <= 1 entry, fall back to labelling the volumes 0..t-1.
    if timesteps is None or len(timesteps) <= 1:
        logging.info('No timesteps given for: {}, use img.shape[0]'.format(path))
        timesteps = list(range(img_4d.shape[0]))
    assert len(timesteps) == img_4d.shape[0], 'timesteps does not match img_4d.shape[0]'

    if img_4d.shape[-1] == 4:
        img_4d = img_4d[..., 1:]  # ignore the background channel
    elif img_4d.shape[-1] == 1:
        img_4d = img_4d[..., 0]  # matplotlib cannot plot a trailing single channel

    if mask_4d is not None and mask_4d.shape[-1] in (3, 4):
        mask_4d = mask_4d[..., -3:]  # ignore the background channel of 4-channel masks

    # figure size grows with the grid (timesteps x z-slices), capped at 30 (w) x 20 (h)
    z_size = min(int(2 * img_4d.shape[1]), 30)
    t_size = min(int(2 * len(timesteps)), 20)
    logging.info('figure: {} x {}'.format(z_size, t_size))

    # squeeze=False keeps `ax` 2D, also for long-axis volumes with a single z-slice
    fig, ax = plt.subplots(len(timesteps), img_4d.shape[1], figsize=[z_size, t_size], squeeze=False)
    for t_, img_3d in enumerate(img_4d):  # rows: time
        for z, slice_2d in enumerate(img_3d):  # columns: z-axis
            if mask_4d is not None:
                ax[t_][z] = show_slice_transparent(slice_2d, mask_4d[t_, z, ...], show=True, ax=ax[t_][z])
            else:
                ax[t_][z] = show_slice_transparent(slice_2d, show=True, ax=ax[t_][z])
            ax[t_][z].set_xticks([])
            ax[t_][z].set_yticks([])
            if t_ == 0:  # column titles on the first row
                ax[t_][z].set_title('z-axis: {}'.format(z), color='r')
            if z == 0:  # row labels on the first column
                ax[t_][z].set_ylabel('t-axis: {}'.format(timesteps[t_]), color='r')

    # NOTE(review): tight_layout() recomputes the spacing and may override subplots_adjust
    plt.subplots_adjust(wspace=0.0, hspace=0.0)
    plt.tight_layout()

    if save:
        ensure_dir(path)
        save_plot(fig, path, f_name, override=True, tight=False)
    else:
        return fig
Beispiel #7
0
def copy_meta_and_save(new_image,
                       reference_sitk_img,
                       full_filename=None,
                       override_spacing=None,
                       copy_direction=True):
    """
    Copy metadata and spatial information from a reference image to ``new_image``
    and optionally write the result to disk.

    Different dimensions are handled as follows:
    - same dimension and size: CopyInformation
    - same dimension, different size: direction (optional), origin and spacing
      are copied
    - reference has more dimensions (e.g. 4D -> 3D): the leading origin/spacing
      entries are copied and the direction matrix is cropped
    - reference has fewer dimensions (e.g. 3D -> 4D): origin and spacing are
      padded with 1.0; the direction is not copied

    :param new_image: sitk.Image or np.ndarray. A 4D ndarray is converted by
        joining its 3D sub-arrays with sitk.JoinSeries; any other ndarray via
        sitk.GetImageFromArray.
    :param reference_sitk_img: sitk.Image to copy from, or None to skip copying
    :param full_filename: str, full file path. If given, the image is written
        there and missing parent folders are created.
    :param override_spacing: optional spacing that replaces the copied spacing
        (only applied when a reference image is given)
    :param copy_direction: bool, also copy the direction matrix where possible
    :return: True after a successful write (full_filename given); the
        (converted) new_image if full_filename is None; False if any step raised
        (the error is logged, not re-raised)
    """
    t1 = time()
    try:
        # accept numpy arrays as well as sitk images
        if isinstance(new_image, np.ndarray):
            if new_image.ndim == 4:
                # 4D needs to be built from a series of 3D images
                new_image = sitk.JoinSeries([sitk.GetImageFromArray(img) for img in new_image])
            else:
                new_image = sitk.GetImageFromArray(new_image)

        if reference_sitk_img is not None:
            assert (isinstance(reference_sitk_img,
                               sitk.Image)), 'no reference image given'
            assert (isinstance(new_image, sitk.Image)
                    ), 'only np.ndarrays and sitk images could be stored'

            # copy metadata
            for key in reference_sitk_img.GetMetaDataKeys():
                new_image.SetMetaData(
                    key, get_metadata_maybe(reference_sitk_img, key))
            logging.debug('Metadata_copied: {:0.3f}s'.format(time() - t1))

            ref_dim = reference_sitk_img.GetDimension()
            new_dim = new_image.GetDimension()

            if ref_dim == new_dim and reference_sitk_img.GetSize() == new_image.GetSize():
                # same dimension and size: copy all structural information
                new_image.CopyInformation(reference_sitk_img)

            elif ref_dim == new_dim:
                # same dimension but different size: copy spacing, origin and rotation, keep the size
                if copy_direction:
                    new_image.SetDirection(reference_sitk_img.GetDirection())
                new_image.SetOrigin(reference_sitk_img.GetOrigin())
                new_image.SetSpacing(reference_sitk_img.GetSpacing())

            elif ref_dim > new_dim:
                # smaller target (e.g. 4D -> 3D): keep the leading entries
                shape_ = len(new_image.GetSize())
                reference_shape = len(reference_sitk_img.GetSize())

                if copy_direction:
                    # reshape the flat direction into a matrix, crop it, flatten again
                    dir_ = np.array(reference_sitk_img.GetDirection()).reshape(reference_shape, reference_shape)
                    new_image.SetDirection(dir_[:shape_, :shape_].flatten().tolist())

                new_image.SetOrigin(reference_sitk_img.GetOrigin()[:shape_])
                new_image.SetSpacing(reference_sitk_img.GetSpacing()[:shape_])

            else:
                # bigger target (e.g. 3D -> 4D): pad origin/spacing with 1.0;
                # the direction cannot be copied from a smaller image
                ones = [1.0] * (new_dim - ref_dim)
                new_image.SetOrigin((*reference_sitk_img.GetOrigin(), *ones))
                new_image.SetSpacing((*reference_sitk_img.GetSpacing(), *ones))

            logging.debug('spatial data_copied: {:0.3f}s'.format(time() - t1))

            if override_spacing:
                new_image.SetSpacing(override_spacing)

        if full_filename is not None:
            # Create the target folder only when writing. Doing this
            # unconditionally made os.path.abspath(None) raise, so a call with
            # full_filename=None always returned False instead of the image.
            ensure_dir(os.path.dirname(os.path.abspath(full_filename)))
            writer = sitk.ImageFileWriter()
            writer.SetFileName(full_filename)
            writer.Execute(new_image)
            logging.debug('image saved: {:0.3f}s'.format(time() - t1))
            return True
    except Exception as e:
        # best-effort API: failures are logged and reported as False
        logging.error('Error with saving file: {} - {}'.format(
            full_filename, str(e)))
        return False

    return new_image