def test_25d_init(self):
    """End-to-end check for 2.5D data: resize-sample a volume, aggregate
    the windows back into an image file, and verify the written NIfTI
    file's spatial shape matches the (reordered) window shape.
    """
    reader = get_25d_reader()
    sampler = ResizeSampler(reader=reader,
                            window_sizes=SINGLE_25D_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    aggregator = WindowAsImageAggregator(
        image_reader=reader,
        output_path=os.path.join('testing_data', 'aggregated_identity'),
    )
    more_batch = True
    out_shape = []
    with self.cached_session() as sess:
        sampler.set_num_threads(2)
        # drain the sampler queue until the aggregator reports completion
        # or the queue runs out of windows
        while more_batch:
            try:
                out = sess.run(sampler.pop_batch_op())
                # drop the batch axis, append a singleton channel axis
                out_shape = out['image'].shape[1:] + (1, )
            except tf.errors.OutOfRangeError:
                break
            more_batch = aggregator.decode_batch(
                {'window_image': out['image']}, out['image_location'])
    output_filename = '{}_window_image_niftynet_generated.nii.gz'.format(
        sampler.reader.get_subject_id(0))
    output_file = os.path.join('testing_data', 'aggregated_identity',
                               output_filename)
    # reorder spatial axes to match the aggregator's output convention
    out_shape = [out_shape[i] for i in NEW_ORDER_2D] + [
        1,
    ]
    # only the first two (spatial) dimensions are compared
    self.assertAllClose(nib.load(output_file).shape, out_shape[:2])
    sampler.close_all()
def test_init_2d_mo_bidimcsv(self):
    """End-to-end check for 2D multi-output aggregation: alongside the
    reconstructed image, per-window scalar ('csv_sum') and 2D
    ('csv_stats2d') values are aggregated into CSV files whose shapes
    are then verified.
    """
    reader = get_2d_reader()
    sampler = ResizeSampler(reader=reader,
                            window_sizes=MOD_2D_DATA,
                            batch_size=1,
                            shuffle=False,
                            queue_length=50)
    aggregator = WindowAsImageAggregator(
        image_reader=reader,
        output_path=os.path.join('testing_data', 'aggregated_identity'),
    )
    more_batch = True
    out_shape = []
    with self.cached_session() as sess:
        sampler.set_num_threads(2)
        while more_batch:
            try:
                out = sess.run(sampler.pop_batch_op())
                # drop the batch axis, append a singleton channel axis
                out_shape = out['image'].shape[1:] + (1, )
            except tf.errors.OutOfRangeError:
                break
            # scalar summary of the window (note: a sum, despite the name)
            min_val = np.sum((np.asarray(out['image']).flatten()))
            # a 2x3 block of window statistics: [min, max, sum] stacked twice
            stats_val = [
                np.min(out['image']),
                np.max(out['image']),
                np.sum(out['image'])
            ]
            stats_val = np.expand_dims(stats_val, 0)
            stats_val = np.concatenate([stats_val, stats_val], axis=0)
            more_batch = aggregator.decode_batch(
                {
                    'window_image': out['image'],
                    'csv_sum': min_val,
                    'csv_stats2d': stats_val
                }, out['image_location'])
    output_filename = '{}_window_image_niftynet_generated.nii.gz'.format(
        sampler.reader.get_subject_id(0))
    sum_filename = os.path.join(
        'testing_data', 'aggregated_identity',
        '{}_csv_sum_niftynet_generated.csv'.format(
            sampler.reader.get_subject_id(0)))
    stats_filename = os.path.join(
        'testing_data', 'aggregated_identity',
        '{}_csv_stats2d_niftynet_generated.csv'.format(
            sampler.reader.get_subject_id(0)))
    output_file = os.path.join('testing_data', 'aggregated_identity',
                               output_filename)
    # reorder spatial axes to match the aggregator's output convention
    out_shape = [out_shape[i] for i in NEW_ORDER_2D] + [
        1,
    ]
    self.assertAllClose(nib.load(output_file).shape, out_shape[:2])
    # one aggregated row: [location..., value] layouts give 2 and 7 columns
    min_pd = pd.read_csv(sum_filename)
    self.assertAllClose(min_pd.shape, [1, 2])
    stats_pd = pd.read_csv(stats_filename)
    self.assertAllClose(stats_pd.shape, [1, 7])
    sampler.close_all()
class AutoencoderApplication(BaseApplication):
    """NiftyNet application for a (variational) autoencoder.

    Supports training plus four inference types: 'encode',
    'encode-decode', 'sample' (decode random noise) and
    'linear_interpolation' (decode interpolated latent codes).
    """

    # configuration-file section this application reads its task params from
    REQUIRED_CONFIG_SECTION = "AUTOENCODER"

    def __init__(self, net_param, action_param, action):
        """Store parameters; readers/samplers/network are built by the
        initialise_* hooks called later by the application driver."""
        BaseApplication.__init__(self)
        tf.logging.info('starting autoencoder application')
        self.action = action
        self.net_param = net_param
        self.action_param = action_param
        self.data_param = None
        self.autoencoder_param = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create image readers for the current action / inference type."""
        self.data_param = data_param
        self.autoencoder_param = task_param

        if not self.is_training:
            # validates the configured inference type; raises on unknowns
            self._infer_type = look_up_operations(
                self.autoencoder_param.inference_type, SUPPORTED_INFERENCE)
        else:
            self._infer_type = None

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_evaluation:
            # BUG FIX: the exception was previously constructed but never
            # raised, so evaluation silently continued unsupported.
            raise NotImplementedError('Evaluation is not yet '
                                      'supported in this application.')
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        if self._infer_type in ('encode', 'encode-decode'):
            self.readers = [ImageReader(['image'])]
            self.readers[0].initialise(data_param, task_param, file_lists[0])
        elif self._infer_type == 'sample':
            # sampling decodes random noise: no input images required
            self.readers = []
        elif self._infer_type == 'linear_interpolation':
            # interpolation reads previously-inferred feature codes
            self.readers = [ImageReader(['feature'])]
            # BUG FIX: previously initialised with ``[file_lists]`` (a
            # doubly-nested list); a single file list is expected, as in
            # the encode/encode-decode branch above.
            self.readers[0].initialise(data_param, task_param, file_lists[0])
        # if self.is_training or self._infer_type in ('encode', 'encode-decode'):
        #     mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        #     self.reader.add_preprocessing_layers([mean_var_normaliser])

    def initialise_sampler(self):
        """Create samplers: resize samplers for training/encoding, a
        linear interpolation sampler for latent traversal, none for
        'sample' mode."""
        self.sampler = []
        if self.is_training:
            self.sampler.append([ResizeSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=1,
                shuffle_buffer=True,
                queue_length=self.net_param.queue_length)
                for reader in self.readers])
            return
        if self._infer_type in ('encode', 'encode-decode'):
            self.sampler.append([ResizeSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=1,
                shuffle_buffer=False,
                queue_length=self.net_param.queue_length)
                for reader in self.readers])
            return
        if self._infer_type == 'linear_interpolation':
            self.sampler.append([LinearInterpolateSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                n_interpolations=self.autoencoder_param.n_interpolations,
                queue_length=self.net_param.queue_length)
                for reader in self.readers])
            return

    def initialise_network(self):
        """Instantiate the network with optional L1/L2 weight/bias
        regularisers taken from the network parameters."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: training loss/gradients, or the inference
        outputs appropriate to the configured inference type."""

        def switch_sampler(for_training):
            # picks the train (index 0) or validation (index -1) sampler
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)
            image = tf.cast(data_dict['image'], tf.float32)
            net_output = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            data_loss = loss_func(net_output)
            loss = data_loss
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                    loss = loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            outputs_collector.add_to_collection(
                var=data_loss, name='variational_lower_bound',
                average_over_devices=True, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='variational_lower_bound',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=net_output[4], name='Originals',
                average_over_devices=False,
                summary_type='image3_coronal', collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=net_output[2], name='Means',
                average_over_devices=False,
                summary_type='image3_coronal', collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=net_output[5], name='Variances',
                average_over_devices=False,
                summary_type='image3_coronal', collection=TF_SUMMARIES)
        else:
            if self._infer_type in ('encode', 'encode-decode'):
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                image = tf.cast(data_dict['image'], dtype=tf.float32)
                net_output = self.net(image, is_training=False)

                outputs_collector.add_to_collection(
                    var=data_dict['image_location'], name='location',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
                if self._infer_type == 'encode-decode':
                    outputs_collector.add_to_collection(
                        var=net_output[2], name='generated_image',
                        average_over_devices=True, collection=NETWORK_OUTPUT)
                if self._infer_type == 'encode':
                    outputs_collector.add_to_collection(
                        var=net_output[7], name='embedded',
                        average_over_devices=True, collection=NETWORK_OUTPUT)

                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'sample':
                # decode Gaussian noise shaped like the latent code
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                noise_shape = net_output[-1].shape.as_list()
                stddev = self.autoencoder_param.noise_stddev
                noise = tf.random_normal(shape=noise_shape,
                                         mean=0.0,
                                         stddev=stddev,
                                         dtype=tf.float32)
                partially_decoded_sample = self.net.shared_decoder(
                    noise, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)

                outputs_collector.add_to_collection(
                    var=decoder_output, name='generated_image',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=None,
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'linear_interpolation':
                # construct the entire network
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                real_code = data_dict['feature']
                real_code = tf.reshape(real_code, net_output[-1].get_shape())
                partially_decoded_sample = self.net.shared_decoder(
                    real_code, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)

                outputs_collector.add_to_collection(
                    var=decoder_output, name='generated_image',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
                outputs_collector.add_to_collection(
                    var=data_dict['feature_location'], name='location',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
            else:
                raise NotImplementedError

    def interpret_output(self, batch_output):
        """Route network outputs to the aggregator; returns the
        aggregator's 'more batches expected' flag (True while training)."""
        if self.is_training:
            return True
        infer_type = look_up_operations(
            self.autoencoder_param.inference_type, SUPPORTED_INFERENCE)
        if infer_type == 'encode':
            return self.output_decoder.decode_batch(
                batch_output['embedded'],
                batch_output['location'][:, 0:1])
        if infer_type == 'encode-decode':
            return self.output_decoder.decode_batch(
                batch_output['generated_image'],
                batch_output['location'][:, 0:1])
        if infer_type == 'sample':
            return self.output_decoder.decode_batch(
                batch_output['generated_image'], None)
        if infer_type == 'linear_interpolation':
            # first two location columns identify the interpolation pair
            return self.output_decoder.decode_batch(
                batch_output['generated_image'],
                batch_output['location'][:, :2])
class AutoencoderApplication(BaseApplication):
    """NiftyNet application for a (variational) autoencoder.

    Supports training plus four inference types: 'encode',
    'encode-decode', 'sample' (decode random noise) and
    'linear_interpolation' (decode interpolated latent codes).
    """

    # configuration-file section this application reads its task params from
    REQUIRED_CONFIG_SECTION = "AUTOENCODER"

    def __init__(self, net_param, action_param, is_training):
        """Store parameters; readers/samplers/network are built by the
        initialise_* hooks called later by the application driver."""
        BaseApplication.__init__(self)
        tf.logging.info('starting autoencoder application')
        self.is_training = is_training
        self.net_param = net_param
        self.action_param = action_param
        self.data_param = None
        self.autoencoder_param = None

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Create image readers for the current phase.

        Training uses the partitioner's train (and optionally validation)
        file lists; inference uses the inference files, with the reader
        type depending on the configured inference type.
        """
        self.data_param = data_param
        self.autoencoder_param = task_param
        if not self.is_training:
            # validates the configured inference type; raises on unknowns
            self._infer_type = look_up_operations(
                self.autoencoder_param.inference_type, SUPPORTED_INFERENCE)
        else:
            self._infer_type = None
        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.train_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        if self._infer_type in ('encode', 'encode-decode'):
            self.readers = [ImageReader(['image'])]
            self.readers[0].initialise(data_param, task_param,
                                       data_partitioner.inference_files)
        elif self._infer_type == 'sample':
            # sampling decodes random noise: no input images required
            self.readers = []
        elif self._infer_type == 'linear_interpolation':
            # interpolation reads previously-inferred feature codes
            self.readers = [ImageReader(['feature'])]
            self.readers[0].initialise(data_param, task_param,
                                       data_partitioner.inference_files)
        # if self.is_training or self._infer_type in ('encode', 'encode-decode'):
        #     mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        #     self.reader.add_preprocessing_layers([mean_var_normaliser])

    def initialise_sampler(self):
        """Create samplers: resize samplers for training/encoding, a
        linear interpolation sampler for latent traversal, none for
        'sample' mode."""
        self.sampler = []
        if self.is_training:
            self.sampler.append([
                ResizeSampler(reader=reader,
                              data_param=self.data_param,
                              batch_size=self.net_param.batch_size,
                              windows_per_image=1,
                              shuffle_buffer=True,
                              queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])
            return
        if self._infer_type in ('encode', 'encode-decode'):
            self.sampler.append([
                ResizeSampler(reader=reader,
                              data_param=self.data_param,
                              batch_size=self.net_param.batch_size,
                              windows_per_image=1,
                              shuffle_buffer=False,
                              queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])
            return
        if self._infer_type == 'linear_interpolation':
            self.sampler.append([
                LinearInterpolateSampler(
                    reader=reader,
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size,
                    n_interpolations=self.autoencoder_param.n_interpolations,
                    queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])
            return

    def initialise_network(self):
        """Instantiate the network with optional L1/L2 weight/bias
        regularisers taken from the network parameters."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: training loss/gradients, or the inference
        outputs appropriate to the configured inference type."""

        def switch_sampler(for_training):
            # picks the train (index 0) or validation (index -1) sampler
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)
            image = tf.cast(data_dict['image'], tf.float32)
            net_output = self.net(image, is_training=self.is_training)
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            data_loss = loss_func(net_output)
            loss = data_loss
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                    loss = loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            outputs_collector.add_to_collection(var=data_loss,
                                                name='variational_lower_bound',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='variational_lower_bound',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=net_output[4],
                                                name='Originals',
                                                average_over_devices=False,
                                                summary_type='image3_coronal',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=net_output[2],
                                                name='Means',
                                                average_over_devices=False,
                                                summary_type='image3_coronal',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=net_output[5],
                                                name='Variances',
                                                average_over_devices=False,
                                                summary_type='image3_coronal',
                                                collection=TF_SUMMARIES)
        else:
            if self._infer_type in ('encode', 'encode-decode'):
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                image = tf.cast(data_dict['image'], dtype=tf.float32)
                net_output = self.net(image, is_training=False)
                outputs_collector.add_to_collection(
                    var=data_dict['image_location'],
                    name='location',
                    average_over_devices=True,
                    collection=NETWORK_OUTPUT)
                if self._infer_type == 'encode-decode':
                    outputs_collector.add_to_collection(
                        var=net_output[2],
                        name='generated_image',
                        average_over_devices=True,
                        collection=NETWORK_OUTPUT)
                if self._infer_type == 'encode':
                    outputs_collector.add_to_collection(
                        var=net_output[7],
                        name='embedded',
                        average_over_devices=True,
                        collection=NETWORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'sample':
                # decode Gaussian noise shaped like the latent code
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                noise_shape = net_output[-1].get_shape().as_list()
                stddev = self.autoencoder_param.noise_stddev
                noise = tf.random_normal(shape=noise_shape,
                                         mean=0.0,
                                         stddev=stddev,
                                         dtype=tf.float32)
                partially_decoded_sample = self.net.shared_decoder(
                    noise, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)
                outputs_collector.add_to_collection(var=decoder_output,
                                                    name='generated_image',
                                                    average_over_devices=True,
                                                    collection=NETWORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=None,
                    output_path=self.action_param.save_seg_dir)
                return
            elif self._infer_type == 'linear_interpolation':
                # construct the entire network
                image_size = (self.net_param.batch_size,) + \
                             self.action_param.spatial_window_size + (1,)
                dummy_image = tf.zeros(image_size)
                net_output = self.net(dummy_image, is_training=False)
                data_dict = self.get_sampler()[0][0].pop_batch_op()
                real_code = data_dict['feature']
                real_code = tf.reshape(real_code, net_output[-1].get_shape())
                partially_decoded_sample = self.net.shared_decoder(
                    real_code, is_training=False)
                decoder_output = self.net.decoder_means(
                    partially_decoded_sample, is_training=False)
                outputs_collector.add_to_collection(var=decoder_output,
                                                    name='generated_image',
                                                    average_over_devices=True,
                                                    collection=NETWORK_OUTPUT)
                outputs_collector.add_to_collection(
                    var=data_dict['feature_location'],
                    name='location',
                    average_over_devices=True,
                    collection=NETWORK_OUTPUT)
                self.output_decoder = WindowAsImageAggregator(
                    image_reader=self.readers[0],
                    output_path=self.action_param.save_seg_dir)
            else:
                raise NotImplementedError

    def interpret_output(self, batch_output):
        """Route network outputs to the aggregator; returns the
        aggregator's 'more batches expected' flag (True while training)."""
        if self.is_training:
            return True
        else:
            infer_type = look_up_operations(
                self.autoencoder_param.inference_type, SUPPORTED_INFERENCE)
            if infer_type == 'encode':
                return self.output_decoder.decode_batch(
                    batch_output['embedded'],
                    batch_output['location'][:, 0:1])
            if infer_type == 'encode-decode':
                return self.output_decoder.decode_batch(
                    batch_output['generated_image'],
                    batch_output['location'][:, 0:1])
            if infer_type == 'sample':
                return self.output_decoder.decode_batch(
                    batch_output['generated_image'], None)
            if infer_type == 'linear_interpolation':
                # first two location columns identify the interpolation pair
                return self.output_decoder.decode_batch(
                    batch_output['generated_image'],
                    batch_output['location'][:, :2])
class GANApplication(BaseApplication):
    """NiftyNet application for a conditional GAN: trains a generator
    and discriminator jointly; at inference, decodes random vectors
    conditioned on resized input images."""

    # configuration-file section this application reads its task params from
    REQUIRED_CONFIG_SECTION = "GAN"

    def __init__(self, net_param, action_param, is_training):
        """Store parameters; readers/samplers/network are built by the
        initialise_* hooks called later by the application driver."""
        BaseApplication.__init__(self)
        tf.logging.info('starting GAN application')
        self.is_training = is_training
        self.net_param = net_param
        self.action_param = action_param
        self.data_param = None
        self.gan_param = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create readers and attach normalisation/augmentation layers."""
        self.data_param = data_param
        self.gan_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.train_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image', 'conditioning'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            # inference only needs the conditioning input
            inference_reader = ImageReader(['conditioning'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)
        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image',
            binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        # augmentation is applied during training only
        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        for reader in self.readers:
            reader.add_preprocessing_layers(
                normalisation_layers + augmentation_layers)

    def initialise_sampler(self):
        """Training: one resize sampler per reader.  Inference: a random
        vector sampler plus a resize sampler repeating each image once
        per interpolation step."""
        self.sampler = []
        if self.is_training:
            self.sampler.append([ResizeSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=1,
                shuffle_buffer=True,
                queue_length=self.net_param.queue_length)
                for reader in self.readers])
        else:
            self.sampler.append([RandomVectorSampler(
                names=('vector',),
                vector_size=(self.gan_param.noise_size,),
                batch_size=self.net_param.batch_size,
                n_interpolations=self.gan_param.n_interpolations,
                repeat=None,
                queue_length=self.net_param.queue_length)
                for _ in self.readers])
            # repeat each resized image n times, so that each
            # image matches one random vector,
            # (n = self.gan_param.n_interpolations)
            self.sampler.append([ResizeSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.gan_param.n_interpolations,
                shuffle_buffer=False,
                queue_length=self.net_param.queue_length)
                for reader in self.readers])

    def initialise_network(self):
        """Instantiate the GAN network from the factory (no regularisers)."""
        self.net = ApplicationNetFactory.create(self.net_param.name)()

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: adversarial losses and per-subnetwork
        gradients for training, or generated images for inference."""
        if self.is_training:
            def switch_sampler(for_training):
                # picks the train (index 0) or validation (index -1) sampler
                with tf.name_scope('train' if for_training else 'validation'):
                    sampler = self.get_sampler()[0][0 if for_training else -1]
                    return sampler.pop_batch_op()

            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            images = tf.cast(data_dict['image'], tf.float32)
            noise_shape = [self.net_param.batch_size,
                           self.gan_param.noise_size]
            noise = tf.random_normal(shape=noise_shape,
                                     mean=0.0,
                                     stddev=1.0,
                                     dtype=tf.float32)
            conditioning = data_dict['conditioning']
            net_output = self.net(
                noise, images, conditioning, self.is_training)

            loss_func = LossFunction(
                loss_type=self.action_param.loss_type)
            real_logits = net_output[1]
            fake_logits = net_output[2]
            lossG, lossD = loss_func(real_logits, fake_logits)
            if self.net_param.decay > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(l_reg) for l_reg in reg_losses])
                    lossD = lossD + reg_loss
                    lossG = lossG + reg_loss

            # variables to display in STDOUT
            outputs_collector.add_to_collection(
                var=lossD, name='lossD',
                average_over_devices=True, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=lossG, name='lossG',
                average_over_devices=False, collection=CONSOLE)
            # variables to display in tensorboard
            outputs_collector.add_to_collection(
                var=lossG, name='lossG',
                average_over_devices=False, collection=TF_SUMMARIES)
            # BUG FIX: the 'lossD' summary previously logged lossG; record
            # the discriminator loss under its own name instead.
            outputs_collector.add_to_collection(
                var=lossD, name='lossD',
                average_over_devices=True, collection=TF_SUMMARIES)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            with tf.name_scope('ComputeGradients'):
                # gradients of generator
                generator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                generator_grads = self.optimiser.compute_gradients(
                    lossG, var_list=generator_variables)
                # gradients of discriminator
                discriminator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                discriminator_grads = self.optimiser.compute_gradients(
                    lossD, var_list=discriminator_variables)
                grads = [generator_grads, discriminator_grads]
                # add the grads back to application_driver's training_grads
                gradients_collector.add_to_collection(grads)
        else:
            data_dict = self.get_sampler()[0][0].pop_batch_op()
            conditioning_dict = self.get_sampler()[1][0].pop_batch_op()
            conditioning = conditioning_dict['conditioning']
            # generator only needs the conditioning's spatial shape
            image_size = conditioning.shape.as_list()[:-1]
            dummy_image = tf.zeros(image_size + [1])
            net_output = self.net(data_dict['vector'],
                                  dummy_image,
                                  conditioning,
                                  self.is_training)
            outputs_collector.add_to_collection(
                var=net_output[0], name='image',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=conditioning_dict['conditioning_location'],
                name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)

    def interpret_output(self, batch_output):
        """Route generated images to the aggregator; True while training."""
        if self.is_training:
            return True
        return self.output_decoder.decode_batch(
            batch_output['image'], batch_output['location'])
class GANApplication(BaseApplication):
    """NiftyNet application for a conditional GAN: trains a generator
    and discriminator jointly; at inference, decodes random vectors
    conditioned on resized input images."""

    # configuration-file section this application reads its task params from
    REQUIRED_CONFIG_SECTION = "GAN"

    def __init__(self, net_param, action_param, action):
        """Store parameters; readers/samplers/network are built by the
        initialise_* hooks called later by the application driver."""
        BaseApplication.__init__(self)
        tf.logging.info('starting GAN application')
        self.action = action
        self.net_param = net_param
        self.action_param = action_param
        self.data_param = None
        self.gan_param = None

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Create readers and attach normalisation/augmentation layers."""
        self.data_param = data_param
        self.gan_param = task_param
        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image', 'conditioning'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self.is_inference:
            # inference only needs the conditioning input
            inference_reader = ImageReader(['conditioning'])
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]
        elif self.is_evaluation:
            # BUG FIX: the exception was previously constructed but never
            # raised, so evaluation silently continued unsupported.
            raise NotImplementedError('Evaluation is not yet '
                                      'supported in this application.')

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)
        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image',
            binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        # augmentation is applied during training only
        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        for reader in self.readers:
            reader.add_preprocessing_layers(normalisation_layers +
                                            augmentation_layers)

    def initialise_sampler(self):
        """Training: one resize sampler per reader.  Inference: a random
        vector sampler plus a resize sampler repeating each image once
        per interpolation step."""
        self.sampler = []
        if self.is_training:
            self.sampler.append([
                ResizeSampler(reader=reader,
                              data_param=self.data_param,
                              batch_size=self.net_param.batch_size,
                              windows_per_image=1,
                              shuffle_buffer=True,
                              queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])
        else:
            self.sampler.append([
                RandomVectorSampler(
                    names=('vector', ),
                    vector_size=(self.gan_param.noise_size, ),
                    batch_size=self.net_param.batch_size,
                    n_interpolations=self.gan_param.n_interpolations,
                    repeat=None,
                    queue_length=self.net_param.queue_length)
                for _ in self.readers
            ])
            # repeat each resized image n times, so that each
            # image matches one random vector,
            # (n = self.gan_param.n_interpolations)
            self.sampler.append([
                ResizeSampler(
                    reader=reader,
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size,
                    windows_per_image=self.gan_param.n_interpolations,
                    shuffle_buffer=False,
                    queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])

    def initialise_network(self):
        """Instantiate the GAN network from the factory (no regularisers)."""
        self.net = ApplicationNetFactory.create(self.net_param.name)()

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: adversarial losses and per-subnetwork
        gradients for training, or generated images for inference."""
        if self.is_training:
            def switch_sampler(for_training):
                # picks the train (index 0) or validation (index -1) sampler
                with tf.name_scope('train' if for_training else 'validation'):
                    sampler = self.get_sampler()[0][0 if for_training else -1]
                    return sampler.pop_batch_op()

            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            images = tf.cast(data_dict['image'], tf.float32)
            noise_shape = [
                self.net_param.batch_size, self.gan_param.noise_size
            ]
            noise = tf.random_normal(shape=noise_shape,
                                     mean=0.0,
                                     stddev=1.0,
                                     dtype=tf.float32)
            conditioning = data_dict['conditioning']
            net_output = self.net(noise, images, conditioning,
                                  self.is_training)

            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            real_logits = net_output[1]
            fake_logits = net_output[2]
            lossG, lossD = loss_func(real_logits, fake_logits)
            if self.net_param.decay > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(l_reg) for l_reg in reg_losses])
                    lossD = lossD + reg_loss
                    lossG = lossG + reg_loss

            # variables to display in STDOUT
            outputs_collector.add_to_collection(var=lossD,
                                                name='lossD',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=lossG,
                                                name='lossG',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            # variables to display in tensorboard
            outputs_collector.add_to_collection(var=lossG,
                                                name='lossG',
                                                average_over_devices=False,
                                                collection=TF_SUMMARIES)
            # BUG FIX: the 'lossD' summary previously logged lossG; record
            # the discriminator loss under its own name instead.
            outputs_collector.add_to_collection(var=lossD,
                                                name='lossD',
                                                average_over_devices=True,
                                                collection=TF_SUMMARIES)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            with tf.name_scope('ComputeGradients'):
                # gradients of generator
                generator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                generator_grads = self.optimiser.compute_gradients(
                    lossG, var_list=generator_variables)
                # gradients of discriminator
                discriminator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                discriminator_grads = self.optimiser.compute_gradients(
                    lossD, var_list=discriminator_variables)
                grads = [generator_grads, discriminator_grads]
                # add the grads back to application_driver's training_grads
                gradients_collector.add_to_collection(grads)
        else:
            data_dict = self.get_sampler()[0][0].pop_batch_op()
            conditioning_dict = self.get_sampler()[1][0].pop_batch_op()
            conditioning = conditioning_dict['conditioning']
            # generator only needs the conditioning's spatial shape
            image_size = conditioning.shape.as_list()[:-1]
            dummy_image = tf.zeros(image_size + [1])
            net_output = self.net(data_dict['vector'], dummy_image,
                                  conditioning, self.is_training)
            outputs_collector.add_to_collection(var=net_output[0],
                                                name='image',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=conditioning_dict['conditioning_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)

    def interpret_output(self, batch_output):
        """Route generated images to the aggregator; True while training."""
        if self.is_training:
            return True
        return self.output_decoder.decode_batch(batch_output['image'],
                                                batch_output['location'])
class RegressionApplication(BaseApplication):
    """NiftyNet application for image regression.

    Wires together the image readers, window samplers, the regression
    network, the loss, and the output aggregators, according to the
    current action (train / inference / evaluation).
    """

    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, action):
        """Store configuration and declare the supported sampling modes.

        :param net_param: network-level configuration parameters
        :param action_param: parameters of the requested action
        :param action: action name (train / inference / evaluation)
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.regression_param = None
        # sampling mode -> (training sampler, inference sampler,
        #                   inference aggregator) initialisers
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
            'balanced': (self.initialise_balanced_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create image readers and their preprocessing pipelines.

        :param data_param: input-source configuration
        :param task_param: regression task configuration
        :param data_partitioner: provides per-phase file lists
        :raises ValueError: if the current action is not supported
        """
        self.data_param = data_param
        self.regression_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'output', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image',)
        elif self.is_evaluation:
            reader_names = ('image', 'output', 'inferred')
        else:
            tf.logging.fatal(
                'Action `%s` not supported. Expected one of %s',
                self.action, self.SUPPORTED_PHASES)
            raise ValueError
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(
                data_param, task_param, file_list)
            for file_list in file_lists]

        # initialise input preprocessing layers
        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image') \
            if self.net_param.whitening else None
        histogram_normaliser = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(task_param).get('image'),
            model_filename=self.net_param.histogram_ref_file,
            norm_type=self.net_param.norm_type,
            cutoff=self.net_param.cutoff,
            name='hist_norm_layer') \
            if (self.net_param.histogram_ref_file and
                self.net_param.normalisation) else None
        rgb_normaliser = RGBHistogramEquilisationLayer(
            image_name='image',
            name='rbg_norm_layer') \
            if self.net_param.rgb_normalisation else None

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)
        if rgb_normaliser is not None:
            normalisation_layers.append(rgb_normaliser)

        volume_padding_layer = [
            PadLayer(image_name=SUPPORTED_INPUT,
                     border=self.net_param.volume_padding_size,
                     mode=self.net_param.volume_padding_mode,
                     pad_to=self.net_param.volume_padding_to_size)]

        # initialise training data augmentation layers
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=train_param.scaling_percentage[0],
                    max_percentage=train_param.scaling_percentage[1],
                    antialiasing=train_param.antialiasing,
                    isotropic=train_param.isotropic_scaling))
            if train_param.rotation_angle:
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_uniform_angle(train_param.rotation_angle)
                augmentation_layers.append(rotation_layer)
            if train_param.do_elastic_deformation:
                # assumes all modalities share one spatial rank
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(RandomElasticDeformationLayer(
                    spatial_rank=spatial_rank,
                    num_controlpoints=train_param.num_ctrl_points,
                    std_deformation_sigma=train_param.deformation_sigma,
                    proportion_to_augment=train_param.proportion_to_deform))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)
        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer + normalisation_layers)

    def initialise_uniform_sampler(self):
        """Random uniform window sampler per reader (training)."""
        self.sampler = [[
            UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_weighted_sampler(self):
        """Frequency-map weighted window sampler per reader (training)."""
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_resize_sampler(self):
        """Whole-image resize sampler per reader; shuffles when training."""
        self.sampler = [[
            ResizeSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle=self.is_training,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_sampler(self):
        """Sliding-window grid sampler per reader (inference)."""
        self.sampler = [[
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_balanced_sampler(self):
        """Class-balanced window sampler per reader (training)."""
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_aggregator(self):
        """Aggregator stitching grid windows back into volumes."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix,
            fill_constant=self.action_param.fill_constant)

    def initialise_resize_aggregator(self):
        """Aggregator resizing network windows back to image shape."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_identity_aggregator(self):
        """Aggregator writing each window out unchanged."""
        self.output_decoder = WindowAsImageAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        """Build the sampler matching the configured mode and action."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        """Build the output aggregator matching the configured mode."""
        if self.net_param.force_output_identity_resizing:
            self.initialise_identity_aggregator()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Instantiate the regression network with optional L1/L2 decay."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        # regression network always has a single output channel
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    @staticmethod
    def _nonzero_loss_border(border):
        """Return True when `border` requests an actual loss-border crop.

        `border` is a plain Python configuration value: None, a scalar,
        or a per-axis sequence of border sizes.
        """
        if border is None:
            return False
        try:
            return any(border)
        except TypeError:  # scalar border value
            return bool(border)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph.

        Training: sample windows, compute the regression loss (optionally
        cropped by `loss_border` and regularised), and collect gradients
        plus console/summary outputs.  Inference: run the network and
        collect windows and locations for the aggregator.
        """
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            self.patience = self.action_param.patience
            self.mode = self.action_param.early_stopping_mode
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            weight_map = data_dict.get('weight', None)
            border = self.regression_param.loss_border
            # `loss_border` is a Python config value, so the zero test
            # must happen in Python.  The previous expression
            # `border == None or tf.reduce_sum(tf.abs(border)) == 0`
            # compared a TF1 Tensor to an int (identity-based, always
            # False), so the zero-border shortcut never triggered.
            if not self._nonzero_loss_border(border):
                data_loss = loss_func(prediction=net_out,
                                      ground_truth=data_dict['output'],
                                      weight_map=weight_map)
            else:
                crop_layer = CropLayer(border)
                weight_map = None if weight_map is None else crop_layer(
                    weight_map)
                data_loss = loss_func(
                    prediction=crop_layer(net_out),
                    ground_truth=crop_layer(data_dict['output']),
                    weight_map=weight_map)
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            # Get all vars
            to_optimise = tf.trainable_variables()
            vars_to_freeze = \
                self.action_param.vars_to_freeze or \
                self.action_param.vars_to_restore
            if vars_to_freeze:
                import re
                var_regex = re.compile(vars_to_freeze)
                # Only optimise vars that are not frozen
                to_optimise = \
                    [v for v in to_optimise if not var_regex.search(v.name)]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables are fixed (--vars_to_freeze %s)",
                    len(to_optimise),
                    len(tf.trainable_variables()),
                    vars_to_freeze)

            self.total_loss = loss

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=self.total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=self.total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        elif self.is_inference:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)
            net_out = PostProcessingLayer('IDENTITY')(net_out)

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Decode one inference batch; always returns True when training."""
        if self.is_inference:
            return self.output_decoder.decode_batch(
                {'window_reg': batch_output['window']},
                batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Create the regression evaluator for the evaluation action."""
        self.eval_param = eval_param
        self.evaluator = RegressionEvaluator(self.readers[0],
                                             self.regression_param,
                                             eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Register the inferred output as an extra reader input."""
        return self.add_inferred_output_like(data_param, task_param, 'output')
class MultiOutputApplication(BaseApplication):
    """Test application emitting several outputs per inference window.

    Produces an argmax map, softmax probabilities and two scalar values
    (sum and min of the softmax) per window, exercising the CSV output
    path of the aggregators.  Network is hard-coded to 'toynet'.
    """

    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        """Store configuration and declare supported sampling modes.

        :param net_param: network-level configuration parameters
        :param action_param: parameters of the requested action
        :param action: action name (train / inference / evaluation)
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting multioutput test')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.multioutput_param = None
        # sampling mode -> (training sampler, inference sampler,
        #                   inference aggregator) initialisers
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
            'classifier': (self.initialise_resize_sampler,
                           self.initialise_resize_sampler,
                           self.initialise_classifier_aggregator),
            'identity': (self.initialise_uniform_sampler,
                         self.initialise_resize_sampler,
                         self.initialise_identity_aggregator)
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create one ImageReader per data partition.

        Unlike the full applications, no normalisation/augmentation
        preprocessing layers are attached here.

        :raises ValueError: if the current action is not supported
        """
        self.data_param = data_param
        self.multioutput_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'label', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image',)
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal(
                'Action `%s` not supported. Expected one of %s',
                self.action, self.SUPPORTED_PHASES)
            raise ValueError
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(
                data_param, task_param, file_list)
            for file_list in file_lists]

    def initialise_uniform_sampler(self):
        """Random uniform window sampler per reader (training)."""
        self.sampler = [[
            UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_weighted_sampler(self):
        """Frequency-map weighted window sampler per reader (training)."""
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_resize_sampler(self):
        """Whole-image resize sampler per reader; shuffles when training."""
        self.sampler = [[
            ResizeSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle=self.is_training,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_sampler(self):
        """Sliding-window grid sampler per reader (inference)."""
        self.sampler = [[
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_balanced_sampler(self):
        """Class-balanced window sampler per reader.

        NOTE(review): not referenced by SUPPORTED_SAMPLING in this class;
        kept for parity with the other applications.
        """
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_aggregator(self):
        """Aggregator stitching grid windows back into volumes."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix,
            fill_constant=self.action_param.fill_constant)

    def initialise_resize_aggregator(self):
        """Aggregator resizing network windows back to image shape."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_identity_aggregator(self):
        """Aggregator writing each window out unchanged."""
        self.output_decoder = WindowAsImageAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            postfix=self.action_param.output_postfix)

    def initialise_classifier_aggregator(self):
        """Placeholder: classifier aggregation is not implemented here.

        NOTE(review): selecting 'classifier' sampling leaves
        self.output_decoder unset, so interpret_output would fail.
        """
        pass
        # self.output_decoder = ClassifierSamplesAggregator(
        #     image_reader=self.readers[0],
        #     output_path=self.action_param.save_seg_dir,
        #     postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        """Build the sampler matching the configured mode and action."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        """Build the output aggregator matching the configured mode."""
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Instantiate the (hard-coded) toynet with optional L1/L2 decay."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        # network name deliberately fixed to 'toynet' for this test app
        self.net = ApplicationNetFactory.create('toynet')(
            num_classes=self.multioutput_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph.

        Training: sample windows, compute the segmentation loss
        (optionally regularised) and collect gradients and outputs.
        Inference: run the network and collect argmax/softmax windows
        plus two scalar softmax statistics for CSV aggregation.
        """
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            # extract data
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.multioutput_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.multioutput_param.softmax)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            # set the optimiser and the gradient
            to_optimise = tf.trainable_variables()
            vars_to_freeze = \
                self.action_param.vars_to_freeze or \
                self.action_param.vars_to_restore
            if vars_to_freeze:
                import re
                var_regex = re.compile(vars_to_freeze)
                # Only optimise vars that are not frozen
                to_optimise = \
                    [v for v in to_optimise if not var_regex.search(v.name)]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables fixed (--vars_to_freeze %s)",
                    len(to_optimise),
                    len(tf.trainable_variables()),
                    vars_to_freeze)

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        elif self.is_inference:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            num_classes = self.multioutput_param.num_classes
            argmax_layer = PostProcessingLayer('ARGMAX',
                                               num_classes=num_classes)
            softmax_layer = PostProcessingLayer('SOFTMAX',
                                                num_classes=num_classes)
            arg_max_out = argmax_layer(net_out)
            soft_max_out = softmax_layer(net_out)
            # scalar statistics of the softmax, emitted via the CSV path
            # sum_prob_out = tf.reshape(tf.reduce_sum(soft_max_out),[1,1])
            # min_prob_out = tf.reshape(tf.reduce_min(soft_max_out),[1,1])
            sum_prob_out = tf.reduce_sum(soft_max_out)
            min_prob_out = tf.reduce_min(soft_max_out)

            outputs_collector.add_to_collection(var=arg_max_out,
                                                name='window_argmax',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=soft_max_out,
                                                name='window_softmax',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=sum_prob_out,
                                                name='csv_sum',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=min_prob_out,
                                                name='csv_min',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Decode one inference batch; always returns True when training."""
        if self.is_inference:
            return self.output_decoder.decode_batch(
                {'window_argmax': batch_output['window_argmax'],
                 'window_softmax': batch_output['window_softmax'],
                 'csv_sum': batch_output['csv_sum'],
                 'csv_min': batch_output['csv_min']},
                batch_output['location'])
        return True