def test_operations(self):
    """Calling the reader yields a subject index, data dict and interp orders."""
    image_reader = ImageReader(['image'])
    image_reader.initialise_reader(SINGLE_MOD_DATA, SINGLE_MOD_TASK)
    subject_idx, volume, orders = image_reader()
    # the output interp order must come from the 'lesion' input source
    self.assertEqual(
        SINGLE_MOD_DATA['lesion'].interp_order, orders['image'][0])
    self.assertAllClose(volume['image'].shape, (256, 168, 256, 1, 1))
def test_properties(self):
    """Reader exposes the expected shapes, dtypes, names and input sources."""
    reader = ImageReader(['image'])
    reader.initialise_reader(SINGLE_MOD_DATA, SINGLE_MOD_TASK)
    # assertEquals is a deprecated unittest alias (removed in Python 3.12);
    # use assertEqual instead
    self.assertEqual(len(reader.output_list), 4)
    self.assertDictEqual(reader.shapes, {'image': (256, 168, 256, 1, 1)})
    self.assertDictEqual(reader.tf_dtypes, {'image': tf.float32})
    self.assertEqual(reader.names, ['image'])
    self.assertDictEqual(reader.input_sources, {'image': ('lesion', )})
    self.assertEqual(reader.get_subject_id(1)[:4], 'Fin_')
def get_label_reader():
    """Build a 'label' reader with discrete label normalisation and padding."""
    label_reader = ImageReader(['label'])
    # NOTE: MOD_LABEl_TASK (lower-case 'l') is the constant's actual name
    label_reader.initialise_reader(MOD_LABEL_DATA, MOD_LABEl_TASK)
    normaliser = DiscreteLabelNormalisationLayer(
        image_name='label',
        modalities=vars(SINGLE_25D_TASK).get('label'),
        model_filename=os.path.join('testing_data', 'agg_test.txt'))
    label_reader.add_preprocessing_layers(normaliser)
    padding = PadLayer(image_name=('label', ), border=(5, 6, 7))
    label_reader.add_preprocessing_layers([padding])
    return label_reader
def test_preprocessing_zero_padding(self):
    """A PadLayer with a (0, 0, 0) border must leave the volume shape unchanged."""
    expected_shape = (256, 168, 256, 1, 1)
    image_reader = ImageReader(['image'])
    image_reader.initialise_reader(SINGLE_MOD_DATA, SINGLE_MOD_TASK)
    subject_idx, volume, orders = image_reader()
    self.assertEqual(
        SINGLE_MOD_DATA['lesion'].interp_order, orders['image'][0])
    self.assertAllClose(volume['image'].shape, expected_shape)
    # zero-size border: padding becomes a no-op
    zero_pad = PadLayer(image_name=['image'], border=(0, 0, 0))
    image_reader.add_preprocessing_layers([zero_pad])
    subject_idx, volume, orders = image_reader(idx=2)
    self.assertEqual(subject_idx, 2)
    self.assertAllClose(volume['image'].shape, expected_shape)
def test_trainable_preprocessing(self):
    """Trainable label normalisation needs an initialised reader first."""
    label_file = os.path.join('testing_data', 'label_reader.txt')
    # start from a clean model file so the normaliser is retrained
    if os.path.exists(label_file):
        os.remove(label_file)
    label_normaliser = DiscreteLabelNormalisationLayer(
        image_name='label',
        modalities=vars(LABEL_TASK).get('label'),
        model_filename=os.path.join('testing_data', 'label_reader.txt'))
    reader = ImageReader(['label'])
    # adding a trainable layer before initialise_reader must fail;
    # assertRaisesRegexp is deprecated (removed in Python 3.12)
    with self.assertRaisesRegex(AssertionError, ''):
        reader.add_preprocessing_layers(label_normaliser)
    reader.initialise_reader(LABEL_DATA, LABEL_TASK)
    reader.add_preprocessing_layers(label_normaliser)
    reader.add_preprocessing_layers(
        [PadLayer(image_name=['label'], border=(10, 5, 5))])
    idx, data, interp_order = reader(idx=0)
    # normalised labels are consecutive integers starting at zero
    unique_data = np.unique(data['label'])
    expected = np.array(range(156), dtype=np.float32)
    self.assertAllClose(unique_data, expected)
    self.assertAllClose(data['label'].shape, (83, 73, 73, 1, 1))
def test_initialisation(self):
    """Reader construction validates field names and modality specs."""
    # unknown field name raises ValueError
    # (assertRaisesRegexp/assertEquals are deprecated unittest aliases,
    # removed in Python 3.12)
    with self.assertRaisesRegex(ValueError, ''):
        reader = ImageReader(['test'])
        reader.initialise_reader(MULTI_MOD_DATA, MULTI_MOD_TASK)
    # field names must be provided
    with self.assertRaisesRegex(AssertionError, ''):
        reader = ImageReader(None)
        reader.initialise_reader(MULTI_MOD_DATA, MULTI_MOD_TASK)
    reader = ImageReader(['image'])
    reader.initialise_reader(MULTI_MOD_DATA, MULTI_MOD_TASK)
    self.assertEqual(len(reader.output_list), 4)
    reader = ImageReader(['image'])
    reader.initialise_reader(SINGLE_MOD_DATA, SINGLE_MOD_TASK)
    self.assertEqual(len(reader.output_list), 4)
def test_errors(self):
    """Bad inputs raise; an out-of-range idx yields (-1, None, ...)."""
    # assertRaisesRegexp is deprecated (removed in Python 3.12)
    with self.assertRaisesRegex(AttributeError, ''):
        reader = ImageReader(['image'])
        reader.initialise_reader(BAD_DATA, SINGLE_MOD_TASK)
    with self.assertRaisesRegex(ValueError, ''):
        reader = ImageReader(['image'])
        reader.initialise_reader(SINGLE_MOD_DATA, BAD_TASK)
    reader = ImageReader(['image'])
    reader.initialise_reader(SINGLE_MOD_DATA, SINGLE_MOD_TASK)
    # an index beyond the dataset returns sentinel values, not an exception
    idx, data, interp_order = reader(idx=100)
    self.assertEqual(idx, -1)
    self.assertEqual(data, None)
    idx, data, interp_order = reader(shuffle=True)
    self.assertEqual(data['image'].shape, (256, 168, 256, 1, 1))
def get_dynamic_window_reader():
    """Return an 'image' reader initialised with the dynamic-window test data."""
    dyn_reader = ImageReader(['image'])
    dyn_reader.initialise_reader(DYNAMIC_MOD_DATA, DYNAMIC_MOD_TASK)
    return dyn_reader
def get_2d_reader():
    """Return an 'image' reader initialised with the 2D test data."""
    reader_2d = ImageReader(['image'])
    reader_2d.initialise_reader(MOD_2D_DATA, MOD_2D_TASK)
    return reader_2d
def get_3d_reader():
    """Return an 'image' reader initialised with the multi-modality 3D data."""
    reader_3d = ImageReader(['image'])
    reader_3d.initialise_reader(MULTI_MOD_DATA, MULTI_MOD_TASK)
    return reader_3d
def get_25d_reader():
    """Return an 'image' reader initialised with the 2.5D test data."""
    reader_25d = ImageReader(['image'])
    reader_25d.initialise_reader(SINGLE_25D_DATA, SINGLE_25D_TASK)
    return reader_25d
class GANApplication(BaseApplication):
    """NiftyNet application wiring up a conditional GAN.

    Training reads 'image' + 'conditioning' inputs; inference feeds
    random noise vectors together with 'conditioning' windows through
    the generator and aggregates the generated windows into images.
    """

    REQUIRED_CONFIG_SECTION = "GAN"

    def __init__(self, net_param, action_param, is_training):
        # Only stores configuration; the driver calls the initialise_*
        # methods later to build readers, samplers and the network.
        BaseApplication.__init__(self)
        tf.logging.info('starting GAN application')
        self.is_training = is_training
        self.net_param = net_param
        self.action_param = action_param
        self.data_param = None
        self.gan_param = None

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        """Create the image reader and attach normalisation/augmentation."""
        self.data_param = data_param
        self.gan_param = task_param
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(['image', 'conditioning'])
        else:
            # in the inference process use image input only
            self.reader = ImageReader(['conditioning'])
        if self.reader:
            self.reader.initialise_reader(data_param, task_param)

        # optional foreground mask restricts normalisation statistics
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)
        else:
            foreground_masking_layer = None
        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image',
            binary_masking_func=foreground_masking_layer)
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
        else:
            histogram_normaliser = None

        # NOTE(review): if normalisation/whitening flags are set while the
        # corresponding normaliser is None, a None layer is appended here.
        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        # augmentation only applies during training
        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer(
                    min_angle=self.action_param.rotation_angle[0],
                    max_angle=self.action_param.rotation_angle[1]))
        if self.reader:
            self.reader.add_preprocessing_layers(
                normalisation_layers + augmentation_layers)

    def initialise_sampler(self):
        """Training: one ResizeSampler; inference: a noise-vector sampler
        followed by a ResizeSampler for the conditioning windows."""
        self.sampler = []
        if self.is_training:
            self.sampler.append(ResizeSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=1,
                shuffle_buffer=True,
                queue_length=self.net_param.queue_length))
        else:
            self.sampler.append(RandomVectorSampler(
                names=('vector',),
                vector_size=(self.gan_param.noise_size,),
                batch_size=self.net_param.batch_size,
                n_interpolations=self.gan_param.n_interpolations,
                repeat=None,
                queue_length=self.net_param.queue_length))
            # repeat each resized image n times, so that each
            # image matches one random vector,
            # (n = self.gan_param.n_interpolations)
            self.sampler.append(ResizeSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.gan_param.n_interpolations,
                shuffle_buffer=False,
                queue_length=self.net_param.queue_length))

    def initialise_network(self):
        # instantiate the network class named in the configuration
        self.net = ApplicationNetFactory.create(self.net_param.name)()

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: losses and gradients when training,
        generator outputs + window aggregator for inference."""
        if self.is_training:
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            # a new pop_batch_op for each gpu tower
            data_dict = self.get_sampler()[0].pop_batch_op()
            images = tf.cast(data_dict['image'], tf.float32)
            noise_shape = [self.net_param.batch_size,
                           self.gan_param.noise_size]
            # noise is drawn once and stored in a tf.Variable
            noise = tf.Variable(tf.random_normal(shape=noise_shape,
                                                 mean=0.0,
                                                 stddev=1.0,
                                                 dtype=tf.float32))
            # NOTE(review): tf.stop_gradient returns a new tensor; the
            # return value is discarded here, so this call has no effect.
            # Likely intended: noise = tf.stop_gradient(noise) — confirm.
            tf.stop_gradient(noise)
            conditioning = data_dict['conditioning']
            net_output = self.net(noise,
                                  images,
                                  conditioning,
                                  self.is_training)
            loss_func = LossFunction(
                loss_type=self.action_param.loss_type)
            # net_output indices 1 and 2 are treated as the real/fake
            # discriminator logits respectively
            real_logits = net_output[1]
            fake_logits = net_output[2]
            lossG, lossD = loss_func(real_logits, fake_logits)
            if self.net_param.decay > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(l_reg) for l_reg in reg_losses])
                    lossD = lossD + reg_loss
                    lossG = lossG + reg_loss

            # variables to display in STDOUT
            outputs_collector.add_to_collection(
                var=lossD, name='lossD',
                average_over_devices=True, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=lossG, name='lossG',
                average_over_devices=False, collection=CONSOLE)
            # variables to display in tensorboard
            outputs_collector.add_to_collection(
                var=lossG, name='lossG',
                average_over_devices=False, collection=TF_SUMMARIES)
            # NOTE(review): this summary is named 'lossD' but passes
            # var=lossG — probably should be var=lossD; confirm.
            outputs_collector.add_to_collection(
                var=lossG, name='lossD',
                average_over_devices=True, collection=TF_SUMMARIES)

            with tf.name_scope('ComputeGradients'):
                # gradients of generator
                generator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                generator_grads = self.optimiser.compute_gradients(
                    lossG, var_list=generator_variables)
                # gradients of discriminator
                discriminator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                discriminator_grads = self.optimiser.compute_gradients(
                    lossD, var_list=discriminator_variables)
                grads = [generator_grads, discriminator_grads]
                # add the grads back to application_driver's training_grads
                gradients_collector.add_to_collection(grads)
        else:
            # sampler[0] yields noise vectors, sampler[1] the conditioning
            # windows (see initialise_sampler)
            data_dict = self.get_sampler()[0].pop_batch_op()
            conditioning_dict = self.get_sampler()[1].pop_batch_op()
            conditioning = conditioning_dict['conditioning']
            # drop the channel dim, then append a single channel of zeros
            image_size = conditioning.shape.as_list()[:-1]
            dummy_image = tf.zeros(image_size + [1])
            net_output = self.net(data_dict['vector'],
                                  dummy_image,
                                  conditioning,
                                  self.is_training)
            outputs_collector.add_to_collection(
                var=net_output[0], name='image',
                average_over_devices=False, collection=NETORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=conditioning_dict['conditioning_location'],
                name='location',
                average_over_devices=False, collection=NETORK_OUTPUT)
            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.reader,
                output_path=self.action_param.save_seg_dir)

    def interpret_output(self, batch_output):
        """During inference, write generated windows back as images."""
        if self.is_training:
            return True
        return self.output_decoder.decode_batch(
            batch_output['image'], batch_output['location'])
class AutoencoderApplication(BaseApplication):
    """NiftyNet (variational) autoencoder application.

    Training optimises the variational lower bound; inference supports
    four modes (see SUPPORTED_INFERENCE): 'encode', 'encode-decode',
    'sample' and 'linear_interpolation'.
    """

    REQUIRED_CONFIG_SECTION = "AUTOENCODER"

    def __init__(self, net_param, action_param, is_training):
        # Only stores configuration; the driver calls the initialise_*
        # methods later to build readers, samplers and the network.
        BaseApplication.__init__(self)
        tf.logging.info('starting autoencoder application')
        self.is_training = is_training
        self.net_param = net_param
        self.action_param = action_param
        self.data_param = None
        self.autoencoder_param = None

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        """Create the reader appropriate for training / inference mode."""
        self.data_param = data_param
        self.autoencoder_param = task_param
        if not self.is_training:
            self._infer_type = look_up_operations(
                self.autoencoder_param.inference_type,
                SUPPORTED_INFERENCE)
        else:
            self._infer_type = None
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(['image'])
        # when training, _infer_type is None so none of the following
        # branches fire and the training reader is kept
        if self._infer_type in ('encode', 'encode-decode'):
            self.reader = ImageReader(['image'])
        elif self._infer_type == 'sample':
            # empty tuple: falsy placeholder, skips reader initialisation
            self.reader = ()
        elif self._infer_type == 'linear_interpolation':
            self.reader = ImageReader(['feature'])
        if self.reader:
            self.reader.initialise_reader(data_param, task_param)

        # augmentation only applies during training
        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.
                        scaling_percentage[0],
                        max_percentage=self.action_param.
                        scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(
                    RandomRotationLayer(
                        min_angle=self.action_param.rotation_angle[0],
                        max_angle=self.action_param.rotation_angle[1]))
            self.reader.add_preprocessing_layers(augmentation_layers)

    def initialise_sampler(self):
        """Pick the sampler matching the mode; 'sample' needs no sampler."""
        self.sampler = []
        if self.is_training:
            self.sampler.append(
                ResizeSampler(reader=self.reader,
                              data_param=self.data_param,
                              batch_size=self.net_param.batch_size,
                              windows_per_image=1,
                              shuffle_buffer=True,
                              queue_length=self.net_param.queue_length))
            return
        if self._infer_type in ('encode', 'encode-decode'):
            self.sampler.append(
                ResizeSampler(reader=self.reader,
                              data_param=self.data_param,
                              batch_size=self.net_param.batch_size,
                              windows_per_image=1,
                              shuffle_buffer=False,
                              queue_length=self.net_param.queue_length))
            return
        if self._infer_type == 'linear_interpolation':
            self.sampler.append(
                LinearInterpolateSampler(
                    reader=self.reader,
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size,
                    n_interpolations=self.autoencoder_param.n_interpolations,
                    queue_length=self.net_param.queue_length))
            return

    def initialise_network(self):
        """Build the network with optional L1/L2 weight regularisation."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph for the selected training/inference mode."""
        if self.is_training:
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            data_dict = self.get_sampler()[0].pop_batch_op()
            image = tf.cast(data_dict['image'], dtype=tf.float32)
            net_output = self.net(image, is_training=True)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            data_loss = loss_func(net_output)
            loss = data_loss
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(reg_loss)
                         for reg_loss in reg_losses])
                    loss = loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            outputs_collector.add_to_collection(
                var=data_loss, name='variational_lower_bound',
                average_over_devices=True, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='variational_lower_bound',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            # indices into the net's output tuple; the collection names
            # suggest [4]=originals, [2]=means, [5]=variances — confirm
            # against the network definition
            outputs_collector.add_to_collection(
                var=net_output[4], name='Originals',
                average_over_devices=False, summary_type='image3_coronal',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=net_output[2], name='Means',
                average_over_devices=False, summary_type='image3_coronal',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=net_output[5], name='Variances',
                average_over_devices=False, summary_type='image3_coronal',
                collection=TF_SUMMARIES)
        else:
            if self._infer_type in ('encode', 'encode-decode'):
                data_dict = self.get_sampler()[0].pop_batch_op()
                image = tf.cast(data_dict['image'], dtype=tf.float32)
                net_output = self.net(image, is_training=False)
                outputs_collector.add_to_collection(
                    var=data_dict['image_location'], name='location',
                    average_over_devices=True, collection=NETORK_OUTPUT)
                if self._infer_type == 'encode-decode':
                    outputs_collector.add_to_collection(
                        var=net_output[2], name='generated_image',
                        average_over_devices=True, collection=NETORK_OUTPUT)
                if self._infer_type == 'encode':
                    outputs_collector.add_to_collection(
                        var=net_output[7], name='embedded',
                        average_over_devices=True, collection=NETORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.reader,
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'sample':
                # build the network once on a dummy input so the decoder
                # sub-graphs exist, then decode random codes
                image_size = (self.net_param.batch_size,) + \
                    self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                noise_shape = net_output[-1].get_shape().as_list()
                stddev = self.autoencoder_param.noise_stddev
                noise = tf.random_normal(shape=noise_shape,
                                         mean=0.0,
                                         stddev=stddev,
                                         dtype=tf.float32)
                partially_decoded_sample = self.net.shared_decoder(
                    noise, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)
                outputs_collector.add_to_collection(
                    var=decoder_output, name='generated_image',
                    average_over_devices=True, collection=NETORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=None,
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'linear_interpolation':
                # construct the entire network
                image_size = (self.net_param.batch_size,) + \
                    self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                data_dict = self.get_sampler()[0].pop_batch_op()
                real_code = data_dict['feature']
                # reshape interpolated codes to match the latent shape
                real_code = tf.reshape(real_code,
                                       net_output[-1].get_shape())
                partially_decoded_sample = self.net.shared_decoder(
                    real_code, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)
                outputs_collector.add_to_collection(
                    var=decoder_output, name='generated_image',
                    average_over_devices=True, collection=NETORK_OUTPUT)
                outputs_collector.add_to_collection(
                    var=data_dict['feature_location'], name='location',
                    average_over_devices=True, collection=NETORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.reader,
                    output_path=self.action_param.save_seg_dir)
            else:
                raise NotImplementedError

    def interpret_output(self, batch_output):
        """Route inference outputs to the aggregator per inference mode."""
        if self.is_training:
            return True
        else:
            infer_type = look_up_operations(
                self.autoencoder_param.inference_type,
                SUPPORTED_INFERENCE)
            if infer_type == 'encode':
                # location[:, 0:1]: subject id column only
                return self.output_decoder.decode_batch(
                    batch_output['embedded'],
                    batch_output['location'][:, 0:1])
            if infer_type == 'encode-decode':
                return self.output_decoder.decode_batch(
                    batch_output['generated_image'],
                    batch_output['location'][:, 0:1])
            if infer_type == 'sample':
                return self.output_decoder.decode_batch(
                    batch_output['generated_image'], None)
            if infer_type == 'linear_interpolation':
                return self.output_decoder.decode_batch(
                    batch_output['generated_image'],
                    batch_output['location'][:, :2])
def test_volume_loader(self):
    """End-to-end histogram + mean/var normalisation of test volumes.

    Trains a histogram model on the 4 test subjects, checks the learned
    landmark mappings against reference values, then verifies that
    whitening yields zero mean / unit std while preserving voxel order.
    """
    # reference landmark values expected after histogram training
    expected_T1 = np.array(
        [0.0, 8.24277910972, 21.4917343731, 27.0551695202,
         32.6186046672, 43.5081573038, 53.3535675285, 61.9058849776,
         70.0929786194, 73.9944243858, 77.7437509974, 88.5331971492,
         100.0])
    expected_FLAIR = np.array(
        [0.0, 5.36540863446, 15.5386130103, 20.7431912042,
         26.1536608309, 36.669150376, 44.7821246138, 50.7930589961,
         56.1703089214, 59.2393548654, 63.1565641037, 78.7271261392,
         100.0])
    reader = ImageReader(['image'])
    reader.initialise_reader(DATA_PARAM, TASK_PARAM)
    self.assertAllClose(len(reader._file_list), 4)
    foreground_masking_layer = BinaryMaskingLayer(
        type_str='otsu_plus',
        multimod_fusion='or')
    hist_norm = HistogramNormalisationLayer(
        image_name='image',
        modalities=vars(TASK_PARAM).get('image'),
        model_filename=MODEL_FILE,
        binary_masking_func=foreground_masking_layer,
        cutoff=(0.05, 0.95),
        name='hist_norm_layer')
    # remove any stale model so training starts from scratch
    if os.path.exists(MODEL_FILE):
        os.remove(MODEL_FILE)
    hist_norm.train(reader.output_list)
    out_map = hist_norm.mapping
    self.assertAllClose(out_map['T1'], expected_T1)
    self.assertAllClose(out_map['FLAIR'], expected_FLAIR)

    # normalise a uniformly sampled random image
    test_shape = (20, 20, 20, 3, 2)
    rand_image = np.random.uniform(low=-10.0, high=10.0, size=test_shape)
    norm_image = np.copy(rand_image)
    # dict call and (array, mask) call must produce identical results
    norm_image_dict, mask_dict = hist_norm({'image': norm_image})
    norm_image, mask = hist_norm(norm_image, mask_dict)
    self.assertAllClose(norm_image_dict['image'], norm_image)
    self.assertAllClose(mask_dict['image'], mask)

    # apply mean std normalisation
    mv_norm = MeanVarNormalisationLayer(
        image_name='image',
        binary_masking_func=foreground_masking_layer)
    norm_image, _ = mv_norm(norm_image, mask)
    self.assertAllClose(norm_image.shape, mask.shape)
    mv_norm = MeanVarNormalisationLayer(
        image_name='image',
        binary_masking_func=None)
    norm_image, _ = mv_norm(norm_image)

    # mapping should keep at least the order of the images
    rand_image = rand_image[:, :, :, 1, 1].flatten()
    norm_image = norm_image[:, :, :, 1, 1].flatten()
    order_before = rand_image[1:] > rand_image[:-1]
    order_after = norm_image[1:] > norm_image[:-1]
    self.assertAllClose(np.mean(norm_image), 0.0)
    self.assertAllClose(np.std(norm_image), 1.0)
    self.assertAllClose(order_before, order_after)
    # clean up the trained model file
    if os.path.exists(MODEL_FILE):
        os.remove(MODEL_FILE)
class SegmentationApplication(BaseApplication):
    """NiftyNet segmentation application (STNeuroNet variant).

    Window sampling and aggregation strategies are looked up in
    SUPPORTED_SAMPLING by net_param.window_sampling; the training path
    collapses the 3D label to 2D via max-pooling (STNeuroNet-specific).
    """

    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        # Only stores configuration; the driver calls the initialise_*
        # methods later to build readers, samplers and the network.
        BaseApplication.__init__(self)
        tf.logging.info('starting segmentation application')
        self.is_training = is_training
        self.net_param = net_param
        self.action_param = action_param
        self.data_param = None
        self.segmentation_param = None
        # (train sampler, inference sampler, aggregator) per strategy
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        """Create the reader and attach padding/normalisation/augmentation."""
        self.data_param = data_param
        self.segmentation_param = task_param
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(SUPPORTED_INPUT)
        else:
            # in the inference process use image input only
            self.reader = ImageReader(['image'])
        self.reader.initialise_reader(data_param, task_param)

        # optional foreground mask restricts normalisation statistics
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)
        else:
            foreground_masking_layer = None
        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image',
            binary_masking_func=foreground_masking_layer)
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
        else:
            histogram_normaliser = None
        # the label normaliser shares the histogram reference file
        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)
        else:
            label_normaliser = None

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.rotation_angle:
                rotation_layer = RandomRotationLayer()
                # NOTE(review): the inner condition repeats the outer one,
                # so the non-uniform-angle else branch is unreachable;
                # the outer test probably should also cover
                # rotation_angle_x/y/z — confirm intent.
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)
        # ========================== Disable scaling and rotation =====================
        # if self.action_param.scaling_percentage:
        #     augmentation_layers.append(RandomSpatialScalingLayer(
        #         min_percentage=self.action_param.scaling_percentage[0],
        #         max_percentage=self.action_param.scaling_percentage[1]))
        # =============================================================================
        # =============================================================================
        # if self.action_param.rotation_angle:
        #     augmentation_layers.append(RandomRotationLayer(
        #         min_angle=self.action_param.rotation_angle[0],
        #         max_angle=self.action_param.rotation_angle[1]))
        # =============================================================================

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))
        # padding runs first, then normalisation, then augmentation
        self.reader.add_preprocessing_layers(volume_padding_layer +
                                             normalisation_layers +
                                             augmentation_layers)

    def initialise_uniform_sampler(self):
        """Uniformly sampled training windows."""
        self.sampler = [
            UniformSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_resize_sampler(self):
        """Whole-volume resize sampling (train and inference)."""
        self.sampler = [
            ResizeSampler(reader=self.reader,
                          data_param=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle_buffer=self.is_training,
                          queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_sampler(self):
        """Dense sliding-window sampling for inference."""
        self.sampler = [
            GridSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_aggregator(self):
        """Aggregator matching the grid sampler."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Aggregator matching the resize sampler."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Pick train/inference sampler from SUPPORTED_SAMPLING."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Build the network with optional L1/L2 weight regularisation."""
        num_classes = self.segmentation_param.num_classes
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=num_classes,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: loss + gradients when training,
        post-processed windows + aggregator for inference."""
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)
        if self.is_training:
            label = data_dict.get('label', None)
            # Changed label on 11/29/2017: This will generate a 2D label
            # from the 3D label provided in the input. Only suitable for
            # STNeuroNet
            # max-pool over the full extent of axis 3 collapses that
            # spatial dimension to size 1
            k = label.get_shape().as_list()
            label = tf.nn.max_pool3d(label, [1, 1, 1, k[3], 1],
                                     [1, 1, 1, 1, 1], 'VALID',
                                     data_format='NDHWC')
            print('label shape is{}'.format(label.get_shape()))
            print('Image shape is{}'.format(image.get_shape()))
            print('Out shape is{}'.format(net_out.get_shape()))
            ####
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=label,
                weight_map=data_dict.get('weight', None))
            # NOTE(review): if decay > 0 but no regularisation losses were
            # collected, 'loss' is never assigned here — confirm.
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(reg_loss)
                         for reg_loss in reg_losses])
                    loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            # ADDED on 10/30 by Soltanian-Zadeh for tensorboard visualization
            seg_summary = tf.to_float(
                tf.expand_dims(tf.argmax(net_out, -1), -1)) * (
                255. / self.segmentation_param.num_classes - 1)
            label_summary = tf.to_float(tf.expand_dims(
                label, -1)) * (255. / self.segmentation_param.num_classes
                               - 1)
            # per-volume intensity rescaling for display only
            m, v = tf.nn.moments(image, axes=[1, 2, 3], keep_dims=True)
            img_summary = tf.minimum(
                255., tf.maximum(0., (tf.to_float(image - m) /
                                      (tf.sqrt(v) * 2.) + 1.) * 127.))
            image3_axial('img', img_summary, 50, [tf.GraphKeys.SUMMARIES])
            image3_axial('seg', seg_summary, 5, [tf.GraphKeys.SUMMARIES])
            image3_axial('label', label_summary, 5,
                         [tf.GraphKeys.SUMMARIES])
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)
            print('output shape is{}'.format(net_out.get_shape()))
            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """During inference, write decoded windows back via the aggregator."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True
class RegressionApplication(BaseApplication):
    """NiftyNet-style application for patch-based image regression.

    Wires image readers, preprocessing/augmentation layers, window
    samplers, the regression network (single output channel) and the
    training loss / inference aggregator together, driven by the
    `[REGRESSION]` section of the configuration.
    """

    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, is_training):
        """Store configuration and register the sampling-mode dispatch.

        :param net_param: network-level configuration namespace
        :param action_param: training/inference action configuration
        :param is_training: True to build the training pipeline
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.is_training = is_training
        self.net_param = net_param
        self.action_param = action_param
        self.regression_param = None
        self.data_param = None
        # sampling mode -> (training sampler initialiser,
        #                   inference sampler initialiser,
        #                   aggregator initialiser)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        """Create the image reader and attach preprocessing layers.

        Training uses all SUPPORTED_INPUT sources; inference reads the
        'image' source only.  Layer order passed to the reader is:
        padding, normalisation, augmentation.
        """
        self.data_param = data_param
        self.regression_param = task_param
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(SUPPORTED_INPUT)
        else:
            # in the inference process use image input only
            self.reader = ImageReader(['image'])
        self.reader.initialise_reader(data_param, task_param)

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
        else:
            histogram_normaliser = None

        normalisation_layers = []
        # Fix: previously a None was appended when `normalisation` was
        # enabled without a histogram_ref_file, which crashed later in
        # the reader's preprocessing; skip the layer in that case.
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_uniform_angle(
                    self.action_param.rotation_angle)
                augmentation_layers.append(rotation_layer)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))
        self.reader.add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)

    def initialise_uniform_sampler(self):
        """Sample training windows uniformly at random."""
        self.sampler = [UniformSampler(
            reader=self.reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length)]

    def initialise_weighted_sampler(self):
        """Sample training windows according to a weight map."""
        self.sampler = [WeightedSampler(
            reader=self.reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length)]

    def initialise_resize_sampler(self):
        """Resize whole volumes to the window size (train and infer)."""
        self.sampler = [ResizeSampler(
            reader=self.reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            shuffle_buffer=self.is_training,
            queue_length=self.net_param.queue_length)]

    def initialise_grid_sampler(self):
        """Tile each volume with a regular grid of inference windows."""
        self.sampler = [GridSampler(
            reader=self.reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            spatial_window_size=self.action_param.spatial_window_size,
            window_border=self.action_param.border,
            queue_length=self.net_param.queue_length)]

    def initialise_grid_aggregator(self):
        """Aggregate grid-sampled windows back into whole volumes."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Aggregate resize-sampled outputs back to original shape."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Pick the train or inference sampler for the configured mode."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Instantiate the regression net with optional L1/L2 decay."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if decay > 0 and reg_type in ('l1', 'l2'):
            # local import: contrib is only needed when decay is active
            from tensorflow.contrib.layers.python.layers import regularizers
            make_reg = (regularizers.l2_regularizer if reg_type == 'l2'
                        else regularizers.l1_regularizer)
            w_regularizer = make_reg(decay)
            b_regularizer = make_reg(decay)

        # regression: a single continuous output channel
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the graph from sampler output to loss or aggregation.

        Training: creates the optimiser, a border-cropped regression
        loss (plus mean regularisation loss when decay > 0), gradients,
        and console/TensorBoard outputs.  Inference: registers the
        network output window and its location for the aggregator.
        """
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:
            # ignore `loss_border` voxels at the window edges in the loss
            crop_layer = CropLayer(
                border=self.regression_param.loss_border, name='crop-88')
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            prediction = crop_layer(net_out)
            ground_truth = crop_layer(data_dict.get('output', None))
            weight_map = None if data_dict.get('weight', None) is None \
                else crop_layer(data_dict.get('weight', None))
            data_loss = loss_func(prediction=prediction,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)

            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='Loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='Loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        else:
            crop_layer = CropLayer(border=0, name='crop-88')
            post_process_layer = PostProcessingLayer('IDENTITY')
            net_out = post_process_layer(crop_layer(net_out))
            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            # prepare the aggregator for the configured sampling mode
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Decode inference windows; no-op (True) during training."""
        if not self.is_training:
            return self.output_decoder.decode_batch(batch_output['window'],
                                                    batch_output['location'])
        return True