def test_3d_argmax_shape(self):
    x = self.get_3d_input()
    post_process_layer = PostProcessingLayer("ARGMAX")
    out_post = post_process_layer(x)
    print(post_process_layer)
    with self.cached_session() as sess:
        out = sess.run(out_post)
    x_shape = tuple(x.shape.as_list()[:-1])
    self.assertAllClose(x_shape + (1,), out.shape)
def test_2d_shape(self):
    x = self.get_2d_input()
    post_process_layer = PostProcessingLayer("IDENTITY")
    out_post = post_process_layer(x)
    print(post_process_layer)
    with self.cached_session() as sess:
        out = sess.run(out_post)
    x_shape = tuple(x.shape.as_list())
    self.assertAllClose(x_shape, out.shape)
def test_3d_shape(self):
    x = self.get_3d_input()
    post_process_layer = PostProcessingLayer("SOFTMAX")
    out_post = post_process_layer(x)
    print(post_process_layer)
    with self.cached_session() as sess:
        out = sess.run(out_post)
    x_shape = tuple(x.get_shape().as_list())
    self.assertAllClose(x_shape, out.shape)
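# The tests above rely on input helpers defined elsewhere in the test class.
# A minimal sketch of what they could look like, assuming constant-valued
# test tensors (the shapes and dtype below are illustrative assumptions,
# not taken from the original file):

def get_3d_input(self):
    # batch x spatial (x, y, z) x channels
    return tf.ones(shape=[2, 16, 16, 16, 8], dtype=tf.float32)

def get_2d_input(self):
    # batch x spatial (x, y) x channels
    return tf.ones(shape=[2, 16, 16, 8], dtype=tf.float32)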
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    def switch_sampler(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0][0 if for_training else -1]
            return sampler.pop_batch_op()

    if self.is_training:
        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: switch_sampler(True),
                                lambda: switch_sampler(False))
        else:
            data_dict = switch_sampler(for_training=True)

        image = tf.cast(data_dict['image'], tf.float32)
        net_args = {'is_training': self.is_training,
                    'keep_prob': self.net_param.keep_prob}
        net_out = self.net(image, **net_args)

        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        loss_func = LossFunction(loss_type=self.action_param.loss_type)
        crop_layer = CropLayer(border=self.regression_param.loss_border,
                               name='crop-88')
        prediction = crop_layer(net_out)
        ground_truth = crop_layer(data_dict.get('output', None))
        weight_map = None if data_dict.get('weight', None) is None \
            else crop_layer(data_dict.get('weight', None))
        data_loss = loss_func(prediction=prediction,
                              ground_truth=ground_truth,
                              weight_map=weight_map)

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if self.net_param.decay > 0.0 and reg_losses:
            reg_loss = tf.reduce_mean(
                [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
            loss = data_loss + reg_loss
        else:
            loss = data_loss
        grads = self.optimiser.compute_gradients(loss)

        # Gradient clipping associated with VDSR3D:
        # gradients are clipped by value, instead of by global norm.
        # The VDSR authors do not specify a threshold for the clipping,
        # so the value below is scaled by the learning rate.
        # grads2, vars2 = zip(*grads)
        # grads2, _ = tf.clip_by_global_norm(grads2, 5.0)
        # grads = zip(grads2, vars2)
        grads = [(tf.clip_by_value(grad,
                                   -0.00001 / self.action_param.lr,
                                   +0.00001 / self.action_param.lr), val)
                 for grad, val in grads if grad is not None]

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=data_loss, name='Loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='Loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
    elif self.is_inference:
        data_dict = switch_sampler(for_training=False)
        image = tf.cast(data_dict['image'], tf.float32)
        net_args = {'is_training': self.is_training,
                    'keep_prob': self.net_param.keep_prob}
        net_out = self.net(image, **net_args)
        crop_layer = CropLayer(border=0, name='crop-88')
        post_process_layer = PostProcessingLayer('IDENTITY')
        net_out = post_process_layer(crop_layer(net_out))

        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        self.initialise_aggregator()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    def switch_sampler(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0][0 if for_training else -1]
            return sampler.pop_batch_op()

    if self.is_training:
        print("- CONNECT DATA AND NETWORK - TRAINING ---------------")
        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: switch_sampler(for_training=True),
                                lambda: switch_sampler(for_training=False))
        else:
            data_dict = switch_sampler(for_training=True)

        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, is_training=self.is_training)

        with tf.name_scope('Optimiser'):
            # ADAM optimiser by default
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        print("optimiser name: ", self.action_param.optimiser)
        print("learning rate: ", self.action_param.lr)

        # loss function
        loss_func = LossFunction(
            n_class=self.segmentation_param.num_classes,
            loss_type=self.action_param.loss_type)
        ground_truth = data_dict.get('label', None)
        weight_map = data_dict.get('weight', None)
        data_loss = loss_func(
            prediction=net_out,
            ground_truth=ground_truth,
            weight_map=weight_map)

        # setting up variables used for debugging/printing
        self.loss_variable = data_loss
        self.first_slice = ground_truth
        # self.first_slice = tf.squeeze(
        #     tf.slice(ground_truth, [0, 0, 0, 60, 0], [1, 103, 103, 1, 1]))
        # self.first_slice_cut = tf.slice(
        #     ground_truth, [0, 52, 52, 60, 1], [1, 30, 30, 1, 1])
        self.netOut = tf.nn.softmax(net_out)
        # self.netOut = tf.squeeze(netOut[0, 50, 50, 1, :])
        GROUNDTRUTH, PREDICTION, CONT = loss_func.return_loss_args()
        self.GROUNDTRUTH = GROUNDTRUTH
        self.PREDICTION = PREDICTION
        self.CONT = CONT
        self.SUMA = loss_func.SUMA
        print("Finished setting up variables in connect_data_and_network")

        # calculating regularisers
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        print("regularisation loss key: ",
              tf.GraphKeys.REGULARIZATION_LOSSES)
        if self.net_param.decay > 0.0 and reg_losses:
            reg_loss = tf.reduce_mean(
                [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
            loss = data_loss + reg_loss
        else:
            loss = data_loss
        grads = self.optimiser.compute_gradients(loss)
        self.GRADS = grads

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=data_loss, name='dice_loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='dice_loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=image * 180.0, name='image',
            average_over_devices=False, summary_type='image3_sagittal',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=image, name='image',
            average_over_devices=False,
            collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=tf.reduce_mean(image), name='mean_image',
            average_over_devices=False, summary_type='scalar',
            collection=CONSOLE)
    else:
        # converting logits into final output for
        # classification probabilities or argmax classification labels
        data_dict = switch_sampler(for_training=False)
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, is_training=self.is_training)

        output_prob = self.segmentation_param.output_prob
        num_classes = self.segmentation_param.num_classes
        if output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'SOFTMAX', num_classes=num_classes)
        elif not output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'ARGMAX', num_classes=num_classes)
        else:
            post_process_layer = PostProcessingLayer(
                'IDENTITY', num_classes=num_classes)
        net_out = post_process_layer(net_out)

        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)

        init_aggregator = \
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
        init_aggregator()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    def switch_sampler(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0][0 if for_training else -1]
            return sampler.pop_batch_op()

    if self.is_training:
        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: switch_sampler(for_training=True),
                                lambda: switch_sampler(for_training=False))
        else:
            data_dict = switch_sampler(for_training=True)

        image = tf.cast(data_dict['image'], tf.float32)
        net_args = {'is_training': self.is_training,
                    'keep_prob': self.net_param.keep_prob}
        net_out = self.net(image, **net_args)

        with tf.name_scope('Optimiser'):
            self.learning_rate = tf.placeholder(tf.float32, shape=[])
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.learning_rate)
        loss_func = LossFunction(
            n_class=self.segmentation_param.num_classes,
            loss_type=self.action_param.loss_type,
            softmax=self.segmentation_param.softmax)
        data_loss = loss_func(
            prediction=net_out,
            ground_truth=data_dict.get('label', None),
            weight_map=data_dict.get('weight_map', None))
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if self.net_param.decay > 0.0 and reg_losses:
            reg_loss = tf.reduce_mean(
                [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
            loss = data_loss + reg_loss
        else:
            loss = data_loss

        # Get all trainable variables
        to_optimise = tf.trainable_variables()
        vars_to_freeze = \
            self.action_param.vars_to_freeze or \
            self.action_param.vars_to_restore
        if vars_to_freeze:
            import re
            var_regex = re.compile(vars_to_freeze)
            # Only optimise variables that are not frozen
            to_optimise = \
                [v for v in to_optimise if not var_regex.search(v.name)]
            tf.logging.info(
                "Optimizing %d out of %d trainable variables, "
                "the other variables are fixed (--vars_to_freeze %s)",
                len(to_optimise),
                len(tf.trainable_variables()),
                vars_to_freeze)

        grads = self.optimiser.compute_gradients(
            loss, var_list=to_optimise, colocate_gradients_with_ops=True)
        # clip gradients by global norm
        gradients, variables = zip(*grads)
        gradients, _ = tf.clip_by_global_norm(
            gradients, self.action_param.gradient_clipping_value)
        grads = list(zip(gradients, variables))
        gnorm = tf.global_norm(list(gradients))

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=data_loss, name='loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=gnorm, name='gnorm',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=gnorm, name='gnorm',
            average_over_devices=False, summary_type='scalar',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=data_loss, name='loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=self.learning_rate, name='lr',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        # outputs_collector.add_to_collection(
        #     var=image, name='image',
        #     average_over_devices=False, summary_type='image3_sagittal',
        #     collection=TF_SUMMARIES)
        # outputs_collector.add_to_collection(
        #     var=net_out, name='output',
        #     average_over_devices=False, summary_type='image3_sagittal',
        #     collection=TF_SUMMARIES)
    elif self.is_inference:
        # converting logits into final output for
        # classification probabilities or argmax classification labels
        data_dict = switch_sampler(for_training=False)
        image = tf.cast(data_dict['image'], tf.float32)
        net_args = {'is_training': self.is_training,
                    'keep_prob': self.net_param.keep_prob}
        net_out = self.net(image, **net_args)

        output_prob = self.segmentation_param.output_prob
        num_classes = self.segmentation_param.num_classes
        if output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'SOFTMAX', num_classes=num_classes)
        elif not output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'ARGMAX', num_classes=num_classes)
        else:
            post_process_layer = PostProcessingLayer(
                'IDENTITY', num_classes=num_classes)
        net_out = post_process_layer(net_out)

        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        self.initialise_aggregator()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    def switch_sampler(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0][0 if for_training else -1]
            return sampler.pop_batch_op()

    if self.is_training:
        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: switch_sampler(for_training=True),
                                lambda: switch_sampler(for_training=False))
        else:
            data_dict = switch_sampler(for_training=True)

        image = tf.cast(data_dict['image'], tf.float32)
        net_args = {'is_training': self.is_training,
                    'keep_prob': self.net_param.keep_prob}
        net_out = self.net(image, **net_args)

        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        loss_func = LossFunction(
            n_class=self.classification_param.num_classes,
            loss_type=self.action_param.loss_type)
        data_loss = loss_func(
            prediction=net_out,
            ground_truth=data_dict.get('label', None))
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if self.net_param.decay > 0.0 and reg_losses:
            reg_loss = tf.reduce_mean(
                [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
            loss = data_loss + reg_loss
        else:
            loss = data_loss
        grads = self.optimiser.compute_gradients(
            loss, colocate_gradients_with_ops=True)

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=data_loss, name='data_loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='data_loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        self.add_confusion_matrix_summaries_(outputs_collector,
                                             net_out,
                                             data_dict)
    else:
        # converting logits into final output for
        # classification probabilities or argmax classification labels
        data_dict = switch_sampler(for_training=False)
        image = tf.cast(data_dict['image'], tf.float32)
        net_args = {'is_training': self.is_training,
                    'keep_prob': self.net_param.keep_prob}
        net_out = self.net(image, **net_args)
        tf.logging.info(
            'net_out.shape may need to be resized: %s', net_out.shape)
        output_prob = self.classification_param.output_prob
        num_classes = self.classification_param.num_classes
        if output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'SOFTMAX', num_classes=num_classes)
        elif not output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'ARGMAX', num_classes=num_classes)
        else:
            post_process_layer = PostProcessingLayer(
                'IDENTITY', num_classes=num_classes)
        net_out = post_process_layer(net_out)

        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        self.initialise_aggregator()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    def switch_sampler(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0][0 if for_training else -1]
            return sampler.pop_batch_op()

    if self.is_training:
        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: switch_sampler(True),
                                lambda: switch_sampler(False))
        else:
            data_dict = switch_sampler(for_training=True)

        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, is_training=self.is_training)

        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        loss_func = LossFunction(loss_type=self.action_param.loss_type)
        crop_layer = CropLayer(
            border=self.regression_param.loss_border, name='crop-88')
        prediction = crop_layer(net_out)
        ground_truth = crop_layer(data_dict.get('output', None))
        weight_map = None if data_dict.get('weight', None) is None \
            else crop_layer(data_dict.get('weight', None))
        data_loss = loss_func(prediction=prediction,
                              ground_truth=ground_truth,
                              weight_map=weight_map)

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if self.net_param.decay > 0.0 and reg_losses:
            reg_loss = tf.reduce_mean(
                [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
            loss = data_loss + reg_loss
        else:
            loss = data_loss
        grads = self.optimiser.compute_gradients(loss)

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=data_loss, name='Loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='Loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
    else:
        data_dict = switch_sampler(for_training=False)
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, is_training=self.is_training)
        crop_layer = CropLayer(border=0, name='crop-88')
        post_process_layer = PostProcessingLayer('IDENTITY')
        net_out = post_process_layer(crop_layer(net_out))

        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)

        init_aggregator = \
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
        init_aggregator()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    def switch_sampler(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0][0 if for_training else -1]
            return sampler.pop_batch_op()

    if self.is_training:
        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: switch_sampler(for_training=True),
                                lambda: switch_sampler(for_training=False))
        else:
            data_dict = switch_sampler(for_training=True)

        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, is_training=self.is_training)

        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        loss_func = LossFunction(
            n_class=self.segmentation_param.num_classes,
            loss_type=self.action_param.loss_type)
        data_loss = loss_func(
            prediction=net_out,
            ground_truth=data_dict.get('label', None),
            weight_map=data_dict.get('weight', None))

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if self.net_param.decay > 0.0 and reg_losses:
            reg_loss = tf.reduce_mean(
                [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
            loss = data_loss + reg_loss
        else:
            loss = data_loss
        grads = self.optimiser.compute_gradients(loss)

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=data_loss, name='dice_loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='dice_loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        # outputs_collector.add_to_collection(
        #     var=image * 180.0, name='image',
        #     average_over_devices=False, summary_type='image3_sagittal',
        #     collection=TF_SUMMARIES)
        # outputs_collector.add_to_collection(
        #     var=image, name='image',
        #     average_over_devices=False,
        #     collection=NETWORK_OUTPUT)
        # outputs_collector.add_to_collection(
        #     var=tf.reduce_mean(image), name='mean_image',
        #     average_over_devices=False, summary_type='scalar',
        #     collection=CONSOLE)
    else:
        # converting logits into final output for
        # classification probabilities or argmax classification labels
        data_dict = switch_sampler(for_training=False)
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, is_training=self.is_training)

        output_prob = self.segmentation_param.output_prob
        num_classes = self.segmentation_param.num_classes
        if output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'SOFTMAX', num_classes=num_classes)
        elif not output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'ARGMAX', num_classes=num_classes)
        else:
            post_process_layer = PostProcessingLayer(
                'IDENTITY', num_classes=num_classes)
        net_out = post_process_layer(net_out)

        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)

        init_aggregator = \
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
        init_aggregator()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    def switch_sampler(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0][0 if for_training else -1]
            return sampler.pop_batch_op()

    if self.is_training:
        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: switch_sampler(for_training=True),
                                lambda: switch_sampler(for_training=False))
        else:
            data_dict = switch_sampler(for_training=True)

        image = tf.cast(data_dict['image'], tf.float32)
        net_args = {'is_training': self.is_training,
                    'keep_prob': self.net_param.keep_prob}
        net_out = self.net(image, **net_args)

        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        loss_func = LossFunction(loss_type=self.action_param.loss_type)
        crop_layer = CropLayer(border=self.regression_param.loss_border)
        weight_map = data_dict.get('weight', None)
        weight_map = None if weight_map is None else crop_layer(weight_map)
        data_loss = loss_func(
            prediction=crop_layer(net_out),
            ground_truth=crop_layer(data_dict['output']),
            weight_map=weight_map)
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if self.net_param.decay > 0.0 and reg_losses:
            reg_loss = tf.reduce_mean(
                [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
            loss = data_loss + reg_loss
        else:
            loss = data_loss

        # Get all trainable variables
        to_optimise = tf.trainable_variables()
        vars_to_freeze = \
            self.action_param.vars_to_freeze or \
            self.action_param.vars_to_restore
        if vars_to_freeze:
            import re
            var_regex = re.compile(vars_to_freeze)
            # Only optimise variables that are not frozen
            to_optimise = \
                [v for v in to_optimise if not var_regex.search(v.name)]
            tf.logging.info(
                "Optimizing %d out of %d trainable variables, "
                "the other variables are fixed (--vars_to_freeze %s)",
                len(to_optimise),
                len(tf.trainable_variables()),
                vars_to_freeze)

        grads = self.optimiser.compute_gradients(
            loss, var_list=to_optimise, colocate_gradients_with_ops=True)
        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=data_loss, name='loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
    elif self.is_inference:
        data_dict = switch_sampler(for_training=False)
        image = tf.cast(data_dict['image'], tf.float32)
        net_args = {'is_training': self.is_training,
                    'keep_prob': self.net_param.keep_prob}
        net_out = self.net(image, **net_args)
        net_out = PostProcessingLayer('IDENTITY')(net_out)

        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        self.initialise_aggregator()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    def switch_sampler(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0][0 if for_training else -1]
            return sampler.pop_batch_op()

    self.var = tf.placeholder_with_default(0, [], 'var')
    self.choices = tf.placeholder_with_default(
        [True, True, True, True], [4], 'choices')

    if self.is_training:
        self.lr = tf.placeholder_with_default(
            self.action_param.lr, [], 'learning_rate')
        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: switch_sampler(for_training=True),
                                lambda: switch_sampler(for_training=False))
        else:
            data_dict = switch_sampler(for_training=True)

        image = tf.cast(data_dict['image'], tf.float32)
        image_unstack = tf.unstack(image, axis=-1)
        # one network input per modality, each with a singleton channel dim
        net_img, post_param = self.net(
            {MODALITIES_img[k]: tf.expand_dims(image_unstack[k], -1)
             for k in range(4)},
            self.choices,
            is_training=self.is_training)
        net_seg = net_img['seg']
        net_img = tf.concat(
            [net_img[mod] for mod in MODALITIES_img], axis=-1)

        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.lr)

        # segmentation loss: cross-entropy + Dice
        cross = LossFunction(n_class=4, loss_type='CrossEntropy')
        dice = LossFunction(n_class=4, loss_type='Dice', softmax=True)
        gt = data_dict['label']
        loss_cross = cross(prediction=net_seg, ground_truth=gt,
                           weight_map=None)
        loss_dice = dice(prediction=net_seg, ground_truth=gt)
        loss_seg = loss_cross + loss_dice

        # image reconstruction loss
        loss_reconstruction = tf.reduce_mean(tf.square(net_img - image))

        # KL divergence terms accumulated over the skip connections
        sum_inter_KLD = 0.0
        sum_prior_KLD = 0.0
        nb_skip = len(post_param)
        for k in range(nb_skip):
            inter_KLD, prior_KLD = compute_KLD(post_param[k]['mu'],
                                               post_param[k]['logvar'],
                                               self.choices)
            sum_inter_KLD += inter_KLD
            sum_prior_KLD += prior_KLD
        KLD = 1 / nb_skip * sum_inter_KLD + 1 / nb_skip * sum_prior_KLD

        data_loss = loss_seg + 0.1 * KLD + 0.1 * loss_reconstruction

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        reg_loss = tf.reduce_mean(
            [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
        loss = data_loss + reg_loss
        grads = self.optimiser.compute_gradients(loss)

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=loss, name='loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=loss, name='loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=KLD, name='KLD',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=loss_reconstruction, name='loss_reconstruction',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=self.choices, name='choices',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=self.lr, name='lr',
            average_over_devices=False, collection=CONSOLE)
    elif self.is_inference:
        # converting logits into final output for
        # classification probabilities or argmax classification labels
        data_dict = switch_sampler(for_training=False)
        image = tf.cast(data_dict['image'], tf.float32)
        image = tf.unstack(image, axis=-1)

        choices = self.segmentation_param.choices
        choices = [str2bool(k) for k in choices]
        output_mod = self.segmentation_param.output_mod[0]
        post_process_layer = PostProcessingLayer('ARGMAX', num_classes=4)
        if output_mod == 'seg':
            # predict using samples
            net_img, _ = self.net(
                {MODALITIES_img[k]: tf.expand_dims(image[k], -1)
                 for k in range(4)},
                choices,
                is_training=True,
                is_inference=False)
            net_out = post_process_layer(net_img['seg'])
        else:
            # predict using means
            net_img, _ = self.net(
                {MODALITIES_img[k]: tf.expand_dims(image[k], -1)
                 for k in range(4)},
                choices,
                is_training=True,
                is_inference=True)
            net_out = net_img[output_mod]

        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        self.initialise_aggregator()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    def switch_sampler(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0][0 if for_training else -1]
            return sampler.pop_batch_op()

    self.var = tf.placeholder_with_default(0, [], 'var')
    self.choices = tf.placeholder_with_default(
        [True, True, True, True], [4], 'choices')

    if self.is_training:
        self.lr = tf.placeholder_with_default(
            self.action_param.lr, [], 'learning_rate')
        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: switch_sampler(for_training=True),
                                lambda: switch_sampler(for_training=False))
        else:
            data_dict = switch_sampler(for_training=True)

        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.lr)

        image = tf.cast(data_dict['image'], tf.float32)
        image_unstack = tf.unstack(image, axis=-1)
        net_seg = self.net(
            {MODALITIES_img[k]: tf.expand_dims(image_unstack[k], -1)
             for k in range(4)},
            self.choices,
            is_training=self.is_training)

        cross = LossFunction(n_class=4, loss_type='CrossEntropy')
        dice = LossFunction(n_class=4, loss_type='Dice', softmax=True)
        gt = data_dict['label']
        loss_cross = cross(prediction=net_seg, ground_truth=gt,
                           weight_map=None)
        loss_dice = dice(prediction=net_seg, ground_truth=gt)
        data_loss = loss_cross + loss_dice

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        reg_loss = tf.reduce_mean(
            [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
        loss = data_loss + reg_loss
        grads = self.optimiser.compute_gradients(
            loss, colocate_gradients_with_ops=False)

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=data_loss, name='loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=loss, name='loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=self.choices, name='choices',
            average_over_devices=False, collection=CONSOLE)
    elif self.is_inference:
        # converting logits into final output for
        # classification probabilities or argmax classification labels
        data_dict = switch_sampler(for_training=False)
        image = tf.cast(data_dict['image'], tf.float32)
        image = tf.unstack(image, axis=-1)

        choices = self.segmentation_param.choices
        choices = [str2bool(k) for k in choices]
        print(choices)
        net_seg = self.net(
            {MODALITIES_img[k]: tf.expand_dims(image[k], -1)
             for k in range(4)},
            choices,
            is_training=True,
            is_inference=True)
        post_process_layer = PostProcessingLayer('ARGMAX', num_classes=4)
        net_seg = post_process_layer(net_seg)

        outputs_collector.add_to_collection(
            var=net_seg, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        self.initialise_aggregator()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    def switch_sampler(for_training):
        with tf.name_scope('train' if for_training else 'validation'):
            sampler = self.get_sampler()[0][0 if for_training else -1]
            return sampler.pop_batch_op()

    def mixup_switch_sampler(for_training):
        # get first set of samples
        d_dict = switch_sampler(for_training=for_training)
        mix_fields = ('image', 'weight', 'label')
        if not for_training:
            with tf.name_scope('nomix'):
                # ensure label is appropriate for dense loss functions
                ground_truth = tf.cast(d_dict['label'], tf.int32)
                one_hot = tf.one_hot(
                    tf.squeeze(ground_truth, axis=-1),
                    depth=self.segmentation_param.num_classes)
                d_dict['label'] = one_hot
        else:
            with tf.name_scope('mixup'):
                # get the mixing parameter from the Beta distribution
                alpha = self.segmentation_param.mixup_alpha
                beta = tf.distributions.Beta(alpha, alpha)  # 1, 1: uniform
                rand_frac = beta.sample()
                # get another minibatch
                d_dict_to_mix = switch_sampler(for_training=True)
                # look at binarised labels: sort them
                if self.segmentation_param.mix_match:
                    # sum up the positive labels to sort by their volumes
                    inds1 = tf.argsort(
                        tf.map_fn(tf.reduce_sum,
                                  tf.cast(d_dict['label'], tf.int64)))
                    inds2 = tf.argsort(
                        tf.map_fn(tf.reduce_sum,
                                  tf.cast(d_dict_to_mix['label'] > 0,
                                          tf.int64)))
                    for field in [field for field in mix_fields
                                  if field in d_dict]:
                        d_dict[field] = tf.gather(d_dict[field],
                                                  indices=inds1)
                        # note: sorted in the opposite direction for
                        # d_dict_to_mix
                        d_dict_to_mix[field] = tf.gather(
                            d_dict_to_mix[field], indices=inds2[::-1])
                # making the labels dense and one-hot
                for d in (d_dict, d_dict_to_mix):
                    ground_truth = tf.cast(d['label'], tf.int32)
                    one_hot = tf.one_hot(
                        tf.squeeze(ground_truth, axis=-1),
                        depth=self.segmentation_param.num_classes)
                    d['label'] = one_hot
                # do the mixing for any fields that are relevant and present
                mixed_up = {
                    field: d_dict[field] * rand_frac +
                    d_dict_to_mix[field] * (1 - rand_frac)
                    for field in mix_fields if field in d_dict
                }
                # reassign all relevant values in d_dict
                d_dict.update(mixed_up)
        return d_dict

    if self.is_training:
        if not self.segmentation_param.do_mixup:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: switch_sampler(for_training=True),
                                lambda: switch_sampler(for_training=False))
        else:
            # mix up the samples if not in the validation phase
            data_dict = tf.cond(
                tf.logical_not(self.is_validation),
                lambda: mixup_switch_sampler(for_training=True),
                lambda: mixup_switch_sampler(for_training=False))

        image = tf.cast(data_dict['image'], tf.float32)
        net_args = {'is_training': self.is_training,
                    'keep_prob': self.net_param.keep_prob}
        net_out = self.net(image, **net_args)

        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        loss_func = LossFunction(
            n_class=self.segmentation_param.num_classes,
            loss_type=self.action_param.loss_type,
            softmax=self.segmentation_param.softmax)
        data_loss = loss_func(
            prediction=net_out,
            ground_truth=data_dict.get('label', None),
            weight_map=data_dict.get('weight', None))
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if self.net_param.decay > 0.0 and reg_losses:
            reg_loss = tf.reduce_mean(
                [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
            loss = data_loss + reg_loss
        else:
            loss = data_loss

        # Get all trainable variables
        to_optimise = tf.trainable_variables()
        vars_to_freeze = \
            self.action_param.vars_to_freeze or \
            self.action_param.vars_to_restore
        if vars_to_freeze:
            import re
            var_regex = re.compile(vars_to_freeze)
            # Only optimise variables that are not frozen
            to_optimise = \
                [v for v in to_optimise if not var_regex.search(v.name)]
            tf.logging.info(
                "Optimizing %d out of %d trainable variables, "
                "the other variables are fixed (--vars_to_freeze %s)",
                len(to_optimise),
                len(tf.trainable_variables()),
                vars_to_freeze)

        grads = self.optimiser.compute_gradients(
            loss, var_list=to_optimise, colocate_gradients_with_ops=True)
        self.total_loss = loss

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=self.total_loss, name='total_loss',
            average_over_devices=True, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=self.total_loss, name='total_loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        outputs_collector.add_to_collection(
            var=data_loss, name='loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
    elif self.is_inference:
        # converting logits into final output for
        # classification probabilities or argmax classification labels
        data_dict = switch_sampler(for_training=False)
        image = tf.cast(data_dict['image'], tf.float32)
        net_args = {'is_training': self.is_training,
                    'keep_prob': self.net_param.keep_prob}
        net_out = self.net(image, **net_args)

        output_prob = self.segmentation_param.output_prob
        num_classes = self.segmentation_param.num_classes
        if output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'SOFTMAX', num_classes=num_classes)
        elif not output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'ARGMAX', num_classes=num_classes)
        else:
            post_process_layer = PostProcessingLayer(
                'IDENTITY', num_classes=num_classes)
        net_out = post_process_layer(net_out)

        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        self.initialise_aggregator()
def connect_data_and_network(self,
                             outputs_collector=None,
                             gradients_collector=None):
    data_dict = self.get_sampler()[0].pop_batch_op()
    image = tf.cast(data_dict['image'], tf.float32)
    net_out = self.net(image, self.is_training)

    if self.is_training:
        label = data_dict.get('label', None)
        # Changed label on 11/29/2017: this generates a 2D label from the
        # 3D label provided in the input. Only suitable for STNeuroNet.
        k = label.get_shape().as_list()
        label = tf.nn.max_pool3d(label, [1, 1, 1, k[3], 1],
                                 [1, 1, 1, 1, 1], 'VALID',
                                 data_format='NDHWC')
        print('label shape is {}'.format(label.get_shape()))
        print('Image shape is {}'.format(image.get_shape()))
        print('Out shape is {}'.format(net_out.get_shape()))

        with tf.name_scope('Optimiser'):
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.action_param.lr)
        loss_func = LossFunction(
            n_class=self.segmentation_param.num_classes,
            loss_type=self.action_param.loss_type)
        data_loss = loss_func(prediction=net_out,
                              ground_truth=label,
                              weight_map=data_dict.get('weight', None))

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if self.net_param.decay > 0.0 and reg_losses:
            reg_loss = tf.reduce_mean(
                [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
            loss = data_loss + reg_loss
        else:
            loss = data_loss
        grads = self.optimiser.compute_gradients(loss)

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        outputs_collector.add_to_collection(
            var=data_loss, name='dice_loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='dice_loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)

        # Added on 10/30 by Soltanian-Zadeh for TensorBoard visualisation
        seg_summary = tf.to_float(
            tf.expand_dims(tf.argmax(net_out, -1), -1)) * (
                255. / self.segmentation_param.num_classes - 1)
        label_summary = tf.to_float(tf.expand_dims(label, -1)) * (
            255. / self.segmentation_param.num_classes - 1)
        m, v = tf.nn.moments(image, axes=[1, 2, 3], keep_dims=True)
        img_summary = tf.minimum(255., tf.maximum(
            0., (tf.to_float(image - m) / (tf.sqrt(v) * 2.) + 1.) * 127.))
        image3_axial('img', img_summary, 50, [tf.GraphKeys.SUMMARIES])
        image3_axial('seg', seg_summary, 5, [tf.GraphKeys.SUMMARIES])
        image3_axial('label', label_summary, 5, [tf.GraphKeys.SUMMARIES])
    else:
        # converting logits into final output for
        # classification probabilities or argmax classification labels
        output_prob = self.segmentation_param.output_prob
        num_classes = self.segmentation_param.num_classes
        if output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'SOFTMAX', num_classes=num_classes)
        elif not output_prob and num_classes > 1:
            post_process_layer = PostProcessingLayer(
                'ARGMAX', num_classes=num_classes)
        else:
            post_process_layer = PostProcessingLayer(
                'IDENTITY', num_classes=num_classes)
        net_out = post_process_layer(net_out)
        print('output shape is {}'.format(net_out.get_shape()))

        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)

        init_aggregator = \
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
        init_aggregator()