コード例 #1
0
    def test_3d_argmax_shape(self):
        """ARGMAX keeps every leading axis and reduces the last axis to 1."""
        volume = self.get_3d_input()
        argmax_layer = PostProcessingLayer("ARGMAX")
        argmax_out = argmax_layer(volume)
        print(argmax_layer)

        with self.cached_session() as session:
            result = session.run(argmax_out)
            leading_dims = tuple(volume.shape.as_list()[:-1])
            self.assertAllClose(leading_dims + (1, ), result.shape)
コード例 #2
0
    def test_2d_shape(self):
        """IDENTITY post-processing must leave the 2-D input shape intact."""
        image = self.get_2d_input()
        identity_layer = PostProcessingLayer("IDENTITY")
        identity_out = identity_layer(image)
        print(identity_layer)

        with self.cached_session() as session:
            result = session.run(identity_out)
            expected_shape = tuple(image.shape.as_list())
            self.assertAllClose(expected_shape, result.shape)
コード例 #3
0
    def test_3d_shape(self):
        """SOFTMAX post-processing must preserve the shape of a 3-D input."""
        x = self.get_3d_input()
        post_process_layer = PostProcessingLayer("SOFTMAX")
        out_post = post_process_layer(x)
        print(post_process_layer)

        # `test_session` is deprecated in TF 1.x in favour of
        # `cached_session`, which provides the same reusable session.
        with self.cached_session() as sess:
            out = sess.run(out_post)
            x_shape = tuple(x.shape.as_list())
            self.assertAllClose(x_shape, out.shape)
コード例 #4
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire the sampler output into the regression network.

        Training: pops a batch (through ``tf.cond`` on ``self.is_validation``
        when ``validation_every_n > 0``), runs the network, crops prediction,
        target and optional weight map by ``regression_param.loss_border``,
        adds the mean regularisation loss when weight decay is enabled, clips
        every gradient element-wise to ``+/- 1e-5 / action_param.lr`` and
        registers the gradients and the loss.

        Inference: pops a batch from the last sampler, runs the network,
        applies a zero-border crop and an IDENTITY post-processing layer, and
        registers the output window and its location before initialising the
        aggregator.
        """
        def pop_batch(for_training):
            scope = 'train' if for_training else 'validation'
            with tf.name_scope(scope):
                sampler_idx = 0 if for_training else -1
                return self.get_sampler()[0][sampler_idx].pop_batch_op()

        if self.is_training:
            use_validation = self.action_param.validation_every_n > 0
            if use_validation:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: pop_batch(True),
                                    lambda: pop_batch(False))
            else:
                data_dict = pop_batch(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image,
                               is_training=self.is_training,
                               keep_prob=self.net_param.keep_prob)
            with tf.name_scope('Optimiser'):
                optimiser_cls = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_cls.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            crop_layer = CropLayer(border=self.regression_param.loss_border,
                                   name='crop-88')
            prediction = crop_layer(net_out)
            ground_truth = crop_layer(data_dict.get('output', None))
            raw_weight = data_dict.get('weight', None)
            if raw_weight is None:
                weight_map = None
            else:
                weight_map = crop_layer(raw_weight)
            data_loss = loss_func(prediction=prediction,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)

            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                per_tensor = [tf.reduce_mean(term) for term in reg_losses]
                loss = data_loss + tf.reduce_mean(per_tensor)
            else:
                loss = data_loss
            raw_grads = self.optimiser.compute_gradients(loss)

            # VDSR3D gradient clipping: clip by value (not by global norm) to
            # [-theta / lr, theta / lr] with theta = 1e-5; the VDSR authors
            # do not specify the threshold, so theta is a local choice.
            # Variables without a gradient (grad is None) are dropped.
            def clip(grad):
                limit = 0.00001 / self.action_param.lr
                return tf.clip_by_value(grad, -limit, +limit)

            grads = [(clip(grad), var) for grad, var in raw_grads
                     if grad is not None]

            gradients_collector.add_to_collection([grads])
            outputs_collector.add_to_collection(
                var=data_loss, name='Loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='Loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        elif self.is_inference:
            data_dict = pop_batch(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image,
                               is_training=self.is_training,
                               keep_prob=self.net_param.keep_prob)

            zero_crop = CropLayer(border=0, name='crop-88')
            identity_layer = PostProcessingLayer('IDENTITY')
            net_out = identity_layer(zero_crop(net_out))

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()
コード例 #5
0
    def connect_data_and_network(self, outputs_collector=None,
                                 gradients_collector=None):
        """Wire the sampler output into the segmentation network.

        Training: pops a batch (through ``tf.cond`` on ``self.is_validation``
        when ``validation_every_n > 0``), runs the network, builds the data
        loss plus the optional mean regularisation loss, computes gradients
        and registers the loss, gradients and image summaries.  Several
        intermediate tensors are stored on ``self`` (``loss_variable``,
        ``first_slice``, ``netOut``, ``GROUNDTRUTH``, ``PREDICTION``,
        ``CONT``, ``SUMA``, ``GRADS``) -- presumably for inspection by code
        outside this method; verify before removing any of them.

        Any other mode: pops a batch from the last sampler, post-processes the
        logits (SOFTMAX / ARGMAX when ``num_classes > 1`` depending on
        ``output_prob``, IDENTITY otherwise) and registers the output window
        and its location before initialising the aggregator.
        """
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            # Use the logging facility instead of ad-hoc debug prints.
            tf.logging.info('Optimiser: %s, learning rate: %s',
                            self.action_param.optimiser,
                            self.action_param.lr)

            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)

            # Fetch the optional inputs once and reuse them for the loss.
            ground_truth = data_dict.get('label', None)
            weight_map = data_dict.get('weight', None)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)

            # Debug handles kept on the instance.
            self.loss_variable = data_loss
            self.first_slice = ground_truth
            self.netOut = tf.nn.softmax(net_out)
            # return_loss_args() and SUMA are provided by the project's
            # LossFunction (not visible here).
            ground_truth_t, prediction_t, cont_t = \
                loss_func.return_loss_args()
            self.GROUNDTRUTH = ground_truth_t
            self.PREDICTION = prediction_t
            self.CONT = cont_t
            self.SUMA = loss_func.SUMA

            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            grads = self.optimiser.compute_gradients(loss)
            self.GRADS = grads

            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            # NOTE(review): the 180.0 scale for the sagittal image summary is
            # a magic factor whose purpose is not documented here.
            outputs_collector.add_to_collection(
                var=image * 180.0, name='image',
                average_over_devices=False, summary_type='image3_sagittal',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=image, name='image',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=tf.reduce_mean(image), name='mean_image',
                average_over_devices=False, summary_type='scalar',
                collection=CONSOLE)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()
コード例 #6
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire the sampler output into the segmentation network.

        Training: pops a batch (through ``tf.cond`` on ``self.is_validation``
        when ``validation_every_n > 0``), runs the network, builds the data
        loss plus the optional mean regularisation loss, computes gradients
        for the trainable variables not matched by the freeze regex, clips
        them by global norm (``action_param.gradient_clipping_value``) and
        registers the gradients, the loss, the pre-clipping gradient norm
        and the learning rate.  The learning rate is the scalar float32
        placeholder ``self.learning_rate``, so it must be fed whenever ops
        depending on it are run.

        Inference: post-processes the logits (SOFTMAX / ARGMAX when
        ``num_classes > 1`` depending on ``output_prob``, IDENTITY otherwise)
        and registers the output window and its location before
        initialising the aggregator.
        """
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                self.learning_rate = tf.placeholder(tf.float32, shape=[])
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.learning_rate)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.segmentation_param.softmax)
            # NOTE(review): this reads the 'weight_map' key, not 'weight';
            # confirm it matches the configured input section name.
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight_map', None))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            # Optimise only trainable variables not matched by the
            # --vars_to_freeze regex (falling back to --vars_to_restore).
            to_optimise = tf.trainable_variables()
            vars_to_freeze = \
                self.action_param.vars_to_freeze or \
                self.action_param.vars_to_restore
            if vars_to_freeze:
                import re
                var_regex = re.compile(vars_to_freeze)
                to_optimise = \
                    [v for v in to_optimise if not var_regex.search(v.name)]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables fixed (--vars_to_freeze %s)",
                    len(to_optimise),
                    len(tf.trainable_variables()),
                    vars_to_freeze)

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)
            # Clip by global norm.  clip_by_global_norm also returns the norm
            # *before* clipping; report that instead of re-computing the
            # norm of the already-clipped gradients, which can never exceed
            # the clipping value and would hide exploding gradients.
            gradients, variables = zip(*grads)
            gradients, gnorm = tf.clip_by_global_norm(
                gradients, self.action_param.gradient_clipping_value)
            grads = list(zip(gradients, variables))

            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=gnorm, name='gnorm',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=gnorm, name='gnorm',
                average_over_devices=False, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=self.learning_rate, name='lr',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        elif self.is_inference:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()
コード例 #7
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire the sampler output into the classification network.

        Training: pops a batch (through ``tf.cond`` on ``self.is_validation``
        when ``validation_every_n > 0``), runs the network, builds the data
        loss plus the optional mean regularisation loss, and registers the
        gradients, the loss and confusion-matrix summaries.

        Any other mode: pops a batch from the last sampler, post-processes the
        logits (SOFTMAX when ``output_prob`` is set, ARGMAX otherwise, both
        only when ``num_classes > 1``; IDENTITY when ``num_classes <= 1``)
        and registers the output window and its location before initialising
        the aggregator.
        """
        def next_batch(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                samplers = self.get_sampler()[0]
                chosen = samplers[0] if for_training else samplers[-1]
                return chosen.pop_batch_op()

        def run_network(batch):
            image = tf.cast(batch['image'], tf.float32)
            return self.net(image,
                            is_training=self.is_training,
                            keep_prob=self.net_param.keep_prob)

        if not self.is_training:
            data_dict = next_batch(for_training=False)
            net_out = run_network(data_dict)
            tf.logging.info('net_out.shape may need to be resized: %s',
                            net_out.shape)
            wants_prob = self.classification_param.output_prob
            num_classes = self.classification_param.num_classes
            if num_classes > 1:
                mode = 'SOFTMAX' if wants_prob else 'ARGMAX'
            else:
                mode = 'IDENTITY'
            post_process_layer = PostProcessingLayer(
                mode, num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()
            return

        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: next_batch(for_training=True),
                                lambda: next_batch(for_training=False))
        else:
            data_dict = next_batch(for_training=True)
        net_out = run_network(data_dict)

        with tf.name_scope('Optimiser'):
            optimiser_cls = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_cls.get_instance(
                learning_rate=self.action_param.lr)
        loss_func = LossFunction(
            n_class=self.classification_param.num_classes,
            loss_type=self.action_param.loss_type)
        data_loss = loss_func(prediction=net_out,
                              ground_truth=data_dict.get('label', None))

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        loss = data_loss
        if self.net_param.decay > 0.0 and reg_losses:
            loss = data_loss + tf.reduce_mean(
                [tf.reduce_mean(term) for term in reg_losses])
        grads = self.optimiser.compute_gradients(
            loss, colocate_gradients_with_ops=True)

        gradients_collector.add_to_collection([grads])
        outputs_collector.add_to_collection(
            var=data_loss, name='data_loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='data_loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
        self.add_confusion_matrix_summaries_(outputs_collector, net_out,
                                             data_dict)
コード例 #8
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire the sampler output into the regression network.

        Training: pops a batch (through ``tf.cond`` on ``self.is_validation``
        when ``validation_every_n > 0``), runs the network, crops prediction,
        target and optional weight map by ``regression_param.loss_border``,
        adds the mean regularisation loss when weight decay is enabled and
        registers the gradients and the loss.

        Any other mode: pops a batch from the last sampler, applies a
        zero-border crop and an IDENTITY post-processing layer, registers the
        output window and its location, then calls the aggregator
        initialiser registered for ``net_param.window_sampling``.
        """
        def pop_batch(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                index = 0 if for_training else -1
                return self.get_sampler()[0][index].pop_batch_op()

        if not self.is_training:
            data_dict = pop_batch(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            zero_crop = CropLayer(border=0, name='crop-88')
            identity_layer = PostProcessingLayer('IDENTITY')
            net_out = identity_layer(zero_crop(net_out))

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            # Entry [2] of the sampling table is the aggregator initialiser.
            sampling_entry = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling]
            sampling_entry[2]()
            return

        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: pop_batch(True),
                                lambda: pop_batch(False))
        else:
            data_dict = pop_batch(for_training=True)

        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, is_training=self.is_training)
        with tf.name_scope('Optimiser'):
            optimiser_cls = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_cls.get_instance(
                learning_rate=self.action_param.lr)
        loss_func = LossFunction(loss_type=self.action_param.loss_type)

        border_crop = CropLayer(
            border=self.regression_param.loss_border, name='crop-88')
        prediction = border_crop(net_out)
        ground_truth = border_crop(data_dict.get('output', None))
        raw_weight = data_dict.get('weight', None)
        weight_map = None if raw_weight is None else border_crop(raw_weight)
        data_loss = loss_func(prediction=prediction,
                              ground_truth=ground_truth,
                              weight_map=weight_map)

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        loss = data_loss
        if self.net_param.decay > 0.0 and reg_losses:
            loss = data_loss + tf.reduce_mean(
                [tf.reduce_mean(term) for term in reg_losses])
        grads = self.optimiser.compute_gradients(loss)

        gradients_collector.add_to_collection([grads])
        outputs_collector.add_to_collection(
            var=data_loss, name='Loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='Loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
コード例 #9
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire the sampler output into the segmentation network.

        Training: pops a batch (through ``tf.cond`` on ``self.is_validation``
        when ``validation_every_n > 0``), runs the network, builds the data
        loss plus the optional mean regularisation loss, and registers the
        gradients and the loss under the name 'dice_loss'.

        Any other mode: pops a batch from the last sampler, post-processes the
        logits (SOFTMAX when ``output_prob`` is set, ARGMAX otherwise, both
        only when ``num_classes > 1``; IDENTITY when ``num_classes <= 1``),
        registers the output window and its location, then calls the
        aggregator initialiser registered for ``net_param.window_sampling``.
        """
        def fetch_batch(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                index = 0 if for_training else -1
                return self.get_sampler()[0][index].pop_batch_op()

        if not self.is_training:
            data_dict = fetch_batch(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            wants_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if num_classes > 1:
                mode = 'SOFTMAX' if wants_prob else 'ARGMAX'
            else:
                mode = 'IDENTITY'
            post_process_layer = PostProcessingLayer(
                mode, num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            # Entry [2] of the sampling table is the aggregator initialiser.
            sampling_entry = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling]
            sampling_entry[2]()
            return

        if self.action_param.validation_every_n > 0:
            data_dict = tf.cond(tf.logical_not(self.is_validation),
                                lambda: fetch_batch(for_training=True),
                                lambda: fetch_batch(for_training=False))
        else:
            data_dict = fetch_batch(for_training=True)
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, is_training=self.is_training)

        with tf.name_scope('Optimiser'):
            optimiser_cls = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_cls.get_instance(
                learning_rate=self.action_param.lr)
        loss_func = LossFunction(
            n_class=self.segmentation_param.num_classes,
            loss_type=self.action_param.loss_type)
        data_loss = loss_func(
            prediction=net_out,
            ground_truth=data_dict.get('label', None),
            weight_map=data_dict.get('weight', None))

        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        loss = data_loss
        if self.net_param.decay > 0.0 and reg_losses:
            loss = data_loss + tf.reduce_mean(
                [tf.reduce_mean(term) for term in reg_losses])
        grads = self.optimiser.compute_gradients(loss)

        gradients_collector.add_to_collection([grads])
        outputs_collector.add_to_collection(
            var=data_loss, name='dice_loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='dice_loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)
コード例 #10
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Connect the sampled data to the regression network.

        Training: pops a batch. When ``validation_every_n > 0`` the batch comes
        from the first sampler, or from the last sampler while
        ``self.is_validation`` is true; otherwise it always comes from the
        first sampler. It then runs ``self.net`` on the image and evaluates
        the configured loss on tensors passed through
        ``CropLayer(border=loss_border)``. The mean regularisation loss is
        added when ``decay > 0`` and at least one is registered. Gradients
        are taken only w.r.t. trainable variables whose name does not match
        ``vars_to_freeze`` (falling back to ``vars_to_restore``).

        Inference: pops a batch from the last sampler and runs the network
        through an IDENTITY post-processing layer. It publishes the result
        and the window location as network outputs, then initialises the
        aggregator.

        :param outputs_collector: receives console, summary and network
            output variables.
        :param gradients_collector: receives the gradient list (training).
        """
        def _pop_batch(training_phase):
            # sampler 0 serves training batches, the last one validation
            scope_name = 'train' if training_phase else 'validation'
            with tf.name_scope(scope_name):
                position = 0 if training_phase else -1
                chosen_sampler = self.get_sampler()[0][position]
                return chosen_sampler.pop_batch_op()

        def _run_net(batch):
            # cast the image window to float32 and feed it to the network
            net_input = tf.cast(batch['image'], tf.float32)
            return self.net(net_input,
                            is_training=self.is_training,
                            keep_prob=self.net_param.keep_prob)

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                # the sampler is selected at graph run time
                batch = tf.cond(tf.logical_not(self.is_validation),
                                lambda: _pop_batch(training_phase=True),
                                lambda: _pop_batch(training_phase=False))
            else:
                batch = _pop_batch(training_phase=True)
            prediction = _run_net(batch)

            with tf.name_scope('Optimiser'):
                factory_cls = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = factory_cls.get_instance(
                    learning_rate=self.action_param.lr)
            loss_fn = LossFunction(loss_type=self.action_param.loss_type)

            # every tensor entering the loss goes through the same crop layer
            crop = CropLayer(border=self.regression_param.loss_border)
            weights = batch.get('weight', None)
            if weights is not None:
                weights = crop(weights)
            cropped_prediction = crop(prediction)
            cropped_target = crop(batch['output'])
            data_loss = loss_fn(prediction=cropped_prediction,
                                ground_truth=cropped_target,
                                weight_map=weights)

            regularisers = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            loss = data_loss
            if self.net_param.decay > 0.0 and regularisers:
                loss = data_loss + tf.reduce_mean(
                    [tf.reduce_mean(term) for term in regularisers])

            # restrict optimisation to the variables that are not frozen
            trainable = tf.trainable_variables()
            to_optimise = trainable
            freeze_pattern = (self.action_param.vars_to_freeze or
                              self.action_param.vars_to_restore)
            if freeze_pattern:
                import re
                frozen = re.compile(freeze_pattern)
                to_optimise = [var for var in trainable
                               if frozen.search(var.name) is None]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables are fixed (--vars_to_freeze %s)",
                    len(to_optimise), len(trainable), freeze_pattern)

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)
            gradients_collector.add_to_collection([grads])

            # the data loss (without regularisation) goes to the console,
            # not averaged over devices, and to the summaries as an
            # averaged scalar
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        elif self.is_inference:
            batch = _pop_batch(training_phase=False)
            prediction = _run_net(batch)
            window = PostProcessingLayer('IDENTITY')(prediction)

            outputs_collector.add_to_collection(
                var=window, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=batch['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()
コード例 #11
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph for the multi-modal network.

        Training: pops a batch. When ``validation_every_n > 0`` it comes from
        the first sampler, or from the last one while ``self.is_validation``
        is true; otherwise it comes from the first sampler. The first four
        image channels are fed to ``self.net`` as separate modalities, and
        the optimised loss is::

            loss = CrossEntropy + Dice          (on the 'seg' output)
                   + 0.1 * KLD                  (averaged over post_param)
                   + 0.1 * reconstruction MSE   (outputs vs. input image)
                   + mean regularisation loss   (only if any are registered)

        Inference: runs the network with the modality flags parsed from
        ``segmentation_param.choices``. For ``output_mod == 'seg'`` the
        segmentation logits are ARGMAX post-processed; otherwise the
        requested network output is published unchanged.

        :param outputs_collector: receives console and network output
            variables.
        :param gradients_collector: receives the gradient list (training).
        :raises ValueError: if the network returns no posterior parameters
            during training, since the KLD average is then undefined.
        """
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        self.var = tf.placeholder_with_default(0, [], 'var')
        # per-modality on/off mask; defaults to all four modalities enabled
        self.choices = tf.placeholder_with_default([True, True, True, True],
                                                   [4], 'choices')

        if self.is_training:
            self.lr = tf.placeholder_with_default(self.action_param.lr, [],
                                                  'learning_rate')
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            image_unstack = tf.unstack(image, axis=-1)
            net_img, post_param = self.net(
                {
                    MODALITIES_img[k]: tf.expand_dims(image_unstack[k], -1)
                    for k in range(4)
                },
                self.choices,
                is_training=self.is_training)

            net_seg = net_img['seg']
            # re-assemble the per-modality outputs into one image tensor
            net_img = tf.concat([net_img[mod] for mod in MODALITIES_img],
                                axis=-1)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.lr)

            cross = LossFunction(n_class=4, loss_type='CrossEntropy')
            dice = LossFunction(n_class=4, loss_type='Dice', softmax=True)

            gt = data_dict['label']
            loss_cross = cross(prediction=net_seg,
                               ground_truth=gt,
                               weight_map=None)
            loss_dice = dice(prediction=net_seg, ground_truth=gt)
            loss_seg = loss_cross + loss_dice

            loss_reconstruction = tf.reduce_mean(tf.square(net_img - image))

            nb_skip = len(post_param)
            if nb_skip == 0:
                raise ValueError('The network returned no posterior '
                                 'parameters; cannot compute the KLD term.')
            sum_inter_KLD = 0.0
            sum_prior_KLD = 0.0
            for k in range(nb_skip):
                inter_KLD, prior_KLD = compute_KLD(post_param[k]['mu'],
                                                   post_param[k]['logvar'],
                                                   self.choices)
                sum_inter_KLD += inter_KLD
                sum_prior_KLD += prior_KLD

            # Average over the post_param entries with true division:
            # `1/nb_skip` is integer division under Python 2 and would
            # silently zero the KLD term whenever nb_skip > 1.
            KLD = (sum_inter_KLD + sum_prior_KLD) / float(nb_skip)

            data_loss = loss_seg + 0.1 * KLD + 0.1 * loss_reconstruction

            # tf.reduce_mean over an empty list evaluates to NaN, so only add
            # the regularisation term when at least one is registered.
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=KLD, name='KLD',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=loss_reconstruction, name='loss_reconstruction',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=self.choices, name='choices',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=self.lr, name='lr',
                average_over_devices=False, collection=CONSOLE)

        elif self.is_inference:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            image = tf.unstack(image, axis=-1)

            # modality on/off flags from the config, parsed with str2bool
            choices = [str2bool(k) for k in self.segmentation_param.choices]
            tf.logging.info('Modality choices: %s', choices)
            output_mod = self.segmentation_param.output_mod[0]

            post_process_layer = PostProcessingLayer('ARGMAX', num_classes=4)
            net_inputs = {
                MODALITIES_img[k]: tf.expand_dims(image[k], -1)
                for k in range(4)
            }
            if output_mod == 'seg':
                # predict using samples
                net_img, _ = self.net(net_inputs, choices,
                                      is_training=True, is_inference=False)
                net_out = post_process_layer(net_img['seg'])
            else:
                # predict using means
                net_img, _ = self.net(net_inputs, choices,
                                      is_training=True, is_inference=True)
                net_out = net_img[output_mod]

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()
コード例 #12
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the training or inference graph for the segmentation net.

        Training: pops a batch. When ``validation_every_n > 0`` it comes from
        the first sampler, or from the last one while ``self.is_validation``
        is true; otherwise it comes from the first sampler. The first four
        image channels are fed to ``self.net`` as separate modalities. The
        optimised loss is CrossEntropy + Dice on the returned logits, plus
        the mean regularisation loss when any are registered.

        Inference: runs the network with the modality flags parsed from
        ``segmentation_param.choices``, ARGMAX post-processes the result and
        publishes it together with the window location.

        :param outputs_collector: receives console, summary and network
            output variables.
        :param gradients_collector: receives the gradient list (training).
        """
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        self.var = tf.placeholder_with_default(0, [], 'var')
        # per-modality on/off mask; defaults to all four modalities enabled
        self.choices = tf.placeholder_with_default([True, True, True, True],
                                                   [4], 'choices')

        if self.is_training:
            self.lr = tf.placeholder_with_default(self.action_param.lr, [],
                                                  'learning_rate')
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.lr)

            image = tf.cast(data_dict['image'], tf.float32)
            image_unstack = tf.unstack(image, axis=-1)

            net_seg = self.net(
                {
                    MODALITIES_img[k]: tf.expand_dims(image_unstack[k], -1)
                    for k in range(4)
                },
                self.choices,
                is_training=self.is_training)

            cross = LossFunction(n_class=4, loss_type='CrossEntropy')
            dice = LossFunction(n_class=4, loss_type='Dice', softmax=True)

            gt = data_dict['label']
            loss_cross = cross(prediction=net_seg,
                               ground_truth=gt,
                               weight_map=None)
            loss_dice = dice(prediction=net_seg, ground_truth=gt)
            data_loss = loss_cross + loss_dice

            # tf.reduce_mean over an empty list evaluates to NaN, so only add
            # the regularisation term when at least one is registered.
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            grads = self.optimiser.compute_gradients(
                loss, colocate_gradients_with_ops=False)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            # NOTE(review): the console 'loss' is the data loss while the
            # summary 'loss' includes regularisation; confirm this is
            # intended, since both are reported under the same name.
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=loss,
                                                name='loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=self.choices,
                                                name='choices',
                                                average_over_devices=False,
                                                collection=CONSOLE)

        elif self.is_inference:
            # converting logits into argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            image = tf.unstack(image, axis=-1)

            # modality on/off flags from the config, parsed with str2bool
            choices = [str2bool(k) for k in self.segmentation_param.choices]
            tf.logging.info('Modality choices: %s', choices)

            net_seg = self.net(
                {
                    MODALITIES_img[k]: tf.expand_dims(image[k], -1)
                    for k in range(4)
                },
                choices,
                is_training=True,
                is_inference=True)

            post_process_layer = PostProcessingLayer('ARGMAX', num_classes=4)
            net_seg = post_process_layer(net_seg)
            outputs_collector.add_to_collection(var=net_seg,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()
コード例 #13
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the segmentation training or inference graph.

        Training: pops a batch from the first sampler, or from the last
        sampler while ``self.is_validation`` is true. When ``do_mixup`` is
        set:

        * training batches are blended with a second training batch, using
          a ``Beta(mixup_alpha, mixup_alpha)`` fraction;
        * the labels are one-hot encoded (``num_classes`` channels) in both
          the training and validation paths.

        The configured loss, plus the mean regularisation loss when
        ``decay > 0`` and at least one is registered, is differentiated
        w.r.t. the trainable variables whose name does not match
        ``vars_to_freeze`` (falling back to ``vars_to_restore``).

        Inference: runs the network on a batch from the last sampler and
        post-processes the logits with one of three layers:

        * SOFTMAX when ``output_prob`` is set and there is more than one
          class;
        * ARGMAX when ``output_prob`` is not set and there is more than one
          class;
        * IDENTITY otherwise.

        The aggregator is then initialised.

        :param outputs_collector: receives console, summary and network
            output variables.
        :param gradients_collector: receives the gradient list (training).
        """
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        def mixup_switch_sampler(for_training):
            """Pop a batch and, for training, mix it with a second batch.

            Labels are returned one-hot encoded in both cases, so the loss
            sees the same label format whether or not mixing happened.
            """
            # get first set of samples
            d_dict = switch_sampler(for_training=for_training)

            mix_fields = ('image', 'weight', 'label')

            if not for_training:
                with tf.name_scope('nomix'):
                    # ensure label is appropriate for dense loss functions
                    ground_truth = tf.cast(d_dict['label'], tf.int32)
                    one_hot = tf.one_hot(
                        tf.squeeze(ground_truth, axis=-1),
                        depth=self.segmentation_param.num_classes)
                    d_dict['label'] = one_hot
            else:
                with tf.name_scope('mixup'):
                    # mixing fraction ~ Beta(alpha, alpha); alpha=1 is uniform
                    alpha = self.segmentation_param.mixup_alpha
                    beta = tf.distributions.Beta(alpha, alpha)
                    rand_frac = beta.sample()

                    # get another minibatch
                    d_dict_to_mix = switch_sampler(for_training=True)

                    if self.segmentation_param.mix_match:
                        # Sort both batches by foreground volume, i.e. the
                        # number of positive (label > 0) elements per sample.
                        # Both batches must use the same binarised measure;
                        # summing raw label values would weight each element
                        # by its class index.
                        inds1 = tf.argsort(
                            tf.map_fn(tf.reduce_sum,
                                      tf.cast(d_dict['label'] > 0, tf.int64)))
                        inds2 = tf.argsort(
                            tf.map_fn(
                                tf.reduce_sum,
                                tf.cast(d_dict_to_mix['label'] > 0, tf.int64)))
                        for field in mix_fields:
                            if field not in d_dict:
                                continue
                            d_dict[field] = tf.gather(d_dict[field],
                                                      indices=inds1)
                            # the batch to mix is gathered in the opposite
                            # order, so the smallest volumes are paired with
                            # the largest ones
                            d_dict_to_mix[field] = tf.gather(
                                d_dict_to_mix[field], indices=inds2[::-1])

                    # making the labels dense and one-hot
                    for d in (d_dict, d_dict_to_mix):
                        ground_truth = tf.cast(d['label'], tf.int32)
                        one_hot = tf.one_hot(
                            tf.squeeze(ground_truth, axis=-1),
                            depth=self.segmentation_param.num_classes)
                        d['label'] = one_hot

                    # do the mixing for any fields that are relevant and present
                    mixed_up = {
                        field: d_dict[field] * rand_frac +
                        d_dict_to_mix[field] * (1 - rand_frac)
                        for field in mix_fields if field in d_dict
                    }
                    # reassign all relevant values in d_dict
                    d_dict.update(mixed_up)

            return d_dict

        if self.is_training:
            if not self.segmentation_param.do_mixup:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                # mix up the samples only outside the validation phase
                data_dict = tf.cond(
                    tf.logical_not(self.is_validation),
                    lambda: mixup_switch_sampler(for_training=True),
                    lambda: mixup_switch_sampler(for_training=False))

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.segmentation_param.softmax)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=data_dict.get('label', None),
                                  weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            # Get all vars
            to_optimise = tf.trainable_variables()
            vars_to_freeze = \
                self.action_param.vars_to_freeze or \
                self.action_param.vars_to_restore
            if vars_to_freeze:
                import re
                var_regex = re.compile(vars_to_freeze)
                # Only optimise vars that are not frozen
                to_optimise = \
                    [v for v in to_optimise if not var_regex.search(v.name)]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables fixed (--vars_to_freeze %s)",
                    len(to_optimise), len(tf.trainable_variables()),
                    vars_to_freeze)

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)

            self.total_loss = loss

            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            # collecting output variables
            outputs_collector.add_to_collection(var=self.total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=self.total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        elif self.is_inference:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()
コード例 #14
0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the STNeuroNet training or inference graph.

        A batch is popped from the first sampler and ``self.net`` is run on
        the float32 image.

        Training:

        * The label is max-pooled over the whole of spatial axis 3 (NDHWC
          layout, VALID padding), so that axis collapses to size 1 and the
          label matches the network's 2D output.
        * The configured loss is optimised, plus the mean regularisation
          loss when ``decay > 0`` and at least one is registered.
        * Axial image/segmentation/label summaries are written for
          TensorBoard.

        Inference: the logits are post-processed with one of three layers,
        then published with the window location, and the aggregator is
        initialised:

        * SOFTMAX when ``output_prob`` is set and there is more than one
          class;
        * ARGMAX when ``output_prob`` is not set and there is more than one
          class;
        * IDENTITY otherwise.

        :param outputs_collector: receives console, summary and network
            output variables.
        :param gradients_collector: receives the gradient list (training).
        :raises ValueError: in training mode, if the batch has no 'label'.
        """
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:

            label = data_dict.get('label', None)
            if label is None:
                raise ValueError(
                    "STNeuroNet training requires a 'label' input.")

            # Changed label on 11/29/2017: generate a 2D label from the 3D
            # label by max-pooling over the full extent of axis 3.
            # Only suitable for STNeuroNet.
            k = label.get_shape().as_list()
            label = tf.nn.max_pool3d(label, [1, 1, 1, k[3], 1],
                                     [1, 1, 1, 1, 1],
                                     'VALID',
                                     data_format='NDHWC')
            tf.logging.info('label shape is %s', label.get_shape())
            tf.logging.info('Image shape is %s', image.get_shape())
            tf.logging.info('Out shape is %s', net_out.get_shape())

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=label,
                                  weight_map=data_dict.get('weight', None))
            # Default to the data loss so `loss` is always bound, even when
            # decay > 0 but no regularisation losses were registered.
            loss = data_loss
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                    loss = data_loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            # ADDED on 10/30 by Soltanian-Zadeh for tensorboard visualization.
            # Scale class indices 0..num_classes-1 onto 0..255. The divisor
            # must be parenthesised; max(..., 1) avoids dividing by zero
            # when there is a single class.
            num_classes = self.segmentation_param.num_classes
            label_scale = 255. / max(num_classes - 1, 1)
            seg_summary = tf.to_float(
                tf.expand_dims(tf.argmax(net_out, -1), -1)) * label_scale
            label_summary = tf.to_float(tf.expand_dims(label, -1)) * label_scale
            # normalise each image by its mean/std and clip to [0, 255]
            m, v = tf.nn.moments(image, axes=[1, 2, 3], keep_dims=True)
            img_summary = tf.minimum(
                255.,
                tf.maximum(0., (tf.to_float(image - m) /
                                (tf.sqrt(v) * 2.) + 1.) * 127.))
            image3_axial('img', img_summary, 50, [tf.GraphKeys.SUMMARIES])
            image3_axial('seg', seg_summary, 5, [tf.GraphKeys.SUMMARIES])
            image3_axial('label', label_summary, 5, [tf.GraphKeys.SUMMARIES])
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)
            tf.logging.info('output shape is %s', net_out.get_shape())

            # NETWORK_OUTPUT (previously misspelt NETORK_OUTPUT, which raised
            # a NameError at inference time)
            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()