def initialise_resize_aggregator(self):
    """Attach a ``ResizeSamplesAggregator`` as ``self.output_decoder``.

    The aggregator uses the first entry of ``self.readers`` as its image
    reader and takes its output directory, window border, interpolation
    order and output postfix from ``self.action_param``.
    """
    reader = self.readers[0]
    cfg = self.action_param
    self.output_decoder = ResizeSamplesAggregator(
        image_reader=reader,
        output_path=cfg.save_seg_dir,
        window_border=cfg.border,
        interp_order=cfg.output_interp_order,
        postfix=cfg.output_postfix)
def test_25d_init(self):
    """Aggregate resized 2.5-D windows and check the written volume shape.

    Windows are drawn from a ``ResizeSampler`` over the 2.5-D reader and
    passed to ``ResizeSamplesAggregator.decode_batch`` until it returns a
    falsy value or the sampler raises ``OutOfRangeError``. The NIfTI file
    written under ``testing_data/aggregated`` must then have (approximately)
    shape [255, 168, 256, 1, 1].
    """
    reader = get_25d_reader()
    sampler = ResizeSampler(reader=reader,
                            window_sizes=SINGLE_25D_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    aggregator = ResizeSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        interp_order=3)
    try:
        more_batch = True
        # cached_session() replaces the deprecated test_session() and
        # matches the other tests that use this sampler API.
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    # the sampler has no more windows to yield
                    break
                more_batch = aggregator.decode_batch(
                    out['image'], out['image_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join('testing_data',
                                   'aggregated',
                                   output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [255, 168, 256, 1, 1],
            rtol=1e-03, atol=1e-03)
    finally:
        # release the sampler even when the loop or the assertion fails
        sampler.close_all()
def test_inverse_mapping(self):
    """Aggregate resized label windows (queue-based sampler API).

    Label windows are produced by ``ResizeSampler`` threads started with
    ``run_threads`` and passed to ``ResizeSamplesAggregator.decode_batch``
    (interpolation order 0) until it returns a falsy value. The NIfTI file
    written under ``testing_data/aggregated`` must have shape
    [256, 168, 256, 1, 1].
    """
    reader = get_label_reader()
    sampler = ResizeSampler(reader=reader,
                            data_param=MOD_LABEL_DATA,
                            batch_size=1,
                            shuffle_buffer=False,
                            queue_length=50)
    aggregator = ResizeSamplesAggregator(
        image_reader=reader,
        name='label',
        output_path=os.path.join('testing_data', 'aggregated'),
        interp_order=0)
    try:
        more_batch = True
        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            # run_threads starts background sampler threads; close_all()
            # (presumably stopping them) now runs in ``finally`` so they
            # are not left running when something below raises.
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                more_batch = aggregator.decode_batch(
                    out['label'], out['label_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join('testing_data',
                                   'aggregated',
                                   output_filename)
        self.assertAllClose(nib.load(output_file).shape,
                            [256, 168, 256, 1, 1])
    finally:
        sampler.close_all()
def test_inverse_mapping(self):
    """Round-trip label windows through ResizeSampler and the aggregator.

    Each popped batch is passed to ``decode_batch`` as a
    ``{'window_label': ...}`` dict. Decoding stops when the aggregator
    returns a falsy value or the sampler raises ``OutOfRangeError``. The
    file ``window_label_<subject>_niftynet_out.nii.gz`` must then have
    shape [256, 168, 256].
    """
    label_reader = get_label_reader()
    sampler = ResizeSampler(reader=label_reader,
                            window_sizes=MOD_LABEL_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    out_dir = os.path.join('testing_data', 'aggregated')
    aggregator = ResizeSamplesAggregator(image_reader=label_reader,
                                         name='label',
                                         output_path=out_dir,
                                         interp_order=0)
    keep_decoding = True
    with self.cached_session() as sess:
        sampler.set_num_threads(2)
        while keep_decoding:
            try:
                batch = sess.run(sampler.pop_batch_op())
            except tf.errors.OutOfRangeError:
                break
            keep_decoding = aggregator.decode_batch(
                {'window_label': batch['label']},
                batch['label_location'])
    subject_id = sampler.reader.get_subject_id(0)
    written_file = os.path.join(
        'testing_data', 'aggregated',
        'window_label_{}_niftynet_out.nii.gz'.format(subject_id))
    self.assertAllClose(nib.load(written_file).shape, [256, 168, 256])
    sampler.close_all()
def test_init_2d_mo_bidimcsv(self):
    """Aggregate 2-D windows together with a scalar and a 2-D CSV output.

    For every batch, the aggregator receives:

    * the image window,
    * ``csv_sum``: the sum of all window values,
    * ``csv_stats_2d``: a 2x3 array, the row [min, max, sum] stacked twice.

    Afterwards the image output must have shape (128, 128), the sum CSV
    shape [1, 2] and the stats CSV shape [1, 7] (as read by pandas).
    """
    reader = get_2d_reader()
    sampler = ResizeSampler(reader=reader,
                            window_sizes=MOD_2D_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    aggregator = ResizeSamplesAggregator(
        image_reader=reader,
        name='image',
        output_path=os.path.join('testing_data', 'aggregated'),
        interp_order=3)
    try:
        more_batch = True
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    # the sampler has no more windows to yield
                    break
                # sum over every element of the window (was misleadingly
                # named ``min_val``)
                sum_val = np.sum(out['image'])
                stats_val = [
                    np.min(out['image']),
                    np.max(out['image']),
                    np.sum(out['image'])
                ]
                # (3,) -> (1, 3) -> (2, 3): a genuinely 2-D CSV payload
                stats_val = np.expand_dims(stats_val, 0)
                stats_val = np.concatenate([stats_val, stats_val], axis=0)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': sum_val,
                        'csv_stats_2d': stats_val
                    }, out['image_location'])
        subject_id = sampler.reader.get_subject_id(0)
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            subject_id)
        sum_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_sum_{}_niftynet_out.csv'.format(subject_id))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_2d_{}_niftynet_out.csv'.format(subject_id))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)
        self.assertAllClose(nib.load(output_file).shape, (128, 128))
        sum_pd = pd.read_csv(sum_filename)
        self.assertAllClose(sum_pd.shape, [1, 2])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [1, 7])
    finally:
        # release the sampler even when the loop or an assertion fails
        sampler.close_all()
def initialise_resize_aggregator(self):
    '''
    Build the resize-based output decoder from the action parameters.

    Stores a ``ResizeSamplesAggregator`` in ``self.output_decoder``. It
    reads from ``self.readers[0]`` and uses ``save_seg_dir``, ``border``,
    ``output_interp_order`` and ``output_postfix`` from
    ``self.action_param``.

    :return: None; the decoder is stored on ``self``.
    '''
    params = self.action_param
    decoder_kwargs = dict(
        image_reader=self.readers[0],
        output_path=params.save_seg_dir,
        window_border=params.border,
        interp_order=params.output_interp_order,
        postfix=params.output_postfix,
    )
    self.output_decoder = ResizeSamplesAggregator(**decoder_kwargs)
def test_3d_init_mo_3out(self):
    """Aggregate 3-D multi-modal windows plus two CSV side outputs.

    Each batch supplies to the aggregator:

    * the image window,
    * ``csv_sum``: the sum of the window,
    * ``csv_stats``: the list [sum, min, max].

    Afterwards the image output must have shape (256, 168, 256, 1, 2), the
    sum CSV shape [1, 2] and the stats CSV shape [1, 4].
    """
    image_reader = get_3d_reader()
    sampler = ResizeSampler(reader=image_reader,
                            window_sizes=MULTI_MOD_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    out_dir = os.path.join('testing_data', 'aggregated')
    aggregator = ResizeSamplesAggregator(image_reader=image_reader,
                                         name='image',
                                         output_path=out_dir,
                                         interp_order=3)
    keep_decoding = True
    with self.cached_session() as sess:
        sampler.set_num_threads(2)
        while keep_decoding:
            try:
                batch = sess.run(sampler.pop_batch_op())
            except tf.errors.OutOfRangeError:
                break
            window = batch['image']
            window_sum = np.sum(window)
            payload = {
                'window_image': window,
                'csv_sum': window_sum,
                'csv_stats': [window_sum, np.min(window), np.max(window)],
            }
            keep_decoding = aggregator.decode_batch(
                payload, batch['image_location'])
    subject_id = sampler.reader.get_subject_id(0)
    image_file = os.path.join(
        'testing_data', 'aggregated',
        'window_image_{}_niftynet_out.nii.gz'.format(subject_id))
    sum_file = os.path.join(
        'testing_data', 'aggregated',
        'csv_sum_{}_niftynet_out.csv'.format(subject_id))
    stats_file = os.path.join(
        'testing_data', 'aggregated',
        'csv_stats_{}_niftynet_out.csv'.format(subject_id))
    self.assertAllClose(nib.load(image_file).shape, (256, 168, 256, 1, 2))
    self.assertAllClose(pd.read_csv(sum_file).shape, [1, 2])
    self.assertAllClose(pd.read_csv(stats_file).shape, [1, 4])
    sampler.close_all()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    """Build the label-driven registration graph for training or inference.

    Each sampled window's last axis holds four channels. They are split
    into fixed image, fixed label, moving image and moving label, each
    with a trailing channel axis of size 1. ``self.net(fixed, moving)``
    estimates a dense field (the first element if the net returns a
    tuple), which is used by a linear ``ResamplerLayer`` to warp the
    moving inputs.

    Training:
        * Windows come from the first sampler, or are chosen at run time
          between the first and last samplers via ``self.is_validation``
          when ``validation_every_n > 0``.
        * The loss compares the warped moving label with the fixed label.
        * If tensors are registered in the ``'bending_energy'``
          collection, their mean, scaled by ``net_param.decay``, is added.
        * Gradients go to ``gradients_collector`` and scalar metrics to
          ``outputs_collector``.

    Inference:
        * Windows and locations come from ``self.sampler()``.
        * Inputs and warped outputs are exported as NETWORK_OUTPUT.
        * ``self.output_decoder`` is set to a ``ResizeSamplesAggregator``
          on the fixed-image reader.

    :param outputs_collector: receives console/TensorBoard variables
        (training) or network outputs (inference).
    :param gradients_collector: receives the computed gradients; must not
        be None when training.
    """

    def switch_samplers(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0 if for_training else -1]
            return sampler()

    def split_channels(image_windows):
        # Unstack the four stacked volumes along the last axis and restore
        # a trailing channel axis of size 1 on each of them.
        fixed_image, fixed_label, moving_image, moving_label = [
            tf.expand_dims(img, axis=-1)
            for img in tf.unstack(image_windows, axis=-1)
        ]
        return fixed_image, fixed_label, moving_image, moving_label

    def estimate_dense_field(fixed_image, moving_image):
        # The network may return a tuple; the dense field is its first item.
        dense_field = self.net(fixed_image, moving_image)
        if isinstance(dense_field, tuple):
            dense_field = dense_field[0]
        return dense_field

    if self.is_training:
        self.patience = self.action_param.patience
        self.mode = self.action_param.early_stopping_mode
        if self.action_param.validation_every_n > 0:
            # select training or validation windows at graph run time
            sampler_window = \
                tf.cond(tf.logical_not(self.is_validation),
                        lambda: switch_samplers(True),
                        lambda: switch_samplers(False))
        else:
            sampler_window = switch_samplers(True)
        # the second element (window locations) is not needed for training
        image_windows, _ = sampler_window

        fixed_image, fixed_label, moving_image, moving_label = \
            split_channels(image_windows)
        dense_field = estimate_dense_field(fixed_image, moving_image)

        # warp the moving label with the estimated dense field
        resampler = ResamplerLayer(interpolation='linear',
                                   boundary='replicate')
        resampled_moving_label = resampler(moving_label, dense_field)

        # label loss (foreground only)
        loss_func = LossFunction(n_class=1,
                                 loss_type=self.action_param.loss_type,
                                 softmax=False)
        label_loss = loss_func(prediction=resampled_moving_label,
                               ground_truth=fixed_label)
        dice_fg = 1.0 - label_loss

        # Optional regulariser registered by the network. The mean is built
        # once and reused for both the loss and the console output.
        reg_loss = tf.get_collection('bending_energy')
        reg_loss_mean = tf.reduce_mean(reg_loss) if reg_loss else None
        total_loss = label_loss
        if reg_loss_mean is not None:
            total_loss = total_loss + self.net_param.decay * reg_loss_mean
        self.total_loss = total_loss

        # compute training gradients
        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        grads = self.optimiser.compute_gradients(
            total_loss, colocate_gradients_with_ops=True)
        gradients_collector.add_to_collection(grads)

        # One minus the configured loss on labels thresholded at 0.5;
        # reported as Dice (it equals Dice only for a Dice-type loss_type).
        metrics_dice = 1.0 - loss_func(
            prediction=tf.to_float(resampled_moving_label >= 0.5),
            ground_truth=tf.to_float(fixed_label >= 0.5))

        # command line output
        outputs_collector.add_to_collection(
            var=dice_fg, name='one_minus_data_loss', collection=CONSOLE)
        if reg_loss_mean is not None:
            # Without a registered regulariser, tf.reduce_mean([]) is NaN,
            # so only report bending energy when it exists.
            outputs_collector.add_to_collection(
                var=reg_loss_mean, name='bending_energy',
                collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=total_loss, name='total_loss', collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=metrics_dice, name='ave_fg_dice', collection=CONSOLE)

        # for tensorboard
        # NOTE(review): the 'data_loss' summary records 1 - loss (the same
        # value as the console's 'one_minus_data_loss'); the tag is kept
        # unchanged so existing TensorBoard logs stay comparable.
        outputs_collector.add_to_collection(
            var=dice_fg, name='data_loss', average_over_devices=True,
            summary_type='scalar', collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=total_loss, name='total_loss', average_over_devices=True,
            summary_type='scalar', collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=metrics_dice, name='averaged_foreground_Dice',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
    else:
        image_windows, locations = self.sampler()
        fixed_image, fixed_label, moving_image, moving_label = \
            split_channels(image_windows)
        dense_field = estimate_dense_field(fixed_image, moving_image)

        # warp the moving image and label with the estimated dense field
        resampler = ResamplerLayer(interpolation='linear',
                                   boundary='replicate')
        resampled_moving_image = resampler(moving_image, dense_field)
        resampled_moving_label = resampler(moving_label, dense_field)

        # registration order is preserved from the original implementation
        network_outputs = (
            (fixed_image, 'fixed_image'),
            (moving_image, 'moving_image'),
            (resampled_moving_image, 'resampled_moving_image'),
            (resampled_moving_label, 'resampled_moving_label'),
            (fixed_label, 'fixed_label'),
            (moving_label, 'moving_label'),
            (locations, 'locations'),
        )
        for output_var, output_name in network_outputs:
            outputs_collector.add_to_collection(
                var=output_var, name=output_name,
                collection=NETWORK_OUTPUT)

        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],  # fixed image reader
            name='fixed_image',
            output_path=self.action_param.save_seg_dir,
            interp_order=self.action_param.output_interp_order)