Example #1
    def test_set_augmentation(self):

        img_path = os.path.join(base_path, '../test_data/tiffconnector_1/im/')
        label_path = os.path.join(base_path,
                                  '../test_data/tiffconnector_1/labels/')
        c = TiffConnector(img_path, label_path)
        d = Dataset(c)

        size = (1, 3, 4)
        pad = (1, 2, 2)

        m = TrainingBatch(d,
                          size,
                          padding_zxy=pad,
                          batch_size=len(d.label_values()))

        self.assertEqual(m.augmentation, {'flip'})

        m.augment_by_rotation(True)
        self.assertEqual(m.augmentation, {'flip', 'rotate'})
        self.assertEqual(m.rotation_range, (-45, 45))

        m.augment_by_shear(True)
        self.assertEqual(m.augmentation, {'flip', 'rotate', 'shear'})
        self.assertEqual(m.shear_range, (-5, 5))

        m.augment_by_flipping(False)
        self.assertEqual(m.augmentation, {'rotate', 'shear'})
Example #2
    def test_get_pixels_dimension_order(self):

        img_path = os.path.join(base_path, '../test_data/tiffconnector_1/im/')
        label_path = os.path.join(base_path,
                                  '../test_data/tiffconnector_1/labels/')
        c = TiffConnector(img_path, label_path)
        d = Dataset(c)

        size = (2, 5, 4)
        pad = (0, 0, 0)

        m = TrainingBatch(d,
                          size,
                          padding_zxy=pad,
                          batch_size=len(d.label_values()))

        next(m)

        p = m.pixels()
        w = m.weights()
        self.assertEqual(p.shape, (3, 3, 2, 5, 4))
        self.assertEqual(w.shape, (3, 3, 2, 5, 4))

        m.set_pixel_dimension_order('bczxy')
        p = m.pixels()
        w = m.weights()
        self.assertEqual(p.shape, (3, 3, 2, 5, 4))
        self.assertEqual(w.shape, (3, 3, 2, 5, 4))

        m.set_pixel_dimension_order('bzxyc')
        p = m.pixels()
        w = m.weights()
        self.assertEqual(p.shape, (3, 2, 5, 4, 3))
        self.assertEqual(w.shape, (3, 2, 5, 4, 3))
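
The reordering shown above appears to be a plain axis transpose: the default 'bczxy' layout is (batch, channel, z, x, y), and 'bzxyc' moves the channel axis to the end (the layout Keras typically expects). A minimal NumPy sketch of that assumed equivalence, not part of the original test:

import numpy as np

p_bczxy = np.zeros((3, 3, 2, 5, 4))                # batch, channel, z, x, y
p_bzxyc = np.transpose(p_bczxy, (0, 2, 3, 4, 1))   # channel moved to the last axis
assert p_bzxyc.shape == (3, 2, 5, 4, 3)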
Example #3
    def test_normalize_zscore(self):

        img_path = os.path.join(base_path, '../test_data/tiffconnector_1/im/')
        label_path = os.path.join(base_path,
                                  '../test_data/tiffconnector_1/labels/')
        c = TiffConnector(img_path, label_path)
        d = Dataset(c)

        size = (1, 2, 2)
        pad = (0, 0, 0)

        batchsize = 2
        nr_channels = 3
        nz = 1
        nx = 4
        ny = 5

        m = TrainingBatch(d,
                          size,
                          padding_zxy=pad,
                          batch_size=len(d.label_values()))

        m.set_normalize_mode('local_z_score')

        pixels = np.zeros((batchsize, nr_channels, nz, nx, ny))
        pixels[:, 0, :, :, :] = 1
        pixels[:, 1, :, :, :] = 2
        pixels[:, 2, :, :, :] = 3

        p_norm = m._normalize(pixels)
        self.assertTrue((p_norm == 0).all())

        # add variation
        pixels[:, 0, 0, 0, 0] = 2
        pixels[:, 0, 0, 0, 1] = 0

        pixels[:, 1, 0, 0, 0] = 3
        pixels[:, 1, 0, 0, 1] = 1

        pixels[:, 2, 0, 0, 0] = 4
        pixels[:, 2, 0, 0, 1] = 2

        p_norm = m._normalize(pixels)

        assert_array_equal(p_norm[:, 0, :, :, :], p_norm[:, 1, :, :, :])
        assert_array_equal(p_norm[:, 0, :, :, :], p_norm[:, 2, :, :, :])

        val = np.array([[[3.16227766, -3.16227766, 0., 0., 0.],
                         [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]]])

        assert_array_almost_equal(val, p_norm[0, 0, :, :, :])
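
The expected values in val follow from a per-channel z-score over the tile, assuming 'local_z_score' computes (x - mean) / std per channel: channel 0 holds eighteen 1s plus one 2 and one 0, so mean = 1, std = sqrt(0.1) ≈ 0.3162, and the two deviating pixels map to roughly ±3.16227766. A quick check of that arithmetic, separate from the test:

import numpy as np

ch = np.ones(20)          # 1 z-slice * 4 x * 5 y pixels of channel 0
ch[0], ch[1] = 2, 0       # the two deviating pixels from the test
z = (ch - ch.mean()) / ch.std()
print(z[0], z[1])         # approx. 3.16227766 and -3.16227766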
Example #4
    def test_normalize_global(self):

        img_path = os.path.join(base_path, '../test_data/tiffconnector_1/im/')
        label_path = os.path.join(base_path,
                                  '../test_data/tiffconnector_1/labels/')
        c = TiffConnector(img_path, label_path)
        d = Dataset(c)

        size = (1, 2, 2)
        pad = (0, 0, 0)

        batchsize = 2
        nr_channels = 3
        nz = 1
        nx = 4
        ny = 5

        val = np.array(
            [[[0.33333333, 0.33333333, 0.33333333, 0.33333333, 0.33333333],
              [0.33333333, 0.33333333, 0.33333333, 0.33333333, 0.33333333],
              [0.33333333, 0.33333333, 0.33333333, 0.33333333, 0.33333333],
              [0.33333333, 0.33333333, 0.33333333, 0.33333333, 0.33333333]],
             [[0.66666667, 0.66666667, 0.66666667, 0.66666667, 0.66666667],
              [0.66666667, 0.66666667, 0.66666667, 0.66666667, 0.66666667],
              [0.66666667, 0.66666667, 0.66666667, 0.66666667, 0.66666667],
              [0.66666667, 0.66666667, 0.66666667, 0.66666667, 0.66666667]],
             [[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.],
              [1., 1., 1., 1., 1.]]])

        m = TrainingBatch(d,
                          size,
                          padding_zxy=pad,
                          batch_size=len(d.label_values()))
        m.set_normalize_mode('global', minmax=[0, 3])

        pixels = np.zeros((batchsize, nr_channels, nz, nx, ny))
        pixels[:, 0, :, :, :] = 1
        pixels[:, 1, :, :, :] = 2
        pixels[:, 2, :, :, :] = 3

        p_norm = m._normalize(pixels)

        pprint(p_norm)
        print(pixels.shape)
        print(p_norm.shape)

        assert_array_almost_equal(val, p_norm[0, :, 0, :, :])
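
The expected array val is consistent with a min-max rescaling x -> (x - min) / (max - min) using the supplied minmax=[0, 3] (an assumption about what 'global' mode does): the constant channels 1, 2 and 3 map to 1/3, 2/3 and 1. A one-line check:

lo, hi = 0, 3
print([(x - lo) / (hi - lo) for x in (1, 2, 3)])   # [0.333..., 0.666..., 1.0]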
Example #5
    def test_random_tile(self):

        img_path = os.path.join(base_path, '../test_data/tiffconnector_1/im/')
        label_path = os.path.join(base_path,
                                  '../test_data/tiffconnector_1/labels/')
        c = TiffConnector(img_path, label_path)
        d = Dataset(c)

        size = (1, 3, 4)
        pad = (1, 2, 2)

        m = TrainingBatch(d,
                          size,
                          padding_zxy=pad,
                          batch_size=len(d.label_values()))

        m._random_tile(for_label=1)
Example #6
    def test_normalize_global_auto(self):

        img_path = os.path.join(base_path, '../test_data/tiffconnector_1/im/')
        label_path = os.path.join(base_path,
                                  '../test_data/tiffconnector_1/labels/')
        c = TiffConnector(img_path, label_path)
        d = Dataset(c)

        size = (1, 5, 4)
        pad = (0, 0, 0)

        m = TrainingBatch(d,
                          size,
                          padding_zxy=pad,
                          batch_size=len(d.label_values()))
        assert m.global_norm_minmax is None
        m.set_normalize_mode('global')
        assert len(m.global_norm_minmax) == 3
Example #7
    def test_set_pixel_dimension_order(self):

        img_path = os.path.join(base_path, '../test_data/tiffconnector_1/im/')
        label_path = os.path.join(base_path,
                                  '../test_data/tiffconnector_1/labels/')
        c = TiffConnector(img_path, label_path)
        d = Dataset(c)

        size = (1, 3, 4)
        pad = (1, 2, 2)

        m = TrainingBatch(d,
                          size,
                          padding_zxy=pad,
                          batch_size=len(d.label_values()))

        m.set_pixel_dimension_order('bczxy')
        self.assertEqual([0, 1, 2, 3, 4], m.pixel_dimension_order)

        m.set_pixel_dimension_order('bzxyc')
        self.assertEqual([0, 4, 1, 2, 3], m.pixel_dimension_order)
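
The stored list appears to give, for each axis of the internal 'bczxy' layout, its position in the requested order, which is why 'bzxyc' yields [0, 4, 1, 2, 3] (b stays at 0, c moves to index 4, z/x/y shift to 1/2/3). A small sketch of that assumed interpretation:

internal, requested = 'bczxy', 'bzxyc'
print([requested.index(ax) for ax in internal])    # [0, 4, 1, 2, 3]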
Example #8
    def test_getitem_multichannel_labels(self):

        # define data locations
        pixel_image_dir = os.path.join(base_path,
                                       '../test_data/tiffconnector_1/im/')
        label_image_dir = os.path.join(
            base_path, '../test_data/tiffconnector_1/labels_multichannel/')

        tile_size = (1, 5, 4)  # size of network output layer in zxy
        # padding of network input layer in zxy, relative to the output layer
        padding = (0, 2, 2)

        # make training_batch mb and prediction interface p with
        # TiffConnector binding
        c = TiffConnector(pixel_image_dir,
                          label_image_dir,
                          savepath=self.tmpdir)
        d = Dataset(c)
        m = TrainingBatch(d,
                          tile_size,
                          padding_zxy=padding,
                          batch_size=len(d.label_values()))

        for counter, mini in enumerate(m):
            # shape is (6, 6, 1, 5, 4):
            # batch size 6, 6 label classes, 1 z, 5 x, 4 y
            weights = mini.weights()

            # shape is (6, 3, 1, 9, 8):
            # batch size 6, 3 channels, 1 z, 9 x, 8 y (larger xy due to padding)
            pixels = mini.pixels()
            self.assertEqual(weights.shape, (6, 6, 1, 5, 4))
            self.assertEqual(pixels.shape, (6, 3, 1, 9, 8))

            # apply training on mini.pixels and mini.weights goes here

            if counter > 10:  # m is infinite
                break
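
The placeholder comment in the loop above is where an actual training step would go. A minimal sketch of such a step, assuming a compiled Keras model (called model here, a hypothetical name not defined in this example) whose input and output shapes match mini.pixels() and mini.weights():

for counter, mini in enumerate(m):
    # hypothetical training step; `model` is an assumed, already compiled Keras model
    loss = model.train_on_batch(mini.pixels(), mini.weights())
    if counter > 10:        # m iterates forever, so stop explicitly
        break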
Example #9
class Session(object):
    '''
    A session is used for training a model with a connected dataset (pixels
    and labels) or for predicting connected data (pixels) with an already
    trained model.

    Parameters
    ----------
    data : yapic_io.TrainingBatch
        Connector object for binding pixel and label data
    '''
    def __init__(self):

        self.dataset = None
        self.model = None
        self.data = None
        self.data_val = None
        self.history = None
        self.data_predict = None
        self.log_filename = None

        self.output_tile_size_zxy = None
        self.padding_zxy = None

    def load_training_data(self, image_path, label_path):
        '''
        Connect to a training dataset.

        Parameters
        ----------
        image_path : string
            Path to folder with tiff images.
        label_path : string
            Path to folder with label tiff images or path to ilastik project
            file (.ilp file).
        '''

        print(image_path)
        print(label_path)

        self.dataset = Dataset(io_connector(image_path, label_path))

    def load_prediction_data(self, image_path, save_path):
        '''
        Connect to a prediction dataset.

        Parameters
        ----------
        image_path : string
            Path to folder with tiff images to predict.
        save_path : string
            Path to folder for saving prediction images.
        '''

        self.dataset = Dataset(
            io_connector(image_path,
                         '/tmp/this_should_not_exist',
                         savepath=save_path))

    def make_model(self, model_name, input_tile_size_zxy):
        '''
        Initializes a neural network and sets tile sizes of the data connector
        according to the model's input/output shapes.

        Parameters
        ----------
        model_name : string
            Either 'unet_2d' or 'unet_2p5d'
        input_tile_size_zxy: (nr_zslices, nr_x, nr_y)
            Input shape of the model. Larger input shapes require more GPU
            memory. For 'unet_2d', nr_zslices has to be 1.
        '''
        print('tile size zxy: {}'.format(input_tile_size_zxy))
        assert len(input_tile_size_zxy) == 3
        nr_channels = self.dataset.image_dimensions(0)[0]
        print('nr_channels: {}'.format(nr_channels))
        input_size_czxy = [nr_channels] + list(input_tile_size_zxy)
        n_classes = len(self.dataset.label_values())

        self.model = make_model(model_name, n_classes, input_size_czxy)

        output_tile_size_zxy = self.model.output_shape[-4:-1]

        self._configure_minibatch_data(input_tile_size_zxy,
                                       output_tile_size_zxy)

    def load_model(self, model_filepath):
        '''
        Import a Keras model in hdf5 format.

        Parameters
        ----------
        model_filepath : string
            Path to .h5 model file
        '''

        model = load_keras_model(model_filepath)

        n_classes_model = model.output_shape[-1]
        output_tile_size_zxy = model.output_shape[-4:-1]
        n_channels_model = model.input_shape[-1]
        input_tile_size_zxy = model.input_shape[-4:-1]

        n_classes_data = len(self.dataset.label_values())
        n_channels_data = self.dataset.image_dimensions(0)[0]

        msg = ('number of model classes ({}) and data classes ({}) '
               'are not equal').format(n_classes_model, n_classes_data)
        if n_classes_data > 0:
            assert n_classes_data == n_classes_model, msg

        msg = ('number of model channels ({}) and image channels ({}) '
               'are not equal').format(n_channels_model, n_channels_data)
        assert n_channels_data == n_channels_model, msg

        self._configure_minibatch_data(input_tile_size_zxy,
                                       output_tile_size_zxy)
        self.model = model

    def _configure_minibatch_data(self, input_tile_size_zxy,
                                  output_tile_size_zxy):
        padding_zxy = tuple(
            ((np.array(input_tile_size_zxy) - np.array(output_tile_size_zxy)) /
             2).astype(int))

        self.data_val = None
        self.data = TrainingBatch(self.dataset,
                                  output_tile_size_zxy,
                                  padding_zxy=padding_zxy)
        self.data.set_normalize_mode('local')
        self.data.set_pixel_dimension_order('bzxyc')
        next(self.data)
        self.output_tile_size_zxy = output_tile_size_zxy
        self.padding_zxy = padding_zxy

    def define_validation_data(self, valfraction):
        '''
        Splits the dataset into a training fraction and a validation fraction.

        Parameters
        ----------
        valfraction : float
            Approximate fraction of validation data. Has to be between 0 and 1.
        '''
        if self.data_val is not None:
            logger.warning('skipping define_validation_data: already defined')
            return None

        if self.data is None:
            logger.warning('skipping, data not defined yet')
            return None
        self.data.remove_unlabeled_tiles()
        self.data_val = self.data.split(valfraction)

    def train(self,
              max_epochs=3000,
              steps_per_epoch=24,
              log_filename=None,
              model_filename='model.h5'):
        '''
        Starts a training run.

        Parameters
        ----------
        max_epochs : int
            Number of epochs.
        steps_per_epoch : int
            Number of training steps per epoch.
        log_filename : string
           Path to the csv file for logging loss and accuracy.
        model_filename : string
           Path to h5 keras model file


        Notes
        -----
        Validation and logging to the csv file are each executed once per
        epoch.
        '''

        callbacks = []

        if self.data_val:
            save_model_callback = keras.callbacks.ModelCheckpoint(
                model_filename,
                monitor='val_loss',
                verbose=0,
                save_best_only=True)
        else:
            save_model_callback = keras.callbacks.ModelCheckpoint(
                model_filename, monitor='loss', verbose=0, save_best_only=True)
        callbacks.append(save_model_callback)

        if log_filename:
            callbacks.append(
                keras.callbacks.CSVLogger(log_filename,
                                          separator=',',
                                          append=False))

        training_data = ((mb.pixels(), mb.weights()) for mb in self.data)

        if self.data_val:
            validation_data = ((mb.pixels(), mb.weights())
                               for mb in self.data_val)
        else:
            validation_data = None

        self.history = self.model.fit_generator(
            training_data,
            validation_data=validation_data,
            epochs=max_epochs,
            validation_steps=steps_per_epoch,
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks,
            workers=0)

        return self.history

    def predict(self):
        data_predict = PredictionBatch(self.dataset, 2,
                                       self.output_tile_size_zxy,
                                       self.padding_zxy)
        data_predict.set_normalize_mode('local')
        data_predict.set_pixel_dimension_order('bzxyc')

        for item in data_predict:
            result = self.model.predict(item.pixels())
            item.put_probmap_data(result)

    def set_augmentation(self, augment_string):
        '''
        Define data augmentation settings for model training.

        Parameters
        ----------
        augment_string : string
            Choose 'flip' and/or 'rotate' and/or 'shear'.
            Use '+' to specify multiple augmentations (e.g. flip+rotate).
        '''

        if self.data is None:
            logger.warning('could not set augmentation to {}. '
                           'Run make_model() first'.format(augment_string))
            return

        ut.handle_augmentation_setting(self.data, augment_string)

    def set_normalization(self, norm_string):
        '''
        Set pixel normalization scope.

        Parameters
        ----------
        norm_string : string
            For minibatch-wise normalization choose 'local_z_score' or 'local'.
            For global normalization use global_<min>+<max>
            (e.g. 'global_0+255' for 8-bit images and 'global_0+65535' for
            16-bit images).
            Choose 'off' to deactivate.
        '''

        if self.data is None:
            logger.warning('could not set normalization to {}. '
                           'Run make_model() first'.format(norm_string))
            return

        ut.handle_normalization_setting(self.data, norm_string)
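
Taken together, the methods above suggest a training-then-prediction workflow roughly like the following sketch; the paths, the tile size and the epoch counts are placeholder values chosen for illustration, not taken from the source:

# hypothetical end-to-end use of the Session API shown above
session = Session()
session.load_training_data('images/', 'labels/')    # folders with tiff images and labels
session.make_model('unet_2d', (1, 572, 572))         # nr_zslices must be 1 for 'unet_2d'
session.set_normalization('local')
session.set_augmentation('flip+rotate')
session.define_validation_data(0.2)                  # hold out ~20% of tiles for validation
session.train(max_epochs=10,
              steps_per_epoch=24,
              log_filename='training.csv',
              model_filename='model.h5')

session.load_prediction_data('new_images/', 'predictions/')
session.predict()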