def test_25d_init(self):
    """Grid-sample a 2.5D image, aggregate the windows and check the shape.

    Batches are popped from a threaded ``GridSampler`` and passed to a
    ``GridSamplesAggregator`` until ``decode_batch`` returns a falsy value;
    the aggregated NIfTI file is then reloaded and its shape compared.
    """
    reader = get_25d_reader()
    sampler = GridSampler(reader=reader,
                          window_sizes=SINGLE_25D_DATA,
                          batch_size=10,
                          spatial_window_size=None,
                          window_border=(3, 4, 5),
                          queue_length=50)
    aggregator = GridSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        window_border=(3, 4, 5),
        interp_order=0)
    more_batch = True
    # close_all() sits in ``finally`` so the sampler is shut down even when
    # sampling or the assertion below fails.
    try:
        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                # decode_batch's return value decides whether to keep going.
                more_batch = aggregator.decode_batch(
                    out['image'], out['image_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [255, 168, 256, 1, 1],
            rtol=1e-03, atol=1e-03)
    finally:
        sampler.close_all()
def test_inverse_mapping(self):
    """Aggregate grid-sampled label windows and compare with the source.

    The reassembled volume must have shape [256, 168, 256, 1, 1] and, with
    its last two axes dropped, match the reference parcellation file.
    """
    reader = get_label_reader()
    data_param = MOD_LABEL_DATA
    sampler = GridSampler(reader=reader,
                          window_sizes=data_param,
                          batch_size=10,
                          spatial_window_size=None,
                          window_border=(3, 4, 5),
                          queue_length=50)
    aggregator = GridSamplesAggregator(
        image_reader=reader,
        name='label',
        output_path=os.path.join('testing_data', 'aggregated'),
        window_border=(3, 4, 5),
        interp_order=0)
    more_batch = True
    # close_all() sits in ``finally`` so the sampler is shut down even when
    # sampling or the shape assertion fails.
    try:
        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                more_batch = aggregator.decode_batch(
                    out['label'], out['label_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        output_image = nib.load(output_file)
        self.assertAllClose(output_image.shape, [256, 168, 256, 1, 1])
    finally:
        sampler.close_all()
    # np.asanyarray(img.dataobj) is nibabel's non-deprecated equivalent of
    # img.get_data().
    output_data = np.asanyarray(output_image.dataobj)[..., 0, 0]
    expected_data = np.asanyarray(nib.load(os.path.join(
        'testing_data', 'T1_1023_NeuroMorph_Parcellation.nii.gz')).dataobj)
    self.assertAllClose(output_data, expected_data)
def test_inverse_mapping(self):
    """Aggregate grid-sampled label windows and compare with the source.

    The reassembled volume must have shape [256, 168, 256, 1, 1] and, with
    its last two axes dropped, match the reference parcellation file.
    """
    reader = get_label_reader()
    data_param = MOD_LABEL_DATA
    sampler = GridSampler(reader=reader,
                          data_param=data_param,
                          batch_size=10,
                          spatial_window_size=None,
                          window_border=(3, 4, 5),
                          queue_length=50)
    aggregator = GridSamplesAggregator(
        image_reader=reader,
        name='label',
        output_path=os.path.join('testing_data', 'aggregated'),
        window_border=(3, 4, 5),
        interp_order=0)
    more_batch = True
    # close_all() sits in ``finally`` so the sampler is shut down even when
    # sampling or the shape assertion fails.
    try:
        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                more_batch = aggregator.decode_batch(
                    out['label'], out['label_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        output_image = nib.load(output_file)
        self.assertAllClose(output_image.shape, [256, 168, 256, 1, 1])
    finally:
        sampler.close_all()
    # np.asanyarray(img.dataobj) is nibabel's non-deprecated equivalent of
    # img.get_data().
    output_data = np.asanyarray(output_image.dataobj)[..., 0, 0]
    expected_data = np.asanyarray(nib.load(os.path.join(
        'testing_data', 'T1_1023_NeuroMorph_Parcellation.nii.gz')).dataobj)
    self.assertAllClose(output_data, expected_data)
def test_25d_init(self):
    """Grid-sample a 2.5D image, aggregate the windows and check the shape.

    Batches are popped from a threaded ``GridSampler`` and passed to a
    ``GridSamplesAggregator`` until ``decode_batch`` returns a falsy value;
    the aggregated NIfTI file is then reloaded and its shape compared.
    """
    reader = get_25d_reader()
    sampler = GridSampler(reader=reader,
                          data_param=SINGLE_25D_DATA,
                          batch_size=10,
                          spatial_window_size=None,
                          window_border=(3, 4, 5),
                          queue_length=50)
    aggregator = GridSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        window_border=(3, 4, 5),
        interp_order=0)
    more_batch = True
    # close_all() sits in ``finally`` so the sampler is shut down even when
    # sampling or the assertion below fails.
    try:
        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                # decode_batch's return value decides whether to keep going.
                more_batch = aggregator.decode_batch(
                    out['image'], out['image_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [255, 168, 256, 1, 1],
            rtol=1e-03, atol=1e-03)
    finally:
        sampler.close_all()
def test_2d_init(self):
    """Grid-sample a 2D image, aggregate the windows and check the shape.

    Windows are passed to the aggregator under the key ``'window_image'``,
    so the output file is prefixed with ``window_image_``.
    """
    reader = get_2d_reader()
    sampler = GridSampler(reader=reader,
                          window_sizes=MOD_2D_DATA,
                          batch_size=10,
                          spatial_window_size=None,
                          window_border=(3, 4, 5),
                          queue_length=50)
    aggregator = GridSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        window_border=(3, 4, 5),
        interp_order=0)
    more_batch = True
    # close_all() sits in ``finally`` so the sampler is shut down even when
    # sampling or the assertion below fails.
    try:
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                more_batch = aggregator.decode_batch(
                    {'window_image': out['image']}, out['image_location'])
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        self.assertAllClose(nib.load(output_file).shape, [128, 128])
    finally:
        sampler.close_all()
def test_init_2d_mo_bidimcsv(self):
    """Aggregate 2D image windows together with per-window CSV outputs.

    Besides the image windows, each batch carries a per-window sum
    (``'csv_sum'``) and a (batch, 2, 3) min/max/sum table
    (``'csv_stats_2d'``). The test checks the shape of the aggregated image
    and of both CSV files.
    """
    batch_size = 10
    reader = get_2d_reader()
    sampler = GridSampler(reader=reader,
                          window_sizes=MOD_2D_DATA,
                          batch_size=batch_size,
                          spatial_window_size=None,
                          window_border=(3, 4, 5),
                          queue_length=50)
    aggregator = GridSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        window_border=(3, 4, 5),
        interp_order=0)
    more_batch = True
    # close_all() sits in ``finally`` so the sampler is shut down even when
    # sampling or any assertion below fails.
    try:
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                # One flattened row per window in the batch.
                out_flatten = np.reshape(
                    np.asarray(out['image']), [batch_size, -1])
                sum_val = np.sum(out_flatten, 1)
                stats_val = np.concatenate([
                    np.min(out_flatten, 1, keepdims=True),
                    np.max(out_flatten, 1, keepdims=True),
                    np.sum(out_flatten, 1, keepdims=True)
                ], 1)
                # (batch, 3) -> (batch, 2, 3): repeat the stats row so the
                # CSV output is two-dimensional per window.
                stats_val = np.expand_dims(stats_val, 1)
                stats_val = np.concatenate([stats_val, stats_val], axis=1)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': sum_val,
                        'csv_stats_2d': stats_val
                    }, out['image_location'])
        subject_id = sampler.reader.get_subject_id(0)
        output_file = os.path.join(
            'testing_data', 'aggregated',
            'window_image_{}_niftynet_out.nii.gz'.format(subject_id))
        sum_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_sum_{}_niftynet_out.csv'.format(subject_id))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_2d_{}_niftynet_out.csv'.format(subject_id))
        self.assertAllClose(nib.load(output_file).shape, (128, 128))
        sum_pd = pd.read_csv(sum_filename)
        self.assertAllClose(sum_pd.shape, [10, 9])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [10, 14])
    finally:
        sampler.close_all()
def run_inference(self, window_border, inference_path, checkpoint_path):
    """Run windowed inference with ``self.model`` and aggregate the result.

    :param window_border: border passed to ``GridSamplesAggregator``.
    :param inference_path: output directory for the aggregator.
    :param checkpoint_path: file holding a state dict for ``self.model``.
    """
    output = GridSamplesAggregator(
        image_reader=self.samplers[INFERENCE].reader,
        window_border=window_border,
        interp_order=3,
        output_path=inference_path)
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be restored on the device used for this run.
    self.model.load_state_dict(
        torch.load(checkpoint_path, map_location=self.device))
    self.model.to(self.device)
    self.model.eval()
    for batch_output in self.samplers[INFERENCE]():
        window = batch_output['image']
        # [...,0,:] eliminates time coordinate from NiftyNet Volume
        window = window[..., 0, :]
        # Move the last axis (presumably channels) to position 1, the
        # layout PyTorch models take; reversed again after the forward pass.
        window = np.transpose(window, (0, 4, 1, 2, 3))
        window = torch.Tensor(window).to(self.device)
        with torch.no_grad():
            outputs = self.model(window)
        outputs = outputs.cpu().numpy()
        outputs = np.transpose(outputs, (0, 2, 3, 4, 1))
        output.decode_batch(outputs, batch_output['image_location'])
def inference(sampler, model, device, pred_path, cp_path,
              window_border=(8, 8, 8)):
    """Run sliding-window inference over ``sampler`` and save predictions.

    The checkpoint is loaded once. Every window batch yielded by
    ``sampler()`` is then passed through ``model`` and handed to a
    ``GridSamplesAggregator`` writing under ``pred_path``.

    :param sampler: grid sampler; its ``reader`` is given to the aggregator.
    :param model: PyTorch module to evaluate.
    :param device: device the model and inputs are moved to.
    :param pred_path: output directory for the aggregated predictions.
    :param cp_path: checkpoint (state dict) loaded into ``model``.
    :param window_border: border passed to the aggregator.
    """
    output = GridSamplesAggregator(image_reader=sampler.reader,
                                   window_border=window_border,
                                   output_path=pred_path)
    # map_location lets a checkpoint saved on another device be restored on
    # ``device``.
    model.load_state_dict(torch.load(cp_path, map_location=device))
    model.to(device)
    model.eval()
    # Iterate the sampler exactly once. Nesting a second ``sampler()`` loop
    # inside another repeats inference for every outer batch, and can skip
    # windows if both generators advance the same reader.
    for batch_output in sampler():  # for each sliding window step
        window = batch_output['image']
        # [...,0,:] eliminates time coordinate from NiftyNet Volume
        window = window[..., 0, :]
        # Move the last axis (presumably channels) to position 1 for the
        # model; the inverse transpose is applied to the outputs below.
        window = np.transpose(window, (0, 4, 1, 2, 3))
        window = torch.Tensor(window).to(device)
        with torch.no_grad():
            outputs = model(window)
        outputs = outputs.cpu().numpy()
        outputs = np.transpose(outputs, (0, 2, 3, 4, 1))
        output.decode_batch(outputs.astype(np.float32),
                            batch_output['image_location'])
def test_filling(self):
    """Check that the aggregator writes ``fill_constant`` into the borders.

    After aggregation, the slab of width ``test_border[i]`` at both ends of
    each spatial axis must contain only ``test_constant``.
    """
    reader = get_nonnormalising_label_reader()
    test_constant = 0.5731
    postfix = '_niftynet_out_background'
    test_border = (10, 7, 8)
    data_param = MOD_LABEL_DATA
    sampler = GridSampler(reader=reader,
                          window_sizes=data_param,
                          batch_size=10,
                          spatial_window_size=None,
                          window_border=test_border,
                          queue_length=50)
    aggregator = GridSamplesAggregator(
        image_reader=reader,
        name='label',
        output_path=os.path.join('testing_data', 'aggregated'),
        window_border=test_border,
        interp_order=0,
        postfix=postfix,
        fill_constant=test_constant)
    more_batch = True
    # Shut the sampler down like the other aggregator tests do; ``finally``
    # ensures this happens even if sampling fails.
    try:
        with self.test_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                more_batch = aggregator.decode_batch(
                    out['label'], out['label_location'])
        output_filename = '{}{}.nii.gz'.format(
            sampler.reader.get_subject_id(0), postfix)
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
    finally:
        sampler.close_all()

    # np.asanyarray(img.dataobj) is nibabel's non-deprecated equivalent of
    # img.get_data().
    output_data = np.asanyarray(nib.load(output_file).dataobj)[..., 0, 0]
    output_shape = output_data.shape

    def _assert_background(idcs):
        # Every voxel in the selected slab must equal the fill constant.
        extract = output_data[idcs]
        self.assertTrue(np.all(extract == test_constant))

    for axis in range(3):
        extract_idcs = [slice(None)] * 3
        # Leading border slab along this axis.
        extract_idcs[axis] = slice(0, test_border[axis])
        _assert_background(tuple(extract_idcs))
        # Trailing border slab along this axis.
        extract_idcs[axis] = slice(output_shape[axis] - test_border[axis],
                                   output_shape[axis])
        _assert_background(tuple(extract_idcs))
def test_3d_init_mo(self):
    """Aggregate multi-modal 3D windows plus a per-window sum CSV.

    Checks the shape of the aggregated image and of the ``csv_sum`` file.
    """
    batch_size = 10
    reader = get_3d_reader()
    sampler = GridSampler(reader=reader,
                          window_sizes=MULTI_MOD_DATA,
                          batch_size=batch_size,
                          spatial_window_size=None,
                          window_border=(3, 4, 5),
                          queue_length=50)
    aggregator = GridSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        window_border=(3, 4, 5),
        interp_order=0)
    more_batch = True
    # close_all() sits in ``finally`` so the sampler is shut down even when
    # sampling or any assertion below fails.
    try:
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                # One flattened row per window; its sum is the CSV output.
                out_flatten = np.reshape(
                    np.asarray(out['image']), [batch_size, -1])
                sum_val = np.sum(out_flatten, 1)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': sum_val
                    }, out['image_location'])
        subject_id = sampler.reader.get_subject_id(0)
        output_file = os.path.join(
            'testing_data', 'aggregated',
            'window_image_{}_niftynet_out.nii.gz'.format(subject_id))
        sum_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_sum_{}_niftynet_out.csv'.format(subject_id))
        self.assertAllClose(nib.load(output_file).shape,
                            (256, 168, 256, 1, 2))
        sum_pd = pd.read_csv(sum_filename)
        self.assertAllClose(sum_pd.shape, [420, 9])
    finally:
        sampler.close_all()
class BRATSApp(BaseApplication):
    """Segmentation application used for BRATS data.

    Builds image readers with preprocessing, uniform (training) or grid
    (inference) samplers, the network named by ``net_param.name`` and,
    depending on ``is_training``, the training loss/gradients or the
    inference outputs plus a grid aggregator.
    """
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        BaseApplication.__init__(self)
        tf.logging.info('starting BRATS segmentation app')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        # Filled in by initialise_dataset_loader.
        self.data_param = None
        self.segmentation_param = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create ``self.readers`` and attach their preprocessing layers.

        Training builds one reader per file list (train + validation when
        ``validation_every_n > 0``, otherwise all files). Inference builds
        a single reader over the inference files that only loads
        ``'image'``.
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.all_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)

        normalisation_layers = []
        normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            # Build the label normaliser only when it is used: it is
            # configured from net_param.histogram_ref_file, and there is no
            # reason to construct it when label normalisation is disabled.
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)
            normalisation_layers.append(label_normaliser)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))

        for reader in self.readers:
            reader.add_preprocessing_layers(
                normalisation_layers + volume_padding_layer)

    def initialise_sampler(self):
        """Create one sampler per reader: uniform for training, grid otherwise."""
        if self.is_training:
            self.sampler = [[UniformSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
                for reader in self.readers]]
        else:
            self.sampler = [[GridSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
                for reader in self.readers]]

    def initialise_network(self):
        """Instantiate ``self.net`` with optional L1/L2 weight regularisers."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire sampler outputs through the network.

        Training: computes the data loss (plus the mean regularisation loss
        when ``decay > 0`` and such losses exist), its gradients, and logs
        the data loss. Inference: post-processes the logits (softmax, argmax
        or identity), registers window/location outputs and creates the
        grid aggregator.
        """
        def switch_sampler(for_training):
            # First reader's sampler for training, last one for validation.
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.output_decoder = GridSamplesAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir,
                window_border=self.action_param.border,
                interp_order=self.action_param.output_interp_order)

    def interpret_output(self, batch_output):
        """Decode one inference batch; during training always return True."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True
class BRATSApp(BaseApplication):
    """Segmentation application used for BRATS data.

    Builds image readers with preprocessing, uniform (training) or grid
    (inference) samplers, the network named by ``net_param.name`` and,
    depending on ``is_training``, the training loss/gradients or the
    inference outputs plus a grid aggregator.
    """
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        BaseApplication.__init__(self)
        tf.logging.info('starting BRATS segmentation app')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        # Filled in by initialise_dataset_loader.
        self.data_param = None
        self.segmentation_param = None

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Create ``self.readers`` and attach their preprocessing layers.

        Training builds one reader per file list (train + validation when
        ``validation_every_n > 0``, otherwise all files). Inference builds
        a single reader over the inference files that only loads
        ``'image'``.
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.all_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)

        normalisation_layers = []
        normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            # Build the label normaliser only when it is used: it is
            # configured from net_param.histogram_ref_file, and there is no
            # reason to construct it when label normalisation is disabled.
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)
            normalisation_layers.append(label_normaliser)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))

        for reader in self.readers:
            reader.add_preprocessing_layers(normalisation_layers +
                                            volume_padding_layer)

    def initialise_sampler(self):
        """Create one sampler per reader: uniform for training, grid otherwise."""
        if self.is_training:
            self.sampler = [[
                UniformSampler(
                    reader=reader,
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size,
                    windows_per_image=self.action_param.sample_per_volume,
                    queue_length=self.net_param.queue_length)
                for reader in self.readers
            ]]
        else:
            self.sampler = [[
                GridSampler(
                    reader=reader,
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size,
                    spatial_window_size=self.action_param.spatial_window_size,
                    window_border=self.action_param.border,
                    queue_length=self.net_param.queue_length)
                for reader in self.readers
            ]]

    def initialise_network(self):
        """Instantiate ``self.net`` with optional L1/L2 weight regularisers."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire sampler outputs through the network.

        Training: computes the data loss (plus the mean regularisation loss
        when ``decay > 0`` and such losses exist), its gradients, and logs
        the data loss. Inference: post-processes the logits (softmax, argmax
        or identity), registers window/location outputs and creates the
        grid aggregator.
        """
        def switch_sampler(for_training):
            # First reader's sampler for training, last one for validation.
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=data_dict.get('label', None),
                                  weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.output_decoder = GridSamplesAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir,
                window_border=self.action_param.border,
                interp_order=self.action_param.output_interp_order)

    def interpret_output(self, batch_output):
        """Decode one inference batch; during training always return True."""
        if not self.is_training:
            return self.output_decoder.decode_batch(batch_output['window'],
                                                    batch_output['location'])
        return True
class SelectiveSampling(BaseApplication):
    """Segmentation application trained on selectively sampled windows.

    Training windows come from a ``SelectiveSampler`` constrained by the
    task's label settings. Inference uses a ``GridSampler`` and a
    ``GridSamplesAggregator``.
    """
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        super(SelectiveSampling, self).__init__()
        tf.logging.info('starting segmentation application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        # Filled in by initialise_dataset_loader.
        self.data_param = None
        self.segmentation_param = None
        # Maps a sampling name to
        # (training sampler init, inference sampler init, aggregator init).
        self.SUPPORTED_SAMPLING = {
            'selective': (self.initialise_selective_sampler,
                          self.initialise_grid_sampler,
                          self.initialise_grid_aggregator)
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create ``self.readers`` and attach preprocessing/augmentation."""
        self.data_param = data_param
        self.segmentation_param = task_param

        # dataset_to_infer is optional in the action parameters.
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)

        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normaliser = None
        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        # The histogram-based normalisers only exist when histogram_ref_file
        # is set; never hand a None layer to the readers.
        normalisation_layers = []
        if self.net_param.normalisation:
            if histogram_normaliser is not None:
                normalisation_layers.append(histogram_normaliser)
            else:
                tf.logging.warning(
                    'normalisation is enabled but histogram_ref_file is '
                    'not set; skipping histogram normalisation')
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            if label_normaliser is not None:
                normalisation_layers.append(label_normaliser)
            else:
                tf.logging.warning(
                    'label_normalisation is enabled but histogram_ref_file '
                    'is not set; skipping label normalisation')

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if (self.action_param.rotation_angle or
                    self.action_param.rotation_angle_x or
                    self.action_param.rotation_angle_y or
                    self.action_param.rotation_angle_z):
                rotation_layer = RandomRotationLayer()
                # A single rotation_angle takes precedence over per-axis
                # angles.
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        volume_padding_layer = [PadLayer(
            image_name=SUPPORTED_INPUT,
            border=self.net_param.volume_padding_size,
            mode=self.net_param.volume_padding_mode,
            pad_to=self.net_param.volume_padding_to_size)
        ]

        for reader in self.readers:
            reader.add_preprocessing_layers(
                volume_padding_layer +
                normalisation_layers +
                augmentation_layers)

    def initialise_selective_sampler(self):
        """Create label-constrained selective samplers for training."""
        selective_constraints = Constraint(
            self.segmentation_param.compulsory_labels,
            self.segmentation_param.min_sampling_ratio,
            self.segmentation_param.min_numb_labels,
            self.segmentation_param.proba_connect)
        self.sampler = [[
            SelectiveSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                constraint=selective_constraints,
                random_windows_per_image=self.segmentation_param.rand_samples,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_sampler(self):
        """Create grid samplers used for inference."""
        self.sampler = [[GridSampler(
            reader=reader,
            window_sizes=self.data_param,
            batch_size=self.net_param.batch_size,
            spatial_window_size=self.action_param.spatial_window_size,
            window_border=self.action_param.border,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_aggregator(self):
        """Create the aggregator that writes inference results."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Selective sampler for training, grid sampler for inference."""
        if self.is_training:
            self.SUPPORTED_SAMPLING['selective'][0]()
        else:
            self.SUPPORTED_SAMPLING['selective'][1]()

    def initialise_network(self):
        """Instantiate ``self.net`` with initialisers and optional L1/L2 regularisers."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire sampler outputs through the network.

        Training: computes the data loss (plus the mean regularisation loss
        when ``decay > 0`` and such losses exist), its gradients, and logs
        the data loss. Inference: post-processes the logits (softmax, argmax
        or identity), registers window/location outputs and creates the
        aggregator.
        """
        def switch_sampler(for_training):
            # First reader's sampler for training, last one for validation.
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            # Only the 'selective' entry exists; use it directly, as
            # initialise_sampler does. Indexing by net_param.window_sampling
            # raised KeyError for any other configured value.
            init_aggregator = self.SUPPORTED_SAMPLING['selective'][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Decode one inference batch; during training always return True."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                {'window_image': batch_output['window']},
                batch_output['location'])
        return True