def test_25d_init(self):
    """Round-trip 2.5D windows through ResizeSampler/ResizeSamplesAggregator.

    Windows produced by the sampler are fed straight back into the
    aggregator; the shape of the NIfTI file it writes is then compared with
    the expected output shape.
    """
    reader = get_25d_reader()
    sampler = ResizeSampler(reader=reader,
                            window_sizes=SINGLE_25D_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    aggregator = ResizeSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        interp_order=3)
    more_batch = True
    try:
        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    # sampler queue exhausted: nothing more to decode
                    break
                more_batch = aggregator.decode_batch(
                    out['image'], out['image_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [255, 168, 256, 1, 1],
            rtol=1e-03, atol=1e-03)
    finally:
        # Stop the sampler threads even if the session or an assertion
        # fails, so a failing test cannot leave them running.
        sampler.close_all()
def test_inverse_mapping(self):
    """Round-trip a label volume through ResizeSampler/ResizeSamplesAggregator.

    Uses ``interp_order=0`` for the aggregator and checks the shape of the
    NIfTI file written for the first subject.
    """
    reader = get_label_reader()
    sampler = ResizeSampler(reader=reader,
                            data_param=MOD_LABEL_DATA,
                            batch_size=1,
                            shuffle_buffer=False,
                            queue_length=50)
    aggregator = ResizeSamplesAggregator(
        image_reader=reader,
        name='label',
        output_path=os.path.join('testing_data', 'aggregated'),
        interp_order=0)
    more_batch = True
    try:
        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    # Queue exhausted before the aggregator reported
                    # completion: stop cleanly instead of erroring out.
                    break
                more_batch = aggregator.decode_batch(
                    out['label'], out['label_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [256, 168, 256, 1, 1])
    finally:
        # always release the sampler threads, even on failure
        sampler.close_all()
def test_inverse_mapping(self):
    """Round-trip a label volume through ResizeSampler/ResizeSamplesAggregator.

    Uses ``interp_order=0`` for the aggregator and checks the shape of the
    NIfTI file written for the first subject.
    """
    reader = get_label_reader()
    sampler = ResizeSampler(reader=reader,
                            data_param=MOD_LABEL_DATA,
                            batch_size=1,
                            shuffle_buffer=False,
                            queue_length=50)
    aggregator = ResizeSamplesAggregator(
        image_reader=reader,
        name='label',
        output_path=os.path.join('testing_data', 'aggregated'),
        interp_order=0)
    more_batch = True
    try:
        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    # Queue exhausted before the aggregator reported
                    # completion: stop cleanly instead of erroring out.
                    break
                more_batch = aggregator.decode_batch(
                    out['label'], out['label_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [256, 168, 256, 1, 1])
    finally:
        # always release the sampler threads, even on failure
        sampler.close_all()
def test_25d_init(self):
    """Round-trip 2.5D windows through ResizeSampler/ResizeSamplesAggregator.

    The sampler runs on two threads; every popped batch is decoded by the
    aggregator, and the shape of the written NIfTI file is then checked.
    """
    reader = get_25d_reader()
    sampler = ResizeSampler(reader=reader,
                            window_sizes=SINGLE_25D_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    aggregator = ResizeSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        interp_order=3)
    more_batch = True
    try:
        with self.test_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    # sampler queue exhausted: nothing more to decode
                    break
                more_batch = aggregator.decode_batch(
                    out['image'], out['image_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [255, 168, 256, 1, 1],
            rtol=1e-03, atol=1e-03)
    finally:
        # Stop the sampler threads even if an assertion fails.
        sampler.close_all()
def test_inverse_mapping(self):
    """Round-trip a label volume through ResizeSampler/ResizeSamplesAggregator.

    Decoded windows are passed to the aggregator under the output name
    ``window_label``, so the written file is prefixed with that name.
    """
    reader = get_label_reader()
    sampler = ResizeSampler(reader=reader,
                            window_sizes=MOD_LABEL_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    aggregator = ResizeSamplesAggregator(
        image_reader=reader,
        name='label',
        output_path=os.path.join('testing_data', 'aggregated'),
        interp_order=0)
    more_batch = True
    try:
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    # sampler queue exhausted: nothing more to decode
                    break
                more_batch = aggregator.decode_batch(
                    {'window_label': out['label']}, out['label_location'])
        output_filename = 'window_label_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        self.assertAllClose(nib.load(output_file).shape, [256, 168, 256])
    finally:
        # Stop the sampler threads even if an assertion fails.
        sampler.close_all()
def initialise_resize_aggregator(self):
    """Build the resize-based output decoder.

    A ``ResizeSamplesAggregator`` is attached to the first image reader and
    configured from ``self.action_param`` (output directory, window border,
    interpolation order and filename postfix). The result is stored in
    ``self.output_decoder``.
    """
    params = self.action_param
    decoder = ResizeSamplesAggregator(
        image_reader=self.readers[0],
        output_path=params.save_seg_dir,
        window_border=params.border,
        interp_order=params.output_interp_order,
        postfix=params.output_postfix)
    self.output_decoder = decoder
def test_init_2d_mo_bidimcsv(self):
    """Round-trip a 2D image together with two CSV outputs.

    For every batch, the aggregator receives:

    * ``window_image``: the image window itself (written as NIfTI);
    * ``csv_sum``: the sum of the window's intensities;
    * ``csv_stats_2d``: a 2 x 3 table of [min, max, sum] (the same row
      repeated twice).

    The shapes of the three written files are then checked.
    """
    reader = get_2d_reader()
    sampler = ResizeSampler(reader=reader,
                            window_sizes=MOD_2D_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    aggregator = ResizeSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        interp_order=3)
    more_batch = True
    try:
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    # sampler queue exhausted: nothing more to decode
                    break
                # Named sum_val (was the misleading "min_val"): this is the
                # total intensity, written to the 'csv_sum' output.
                sum_val = np.sum(out['image'])
                stats_val = [
                    np.min(out['image']),
                    np.max(out['image']),
                    np.sum(out['image'])
                ]
                stats_val = np.expand_dims(stats_val, 0)
                # duplicate the row to produce a 2-row ("bidim") table
                stats_val = np.concatenate([stats_val, stats_val], axis=0)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': sum_val,
                        'csv_stats_2d': stats_val
                    }, out['image_location'])
        subject_id = sampler.reader.get_subject_id(0)
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            subject_id)
        sum_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_sum_{}_niftynet_out.csv'.format(subject_id))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_2d_{}_niftynet_out.csv'.format(subject_id))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)

        self.assertAllClose(nib.load(output_file).shape, (128, 128))
        sum_pd = pd.read_csv(sum_filename)
        self.assertAllClose(sum_pd.shape, [1, 2])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [1, 7])
    finally:
        # Stop the sampler threads even if an assertion fails.
        sampler.close_all()
def initialise_resize_aggregator(self):
    '''
    Set ``self.output_decoder`` to a ResizeSamplesAggregator bound to the
    first image reader, configured from the action parameters
    (save_seg_dir, border, output_interp_order, output_postfix).
    :return: None
    '''
    action = self.action_param
    first_reader = self.readers[0]
    self.output_decoder = ResizeSamplesAggregator(
        image_reader=first_reader,
        output_path=action.save_seg_dir,
        window_border=action.border,
        interp_order=action.output_interp_order,
        postfix=action.output_postfix)
def test_3d_init_mo_3out(self):
    """Round-trip a multi-modal 3D image together with two CSV outputs.

    For every batch, the aggregator receives:

    * ``window_image``: the image window (written as NIfTI);
    * ``csv_sum``: the sum of the window's intensities;
    * ``csv_stats``: a [sum, min, max] list.

    The shapes of the three written files are then checked.
    """
    reader = get_3d_reader()
    sampler = ResizeSampler(reader=reader,
                            window_sizes=MULTI_MOD_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    aggregator = ResizeSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        interp_order=3)
    more_batch = True
    try:
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    # sampler queue exhausted: nothing more to decode
                    break
                sum_val = np.sum(out['image'])
                stats_val = [
                    np.sum(out['image']),
                    np.min(out['image']),
                    np.max(out['image'])
                ]
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': sum_val,
                        'csv_stats': stats_val
                    }, out['image_location'])
        subject_id = sampler.reader.get_subject_id(0)
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            subject_id)
        sum_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_sum_{}_niftynet_out.csv'.format(subject_id))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_{}_niftynet_out.csv'.format(subject_id))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)

        self.assertAllClose(nib.load(output_file).shape,
                            (256, 168, 256, 1, 2))
        sum_pd = pd.read_csv(sum_filename)
        self.assertAllClose(sum_pd.shape, [1, 2])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [1, 4])
    finally:
        # Stop the sampler threads even if an assertion fails.
        sampler.close_all()
def initialise_grid_aggregator(self):
    """Build the grid-based output decoder.

    A ``GridSamplesAggregator`` is attached to the first image reader and
    configured from ``self.action_param``: output directory, window border,
    interpolation order, filename postfix and the fill constant. It is
    stored in ``self.output_decoder``.
    """
    params = self.action_param
    decoder = GridSamplesAggregator(
        image_reader=self.readers[0],
        output_path=params.save_seg_dir,
        window_border=params.border,
        interp_order=params.output_interp_order,
        postfix=params.output_postfix,
        fill_constant=params.fill_constant)
    self.output_decoder = decoder
def initialise_grid_aggregator(self):
    '''
    Set ``self.output_decoder`` to a GridSamplesAggregator bound to the
    first image reader, configured from the action parameters
    (save_seg_dir, border, output_interp_order, output_postfix,
    fill_constant).
    :return: None
    '''
    action = self.action_param
    first_reader = self.readers[0]
    self.output_decoder = GridSamplesAggregator(
        image_reader=first_reader,
        output_path=action.save_seg_dir,
        window_border=action.border,
        interp_order=action.output_interp_order,
        postfix=action.output_postfix,
        fill_constant=action.fill_constant)
class SegmentationApplication(BaseApplication):
    """Application wiring readers, samplers, the network and output decoders
    together for the training, inference and evaluation actions."""
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        super(SegmentationApplication, self).__init__()
        tf.logging.info('starting segmentation application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        # window_sampling name -> (training sampler initialiser,
        #                          inference sampler initialiser,
        #                          aggregator initialiser)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
            'balanced': (self.initialise_balanced_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create the image readers for the current action and attach
        preprocessing layers (volume padding, normalisation and, for the
        first reader only, augmentation).

        :param data_param: input data specification passed to the readers
        :param task_param: segmentation task parameters
        :param data_partitioner: provides file lists / inference files
        :raises ValueError: if the action is not training, inference or
            evaluation
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'label', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self.is_inference:
            # in the inference process use image input only
            inference_reader = ImageReader({'image'})
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]
        elif self.is_evaluation:
            file_list = data_partitioner.inference_files
            reader = ImageReader({'image', 'label', 'inferred'})
            reader.initialise(data_param, task_param, file_list)
            self.readers = [reader]
        else:
            raise ValueError(
                'Action `{}` not supported. Expected one of {}'.format(
                    self.action, self.SUPPORTED_ACTIONS))

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)

        # histogram normalisation needs a reference file; None without one
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        # label normalisation also needs the reference file; None without one
        label_normalisers = None
        if self.net_param.histogram_ref_file and \
                task_param.label_normalisation:
            label_normalisers = [DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)]
            if self.is_evaluation:
                label_normalisers.append(
                    DiscreteLabelNormalisationLayer(
                        image_name='inferred',
                        modalities=vars(task_param).get('inferred'),
                        model_filename=self.net_param.histogram_ref_file))
                # share the label mapping key between 'label' and 'inferred'
                label_normalisers[-1].key = label_normalisers[0].key

        normalisation_layers = []
        # Only append layers that were actually built. Previously a missing
        # histogram_ref_file appended None here, or crashed with a TypeError
        # on extend(None).
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if label_normalisers and \
                (self.is_training or not task_param.output_prob):
            normalisation_layers.extend(label_normalisers)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

            # add deformation layer
            if self.action_param.do_elastic_deformation:
                spatial_rank = list(
                    self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(RandomElasticDeformationLayer(
                    spatial_rank=spatial_rank,
                    num_controlpoints=self.action_param.num_ctrl_points,
                    std_deformation_sigma=self.action_param.deformation_sigma,
                    proportion_to_augment=(
                        self.action_param.proportion_to_deform)))

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)
        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer + normalisation_layers)

    def initialise_uniform_sampler(self):
        """Create one UniformSampler per reader."""
        self.sampler = [[UniformSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_weighted_sampler(self):
        """Create one WeightedSampler per reader."""
        self.sampler = [[WeightedSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_resize_sampler(self):
        """Create one ResizeSampler per reader; ``shuffle_buffer`` is
        enabled only when training."""
        self.sampler = [[ResizeSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            shuffle_buffer=self.is_training,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_grid_sampler(self):
        """Create one GridSampler per reader using the configured spatial
        window size and border."""
        self.sampler = [[GridSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            spatial_window_size=self.action_param.spatial_window_size,
            window_border=self.action_param.border,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_balanced_sampler(self):
        """Create one BalancedSampler per reader."""
        self.sampler = [[BalancedSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_grid_aggregator(self):
        """Set ``self.output_decoder`` to a GridSamplesAggregator bound to
        the first reader."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Set ``self.output_decoder`` to a ResizeSamplesAggregator bound to
        the first reader."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Create the sampler selected by ``net_param.window_sampling``.

        Training uses the first initialiser of the SUPPORTED_SAMPLING entry
        and inference the second; no sampler is created for other actions.
        """
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        """Create the output decoder matching ``net_param.window_sampling``."""
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Instantiate the configured network, with optional L1/L2 weight and
        bias regularisers when ``decay > 0``."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Connect the sampler output to the network.

        Training: builds the loss (plus mean regularisation loss when
        ``decay > 0``), computes gradients and registers the loss for
        console/TensorBoard output.
        Inference: post-processes the network output (softmax, argmax or
        identity) and registers the window and its location as network
        outputs, then initialises the aggregator.
        """
        def switch_sampler(for_training):
            # first sampler for training, last one for validation
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.segmentation_param.softmax)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(single_loss)
                     for single_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        elif self.is_inference:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """During inference, pass the batch to the output decoder and return
        its result; callers use it as a "keep going" flag. Otherwise return
        True."""
        if self.is_inference:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Store ``eval_param`` and build a SegmentationEvaluator on the
        first reader."""
        self.eval_param = eval_param
        self.evaluator = SegmentationEvaluator(self.readers[0],
                                               self.segmentation_param,
                                               eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Register the 'inferred' input, modelled on the 'label' input."""
        return self.add_inferred_output_like(data_param, task_param, 'label')
class MultiClassifSegApplication(BaseApplication): """This class defines an application for image-level classification problems mapping from images to scalar labels. This is the application class to be instantiated by the driver and referred to in configuration files. Although structurally similar to segmentation, this application supports different samplers/aggregators (because patch-based processing is not appropriate), and monitoring metrics.""" REQUIRED_CONFIG_SECTION = "SEGMENTATION" def __init__(self, net_param, action_param, action): super(MultiClassifSegApplication, self).__init__() tf.logging.info('starting classification application') self.action = action self.net_param = net_param self.eval_param = None self.evaluator = None self.action_param = action_param self.net_multi = None self.data_param = None self.segmentation_param = None self.csv_readers = None self.SUPPORTED_SAMPLING = { 'uniform': (self.initialise_uniform_sampler, self.initialise_grid_sampler, self.initialise_grid_aggregator), 'weighted': (self.initialise_weighted_sampler, self.initialise_grid_sampler, self.initialise_grid_aggregator), 'resize': (self.initialise_resize_sampler, self.initialise_resize_sampler, self.initialise_resize_aggregator), 'balanced': (self.initialise_balanced_sampler, self.initialise_grid_sampler, self.initialise_grid_aggregator), } def initialise_dataset_loader(self, data_param=None, task_param=None, data_partitioner=None): ''' Initialise the data loader both csv readers and image readers and specify preprocessing layers :param data_param: :param task_param: :param data_partitioner: :return: ''' self.data_param = data_param self.segmentation_param = task_param if self.is_training: image_reader_names = ('image', 'sampler', 'label') csv_reader_names = ('value', ) elif self.is_inference: image_reader_names = ('image', ) csv_reader_names = () elif self.is_evaluation: image_reader_names = ('image', 'inferred', 'label') csv_reader_names = ('value', ) else: 
tf.logging.fatal('Action `%s` not supported. Expected one of %s', self.action, self.SUPPORTED_PHASES) raise ValueError try: reader_phase = self.action_param.dataset_to_infer except AttributeError: reader_phase = None file_lists = data_partitioner.get_file_lists_by(phase=reader_phase, action=self.action) self.readers = [ ImageReader(image_reader_names).initialise(data_param, task_param, file_list) for file_list in file_lists ] if self.is_inference: self.action_param.sample_per_volume = 1 if csv_reader_names is not None and list(csv_reader_names): self.csv_readers = [ CSVReader(csv_reader_names).initialise( data_param, task_param, file_list, sample_per_volume=self.action_param.sample_per_volume) for file_list in file_lists ] else: self.csv_readers = [None for file_list in file_lists] foreground_masking_layer = BinaryMaskingLayer( type_str=self.net_param.foreground_type, multimod_fusion=self.net_param.multimod_foreground_type, threshold=0.0) \ if self.net_param.normalise_foreground_only else None mean_var_normaliser = MeanVarNormalisationLayer( image_name='image', binary_masking_func=foreground_masking_layer) \ if self.net_param.whitening else None histogram_normaliser = HistogramNormalisationLayer( image_name='image', modalities=vars(task_param).get('image'), model_filename=self.net_param.histogram_ref_file, binary_masking_func=foreground_masking_layer, norm_type=self.net_param.norm_type, cutoff=self.net_param.cutoff, name='hist_norm_layer') \ if (self.net_param.histogram_ref_file and self.net_param.normalisation) else None label_normaliser = DiscreteLabelNormalisationLayer( image_name='label', modalities=vars(task_param).get('label'), model_filename=self.net_param.histogram_ref_file) \ if (self.net_param.histogram_ref_file and task_param.label_normalisation) else None normalisation_layers = [] if histogram_normaliser is not None: normalisation_layers.append(histogram_normaliser) if mean_var_normaliser is not None: normalisation_layers.append(mean_var_normaliser) if 
label_normaliser is not None: normalisation_layers.append(label_normaliser) augmentation_layers = [] if self.is_training: train_param = self.action_param if train_param.random_flipping_axes != -1: augmentation_layers.append( RandomFlipLayer( flip_axes=train_param.random_flipping_axes)) if train_param.scaling_percentage: augmentation_layers.append( RandomSpatialScalingLayer( min_percentage=train_param.scaling_percentage[0], max_percentage=train_param.scaling_percentage[1])) if train_param.rotation_angle or \ self.action_param.rotation_angle_x or \ self.action_param.rotation_angle_y or \ self.action_param.rotation_angle_z: rotation_layer = RandomRotationLayer() if train_param.rotation_angle: rotation_layer.init_uniform_angle( train_param.rotation_angle) else: rotation_layer.init_non_uniform_angle( self.action_param.rotation_angle_x, self.action_param.rotation_angle_y, self.action_param.rotation_angle_z) augmentation_layers.append(rotation_layer) # only add augmentation to first reader (not validation reader) self.readers[0].add_preprocessing_layers(normalisation_layers + augmentation_layers) for reader in self.readers[1:]: reader.add_preprocessing_layers(normalisation_layers) def initialise_uniform_sampler(self): ''' Create the uniform sampler using information from readers :return: ''' self.sampler = [[ UniformSampler( reader=reader, csv_reader=csv_reader, window_sizes=self.data_param, batch_size=self.net_param.batch_size, windows_per_image=self.action_param.sample_per_volume, queue_length=self.net_param.queue_length) for reader, csv_reader in zip(self.readers, self.csv_readers) ]] def initialise_weighted_sampler(self): ''' Create the weighted sampler using the info from the csv_readers and image_readers and the configuration parameters :return: ''' self.sampler = [[ WeightedSampler( reader=reader, csv_reader=csv_reader, window_sizes=self.data_param, batch_size=self.net_param.batch_size, windows_per_image=self.action_param.sample_per_volume, 
queue_length=self.net_param.queue_length) for reader, csv_reader in zip(self.readers, self.csv_readers) ]] def initialise_resize_sampler(self): ''' Define the resize sampler using the information from the configuration parameters, csv_readers and image_readers :return: ''' self.sampler = [[ ResizeSampler(reader=reader, csv_reader=csv_reader, window_sizes=self.data_param, batch_size=self.net_param.batch_size, shuffle=self.is_training, smaller_final_batch_mode=self.net_param. smaller_final_batch_mode, queue_length=self.net_param.queue_length) for reader, csv_reader in zip(self.readers, self.csv_readers) ]] def initialise_grid_sampler(self): ''' Define the grid sampler based on the information from configuration and the csv_readers and image_readers specifications :return: ''' self.sampler = [[ GridSampler( reader=reader, csv_reader=csv_reader, window_sizes=self.data_param, batch_size=self.net_param.batch_size, spatial_window_size=self.action_param.spatial_window_size, window_border=self.action_param.border, smaller_final_batch_mode=self.net_param. 
smaller_final_batch_mode, queue_length=self.net_param.queue_length) for reader, csv_reader in zip(self.readers, self.csv_readers) ]] def initialise_balanced_sampler(self): ''' Define the balanced sampler based on the information from configuration and the csv_readers and image_readers specifications :return: ''' self.sampler = [[ BalancedSampler( reader=reader, csv_reader=csv_reader, window_sizes=self.data_param, batch_size=self.net_param.batch_size, windows_per_image=self.action_param.sample_per_volume, queue_length=self.net_param.queue_length) for reader, csv_reader in zip(self.readers, self.csv_readers) ]] def initialise_grid_aggregator(self): ''' Define the grid aggregator used for decoding using configuration parameters :return: ''' self.output_decoder = GridSamplesAggregator( image_reader=self.readers[0], output_path=self.action_param.save_seg_dir, window_border=self.action_param.border, interp_order=self.action_param.output_interp_order, postfix=self.action_param.output_postfix, fill_constant=self.action_param.fill_constant) def initialise_resize_aggregator(self): ''' Define the resize aggregator used for decoding using the configuration parameters :return: ''' self.output_decoder = ResizeSamplesAggregator( image_reader=self.readers[0], output_path=self.action_param.save_seg_dir, window_border=self.action_param.border, interp_order=self.action_param.output_interp_order, postfix=self.action_param.output_postfix) def initialise_sampler(self): ''' Specifies the sampler used among those previously defined based on the sampling choice :return: ''' if self.is_training: self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]() elif self.is_inference: self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]() def initialise_aggregator(self): ''' Specifies the aggregator used based on the sampling choice :return: ''' self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]() def initialise_network(self): ''' Initialise the network and specifies the ordering of 
elements :return: ''' w_regularizer = None b_regularizer = None reg_type = self.net_param.reg_type.lower() decay = self.net_param.decay if reg_type == 'l2' and decay > 0: from tensorflow.contrib.layers.python.layers import regularizers w_regularizer = regularizers.l2_regularizer(decay) b_regularizer = regularizers.l2_regularizer(decay) elif reg_type == 'l1' and decay > 0: from tensorflow.contrib.layers.python.layers import regularizers w_regularizer = regularizers.l1_regularizer(decay) b_regularizer = regularizers.l1_regularizer(decay) self.net = ApplicationNetFactory.create( 'niftynet.contrib.csv_reader.toynet_features.ToyNetFeat')( num_classes=self.segmentation_param.num_classes, w_initializer=InitializerFactory.get_initializer( name=self.net_param.weight_initializer), b_initializer=InitializerFactory.get_initializer( name=self.net_param.bias_initializer), w_regularizer=w_regularizer, b_regularizer=b_regularizer, acti_func=self.net_param.activation_function) self.net_multi = ApplicationNetFactory.create( 'niftynet.contrib.csv_reader.class_seg_finnet.ClassSegFinnet')( num_classes=self.segmentation_param.num_classes, w_initializer=InitializerFactory.get_initializer( name=self.net_param.weight_initializer), b_initializer=InitializerFactory.get_initializer( name=self.net_param.bias_initializer), w_regularizer=w_regularizer, b_regularizer=b_regularizer, acti_func=self.net_param.activation_function) def add_confusion_matrix_summaries_(self, outputs_collector, net_out, data_dict): """ This method defines several monitoring metrics that are derived from the confusion matrix """ labels = tf.reshape(tf.cast(data_dict['label'], tf.int64), [-1]) prediction = tf.reshape(tf.argmax(net_out, -1), [-1]) num_classes = 2 conf_mat = tf.confusion_matrix(labels, prediction, num_classes) conf_mat = tf.to_float(conf_mat) if self.segmentation_param.num_classes == 2: outputs_collector.add_to_collection(var=conf_mat[1][1], name='true_positives', average_over_devices=True, 
summary_type='scalar', collection=TF_SUMMARIES) outputs_collector.add_to_collection(var=conf_mat[1][0], name='false_negatives', average_over_devices=True, summary_type='scalar', collection=TF_SUMMARIES) outputs_collector.add_to_collection(var=conf_mat[0][1], name='false_positives', average_over_devices=True, summary_type='scalar', collection=TF_SUMMARIES) outputs_collector.add_to_collection(var=conf_mat[0][0], name='true_negatives', average_over_devices=True, summary_type='scalar', collection=TF_SUMMARIES) else: outputs_collector.add_to_collection(var=conf_mat[tf.newaxis, :, :, tf.newaxis], name='confusion_matrix', average_over_devices=True, summary_type='image', collection=TF_SUMMARIES) outputs_collector.add_to_collection(var=tf.trace(conf_mat), name='accuracy', average_over_devices=True, summary_type='scalar', collection=TF_SUMMARIES) def connect_data_and_network(self, outputs_collector=None, gradients_collector=None): def switch_sampler(for_training): with tf.name_scope('train' if for_training else 'validation'): sampler = self.get_sampler()[0][0 if for_training else -1] return sampler.pop_batch_op() if self.is_training: if self.action_param.validation_every_n > 0: data_dict = tf.cond(tf.logical_not(self.is_validation), lambda: switch_sampler(for_training=True), lambda: switch_sampler(for_training=False)) else: data_dict = switch_sampler(for_training=True) image = tf.cast(data_dict['image'], tf.float32) net_args = { 'is_training': self.is_training, 'keep_prob': self.net_param.keep_prob } net_out = self.net(image, **net_args) net_out_seg, net_out_class = self.net_multi( net_out, self.is_training) with tf.name_scope('Optimiser'): optimiser_class = OptimiserFactory.create( name=self.action_param.optimiser) self.optimiser = optimiser_class.get_instance( learning_rate=self.action_param.lr) loss_func_class = LossFunctionClassification( n_class=2, loss_type='CrossEntropy') loss_func_seg = LossFunctionSegmentation( n_class=self.segmentation_param.num_classes, 
loss_type=self.action_param.loss_type) data_loss_seg = loss_func_seg(prediction=net_out_seg, ground_truth=data_dict.get( 'label', None)) data_loss_class = loss_func_class(prediction=net_out_class, ground_truth=data_dict.get( 'value', None)) reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) if self.net_param.decay > 0.0 and reg_losses: reg_loss = tf.reduce_mean( [tf.reduce_mean(reg_loss) for reg_loss in reg_losses]) loss = data_loss_seg + data_loss_class + reg_loss else: loss = data_loss_seg + data_loss_class self.total_loss = loss self.total_loss = tf.Print( tf.cast(self.total_loss, tf.float32), [loss, tf.shape(net_out_seg), tf.shape(net_out_class)], message='test') grads = self.optimiser.compute_gradients( loss, colocate_gradients_with_ops=True) # collecting gradients variables gradients_collector.add_to_collection([grads]) # collecting output variables outputs_collector.add_to_collection(var=data_loss_class, name='data_loss', average_over_devices=False, collection=CONSOLE) outputs_collector.add_to_collection(var=data_loss_seg, name='data_loss', average_over_devices=True, summary_type='scalar', collection=TF_SUMMARIES) # self.add_confusion_matrix_summaries_(outputs_collector, # net_out_class, # data_dict) else: # converting logits into final output for # classification probabilities or argmax classification labels data_dict = switch_sampler(for_training=False) image = tf.cast(data_dict['image'], tf.float32) net_args = { 'is_training': self.is_training, 'keep_prob': self.net_param.keep_prob } net_out = self.net(image, **net_args) net_out_seg, net_out_class = self.net_multi( net_out, self.is_training) tf.logging.info('net_out.shape may need to be resized: %s', net_out.shape) output_prob = self.segmentation_param.output_prob num_classes = self.segmentation_param.num_classes if output_prob and num_classes > 1: post_process_layer_class = PostProcessingLayer( 'SOFTMAX', num_classes=num_classes) post_process_layer_seg = PostProcessingLayer('SOFTMAX', 
num_classes=2) elif not output_prob and num_classes > 1: post_process_layer_class = PostProcessingLayer( 'ARGMAX', num_classes=num_classes) post_process_layer_seg = PostProcessingLayer('ARGMAX', num_classes=2) else: post_process_layer_class = PostProcessingLayer( 'IDENTITY', num_classes=num_classes) post_process_layer_seg = PostProcessingLayer('IDENTITY', num_classes=2) net_out_class = post_process_layer_class(net_out_class) net_out_seg = post_process_layer_seg(net_out_seg) outputs_collector.add_to_collection(var=net_out_seg, name='seg', average_over_devices=False, collection=NETWORK_OUTPUT) outputs_collector.add_to_collection(var=net_out_class, name='value', average_over_devices=False, collection=NETWORK_OUTPUT) outputs_collector.add_to_collection( var=data_dict['image_location'], name='location', average_over_devices=False, collection=NETWORK_OUTPUT) self.initialise_aggregator() def interpret_output(self, batch_output): ''' Specifies how the output should be decoded :param batch_output: :return: ''' if not self.is_training: return self.output_decoder.decode_batch( { 'window_seg': batch_output['seg'], 'csv_class': batch_output['value'] }, batch_output['location']) return True def initialise_evaluator(self, eval_param): ''' Define the evaluator :param eval_param: :return: ''' self.eval_param = eval_param self.evaluator = ClassificationEvaluator(self.readers[0], self.segmentation_param, eval_param) def add_inferred_output(self, data_param, task_param): ''' Define how to treat added inferred output :param data_param: :param task_param: :return: ''' return self.add_inferred_output_like(data_param, task_param, 'label')
class RegApp(BaseApplication):
    """Label-driven registration application.

    Every file list yields a consecutive pair of readers: a fixed reader
    (``fixed_image``, ``fixed_label``) followed by a moving reader
    (``moving_image``, ``moving_label``).  The network estimates a dense
    field that is used to resample the moving data onto the fixed data.
    """

    REQUIRED_CONFIG_SECTION = "REGISTRATION"

    def __init__(self, net_param, action_param, action):
        """Store configuration sections; readers are created later.

        :param net_param: network section of the configuration
        :param action_param: action (training/inference) section
        :param action: the action name, forwarded to the data partitioner
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting label-driven registration')
        self.action = action
        self.net_param = net_param
        self.action_param = action_param
        self.registration_param = None
        self.data_param = None

    def initialise_dataset_loader(self, data_param=None, task_param=None,
                                  data_partitioner=None):
        """Create one (fixed, moving) reader pair per partitioned file list.

        :param data_param: input data specification
        :param task_param: registration section of the configuration
        :param data_partitioner: provides ``get_file_lists_by``
        :raises NotImplementedError: if the current action is evaluation
        """
        self.data_param = data_param
        self.registration_param = task_param

        if self.is_evaluation:
            # The exception must be raised, not merely constructed,
            # otherwise evaluation silently proceeds.
            raise NotImplementedError('Evaluation is not yet '
                                      'supported in this application.')

        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)

        # Readers are stored as consecutive (fixed, moving) pairs:
        # indices 0/1 for the first file list, 2/3 for the second, ...
        self.readers = []
        for file_list in file_lists:
            fixed_reader = ImageReader({'fixed_image', 'fixed_label'})
            fixed_reader.initialise(data_param, task_param, file_list)
            self.readers.append(fixed_reader)

            moving_reader = ImageReader({'moving_image', 'moving_label'})
            moving_reader.initialise(data_param, task_param, file_list)
            self.readers.append(moving_reader)

        # No volume padding is added: only the fixed target would need it,
        # and the moving image is resampled to match the target anyway.

    def initialise_sampler(self):
        """Create the pairwise samplers.

        Training: a list holding a ``PairwiseUniformSampler`` over readers
        0/1 and, when at least four readers exist, a second (validation)
        sampler over readers 2/3.  Inference: a single
        ``PairwiseResizeSampler`` over readers 0/1.
        """
        if self.is_training:
            self.sampler = []
            assert len(self.readers) >= 2, 'at least two readers are required'
            training_sampler = PairwiseUniformSampler(
                reader_0=self.readers[0],
                reader_1=self.readers[1],
                data_param=self.data_param,
                batch_size=self.net_param.batch_size)
            self.sampler.append(training_sampler)
            # adding validation readers if possible
            if len(self.readers) >= 4:
                validation_sampler = PairwiseUniformSampler(
                    reader_0=self.readers[2],
                    reader_1=self.readers[3],
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size)
                self.sampler.append(validation_sampler)
        else:
            self.sampler = PairwiseResizeSampler(
                reader_0=self.readers[0],
                reader_1=self.readers[1],
                data_param=self.data_param,
                batch_size=self.net_param.batch_size)

    def initialise_network(self):
        """Instantiate the network named in the config with ``decay``."""
        decay = self.net_param.decay
        self.net = ApplicationNetFactory.create(self.net_param.name)(decay)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph.

        Training: estimate a dense field from fixed/moving windows, warp
        the moving label, and minimise the label loss plus
        ``decay * mean(bending_energy)`` when that collection is non-empty.
        Inference: warp the moving image and label, expose them as network
        outputs and create the resize aggregator.
        """
        def switch_samplers(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0 if for_training else -1]
                return sampler()  # returns image only

        if self.is_training:
            self.patience = self.action_param.patience
            self.mode = self.action_param.early_stopping_mode
            if self.action_param.validation_every_n > 0:
                sampler_window = tf.cond(
                    tf.logical_not(self.is_validation),
                    lambda: switch_samplers(True),
                    lambda: switch_samplers(False))
            else:
                sampler_window = switch_samplers(True)

            image_windows, _ = sampler_window
            # decode channels for moving and fixed images
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            # estimate ddf
            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(
                interpolation='linear', boundary='replicate')
            resampled_moving_label = resampler(moving_label, dense_field)

            # compute label loss (foreground only)
            loss_func = LossFunction(
                n_class=1,
                loss_type=self.action_param.loss_type,
                softmax=False)
            label_loss = loss_func(prediction=resampled_moving_label,
                                   ground_truth=fixed_label)
            dice_fg = 1.0 - label_loss

            # appending regularisation loss (computed once, reused below)
            total_loss = label_loss
            reg_loss = tf.get_collection('bending_energy')
            bending_energy = None
            if reg_loss:
                bending_energy = tf.reduce_mean(reg_loss)
                total_loss = total_loss + \
                    self.net_param.decay * bending_energy
            self.total_loss = total_loss

            # compute training gradients
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            grads = self.optimiser.compute_gradients(
                total_loss, colocate_gradients_with_ops=True)
            gradients_collector.add_to_collection(grads)

            metrics_dice = loss_func(
                prediction=tf.to_float(resampled_moving_label >= 0.5),
                ground_truth=tf.to_float(fixed_label >= 0.5))
            metrics_dice = 1.0 - metrics_dice

            # command line output
            outputs_collector.add_to_collection(
                var=dice_fg, name='one_minus_data_loss', collection=CONSOLE)
            if bending_energy is not None:
                # Averaging an empty collection is meaningless (NaN), so the
                # term is only reported when the network registered one.
                outputs_collector.add_to_collection(
                    var=bending_energy, name='bending_energy',
                    collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=total_loss, name='total_loss', collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=metrics_dice, name='ave_fg_dice', collection=CONSOLE)

            # for tensorboard
            outputs_collector.add_to_collection(
                var=dice_fg, name='data_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=total_loss, name='total_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=metrics_dice, name='averaged_foreground_Dice',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            image_windows, locations = self.sampler()
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving image and labels
            resampler = ResamplerLayer(
                interpolation='linear', boundary='replicate')
            resampled_moving_image = resampler(moving_image, dense_field)
            resampled_moving_label = resampler(moving_label, dense_field)

            network_outputs = (
                ('fixed_image', fixed_image),
                ('moving_image', moving_image),
                ('resampled_moving_image', resampled_moving_image),
                ('resampled_moving_label', resampled_moving_label),
                ('fixed_label', fixed_label),
                ('moving_label', moving_label),
                ('locations', locations),
            )
            for out_name, out_var in network_outputs:
                outputs_collector.add_to_collection(
                    var=out_var, name=out_name, collection=NETWORK_OUTPUT)

            self.output_decoder = ResizeSamplesAggregator(
                image_reader=self.readers[0],  # fixed image reader
                name='fixed_image',
                output_path=self.action_param.save_seg_dir,
                interp_order=self.action_param.output_interp_order)

    def interpret_output(self, batch_output):
        """Decode resampled moving images during inference.

        :param batch_output: dict of evaluated network outputs
        :return: ``True`` while training, otherwise the aggregator's
            ``decode_batch`` result
        """
        if self.is_training:
            return True
        return self.output_decoder.decode_batch(
            {'window_resampled': batch_output['resampled_moving_image']},
            batch_output['locations'])
class RegressionApplication(BaseApplication):
    """Image regression application.

    ``SUPPORTED_SAMPLING`` maps each ``window_sampling`` option to a
    triple of initialisers: (training sampler, inference sampler,
    output aggregator).
    """

    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, is_training):
        """Store configuration and register the sampling strategies.

        :param net_param: network section of the configuration
        :param action_param: action (training/inference) section
        :param is_training: whether the application is in training mode
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.is_training = is_training
        self.net_param = net_param
        self.action_param = action_param
        self.regression_param = None
        self.data_param = None
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        """Create the image reader and attach its preprocessing layers.

        Layers are applied in the order: volume padding, normalisation,
        augmentation (augmentation only while training).

        :param data_param: input data specification
        :param task_param: regression section of the configuration
        """
        self.data_param = data_param
        self.regression_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(SUPPORTED_INPUT)
        else:
            # in the inference process use image input only
            self.reader = ImageReader(['image'])
        self.reader.initialise_reader(data_param, task_param)

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
        else:
            histogram_normaliser = None

        normalisation_layers = []
        if self.net_param.normalisation:
            # Without a reference file there is no histogram layer; never
            # put a None entry into the preprocessing chain.
            if histogram_normaliser is None:
                tf.logging.warning(
                    'histogram normalisation requested but no '
                    'histogram_ref_file given; skipping it')
            else:
                normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))

        self.reader.add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)

    def initialise_uniform_sampler(self):
        """Training sampler drawing uniformly located windows."""
        self.sampler = [
            UniformSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)]

    def initialise_weighted_sampler(self):
        """Training sampler drawing windows via ``WeightedSampler``."""
        self.sampler = [
            WeightedSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)]

    def initialise_resize_sampler(self):
        """Sampler resizing whole volumes; shuffles only while training."""
        self.sampler = [
            ResizeSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle_buffer=self.is_training,
                queue_length=self.net_param.queue_length)]

    def initialise_grid_sampler(self):
        """Inference sampler covering volumes on a regular grid."""
        self.sampler = [
            GridSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)]

    def initialise_grid_aggregator(self):
        """Aggregator reassembling grid-sampled windows."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Aggregator for resize-sampled windows."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Run the training or inference sampler initialiser."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Build the single-output network with optional L1/L2 regularisers.

        Regularisers are only created when ``decay > 0`` and ``reg_type``
        is ``'l1'`` or ``'l2'`` (case-insensitive).
        """
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training (loss + gradients) or inference graph.

        Training crops prediction, target (``'output'``) and optional
        ``'weight'`` map by ``loss_border`` before computing the loss;
        the mean regularisation loss is added when ``decay > 0``.
        Inference emits the identity-post-processed prediction and the
        window locations, then creates the configured aggregator.
        """
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:
            crop_layer = CropLayer(
                border=self.regression_param.loss_border, name='crop-88')
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            prediction = crop_layer(net_out)
            ground_truth = crop_layer(data_dict.get('output', None))
            weight = data_dict.get('weight', None)
            weight_map = None if weight is None else crop_layer(weight)
            data_loss = loss_func(prediction=prediction,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)

            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(single_loss)
                     for single_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='Loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='Loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            crop_layer = CropLayer(border=0, name='crop-88')
            post_process_layer = PostProcessingLayer('IDENTITY')
            net_out = post_process_layer(crop_layer(net_out))

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Decode inference windows; always ``True`` while training."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True
def initialise_grid_aggregator(self):
    """Attach a grid-window aggregator as ``self.output_decoder``.

    The aggregator decodes windows for ``self.reader``; output directory,
    window border and interpolation order all come from
    ``self.action_param``.
    """
    source_reader = self.reader
    action_cfg = self.action_param
    self.output_decoder = GridSamplesAggregator(
        image_reader=source_reader,
        output_path=action_cfg.save_seg_dir,
        window_border=action_cfg.border,
        interp_order=action_cfg.output_interp_order,
    )
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    """Build the registration training or inference graph.

    Training: estimate a dense field from fixed/moving windows, warp the
    moving label, and minimise the label loss plus
    ``decay * mean(bending_energy)`` when that collection is non-empty.
    Inference: warp the moving image and label, expose them as network
    outputs and create the resize aggregator.
    """
    def switch_samplers(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0 if for_training else -1]
            return sampler()  # returns image only

    if self.is_training:
        self.patience = self.action_param.patience
        self.mode = self.action_param.early_stopping_mode
        if self.action_param.validation_every_n > 0:
            sampler_window = tf.cond(
                tf.logical_not(self.is_validation),
                lambda: switch_samplers(True),
                lambda: switch_samplers(False))
        else:
            sampler_window = switch_samplers(True)

        image_windows, _ = sampler_window
        # decode channels for moving and fixed images
        image_windows_list = [
            tf.expand_dims(img, axis=-1)
            for img in tf.unstack(image_windows, axis=-1)]
        fixed_image, fixed_label, moving_image, moving_label = \
            image_windows_list

        # estimate ddf
        dense_field = self.net(fixed_image, moving_image)
        if isinstance(dense_field, tuple):
            dense_field = dense_field[0]

        # transform the moving labels
        resampler = ResamplerLayer(
            interpolation='linear', boundary='replicate')
        resampled_moving_label = resampler(moving_label, dense_field)

        # compute label loss (foreground only)
        loss_func = LossFunction(
            n_class=1,
            loss_type=self.action_param.loss_type,
            softmax=False)
        label_loss = loss_func(prediction=resampled_moving_label,
                               ground_truth=fixed_label)
        dice_fg = 1.0 - label_loss

        # appending regularisation loss (computed once, reused below)
        total_loss = label_loss
        reg_loss = tf.get_collection('bending_energy')
        bending_energy = None
        if reg_loss:
            bending_energy = tf.reduce_mean(reg_loss)
            total_loss = total_loss + self.net_param.decay * bending_energy
        self.total_loss = total_loss

        # compute training gradients
        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        grads = self.optimiser.compute_gradients(
            total_loss, colocate_gradients_with_ops=True)
        gradients_collector.add_to_collection(grads)

        metrics_dice = loss_func(
            prediction=tf.to_float(resampled_moving_label >= 0.5),
            ground_truth=tf.to_float(fixed_label >= 0.5))
        metrics_dice = 1.0 - metrics_dice

        # command line output
        outputs_collector.add_to_collection(
            var=dice_fg, name='one_minus_data_loss', collection=CONSOLE)
        if bending_energy is not None:
            # Averaging an empty collection is meaningless (NaN), so the
            # term is only reported when the network registered one.
            outputs_collector.add_to_collection(
                var=bending_energy, name='bending_energy',
                collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=total_loss, name='total_loss', collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=metrics_dice, name='ave_fg_dice', collection=CONSOLE)

        # for tensorboard
        outputs_collector.add_to_collection(
            var=dice_fg, name='data_loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=total_loss, name='total_loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=metrics_dice, name='averaged_foreground_Dice',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
    else:
        image_windows, locations = self.sampler()
        image_windows_list = [
            tf.expand_dims(img, axis=-1)
            for img in tf.unstack(image_windows, axis=-1)]
        fixed_image, fixed_label, moving_image, moving_label = \
            image_windows_list

        dense_field = self.net(fixed_image, moving_image)
        if isinstance(dense_field, tuple):
            dense_field = dense_field[0]

        # transform the moving image and labels
        resampler = ResamplerLayer(
            interpolation='linear', boundary='replicate')
        resampled_moving_image = resampler(moving_image, dense_field)
        resampled_moving_label = resampler(moving_label, dense_field)

        network_outputs = (
            ('fixed_image', fixed_image),
            ('moving_image', moving_image),
            ('resampled_moving_image', resampled_moving_image),
            ('resampled_moving_label', resampled_moving_label),
            ('fixed_label', fixed_label),
            ('moving_label', moving_label),
            ('locations', locations),
        )
        for out_name, out_var in network_outputs:
            outputs_collector.add_to_collection(
                var=out_var, name=out_name, collection=NETWORK_OUTPUT)

        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],  # fixed image reader
            name='fixed_image',
            output_path=self.action_param.save_seg_dir,
            interp_order=self.action_param.output_interp_order)
class RegApp(BaseApplication):
    """Label-driven registration application.

    Every file list yields a consecutive pair of readers: a fixed reader
    (``fixed_image``, ``fixed_label``) followed by a moving reader
    (``moving_image``, ``moving_label``).  The network estimates a dense
    field that is used to resample the moving data onto the fixed data.
    """

    REQUIRED_CONFIG_SECTION = "REGISTRATION"

    def __init__(self, net_param, action_param, action):
        """Store configuration sections; readers are created later.

        :param net_param: network section of the configuration
        :param action_param: action (training/inference) section
        :param action: the action name
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting label-driven registration')
        self.action = action
        self.net_param = net_param
        self.action_param = action_param
        self.registration_param = None
        self.data_param = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create one (fixed, moving) reader pair per file list.

        :param data_param: input data specification
        :param task_param: registration section of the configuration
        :param data_partitioner: forwarded to ``self.get_file_lists``
        :raises NotImplementedError: if the current action is evaluation
        """
        self.data_param = data_param
        self.registration_param = task_param

        if self.is_evaluation:
            # The exception must be raised, not merely constructed,
            # otherwise evaluation silently proceeds.
            raise NotImplementedError('Evaluation is not yet '
                                      'supported in this application.')

        file_lists = self.get_file_lists(data_partitioner)

        # Readers are stored as consecutive (fixed, moving) pairs:
        # indices 0/1 for the first file list, 2/3 for the second, ...
        self.readers = []
        for file_list in file_lists:
            fixed_reader = ImageReader({'fixed_image', 'fixed_label'})
            fixed_reader.initialise(data_param, task_param, file_list)
            self.readers.append(fixed_reader)

            moving_reader = ImageReader({'moving_image', 'moving_label'})
            moving_reader.initialise(data_param, task_param, file_list)
            self.readers.append(moving_reader)

        # No volume padding is added: only the fixed target would need it,
        # and the moving image is resampled to match the target anyway.

    def initialise_sampler(self):
        """Create the pairwise samplers.

        Training: a list holding a ``PairwiseUniformSampler`` over readers
        0/1 and, when at least four readers exist, a second (validation)
        sampler over readers 2/3.  Inference: a single
        ``PairwiseResizeSampler`` over readers 0/1.
        """
        if self.is_training:
            self.sampler = []
            assert len(self.readers) >= 2, 'at least two readers are required'
            training_sampler = PairwiseUniformSampler(
                reader_0=self.readers[0],
                reader_1=self.readers[1],
                data_param=self.data_param,
                batch_size=self.net_param.batch_size)
            self.sampler.append(training_sampler)
            # adding validation readers if possible
            if len(self.readers) >= 4:
                validation_sampler = PairwiseUniformSampler(
                    reader_0=self.readers[2],
                    reader_1=self.readers[3],
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size)
                self.sampler.append(validation_sampler)
        else:
            self.sampler = PairwiseResizeSampler(
                reader_0=self.readers[0],
                reader_1=self.readers[1],
                data_param=self.data_param,
                batch_size=self.net_param.batch_size)

    def initialise_network(self):
        """Instantiate the network named in the config with ``decay``."""
        decay = self.net_param.decay
        self.net = ApplicationNetFactory.create(self.net_param.name)(decay)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph.

        Training: estimate a dense field from fixed/moving windows, warp
        the moving label, and minimise the label loss plus
        ``decay * mean(bending_energy)`` when that collection is non-empty.
        Inference: warp the moving image and label, expose them as network
        outputs and create the resize aggregator.
        """
        def switch_samplers(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0 if for_training else -1]
                return sampler()  # returns image only

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                sampler_window = tf.cond(
                    tf.logical_not(self.is_validation),
                    lambda: switch_samplers(True),
                    lambda: switch_samplers(False))
            else:
                sampler_window = switch_samplers(True)

            image_windows, _ = sampler_window
            # decode channels for moving and fixed images
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            # estimate ddf
            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(
                interpolation='linear', boundary='replicate')
            resampled_moving_label = resampler(moving_label, dense_field)

            # compute label loss (foreground only)
            loss_func = LossFunction(
                n_class=1,
                loss_type=self.action_param.loss_type,
                softmax=False)
            label_loss = loss_func(prediction=resampled_moving_label,
                                   ground_truth=fixed_label)
            dice_fg = 1.0 - label_loss

            # appending regularisation loss (computed once, reused below)
            total_loss = label_loss
            reg_loss = tf.get_collection('bending_energy')
            bending_energy = None
            if reg_loss:
                bending_energy = tf.reduce_mean(reg_loss)
                total_loss = total_loss + \
                    self.net_param.decay * bending_energy

            # compute training gradients
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            grads = self.optimiser.compute_gradients(total_loss)
            gradients_collector.add_to_collection(grads)

            metrics_dice = loss_func(
                prediction=tf.to_float(resampled_moving_label >= 0.5),
                ground_truth=tf.to_float(fixed_label >= 0.5))
            metrics_dice = 1.0 - metrics_dice

            # command line output
            outputs_collector.add_to_collection(
                var=dice_fg, name='one_minus_data_loss', collection=CONSOLE)
            if bending_energy is not None:
                # Averaging an empty collection is meaningless (NaN), so the
                # term is only reported when the network registered one.
                outputs_collector.add_to_collection(
                    var=bending_energy, name='bending_energy',
                    collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=total_loss, name='total_loss', collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=metrics_dice, name='ave_fg_dice', collection=CONSOLE)

            # for tensorboard
            outputs_collector.add_to_collection(
                var=dice_fg, name='data_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=total_loss, name='averaged_total_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=metrics_dice, name='averaged_foreground_Dice',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            image_windows, locations = self.sampler()
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving image and labels
            resampler = ResamplerLayer(
                interpolation='linear', boundary='replicate')
            resampled_moving_image = resampler(moving_image, dense_field)
            resampled_moving_label = resampler(moving_label, dense_field)

            network_outputs = (
                ('fixed_image', fixed_image),
                ('moving_image', moving_image),
                ('resampled_moving_image', resampled_moving_image),
                ('resampled_moving_label', resampled_moving_label),
                ('fixed_label', fixed_label),
                ('moving_label', moving_label),
                ('locations', locations),
            )
            for out_name, out_var in network_outputs:
                outputs_collector.add_to_collection(
                    var=out_var, name=out_name, collection=NETWORK_OUTPUT)

            self.output_decoder = ResizeSamplesAggregator(
                image_reader=self.readers[0],  # fixed image reader
                name='fixed_image',
                output_path=self.action_param.save_seg_dir,
                interp_order=self.action_param.output_interp_order)

    def interpret_output(self, batch_output):
        """Decode resampled moving images during inference.

        :param batch_output: dict of evaluated network outputs
        :return: ``True`` while training, otherwise the aggregator's
            ``decode_batch`` result
        """
        if self.is_training:
            return True
        return self.output_decoder.decode_batch(
            batch_output['resampled_moving_image'],
            batch_output['locations'])
class SegmentationApplication(BaseApplication):
    """Segmentation application driven by the ``[SEGMENTATION]`` section.

    Creates the image reader with its preprocessing/augmentation layers,
    the window samplers, the network, and (at inference) the output
    aggregator. During training the 3D label is max-pooled into a 2D label
    (STNeuroNet-specific, see ``connect_data_and_network``).
    """

    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        BaseApplication.__init__(self)
        tf.logging.info('starting segmentation application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        # window_sampling -> (training sampler initialiser,
        #                     inference sampler initialiser,
        #                     inference aggregator initialiser)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        """Create ``self.reader`` and attach preprocessing layers to it.

        Training reads every name in ``SUPPORTED_INPUT``; inference reads
        ``'image'`` only. Normalisation layers are attached only when they
        were actually constructed, so a missing histogram reference file
        never results in a ``None`` entry in the layer list.
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(SUPPORTED_INPUT)
        else:
            # in the inference process use image input only
            self.reader = ImageReader(['image'])
        self.reader.initialise_reader(data_param, task_param)

        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)
        else:
            foreground_masking_layer = None

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)

        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)
        else:
            histogram_normaliser = None
            label_normaliser = None

        normalisation_layers = []
        # The histogram/label normalisers are None without a reference
        # file; skip them instead of handing None to the reader.
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and label_normaliser is not None:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            # Per-axis angles may be absent from older configurations.
            angle_x = getattr(train_param, 'rotation_angle_x', ())
            angle_y = getattr(train_param, 'rotation_angle_y', ())
            angle_z = getattr(train_param, 'rotation_angle_z', ())
            if train_param.rotation_angle or angle_x or angle_y or angle_z:
                rotation_layer = RandomRotationLayer()
                if train_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        angle_x, angle_y, angle_z)
                augmentation_layers.append(rotation_layer)
            # NOTE: random spatial scaling (RandomSpatialScalingLayer) is
            # deliberately not used by this application.

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))
        self.reader.add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)

    def initialise_uniform_sampler(self):
        """Random fixed-size windows, ``sample_per_volume`` per image."""
        self.sampler = [
            UniformSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_resize_sampler(self):
        """Whole images resized to the window size (shuffled in training)."""
        self.sampler = [
            ResizeSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle_buffer=self.is_training,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_sampler(self):
        """Windows on a regular grid with ``border`` overlap (inference)."""
        self.sampler = [
            GridSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_aggregator(self):
        """Aggregator stitching grid windows back into full volumes."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Aggregator resizing whole-image outputs back to input size."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Run the training or inference sampler initialiser."""
        sampling = self.SUPPORTED_SAMPLING[self.net_param.window_sampling]
        if self.is_training:
            sampling[0]()
        else:
            sampling[1]()

    def initialise_network(self):
        """Instantiate ``self.net`` with optional L1/L2 regularisers."""
        num_classes = self.segmentation_param.num_classes
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=num_classes,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Connect the sampler output to the network.

        Training: computes the loss (plus the mean regularisation loss when
        ``decay > 0`` and regularisation losses exist), collects gradients,
        console/TensorBoard scalars and axial image summaries.
        Inference: post-processes logits (softmax / argmax / identity) and
        registers the window and its location as network outputs, then
        creates the aggregator for the configured sampling mode.
        """
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:
            label = data_dict.get('label', None)
            # Changed 11/29/2017: collapse the 3D label into a 2D label by
            # max-pooling over axis 3 of the label tensor. Only suitable
            # for STNeuroNet.
            k = label.get_shape().as_list()
            label = tf.nn.max_pool3d(label, [1, 1, 1, k[3], 1],
                                     [1, 1, 1, 1, 1], 'VALID',
                                     data_format='NDHWC')
            tf.logging.info('label shape is %s', label.get_shape())
            tf.logging.info('Image shape is %s', image.get_shape())
            tf.logging.info('Out shape is %s', net_out.get_shape())

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=label,
                weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            # Fall back to the data loss whenever there is nothing to add,
            # so `loss` is always bound.
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)

            # Added 10/30 by Soltanian-Zadeh for TensorBoard visualisation.
            # Map class indices 0..num_classes-1 onto the 0..255 display
            # range; the denominator is clamped to avoid division by zero.
            label_scale = 255. / max(
                self.segmentation_param.num_classes - 1, 1)
            seg_summary = tf.to_float(
                tf.expand_dims(tf.argmax(net_out, -1), -1)) * label_scale
            label_summary = tf.to_float(
                tf.expand_dims(label, -1)) * label_scale
            m, v = tf.nn.moments(image, axes=[1, 2, 3], keep_dims=True)
            img_summary = tf.minimum(
                255.,
                tf.maximum(
                    0.,
                    (tf.to_float(image - m) / (tf.sqrt(v) * 2.) + 1.) * 127.))
            image3_axial('img', img_summary, 50, [tf.GraphKeys.SUMMARIES])
            image3_axial('seg', seg_summary, 5, [tf.GraphKeys.SUMMARIES])
            image3_axial('label', label_summary, 5, [tf.GraphKeys.SUMMARIES])
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)
            tf.logging.info('output shape is %s', net_out.get_shape())

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Feed inference windows to the aggregator; no-op in training."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True
class ClassificationApplication(BaseApplication):
    """This class defines an application for image-level classification
    problems mapping from images to scalar labels.

    This is the application class to be instantiated by the driver
    and referred to in configuration files.

    Although structurally similar to segmentation, this application
    supports different samplers/aggregators (because patch-based
    processing is not appropriate), and monitoring metrics."""

    REQUIRED_CONFIG_SECTION = "CLASSIFICATION"

    def __init__(self, net_param, action_param, action):
        super(ClassificationApplication, self).__init__()
        tf.logging.info('starting classification application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.classification_param = None
        # window_sampling -> (training sampler initialiser,
        #                     non-training sampler initialiser)
        self.SUPPORTED_SAMPLING = {
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create one reader per file list and attach preprocessing.

        Augmentation layers go to the first (training) reader only; any
        further readers (e.g. validation) get normalisation layers only.

        Raises:
            ValueError: if the action is not train/inference/evaluation,
                if ``num_classes`` < 2, or if the label normaliser's map
                size disagrees with ``num_classes``.
        """
        self.data_param = data_param
        self.classification_param = task_param

        if self.is_training:
            reader_names = ('image', 'label', 'sampler')
        elif self.is_inference:
            reader_names = ('image',)
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal(
                'Action `%s` not supported. Expected one of %s',
                self.action, self.SUPPORTED_PHASES)
            raise ValueError(
                'Action `{}` not supported.'.format(self.action))
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(
                data_param, task_param, file_list)
            for file_list in file_lists
        ]

        foreground_masking_layer = BinaryMaskingLayer(
            type_str=self.net_param.foreground_type,
            multimod_fusion=self.net_param.multimod_foreground_type,
            threshold=0.0) \
            if self.net_param.normalise_foreground_only else None

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer) \
            if self.net_param.whitening else None
        histogram_normaliser = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(task_param).get('image'),
            model_filename=self.net_param.histogram_ref_file,
            binary_masking_func=foreground_masking_layer,
            norm_type=self.net_param.norm_type,
            cutoff=self.net_param.cutoff,
            name='hist_norm_layer') \
            if (self.net_param.histogram_ref_file and
                self.net_param.normalisation) else None

        label_normaliser = DiscreteLabelNormalisationLayer(
            image_name='label',
            modalities=vars(task_param).get('label'),
            model_filename=self.net_param.histogram_ref_file) \
            if (self.net_param.histogram_ref_file and
                task_param.label_normalisation) else None

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)
        if label_normaliser is not None:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=train_param.scaling_percentage[0],
                        max_percentage=train_param.scaling_percentage[1],
                        antialiasing=train_param.antialiasing,
                        isotropic=train_param.isotropic_scaling))
            if train_param.rotation_angle or \
                    train_param.rotation_angle_x or \
                    train_param.rotation_angle_y or \
                    train_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if train_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        train_param.rotation_angle_x,
                        train_param.rotation_angle_y,
                        train_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            normalisation_layers + augmentation_layers)
        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(normalisation_layers)

        # Checking num_classes is set correctly
        if self.classification_param.num_classes <= 1:
            raise ValueError(
                "Number of classes must be at least 2 for classification")
        for preprocessor in self.readers[0].preprocessors:
            if preprocessor.name == 'label_norm':
                if len(preprocessor.label_map[preprocessor.key[0]]) != \
                        self.classification_param.num_classes:
                    raise ValueError(
                        "Number of unique labels must be equal to "
                        "number of classes (check histogram_ref file)")

    def initialise_resize_sampler(self):
        """One ResizeSampler per reader; shuffled only in training."""
        self.sampler = [[
            ResizeSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle=self.is_training,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_aggregator(self):
        """Aggregator writing inference results for the first reader."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        """Run the training or non-training sampler initialiser."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Instantiate ``self.net`` with initialisers and regularisers."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.classification_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def add_confusion_matrix_summaries_(self,
                                        outputs_collector,
                                        net_out,
                                        data_dict):
        """Add monitoring metrics derived from the batch confusion matrix.

        Binary problems get TP/FN/FP/TN scalars; multi-class problems get
        the confusion matrix as an image. Every problem gets an
        ``accuracy`` scalar: the fraction of correctly classified samples.
        """
        labels = tf.reshape(tf.cast(data_dict['label'], tf.int64), [-1])
        prediction = tf.reshape(tf.argmax(net_out, -1), [-1])
        num_classes = self.classification_param.num_classes
        conf_mat = tf.contrib.metrics.confusion_matrix(
            labels, prediction, num_classes)
        conf_mat = tf.to_float(conf_mat)
        if self.classification_param.num_classes == 2:
            outputs_collector.add_to_collection(
                var=conf_mat[1][1], name='true_positives',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=conf_mat[1][0], name='false_negatives',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=conf_mat[0][1], name='false_positives',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=conf_mat[0][0], name='true_negatives',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            outputs_collector.add_to_collection(
                var=conf_mat[tf.newaxis, :, :, tf.newaxis],
                name='confusion_matrix',
                average_over_devices=True, summary_type='image',
                collection=TF_SUMMARIES)

        # trace(conf_mat) counts correct predictions; dividing by the total
        # count gives accuracy. The denominator is clamped to 1 so an empty
        # batch reports 0 instead of NaN.
        accuracy = tf.trace(conf_mat) / tf.maximum(
            tf.reduce_sum(conf_mat), 1.)
        outputs_collector.add_to_collection(
            var=accuracy, name='accuracy',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Connect samplers, network, loss and output collection.

        Training switches between the training and validation samplers
        via ``self.is_validation`` when validation is enabled. Other
        actions post-process the logits and register them, with their
        locations, as network outputs before creating the aggregator.
        """

        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            self.patience = self.action_param.patience
            self.mode = self.action_param.early_stopping_mode
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(
                    tf.logical_not(self.is_validation),
                    lambda: switch_sampler(for_training=True),
                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.classification_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None))
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            self.total_loss = loss
            grads = self.optimiser.compute_gradients(
                loss, colocate_gradients_with_ops=True)
            outputs_collector.add_to_collection(
                var=self.total_loss, name='total_loss',
                average_over_devices=True, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=self.total_loss, name='total_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='data_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='data_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            self.add_confusion_matrix_summaries_(
                outputs_collector, net_out, data_dict)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)
            tf.logging.info(
                'net_out.shape may need to be resized: %s', net_out.shape)
            output_prob = self.classification_param.output_prob
            num_classes = self.classification_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Send non-training outputs to the aggregator as ``csv``."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                {'csv': batch_output['window']}, batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Create a ClassificationEvaluator over the first reader."""
        self.eval_param = eval_param
        self.evaluator = ClassificationEvaluator(
            self.readers[0], self.classification_param, eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Register the inferred output, configured like ``label``."""
        return self.add_inferred_output_like(data_param, task_param, 'label')
class RegressionApplication(BaseApplication):
    """Regression application driven by the ``[REGRESSION]`` section.

    The network is created with a single output channel and trained
    against the ``output`` image (optionally weighted by ``weight``).
    """

    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, action):
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.regression_param = None
        # window_sampling -> (training sampler initialiser,
        #                     inference sampler initialiser,
        #                     inference aggregator initialiser)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
            'balanced': (self.initialise_balanced_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create one reader per file list and attach preprocessing.

        Padding and normalisation go to every reader; augmentation goes to
        the first (training) reader only.

        Raises:
            ValueError: if the action is not train/inference/evaluation.
        """
        self.data_param = data_param
        self.regression_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'output', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image',)
        elif self.is_evaluation:
            reader_names = ('image', 'output', 'inferred')
        else:
            tf.logging.fatal(
                'Action `%s` not supported. Expected one of %s',
                self.action, self.SUPPORTED_PHASES)
            raise ValueError(
                'Action `{}` not supported.'.format(self.action))
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(
                data_param, task_param, file_list)
            for file_list in file_lists
        ]

        # initialise input preprocessing layers
        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image') \
            if self.net_param.whitening else None
        histogram_normaliser = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(task_param).get('image'),
            model_filename=self.net_param.histogram_ref_file,
            norm_type=self.net_param.norm_type,
            cutoff=self.net_param.cutoff,
            name='hist_norm_layer') \
            if (self.net_param.histogram_ref_file and
                self.net_param.normalisation) else None

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size,
                         mode=self.net_param.volume_padding_mode))

        # initialise training data augmentation layers
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=train_param.scaling_percentage[0],
                        max_percentage=train_param.scaling_percentage[1]))
            # Honour per-axis angles too (same rule as the classification
            # application): a uniform angle takes precedence.
            if train_param.rotation_angle or \
                    train_param.rotation_angle_x or \
                    train_param.rotation_angle_y or \
                    train_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if train_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        train_param.rotation_angle_x,
                        train_param.rotation_angle_y,
                        train_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)
            if train_param.do_elastic_deformation:
                spatial_rank = list(
                    self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(
                    RandomElasticDeformationLayer(
                        spatial_rank=spatial_rank,
                        num_controlpoints=train_param.num_ctrl_points,
                        std_deformation_sigma=train_param.deformation_sigma,
                        proportion_to_augment=train_param.proportion_to_deform))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)
        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer + normalisation_layers)

    def initialise_uniform_sampler(self):
        """One UniformSampler per reader."""
        self.sampler = [[
            UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_weighted_sampler(self):
        """One WeightedSampler per reader."""
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_resize_sampler(self):
        """One ResizeSampler per reader; shuffled only in training."""
        self.sampler = [[
            ResizeSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle=self.is_training,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_sampler(self):
        """One GridSampler per reader, with ``border`` overlap."""
        self.sampler = [[
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_balanced_sampler(self):
        """One BalancedSampler per reader."""
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_aggregator(self):
        """Aggregator stitching grid windows for the first reader."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_resize_aggregator(self):
        """Aggregator resizing whole-image outputs for the first reader."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        """Create samplers for training or inference (none otherwise)."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        """Create the aggregator matching the configured sampling mode."""
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Instantiate a single-output ``self.net`` with regularisers."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Connect samplers, network, loss and output collection.

        Training: the loss is computed on ``loss_border``-cropped
        prediction, target and (optional) weight map. Inference: the raw
        network output is registered with its location and the
        aggregator is created.
        """

        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(
                    tf.logical_not(self.is_validation),
                    lambda: switch_sampler(for_training=True),
                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            crop_layer = CropLayer(border=self.regression_param.loss_border)
            weight_map = data_dict.get('weight', None)
            weight_map = None if weight_map is None else crop_layer(weight_map)
            data_loss = loss_func(
                prediction=crop_layer(net_out),
                ground_truth=crop_layer(data_dict['output']),
                weight_map=weight_map)
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(
                loss, colocate_gradients_with_ops=True)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        elif self.is_inference:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)
            net_out = PostProcessingLayer('IDENTITY')(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Send inference windows to the aggregator; no-op otherwise."""
        if self.is_inference:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Create a RegressionEvaluator over the first reader."""
        self.eval_param = eval_param
        self.evaluator = RegressionEvaluator(
            self.readers[0], self.regression_param, eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Register the inferred output, configured like ``output``."""
        return self.add_inferred_output_like(data_param, task_param, 'output')
class SegmentationApplication(BaseApplication):
    """Segmentation application feeding image modalities to the network as
    separately named inputs.

    Depending on the action, ``connect_data_and_network`` builds either the
    training graph (loss, gradients, console/summary outputs) or the
    inference graph (post-processed windows plus their locations, decoded by
    an aggregator in ``interpret_output``).
    """

    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        super(SegmentationApplication, self).__init__()
        tf.logging.info('starting segmentation application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        # window_sampling name -> (sampler initialiser used for training,
        #                          sampler initialiser used for inference,
        #                          aggregator initialiser used for inference)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
            'balanced': (self.initialise_balanced_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create the image readers and attach preprocessing layers.

        Training: one reader per file list from ``self.get_file_lists``,
        loading image/label/weight/sampler.  Inference: a single image-only
        reader over the inference and validation files.  Evaluation: a
        single reader over the inference files loading image/label/inferred.
        Augmentation layers are only attached to ``self.readers[0]``.

        :raises ValueError: if the action is none of training, inference
            or evaluation.
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'label', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self.is_inference:
            # in the inference process use image input only
            inference_reader = ImageReader({'image'})
            file_list = pd.concat([
                data_partitioner.inference_files,
                data_partitioner.validation_files
            ], axis=0)
            # re-index after concatenation so subject indices are 0..n-1
            file_list.index = range(file_list.shape[0])
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]
        elif self.is_evaluation:
            file_list = data_partitioner.inference_files
            reader = ImageReader({'image', 'label', 'inferred'})
            reader.initialise(data_param, task_param, file_list)
            self.readers = [reader]
        else:
            # NOTE(review): relies on BaseApplication defining
            # SUPPORTED_ACTIONS -- confirm.
            raise ValueError(
                'Action `{}` not supported. Expected one of {}'.format(
                    self.action, self.SUPPORTED_ACTIONS))

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normalisers = None
        if self.net_param.histogram_ref_file and \
                task_param.label_normalisation:
            label_normalisers = [
                DiscreteLabelNormalisationLayer(
                    image_name='label',
                    modalities=vars(task_param).get('label'),
                    model_filename=self.net_param.histogram_ref_file)
            ]
            if self.is_evaluation:
                label_normalisers.append(
                    DiscreteLabelNormalisationLayer(
                        image_name='inferred',
                        modalities=vars(task_param).get('inferred'),
                        model_filename=self.net_param.histogram_ref_file))
                # share the label mapping key so 'inferred' is mapped the
                # same way as 'label'
                label_normalisers[-1].key = label_normalisers[0].key

        normalisation_layers = []
        # histogram_normaliser only exists when a reference file is given;
        # never append ``None`` as a preprocessing layer.
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob):
            if label_normalisers is None:
                # without a histogram_ref_file no label normaliser was
                # built; ``list.extend(None)`` would raise TypeError.
                tf.logging.warning(
                    'label_normalisation is set but histogram_ref_file is '
                    'empty; skipping label normalisation layers.')
            else:
                normalisation_layers.extend(label_normalisers)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

            # add deformation layer
            if self.action_param.do_elastic_deformation:
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(
                    RandomElasticDeformationLayer(
                        spatial_rank=spatial_rank,
                        num_controlpoints=self.action_param.num_ctrl_points,
                        std_deformation_sigma=(
                            self.action_param.deformation_sigma),
                        proportion_to_augment=(
                            self.action_param.proportion_to_deform)))

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)
        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer + normalisation_layers)

    def initialise_uniform_sampler(self):
        """One UniformSampler per reader."""
        self.sampler = [[
            UniformSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_weighted_sampler(self):
        """One WeightedSampler per reader."""
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_resize_sampler(self):
        """One ResizeSampler per reader; shuffles only while training."""
        self.sampler = [[
            ResizeSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle_buffer=self.is_training,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_sampler(self):
        """One GridSampler per reader."""
        self.sampler = [[
            GridSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_balanced_sampler(self):
        """One BalancedSampler per reader."""
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_aggregator(self):
        """Decode grid windows back into full volumes."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Decode resized windows back into full volumes."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Build the training or inference sampler for ``window_sampling``."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        """Build the inference aggregator for ``window_sampling``."""
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Create ``self.net`` with optional L1/L2 weight and bias
        regularisers (only when ``decay > 0``)."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def _run_network(self, data_dict, is_training):
        """Split ``data_dict['image']`` on its last axis and feed the first
        two channels to ``self.net`` keyed by ``MODALITIES[0..1]``."""
        # assumes channel k of the last axis is modality MODALITIES[k] and
        # that at least two channels are present -- TODO confirm
        image = tf.cast(data_dict['image'], tf.float32)
        channels = tf.unstack(image, axis=-1)
        net_inputs = {MODALITIES[k]: tf.expand_dims(channels[k], -1)
                      for k in range(2)}
        return self.net(net_inputs, is_training=is_training)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph for the current action.

        Training: pops a batch (from the last reader's sampler when
        ``self.is_validation`` is true and ``validation_every_n > 0``),
        computes the segmentation loss plus the mean regularisation loss
        (when ``decay > 0``), and registers gradients and the data loss.
        Inference: post-processes the network output with SOFTMAX, ARGMAX
        or IDENTITY and registers windows and their locations.
        """
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)
            net_out = self._run_network(data_dict, is_training=True)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.segmentation_param.softmax)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=None)
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(r_loss) for r_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        elif self.is_inference:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            # use inference mode (previously hard-coded to True, which kept
            # dropout / batch-norm in training mode at inference time)
            net_out = self._run_network(
                data_dict, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Feed inference windows to the aggregator; return its
        continue flag, or True outside inference."""
        if self.is_inference:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Create a SegmentationEvaluator over the first reader."""
        self.eval_param = eval_param
        self.evaluator = SegmentationEvaluator(self.readers[0],
                                               self.segmentation_param,
                                               eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Register the inferred output, configured like ``label``."""
        return self.add_inferred_output_like(data_param, task_param, 'label')
class RegressionRecApplication(BaseApplication):
    """Recursive regression application.

    ``self.net`` produces an initial prediction.  ``self.net2`` (built from
    a deep copy of ``net_param``) is then applied repeatedly to the image
    concatenated with the previous prediction, and its output is added as a
    residual correction.  Training supervises every stage; inference
    outputs the last stage.
    """

    REQUIRED_CONFIG_SECTION = "REGRESSION"

    # number of net2 residual refinements after the base prediction
    _TRAIN_REFINEMENTS = 2  # supervised stages: pct1, pct2, pct3
    # NOTE(review): inference runs one refinement more than training
    # supervises (pct4); kept from the original -- confirm it is intended.
    _INFER_REFINEMENTS = 3

    def __init__(self, net_param, action_param, action):
        BaseApplication.__init__(self)
        tf.logging.info('starting recursive regression application')
        self.action = action

        self.net_param = net_param
        # net2 gets its own (independent) copy of the network parameters
        self.net2_param = copy.deepcopy(net_param)
        self.action_param = action_param
        self.regression_param = None

        self.data_param = None
        # window_sampling name -> (training sampler initialiser,
        #                          inference sampler initialiser,
        #                          inference aggregator initialiser)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create image readers and attach preprocessing layers.

        Training: a reader over the train files, plus one over the
        validation files when ``validation_every_n > 0``.  Otherwise a
        single image-only reader over the inference files.  Augmentation
        layers are attached to the first reader only.
        """
        self.data_param = data_param
        self.regression_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = [data_partitioner.train_files]
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.validation_files)

            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        # histogram_normaliser only exists when a reference file is given;
        # never append ``None`` as a preprocessing layer.
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_uniform_angle(
                    self.action_param.rotation_angle)
                augmentation_layers.append(rotation_layer)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)
        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer + normalisation_layers)

    def initialise_uniform_sampler(self):
        """One UniformSampler per reader."""
        self.sampler = [[
            UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_weighted_sampler(self):
        """One WeightedSampler per reader."""
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_resize_sampler(self):
        """One ResizeSampler per reader; shuffles only while training."""
        self.sampler = [[
            ResizeSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle=self.is_training,
                smaller_final_batch_mode=(
                    self.net_param.smaller_final_batch_mode),
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_sampler(self):
        """One GridSampler per reader."""
        self.sampler = [[
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=(
                    self.net_param.smaller_final_batch_mode),
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_balanced_sampler(self):
        """One BalancedSampler per reader (not listed in SUPPORTED_SAMPLING)."""
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_aggregator(self):
        """Decode grid windows back into full volumes."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Decode resized windows back into full volumes."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Build the training sampler, or the inference one otherwise."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Create ``self.net`` and ``self.net2`` (single output channel
        each) with optional L1/L2 regularisers (only when ``decay > 0``)."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)
        self.net2 = ApplicationNetFactory.create(self.net2_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net2_param.activation_function)

    def _recursive_predictions(self, image, n_refinements):
        """Return ``[pct1, pct2, ...]``: the base prediction of ``self.net``
        followed by ``n_refinements`` residual updates from ``self.net2``
        applied to ``image`` concatenated (axis 4) with the previous stage."""
        predictions = [self.net(image, self.is_training)]
        for _ in range(n_refinements):
            residual = self.net2(
                tf.concat([image, predictions[-1]], 4), self.is_training)
            predictions.append(tf.add(predictions[-1], residual))
        return predictions

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph.

        Training: one loss per refinement stage on border-cropped
        predictions; the optimised loss is their sum, plus the mean
        regularisation loss when ``decay > 0``.  Inference: the last
        refinement stage is registered as the output window.
        """
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            predictions = self._recursive_predictions(
                image, self._TRAIN_REFINEMENTS)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            crop_layer = CropLayer(border=self.regression_param.loss_border,
                                   name='crop-88')
            # ground truth and weight map are identical for every stage,
            # so crop them once
            ground_truth = crop_layer(data_dict.get('output', None))
            weight_map = data_dict.get('weight', None)
            if weight_map is not None:
                weight_map = crop_layer(weight_map)
            data_losses = [
                loss_func(prediction=crop_layer(prediction),
                          ground_truth=ground_truth,
                          weight_map=weight_map)
                for prediction in predictions]

            loss_terms = list(data_losses)
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(r_loss) for r_loss in reg_losses])
                loss_terms.insert(0, reg_loss)
            # left-to-right sum: reg_loss + loss1 + loss2 + ...
            loss = sum(loss_terms[1:], loss_terms[0])

            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            # collecting output variables
            stage_names = ['data_loss{}'.format(i)
                           for i in range(1, len(data_losses) + 1)]
            outputs_collector.add_to_collection(
                var=loss, name='Loss',
                average_over_devices=False, collection=CONSOLE)
            for name, stage_loss in zip(stage_names, data_losses):
                outputs_collector.add_to_collection(
                    var=stage_loss, name=name,
                    average_over_devices=True, collection=CONSOLE)
            for name, stage_loss in zip(stage_names, data_losses):
                outputs_collector.add_to_collection(
                    var=stage_loss, name=name,
                    average_over_devices=True, summary_type='scalar',
                    collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=loss, name='LossSum',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            final_prediction = self._recursive_predictions(
                image, self._INFER_REFINEMENTS)[-1]

            crop_layer = CropLayer(border=0, name='crop-88')
            post_process_layer = PostProcessingLayer('IDENTITY')
            net_out = post_process_layer(crop_layer(final_prediction))

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Outside training, hand windows to the aggregator and return its
        continue flag; during training return True."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                {'window_image': batch_output['window']},
                batch_output['location'])
        return True
class SegmentationApplication(BaseApplication): REQUIRED_CONFIG_SECTION = "SEGMENTATION" def __init__(self, net_param, action_param, action): super(SegmentationApplication, self).__init__() tf.logging.info('starting segmentation application') self.action = action self.net_param = net_param self.action_param = action_param self.data_param = None self.segmentation_param = None self.SUPPORTED_SAMPLING = { 'uniform': (self.initialise_uniform_sampler, self.initialise_grid_sampler, self.initialise_grid_aggregator), 'weighted': (self.initialise_weighted_sampler, self.initialise_grid_sampler, self.initialise_grid_aggregator), 'resize': (self.initialise_resize_sampler, self.initialise_resize_sampler, self.initialise_resize_aggregator), 'balanced': (self.initialise_balanced_sampler, self.initialise_grid_sampler, self.initialise_grid_aggregator), } self.learning_rate = None self.current_lr = tf.constant(0) def initialise_dataset_loader( self, data_param=None, task_param=None, data_partitioner=None): self.data_param = data_param self.segmentation_param = task_param # initialise input image readers if self.is_training: reader_names = ('image', 'label', 'weight_map', 'sampler') elif self.is_inference: # in the inference process use `image` input only reader_names = ('image',) elif self.is_evaluation: reader_names = ('image', 'label', 'inferred') else: tf.logging.fatal( 'Action `%s` not supported. 
Expected one of %s', self.action, self.SUPPORTED_PHASES) raise ValueError try: reader_phase = self.action_param.dataset_to_infer except AttributeError: reader_phase = None file_lists = data_partitioner.get_file_lists_by( phase=reader_phase, action=self.action) self.readers = [ ImageReader(reader_names).initialise( data_param, task_param, file_list) for file_list in file_lists] # initialise input preprocessing layers foreground_masking_layer = BinaryMaskingLayer( type_str=self.net_param.foreground_type, multimod_fusion=self.net_param.multimod_foreground_type, threshold=0.0) \ if self.net_param.normalise_foreground_only else None mean_var_normaliser = MeanVarNormalisationLayer( image_name='image', binary_masking_func=foreground_masking_layer) \ if self.net_param.whitening else None percentile_normaliser = PercentileNormalisationLayer( image_name='image', binary_masking_func=foreground_masking_layer, cutoff=self.net_param.cutoff) \ if self.net_param.percentile_normalisation else None histogram_normaliser = HistogramNormalisationLayer( image_name='image', modalities=vars(task_param).get('image'), model_filename=self.net_param.histogram_ref_file, binary_masking_func=foreground_masking_layer, norm_type=self.net_param.norm_type, cutoff=self.net_param.cutoff, name='hist_norm_layer') \ if (self.net_param.histogram_ref_file and self.net_param.normalisation) else None label_normalisers = None if self.net_param.histogram_ref_file and \ task_param.label_normalisation: label_normalisers = [DiscreteLabelNormalisationLayer( image_name='label', modalities=vars(task_param).get('label'), model_filename=self.net_param.histogram_ref_file)] if self.is_evaluation: label_normalisers.append( DiscreteLabelNormalisationLayer( image_name='inferred', modalities=vars(task_param).get('inferred'), model_filename=self.net_param.histogram_ref_file)) label_normalisers[-1].key = label_normalisers[0].key normalisation_layers = [] if histogram_normaliser is not None: 
normalisation_layers.append(histogram_normaliser) if mean_var_normaliser is not None: normalisation_layers.append(mean_var_normaliser) if percentile_normaliser is not None: normalisation_layers.append(percentile_normaliser) if task_param.label_normalisation and \ (self.is_training or not task_param.output_prob): normalisation_layers.extend(label_normalisers) volume_padding_layer = [] if self.net_param.volume_padding_size: volume_padding_layer.append(PadLayer( image_name=SUPPORTED_INPUT, border=self.net_param.volume_padding_size, mode=self.net_param.volume_padding_mode)) # initialise training data augmentation layers augmentation_layers = [] if self.is_training: train_param = self.action_param if train_param.random_flipping_axes != -1: augmentation_layers.append(RandomFlipLayer( flip_axes=train_param.random_flipping_axes)) if train_param.scaling_percentage: augmentation_layers.append(RandomSpatialScalingLayer( min_percentage=train_param.scaling_percentage[0], max_percentage=train_param.scaling_percentage[1], antialiasing=train_param.antialiasing)) if train_param.rotation_angle or \ train_param.rotation_angle_x or \ train_param.rotation_angle_y or \ train_param.rotation_angle_z: rotation_layer = RandomRotationLayer() if train_param.rotation_angle: rotation_layer.init_uniform_angle( train_param.rotation_angle) else: rotation_layer.init_non_uniform_angle( train_param.rotation_angle_x, train_param.rotation_angle_y, train_param.rotation_angle_z) augmentation_layers.append(rotation_layer) if train_param.do_elastic_deformation: spatial_rank = list(self.readers[0].spatial_ranks.values())[0] augmentation_layers.append(RandomElasticDeformationLayer( spatial_rank=spatial_rank, num_controlpoints=train_param.num_ctrl_points, std_deformation_sigma=train_param.deformation_sigma, proportion_to_augment=train_param.proportion_to_deform)) # only add augmentation to first reader (not validation reader) self.readers[0].add_preprocessing_layers( volume_padding_layer + 
normalisation_layers + augmentation_layers) for reader in self.readers[1:]: reader.add_preprocessing_layers( volume_padding_layer + normalisation_layers) def initialise_uniform_sampler(self): self.sampler = [[UniformSampler( reader=reader, window_sizes=self.data_param, batch_size=self.net_param.batch_size, windows_per_image=self.action_param.sample_per_volume, queue_length=self.net_param.queue_length) for reader in self.readers]] def initialise_weighted_sampler(self): self.sampler = [[WeightedSampler( reader=reader, window_sizes=self.data_param, batch_size=self.net_param.batch_size, windows_per_image=self.action_param.sample_per_volume, queue_length=self.net_param.queue_length) for reader in self.readers]] def initialise_resize_sampler(self): self.sampler = [[ResizeSampler( reader=reader, window_sizes=self.data_param, batch_size=self.net_param.batch_size, shuffle=self.is_training, smaller_final_batch_mode=self.net_param.smaller_final_batch_mode, queue_length=self.net_param.queue_length) for reader in self.readers]] def initialise_grid_sampler(self): self.sampler = [[GridSampler( reader=reader, window_sizes=self.data_param, batch_size=self.net_param.batch_size, spatial_window_size=self.action_param.spatial_window_size, window_border=self.action_param.border, smaller_final_batch_mode=self.net_param.smaller_final_batch_mode, queue_length=self.net_param.queue_length) for reader in self.readers]] def initialise_balanced_sampler(self): self.sampler = [[BalancedSampler( reader=reader, window_sizes=self.data_param, batch_size=self.net_param.batch_size, windows_per_image=self.action_param.sample_per_volume, queue_length=self.net_param.queue_length) for reader in self.readers]] def initialise_grid_aggregator(self): self.output_decoder = GridSamplesAggregator( image_reader=self.readers[0], output_path=self.action_param.save_seg_dir, window_border=self.action_param.border, interp_order=self.action_param.output_interp_order, postfix=self.action_param.output_postfix) def 
initialise_resize_aggregator(self): self.output_decoder = ResizeSamplesAggregator( image_reader=self.readers[0], output_path=self.action_param.save_seg_dir, window_border=self.action_param.border, interp_order=self.action_param.output_interp_order, postfix=self.action_param.output_postfix) def initialise_sampler(self): if self.is_training: self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]() elif self.is_inference: self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]() def initialise_aggregator(self): self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]() def initialise_network(self): w_regularizer = None b_regularizer = None reg_type = self.net_param.reg_type.lower() decay = self.net_param.decay if reg_type == 'l2' and decay > 0: from tensorflow.contrib.layers.python.layers import regularizers w_regularizer = regularizers.l2_regularizer(decay) b_regularizer = regularizers.l2_regularizer(decay) elif reg_type == 'l1' and decay > 0: from tensorflow.contrib.layers.python.layers import regularizers w_regularizer = regularizers.l1_regularizer(decay) b_regularizer = regularizers.l1_regularizer(decay) self.net = ApplicationNetFactory.create(self.net_param.name)( num_classes=self.segmentation_param.num_classes, w_initializer=InitializerFactory.get_initializer( name=self.net_param.weight_initializer), b_initializer=InitializerFactory.get_initializer( name=self.net_param.bias_initializer), w_regularizer=w_regularizer, b_regularizer=b_regularizer, acti_func=self.net_param.activation_function) def connect_data_and_network(self, outputs_collector=None, gradients_collector=None): def switch_sampler(for_training): with tf.name_scope('train' if for_training else 'validation'): sampler = self.get_sampler()[0][0 if for_training else -1] return sampler.pop_batch_op() if self.is_training: if self.action_param.validation_every_n > 0: data_dict = tf.cond(tf.logical_not(self.is_validation), lambda: switch_sampler(for_training=True), lambda: 
switch_sampler(for_training=False)) else: data_dict = switch_sampler(for_training=True) image = tf.cast(data_dict['image'], tf.float32) net_args = {'is_training': self.is_training, 'keep_prob': self.net_param.keep_prob} net_out = self.net(image, **net_args) with tf.name_scope('Optimiser'): self.learning_rate = tf.placeholder(tf.float32, shape=[]) optimiser_class = OptimiserFactory.create( name=self.action_param.optimiser) self.optimiser = optimiser_class.get_instance( learning_rate=self.learning_rate) loss_func = LossFunction( n_class=self.segmentation_param.num_classes, loss_type=self.action_param.loss_type, softmax=self.segmentation_param.softmax) data_loss = loss_func( prediction=net_out, ground_truth=data_dict.get('label', None), weight_map=data_dict.get('weight_map', None)) reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) if self.net_param.decay > 0.0 and reg_losses: reg_loss = tf.reduce_mean( [tf.reduce_mean(reg_loss) for reg_loss in reg_losses]) loss = data_loss + reg_loss else: loss = data_loss # Get all vars to_optimise = tf.trainable_variables() vars_to_freeze = \ self.action_param.vars_to_freeze or \ self.action_param.vars_to_restore if vars_to_freeze: import re var_regex = re.compile(vars_to_freeze) # Only optimise vars that are not frozen to_optimise = \ [v for v in to_optimise if not var_regex.search(v.name)] tf.logging.info( "Optimizing %d out of %d trainable variables, " "the other variables fixed (--vars_to_freeze %s)", len(to_optimise), len(tf.trainable_variables()), vars_to_freeze) grads = self.optimiser.compute_gradients( loss, var_list=to_optimise, colocate_gradients_with_ops=True) # clip gradients gradients, variables = zip(*grads) gradients, _ = tf.clip_by_global_norm(gradients, self.action_param.gradient_clipping_value) grads = list(zip(gradients, variables)) gnorm = tf.global_norm(list(gradients)) # collecting gradients variables gradients_collector.add_to_collection([grads]) # collecting output variables 
outputs_collector.add_to_collection( var=data_loss, name='loss', average_over_devices=False, collection=CONSOLE) outputs_collector.add_to_collection( var=gnorm, name='gnorm', average_over_devices=False, collection=CONSOLE) outputs_collector.add_to_collection( var=gnorm, name='gnorm', average_over_devices=False, summary_type='scalar', collection=TF_SUMMARIES) outputs_collector.add_to_collection( var=data_loss, name='loss', average_over_devices=True, summary_type='scalar', collection=TF_SUMMARIES) outputs_collector.add_to_collection( var=self.learning_rate, name='lr', average_over_devices=True, summary_type='scalar', collection=TF_SUMMARIES) #outputs_collector.add_to_collection( # var=image, name='image', # average_over_devices=False, summary_type='image3_sagittal', # collection=TF_SUMMARIES) #outputs_collector.add_to_collection( # var=net_out, name='output', # average_over_devices=False, summary_type='image3_sagittal', # collection=TF_SUMMARIES) # outputs_collector.add_to_collection( # var=image*180.0, name='image', # average_over_devices=False, summary_type='image3_sagittal', # collection=TF_SUMMARIES) # outputs_collector.add_to_collection( # var=image, name='image', # average_over_devices=False, # collection=NETWORK_OUTPUT) # outputs_collector.add_to_collection( # var=tf.reduce_mean(image), name='mean_image', # average_over_devices=False, summary_type='scalar', # collection=CONSOLE) elif self.is_inference: # converting logits into final output for # classification probabilities or argmax classification labels data_dict = switch_sampler(for_training=False) image = tf.cast(data_dict['image'], tf.float32) net_args = {'is_training': self.is_training, 'keep_prob': self.net_param.keep_prob} net_out = self.net(image, **net_args) output_prob = self.segmentation_param.output_prob num_classes = self.segmentation_param.num_classes if output_prob and num_classes > 1: post_process_layer = PostProcessingLayer( 'SOFTMAX', num_classes=num_classes) elif not output_prob and 
num_classes > 1: post_process_layer = PostProcessingLayer( 'ARGMAX', num_classes=num_classes) else: post_process_layer = PostProcessingLayer( 'IDENTITY', num_classes=num_classes) net_out = post_process_layer(net_out) outputs_collector.add_to_collection( var=net_out, name='window', average_over_devices=False, collection=NETWORK_OUTPUT) outputs_collector.add_to_collection( var=data_dict['image_location'], name='location', average_over_devices=False, collection=NETWORK_OUTPUT) self.initialise_aggregator() def interpret_output(self, batch_output): if self.is_inference: return self.output_decoder.decode_batch( batch_output['window'], batch_output['location']) return True def initialise_evaluator(self, eval_param): self.eval_param = eval_param self.evaluator = SegmentationEvaluator(self.readers[0], self.segmentation_param, eval_param) def add_inferred_output(self, data_param, task_param): return self.add_inferred_output_like(data_param, task_param, 'label') def set_iteration_update(self, iteration_message): """ This function will be called by the application engine at each iteration. """ current_iter = iteration_message.current_iter if iteration_message.is_training: if current_iter < self.action_param.warmup: self.current_lr = self.action_param.lr/(1. + math.exp(10 * (-current_iter/self.action_param.warmup+0.5))) else: self.current_lr = self.action_param.lr * pow( self.action_param.lr_gamma, ((current_iter - self.action_param.warmup) // self.action_param.lr_step_size)) iteration_message.data_feed_dict[self.is_validation] = False iteration_message.data_feed_dict[self.learning_rate] = self.current_lr elif iteration_message.is_validation: iteration_message.data_feed_dict[self.is_validation] = True iteration_message.data_feed_dict[self.learning_rate] = self.current_lr
class SegmentationApplication(BaseApplication):
    """Segmentation application with optional mixup / mix-match training.

    Builds image readers with preprocessing and augmentation layers,
    window samplers, the segmentation network, the training loss, and the
    inference post-processing and aggregation.
    """
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        super(SegmentationApplication, self).__init__()
        tf.logging.info('starting segmentation application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        # window_sampling name -> (training sampler init,
        #                          inference sampler init,
        #                          aggregator init)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
            'balanced': (self.initialise_balanced_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create the image readers and attach preprocessing layers.

        Raises:
            ValueError: if the action is not training/inference/evaluation,
                if ``num_classes <= 1``, or if a label-normalisation layer
                maps a different number of labels than ``num_classes``.
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'label', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image',)
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal(
                'Action `%s` not supported. Expected one of %s',
                self.action, self.SUPPORTED_PHASES)
            raise ValueError(
                'Action `{}` not supported.'.format(self.action))
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(
                data_param, task_param, file_list)
            for file_list in file_lists]

        # initialise input preprocessing layers
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)
        mean_var_normaliser = None
        if self.net_param.whitening:
            mean_var_normaliser = MeanVarNormalisationLayer(
                image_name='image',
                binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file and \
                self.net_param.normalisation:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
        rgb_normaliser = None
        if self.net_param.rgb_normalisation:
            rgb_normaliser = RGBHistogramEquilisationLayer(
                image_name='image',
                name='rbg_norm_layer')

        label_normalisers = None
        if self.net_param.histogram_ref_file and \
                task_param.label_normalisation:
            label_normalisers = [
                DiscreteLabelNormalisationLayer(
                    image_name='label',
                    modalities=vars(task_param).get('label'),
                    model_filename=self.net_param.histogram_ref_file)]
            if self.is_evaluation:
                label_normalisers.append(
                    DiscreteLabelNormalisationLayer(
                        image_name='inferred',
                        modalities=vars(task_param).get('inferred'),
                        model_filename=self.net_param.histogram_ref_file))
                # the 'inferred' layer reuses the 'label' layer's key
                label_normalisers[-1].key = label_normalisers[0].key

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if rgb_normaliser is not None:
            normalisation_layers.append(rgb_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob):
            if label_normalisers:
                normalisation_layers.extend(label_normalisers)
            else:
                # label_normalisers is None without a histogram_ref_file;
                # extending with it used to raise an opaque TypeError.
                tf.logging.warning(
                    'label_normalisation requested but no '
                    'histogram_ref_file given; skipping label '
                    'normalisation')

        volume_padding_layer = [PadLayer(
            image_name=SUPPORTED_INPUT,
            border=self.net_param.volume_padding_size,
            mode=self.net_param.volume_padding_mode,
            pad_to=self.net_param.volume_padding_to_size)]

        # initialise training data augmentation layers
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            self.patience = train_param.patience
            self.mode = self.action_param.early_stopping_mode
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=train_param.scaling_percentage[0],
                    max_percentage=train_param.scaling_percentage[1],
                    antialiasing=train_param.antialiasing,
                    isotropic=train_param.isotropic_scaling))
            if train_param.rotation_angle or \
                    train_param.rotation_angle_x or \
                    train_param.rotation_angle_y or \
                    train_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if train_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        train_param.rotation_angle_x,
                        train_param.rotation_angle_y,
                        train_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)
            if train_param.do_elastic_deformation:
                spatial_rank = list(
                    self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(RandomElasticDeformationLayer(
                    spatial_rank=spatial_rank,
                    num_controlpoints=train_param.num_ctrl_points,
                    std_deformation_sigma=train_param.deformation_sigma,
                    proportion_to_augment=train_param.proportion_to_deform))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)
        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer + normalisation_layers)

        # Checking num_classes is set correctly
        if self.segmentation_param.num_classes <= 1:
            raise ValueError(
                "Number of classes must be at least 2 for segmentation")
        for preprocessor in self.readers[0].preprocessors:
            if preprocessor.name == 'label_norm':
                if len(preprocessor.label_map[preprocessor.key[0]]) != \
                        self.segmentation_param.num_classes:
                    raise ValueError(
                        "Number of unique labels must be equal to "
                        "number of classes (check histogram_ref file)")

    def initialise_uniform_sampler(self):
        """Create one UniformSampler per reader, wrapped in a single list."""
        self.sampler = [[
            UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_weighted_sampler(self):
        """Create one WeightedSampler per reader, wrapped in a single list."""
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_resize_sampler(self):
        """Create one ResizeSampler per reader; shuffles only in training."""
        self.sampler = [[
            ResizeSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle=self.is_training,
                smaller_final_batch_mode=(
                    self.net_param.smaller_final_batch_mode),
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_sampler(self):
        """Create one GridSampler per reader (used for inference windows)."""
        self.sampler = [[
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=(
                    self.net_param.smaller_final_batch_mode),
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_balanced_sampler(self):
        """Create one BalancedSampler per reader, wrapped in a single list."""
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_aggregator(self):
        """Set ``self.output_decoder`` to a GridSamplesAggregator."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix,
            fill_constant=self.action_param.fill_constant)

    def initialise_resize_aggregator(self):
        """Set ``self.output_decoder`` to a ResizeSamplesAggregator."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        """Call the training (index 0) or inference (index 1) sampler init."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        """Call the aggregator init (index 2) for the configured sampling."""
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Instantiate ``self.net`` from the network factory.

        L1 or L2 regularisers (applied to both weights and biases) are
        created only when ``reg_type`` is 'l1'/'l2' and ``decay > 0``.
        """
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire sampler output into the network and register outputs.

        Training: optionally applies mixup (and mix-match volume sorting) to
        training batches, builds the loss and gradients, and registers loss
        outputs.  Inference: applies SOFTMAX, ARGMAX or IDENTITY
        post-processing, registers windows and locations, and initialises
        the aggregator.
        """
        def switch_sampler(for_training):
            # sampler index 0 is used for training, index -1 for validation
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        def mixup_switch_sampler(for_training):
            # get first set of samples
            d_dict = switch_sampler(for_training=for_training)

            mix_fields = ('image', 'weight', 'label')

            if not for_training:
                with tf.name_scope('nomix'):
                    # ensure label is appropriate for dense loss functions
                    ground_truth = tf.cast(d_dict['label'], tf.int32)
                    one_hot = tf.one_hot(
                        tf.squeeze(ground_truth, axis=-1),
                        depth=self.segmentation_param.num_classes)
                    d_dict['label'] = one_hot
            else:
                with tf.name_scope('mixup'):
                    # get the mixing parameter from the Beta distribution
                    alpha = self.segmentation_param.mixup_alpha
                    beta = tf.distributions.Beta(alpha, alpha)  # 1, 1: uniform
                    rand_frac = beta.sample()

                    # get another minibatch
                    d_dict_to_mix = switch_sampler(for_training=True)

                    # look at binarised labels: sort them
                    if self.segmentation_param.mix_match:
                        # sum up the positive (binarised) labels to sort by
                        # foreground volume; both batches must be binarised
                        # the same way so the two orderings are comparable.
                        inds1 = tf.argsort(
                            tf.map_fn(
                                tf.reduce_sum,
                                tf.cast(d_dict['label'] > 0, tf.int64)))
                        inds2 = tf.argsort(
                            tf.map_fn(
                                tf.reduce_sum,
                                tf.cast(d_dict_to_mix['label'] > 0,
                                        tf.int64)))
                        for field in [field for field in mix_fields
                                      if field in d_dict]:
                            d_dict[field] = tf.gather(
                                d_dict[field], indices=inds1)
                            # note: sorted for opposite directions for
                            # d_dict_to_mix
                            d_dict_to_mix[field] = tf.gather(
                                d_dict_to_mix[field], indices=inds2[::-1])

                    # making the labels dense and one-hot
                    for d in (d_dict, d_dict_to_mix):
                        ground_truth = tf.cast(d['label'], tf.int32)
                        one_hot = tf.one_hot(
                            tf.squeeze(ground_truth, axis=-1),
                            depth=self.segmentation_param.num_classes)
                        d['label'] = one_hot

                    # do the mixing for any fields that are relevant and
                    # present
                    mixed_up = {
                        field: d_dict[field] * rand_frac +
                        d_dict_to_mix[field] * (1 - rand_frac)
                        for field in mix_fields if field in d_dict
                    }
                    # reassign all relevant values in d_dict
                    d_dict.update(mixed_up)

            return d_dict

        if self.is_training:
            if not self.segmentation_param.do_mixup:
                data_dict = tf.cond(
                    tf.logical_not(self.is_validation),
                    lambda: switch_sampler(for_training=True),
                    lambda: switch_sampler(for_training=False))
            else:
                # mix up the samples if not in validation phase
                # (validation batches are never mixed)
                data_dict = tf.cond(
                    tf.logical_not(self.is_validation),
                    lambda: mixup_switch_sampler(for_training=True),
                    lambda: mixup_switch_sampler(for_training=False))
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.segmentation_param.softmax)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            # Optimise every trainable variable except those whose name
            # matches the freeze regex (vars_to_freeze, else vars_to_restore)
            to_optimise = tf.trainable_variables()
            vars_to_freeze = \
                self.action_param.vars_to_freeze or \
                self.action_param.vars_to_restore
            if vars_to_freeze:
                import re
                var_regex = re.compile(vars_to_freeze)
                to_optimise = [v for v in to_optimise
                               if not var_regex.search(v.name)]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables fixed (--vars_to_freeze %s)",
                    len(to_optimise),
                    len(tf.trainable_variables()),
                    vars_to_freeze)

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)

            self.total_loss = loss

            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            # collecting output variables
            outputs_collector.add_to_collection(
                var=self.total_loss, name='total_loss',
                average_over_devices=True, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=self.total_loss, name='total_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        elif self.is_inference:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Forward inference windows (as 'window_seg') to the aggregator."""
        if self.is_inference:
            return self.output_decoder.decode_batch(
                {'window_seg': batch_output['window']},
                batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Store ``eval_param`` and build a SegmentationEvaluator."""
        self.eval_param = eval_param
        self.evaluator = SegmentationEvaluator(self.readers[0],
                                               self.segmentation_param,
                                               eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Register the inferred output, configured like 'label'."""
        return self.add_inferred_output_like(data_param, task_param, 'label')
def initialise_resize_aggregator(self):
    """Attach a ResizeSamplesAggregator as ``self.output_decoder``.

    The aggregator reads from the first reader in ``self.readers`` and
    writes to ``action_param.save_seg_dir``, using the configured window
    border and output interpolation order.
    """
    params = self.action_param
    first_reader = self.readers[0]
    self.output_decoder = ResizeSamplesAggregator(
        image_reader=first_reader,
        output_path=params.save_seg_dir,
        window_border=params.border,
        interp_order=params.output_interp_order,
    )
class SegmentationApplication(BaseApplication):
    """Segmentation application instrumented with debugging hooks.

    Besides the usual reader / sampler / network wiring, the training graph
    stores intermediate tensors (loss, ground truth, softmax output and
    values returned by the loss function) on the instance so that they can
    be fetched through the ``return_*`` accessors.
    """
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        super(SegmentationApplication, self).__init__()
        tf.logging.info('starting segmentation application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        # window_sampling name -> (training sampler init,
        #                          inference sampler init,
        #                          aggregator init)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
        }
        # debugging handles, filled in by connect_data_and_network
        self.loss_variable = None
        self.first_slice = None
        self.netOut = None
        self.GROUNDTRUTH = None
        self.PREDICTION = None
        self.CONT = 0
        self.SUMA = None
        self.GRADS = None
        self.CONV_KERNEL = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create image readers and attach preprocessing layers.

        Training uses one reader for the training files plus one for the
        validation files when ``validation_every_n > 0`` (otherwise a single
        reader over all files); inference uses a single 'image'-only reader.
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.all_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normaliser = None
        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        # Both normalisers above are None without a histogram_ref_file;
        # appending None would break the reader's preprocessing pipeline,
        # so only add layers that were actually built.
        normalisation_layers = []
        if self.net_param.normalisation:
            if histogram_normaliser is not None:
                normalisation_layers.append(histogram_normaliser)
            else:
                tf.logging.warning(
                    'normalisation requested but no histogram_ref_file '
                    'given; skipping histogram normalisation')
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            if label_normaliser is not None:
                normalisation_layers.append(label_normaliser)
            else:
                tf.logging.warning(
                    'label_normalisation requested but no '
                    'histogram_ref_file given; skipping label normalisation')

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))
        for reader in self.readers:
            reader.add_preprocessing_layers(
                volume_padding_layer +
                normalisation_layers +
                augmentation_layers)

    def initialise_uniform_sampler(self):
        """Create one UniformSampler per reader, wrapped in a single list."""
        self.sampler = [[UniformSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_weighted_sampler(self):
        """Create one WeightedSampler per reader, wrapped in a single list."""
        self.sampler = [[WeightedSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_resize_sampler(self):
        """Create one ResizeSampler per reader; shuffles only in training."""
        self.sampler = [[ResizeSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            shuffle_buffer=self.is_training,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_sampler(self):
        """Create one GridSampler per reader (used for inference windows)."""
        self.sampler = [[GridSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            spatial_window_size=self.action_param.spatial_window_size,
            window_border=self.action_param.border,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_aggregator(self):
        """Set ``self.output_decoder`` to a GridSamplesAggregator."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Set ``self.output_decoder`` to a ResizeSamplesAggregator."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Call the training (index 0) or inference (index 1) sampler init."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Instantiate ``self.net`` from the network factory.

        L1 or L2 regularisers (applied to both weights and biases) are
        created only when ``reg_type`` is 'l1'/'l2' and ``decay > 0``.
        """
        print("Initializing network")
        # weight and bias regularisers
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        # weight initializer looked up by its configured name
        w_ini = InitializerFactory.get_initializer(
            name=self.net_param.weight_initializer)
        print("wWwWwWwWwWWWWWWwWWWWWWWWWweight_initializer; ",
              self.net_param.weight_initializer)
        print("NNNNNNNNname of application: ", self.net_param.name)

        # create the configured network with its initialisers,
        # regularisers and activation function
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=w_ini,
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self, outputs_collector=None,
                                 gradients_collector=None):
        """Wire sampler output into the network and register outputs.

        Training also stores debugging tensors on the instance (see the
        ``return_*`` accessors).  Inference applies SOFTMAX, ARGMAX or
        IDENTITY post-processing and initialises the aggregator.
        """
        def switch_sampler(for_training):
            # sampler index 0 is used for training, index -1 for validation
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            print("-CONNECT DATA AND NETWORK -TRAINING---------------")
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(
                    tf.logical_not(self.is_validation),
                    lambda: switch_sampler(for_training=True),
                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            print("####################################nombre del optimiser: ",
                  self.action_param.optimiser)
            print("##############################3learning rate: ",
                  self.action_param.lr)

            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            ground_truth = data_dict.get('label', None)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=ground_truth,
                weight_map=data_dict.get('weight', None))

            # debugging handles exposed through the return_* accessors
            self.loss_variable = data_loss
            self.first_slice = ground_truth
            self.netOut = tf.nn.softmax(net_out)
            GROUNDTRUTH, PREDICTION, CONT = loss_func.return_loss_args()
            self.GROUNDTRUTH = GROUNDTRUTH
            self.PREDICTION = PREDICTION
            self.CONT = CONT
            self.SUMA = loss_func.SUMA
            print("Salio del seteo de variable en connect data and net")

            # regularisation losses
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            print("############## que que e isso: ",
                  tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            self.GRADS = grads
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=image * 180.0, name='image',
                average_over_devices=False, summary_type='image3_sagittal',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=image, name='image',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=tf.reduce_mean(image), name='mean_image',
                average_over_devices=False, summary_type='scalar',
                collection=CONSOLE)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def return_loss_variable(self):
        """Return the data-loss tensor stored during training setup."""
        return self.loss_variable

    def return_first_slice(self):
        """Return (ground truth, its shape, softmax output, its shape)."""
        return (self.first_slice, self.first_slice.get_shape(),
                self.netOut, self.netOut.get_shape())

    def return_seg_args(self):
        """Return the values obtained from ``loss_func.return_loss_args()``."""
        return self.GROUNDTRUTH, self.PREDICTION, self.CONT

    def interpret_output(self, batch_output):
        """Forward inference windows to the aggregator; otherwise True."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True
class RegressionApplication(BaseApplication):
    """Regression application driven by an ``is_training`` flag.

    ``SUPPORTED_SAMPLING`` maps each ``net_param.window_sampling`` value to
    a tuple of three bound methods, in order:
    the sampler initialiser used for training, the sampler initialiser
    used for inference, and the output-aggregator initialiser used for
    inference.
    """

    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, is_training):
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.is_training = is_training
        self.net_param = net_param
        self.action_param = action_param
        self.regression_param = None
        self.data_param = None
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create the image readers and attach preprocessing layers.

        :param data_param: input data specification (stored on the instance)
        :param task_param: ``REGRESSION`` section parameters (stored as
            ``self.regression_param``)
        :param data_partitioner: object providing ``train_files``,
            ``validation_files`` and ``inference_files``
        """
        self.data_param = data_param
        self.regression_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            # reader 0 is always the training set; a validation reader is
            # appended only when periodic validation is enabled
            file_lists = [data_partitioner.train_files]
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.validation_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            if histogram_normaliser is None:
                # No histogram reference file means no layer was built;
                # appending None would put an invalid entry into every
                # reader's preprocessing pipeline.
                tf.logging.warning(
                    'normalisation requested but no histogram_ref_file '
                    'is set; skipping histogram normalisation')
            else:
                normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))

        for reader in self.readers:
            reader.add_preprocessing_layers(
                volume_padding_layer +
                normalisation_layers +
                augmentation_layers)

    def initialise_uniform_sampler(self):
        self.sampler = [[UniformSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_weighted_sampler(self):
        self.sampler = [[WeightedSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_resize_sampler(self):
        # the shuffle buffer is only enabled while training
        self.sampler = [[ResizeSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            shuffle_buffer=self.is_training,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_sampler(self):
        self.sampler = [[GridSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            spatial_window_size=self.action_param.spatial_window_size,
            window_border=self.action_param.border,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_aggregator(self):
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Run the training (index 0) or inference (index 1) sampler
        initialiser registered for ``net_param.window_sampling``."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Create the network with optional L1/L2 weight regularisation."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training graph (loss + gradients) or inference graph.

        In training mode both collectors must be provided; in inference
        mode only ``outputs_collector`` is used.
        """
        def switch_sampler(for_training):
            # sampler row 0 holds one sampler per reader: index 0 is the
            # training reader, index -1 the validation (or only) reader
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            crop_layer = CropLayer(
                border=self.regression_param.loss_border, name='crop-88')
            prediction = crop_layer(net_out)
            ground_truth = crop_layer(data_dict.get('output', None))
            weight = data_dict.get('weight', None)
            weight_map = None if weight is None else crop_layer(weight)
            data_loss = loss_func(prediction=prediction,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='Loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='Loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            crop_layer = CropLayer(border=0, name='crop-88')
            post_process_layer = PostProcessingLayer('IDENTITY')
            net_out = post_process_layer(crop_layer(net_out))

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Feed an inference batch to the aggregator; no-op when training.

        :return: the aggregator's ``decode_batch`` result at inference,
            ``True`` otherwise
        """
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True
class RegressionApplication(BaseApplication):
    """Regression application supporting training, inference and evaluation.

    ``self.action`` stores the requested action; the mode checks
    ``is_training`` / ``is_inference`` / ``is_evaluation`` come from
    ``BaseApplication`` (presumably derived from ``self.action`` -- confirm).
    ``SUPPORTED_SAMPLING`` maps each ``net_param.window_sampling`` value to
    ``(training sampler initialiser, inference sampler initialiser,
    output-aggregator initialiser)``.
    """

    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, action):
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.regression_param = None
        self.data_param = None
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
            'balanced': (self.initialise_balanced_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create the image readers for the current action and attach
        preprocessing layers to them.

        :raises ValueError: if the action is not training, inference or
            evaluation
        """
        self.data_param = data_param
        self.regression_param = task_param
        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'output', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self.is_inference:
            inference_reader = ImageReader(['image'])
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]
        elif self.is_evaluation:
            reader = ImageReader({'image', 'output', 'inferred'})
            reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [reader]
        else:
            raise ValueError(
                'Action `{}` not supported. Expected one of {}'.format(
                    self.action, self.SUPPORTED_ACTIONS))

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            if histogram_normaliser is None:
                # No histogram reference file means no layer was built;
                # appending None would put an invalid entry into every
                # reader's preprocessing pipeline.
                tf.logging.warning(
                    'normalisation requested but no histogram_ref_file '
                    'is set; skipping histogram normalisation')
            else:
                normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size,
                         mode=self.net_param.volume_padding_mode))

        for reader in self.readers:
            reader.add_preprocessing_layers(
                volume_padding_layer +
                normalisation_layers +
                augmentation_layers)

    def initialise_uniform_sampler(self):
        self.sampler = [[
            UniformSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_weighted_sampler(self):
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_resize_sampler(self):
        # the shuffle buffer is only enabled while training
        self.sampler = [[
            ResizeSampler(reader=reader,
                          data_param=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle_buffer=self.is_training,
                          queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_sampler(self):
        self.sampler = [[
            GridSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_balanced_sampler(self):
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_aggregator(self):
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_resize_aggregator(self):
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        """Run the training (index 0) or inference (index 1) sampler
        initialiser; evaluation creates no sampler."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Create the network with optional L1/L2 weight regularisation."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph; evaluation builds none.

        Training: cropped regression loss (plus the mean regularisation
        loss when ``decay > 0``), value-clipped gradients registered with
        ``gradients_collector``. Inference: the identity-post-processed
        network output and window locations, then the output aggregator.
        """
        def switch_sampler(for_training):
            # sampler row 0 holds one sampler per reader: index 0 is the
            # training reader, index -1 the validation (or only) reader
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            crop_layer = CropLayer(border=self.regression_param.loss_border,
                                   name='crop-88')
            prediction = crop_layer(net_out)
            ground_truth = crop_layer(data_dict.get('output', None))
            weight = data_dict.get('weight', None)
            weight_map = None if weight is None else crop_layer(weight)
            data_loss = loss_func(prediction=prediction,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)

            # Gradient clipping associated with VDSR3D: every gradient is
            # clipped element-wise to [-theta / lr, theta / lr] (by value,
            # not by global norm). The VDSR authors do not specify theta;
            # 1e-5 is used here. Variables without a gradient are dropped.
            # NOTE(review): this clipping applies to every network run by
            # this application, not only VDSR3D -- confirm it is intended.
            clip_value = 0.00001 / self.action_param.lr
            grads = [(tf.clip_by_value(grad, -clip_value, clip_value), var)
                     for grad, var in grads if grad is not None]

            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='Loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='Loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        elif self.is_inference:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            crop_layer = CropLayer(border=0, name='crop-88')
            post_process_layer = PostProcessingLayer('IDENTITY')
            net_out = post_process_layer(crop_layer(net_out))

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Feed an inference batch to the aggregator; ``True`` otherwise."""
        if self.is_inference:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Create a ``RegressionEvaluator`` over the first reader."""
        self.eval_param = eval_param
        self.evaluator = RegressionEvaluator(self.readers[0],
                                             self.regression_param,
                                             eval_param)

    def add_inferred_output(self, data_param, task_param):
        return self.add_inferred_output_like(data_param, task_param, 'output')
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    """Build the registration graph for training or inference.

    Each sampled window is unstacked along its last axis into four
    single-channel tensors, in this order: fixed image, fixed label,
    moving image, moving label. The network predicts a dense field that
    is used to resample the moving label (and, at inference, the moving
    image) with linear interpolation and replicate boundaries.

    Training registers the gradients of the foreground label loss (plus
    ``decay`` times the mean bending energy, when the network provides
    one) and console/TensorBoard metrics. Inference registers the images,
    labels, resampled outputs and window locations, and creates a
    ``ResizeSamplesAggregator`` over the fixed-image reader.
    """
    def switch_samplers(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0 if for_training else -1]
            return sampler()  # returns (image_windows, locations)

    def split_windows(image_windows):
        # decode channels for fixed and moving images/labels
        return [tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)]

    def estimate_field(fixed_image, moving_image):
        # some networks return a tuple; its first element is the field
        dense_field = self.net(fixed_image, moving_image)
        if isinstance(dense_field, tuple):
            dense_field = dense_field[0]
        return dense_field

    if self.is_training:
        if self.action_param.validation_every_n > 0:
            sampler_window = tf.cond(
                tf.logical_not(self.is_validation),
                lambda: switch_samplers(True),
                lambda: switch_samplers(False))
        else:
            sampler_window = switch_samplers(True)

        image_windows, _ = sampler_window
        fixed_image, fixed_label, moving_image, moving_label = \
            split_windows(image_windows)

        # estimate ddf
        dense_field = estimate_field(fixed_image, moving_image)

        # transform the moving labels
        resampler = ResamplerLayer(
            interpolation='linear', boundary='replicate')
        resampled_moving_label = resampler(moving_label, dense_field)

        # compute label loss (foreground only)
        loss_func = LossFunction(
            n_class=1,
            loss_type=self.action_param.loss_type,
            softmax=False)
        label_loss = loss_func(prediction=resampled_moving_label,
                               ground_truth=fixed_label)
        dice_fg = 1.0 - label_loss

        # appending regularisation loss
        total_loss = label_loss
        reg_loss = tf.get_collection('bending_energy')
        if reg_loss:
            total_loss = total_loss + \
                self.net_param.decay * tf.reduce_mean(reg_loss)

        # compute training gradients
        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        grads = self.optimiser.compute_gradients(total_loss)
        gradients_collector.add_to_collection(grads)

        # Dice on labels binarised at 0.5
        metrics_dice = loss_func(
            prediction=tf.to_float(resampled_moving_label >= 0.5),
            ground_truth=tf.to_float(fixed_label >= 0.5))
        metrics_dice = 1.0 - metrics_dice

        # command line output
        outputs_collector.add_to_collection(
            var=dice_fg, name='one_minus_data_loss',
            collection=CONSOLE)
        if reg_loss:
            # only report bending energy when the collection is non-empty;
            # otherwise this would be the mean of an empty list
            outputs_collector.add_to_collection(
                var=tf.reduce_mean(reg_loss), name='bending_energy',
                collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=total_loss, name='total_loss', collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=metrics_dice, name='ave_fg_dice', collection=CONSOLE)

        # for tensorboard
        # NOTE(review): the 'data_loss' summary holds 1 - label_loss.
        outputs_collector.add_to_collection(
            var=dice_fg,
            name='data_loss',
            average_over_devices=True,
            summary_type='scalar',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=total_loss,
            name='averaged_total_loss',
            average_over_devices=True,
            summary_type='scalar',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=metrics_dice,
            name='averaged_foreground_Dice',
            average_over_devices=True,
            summary_type='scalar',
            collection=TF_SUMMARIES)
    else:
        image_windows, locations = self.sampler()
        fixed_image, fixed_label, moving_image, moving_label = \
            split_windows(image_windows)

        dense_field = estimate_field(fixed_image, moving_image)

        # transform the moving image and labels
        resampler = ResamplerLayer(
            interpolation='linear', boundary='replicate')
        resampled_moving_image = resampler(moving_image, dense_field)
        resampled_moving_label = resampler(moving_label, dense_field)

        # registration order matters for the output collection
        network_outputs = (
            ('fixed_image', fixed_image),
            ('moving_image', moving_image),
            ('resampled_moving_image', resampled_moving_image),
            ('resampled_moving_label', resampled_moving_label),
            ('fixed_label', fixed_label),
            ('moving_label', moving_label),
            ('locations', locations),
        )
        for output_name, output_var in network_outputs:
            outputs_collector.add_to_collection(
                var=output_var, name=output_name,
                collection=NETWORK_OUTPUT)

        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],  # fixed image reader
            name='fixed_image',
            output_path=self.action_param.save_seg_dir,
            interp_order=self.action_param.output_interp_order)