def test_remove_unlabeled_tiles(self): img_path = os.path.join( base_path, '../test_data/ilastik/pixels_ilastik-multiim-1.2') label_path = os.path.join( base_path, '../test_data/ilastik/ilastik-multiim-1.2.ilp') c = IlastikConnector(img_path, label_path) d = Dataset(c) size = (2, 6, 4) pad = (0, 0, 0) m = TrainingBatch(d, size, padding_zxy=pad) n_pos_lbl_1 = len(m.tile_pos_for_label[1]) n_pos_lbl_2 = len(m.tile_pos_for_label[2]) m.remove_unlabeled_tiles() n_pos_lbl_1_after = len(m.tile_pos_for_label[1]) n_pos_lbl_2_after = len(m.tile_pos_for_label[2]) self.assertTrue(n_pos_lbl_1 > n_pos_lbl_1_after) self.assertTrue(n_pos_lbl_2 > n_pos_lbl_2_after)
class Session(object): ''' A session is used for training a model with a connected dataset (pixels and labels) or for predicting connected data (pixels) with an already trained model. Parameters ---------- data : yapic_io.TrainingBatch Connector object for binding pixel and label data ''' def __init__(self): self.dataset = None self.model = None self.data = None self.data_val = None self.history = None self.data_predict = None self.log_filename = None self.output_tile_size_zxy = None self.padding_zxy = None def load_training_data(self, image_path, label_path): ''' Connect to a training dataset. Parameters ---------- image_path : string Path to folder with tiff images. label_path : string Path to folder with label tiff images or path to ilastik project file (.ilp file). ''' print(image_path) print(label_path) self.dataset = Dataset(io_connector(image_path, label_path)) def load_prediction_data(self, image_path, save_path): ''' Connect to a prediction dataset. Parameters ---------- image_path : string Path to folder with tiff images to predict. save_path : string Path to folder for saving prediction images. ''' self.dataset = Dataset( io_connector(image_path, '/tmp/this_should_not_exist', savepath=save_path)) def make_model(self, model_name, input_tile_size_zxy): ''' Initialzes a neural network and sets tile sizes of the data connector accordingly to model input/output shapes. Parameters ---------- model_name : string Either 'unet_2d' or 'unet_2p5d' input_tile_size_zxy: (nr_zslices, nr_x, nr_y) Input shape of the model. Large input shapes require large memory for used GPU hardware. For 'unet_2d', nr_zslices has to be 1. ''' print('tile size zxy: {}'.format(input_tile_size_zxy)) assert len(input_tile_size_zxy) == 3 nr_channels = self.dataset.image_dimensions(0)[0] print('nr_channels: {}'.format(nr_channels)) input_size_czxy = [nr_channels] + list(input_tile_size_zxy) n_classes = len(self.dataset.label_values()) self.model = make_model(model_name, n_classes, input_size_czxy) output_tile_size_zxy = self.model.output_shape[-4:-1] self._configure_minibatch_data(input_tile_size_zxy, output_tile_size_zxy) def load_model(self, model_filepath): ''' Import a Keras model in hfd5 format. Parameters ---------- model_filepath : string Path to .h5 model file ''' model = load_keras_model(model_filepath) n_classes_model = model.output_shape[-1] output_tile_size_zxy = model.output_shape[-4:-1] n_channels_model = model.input_shape[-1] input_tile_size_zxy = model.input_shape[-4:-1] n_classes_data = len(self.dataset.label_values()) n_channels_data = self.dataset.image_dimensions(0)[0] msg = ('nr of model classes ({}) and data classes ({}) ' 'is not equal').format(n_classes_model, n_classes_data) if n_classes_data > 0: assert n_classes_data == n_classes_model, msg msg = ('nr of model channels ({}) and iamge channels ({}) ' 'is not equal').format(n_channels_model, n_channels_data) assert n_channels_data == n_channels_model, msg self._configure_minibatch_data(input_tile_size_zxy, output_tile_size_zxy) self.model = model def _configure_minibatch_data(self, input_tile_size_zxy, output_tile_size_zxy): padding_zxy = tuple( ((np.array(input_tile_size_zxy) - np.array(output_tile_size_zxy)) / 2).astype(np.int)) self.data_val = None self.data = TrainingBatch(self.dataset, output_tile_size_zxy, padding_zxy=padding_zxy) self.data.set_normalize_mode('local') self.data.set_pixel_dimension_order('bzxyc') next(self.data) self.output_tile_size_zxy = output_tile_size_zxy self.padding_zxy = padding_zxy def define_validation_data(self, valfraction): ''' Splits the dataset into a training fraction and a validation fraction. Parameters ---------- valfraction : float Approximate fraction of validation data. Has to be between 0 and 1. ''' if self.data_val is not None: logger.warning('skipping define_validation_data: already defined') return None if self.data is None: logger.warning('skipping, data not defined yet') return None self.data.remove_unlabeled_tiles() self.data_val = self.data.split(valfraction) def train(self, max_epochs=3000, steps_per_epoch=24, log_filename=None, model_filename='model.h5'): ''' Starts a training run. Parameters ---------- max_epochs : int Number of epochs. steps_per_epoch : int Number of training steps per epoch. log_filename : string Path to the csv file for logging loss and accuracy. model_filename : string Path to h5 keras model file Notes ----- Validation is executed once each epoch. Logging to csv file is executed once each epoch. ''' callbacks = [] if self.data_val: save_model_callback = keras.callbacks.ModelCheckpoint( model_filename, monitor='val_loss', verbose=0, save_best_only=True) else: save_model_callback = keras.callbacks.ModelCheckpoint( model_filename, monitor='loss', verbose=0, save_best_only=True) callbacks.append(save_model_callback) if log_filename: callbacks.append( keras.callbacks.CSVLogger(log_filename, separator=',', append=False)) training_data = ((mb.pixels(), mb.weights()) for mb in self.data) if self.data_val: validation_data = ((mb.pixels(), mb.weights()) for mb in self.data_val) else: validation_data = None self.history = self.model.fit_generator( training_data, validation_data=validation_data, epochs=max_epochs, validation_steps=steps_per_epoch, steps_per_epoch=steps_per_epoch, callbacks=callbacks, workers=0) return self.history def predict(self): data_predict = PredictionBatch(self.dataset, 2, self.output_tile_size_zxy, self.padding_zxy) data_predict.set_normalize_mode('local') data_predict.set_pixel_dimension_order('bzxyc') for item in data_predict: result = self.model.predict(item.pixels()) item.put_probmap_data(result) def set_augmentation(self, augment_string): ''' Define data augmentation settings for model training. Parameters ---------- augment_string : string Choose 'flip' and/or 'rotate' and/or 'shear'. Use '+' to specify multiple augmentations (e.g. flip+rotate). ''' if self.data is None: logger.warning( 'could not set augmentation to {}. Run make_model() first') return ut.handle_augmentation_setting(self.data, augment_string) def set_normalization(self, norm_string): ''' Set pixel normalization scope. Parameters ---------- norm_string : string For minibatch-wise normalization choose 'local_z_score' or 'local'. For global normalization use global_<min>+<max> (e.g. 'global_0+255' for 8-bit images and 'global_0+65535' for 16-bit images). Choose 'off' to deactivate. ''' if self.data is None: logger.warning( 'could not set normalizarion to {}. Run make_model() first') return ut.handle_normalization_setting(self.data, norm_string)