Пример #1
0
    def layer_op(self, input_tensor, is_training, layer_id=-1):
        """Build the DenseVNet forward graph and its TensorBoard summaries.

        :param input_tensor: batched image tensor laid out as
            ``[batch, spatial..., channels]``; every spatial size must be
            divisible by ``2 ** len(hyperparameters['dilation_rates'])``
        :param is_training: when truthy and ``augmentation_scale > 0`` an
            affine augmentation is applied to the input and inverted on the
            output; also forwarded to every sub-layer
        :param layer_id: unused
        :return: segmentation logits resized to the input's spatial size
        :raises NotImplementedError: for inputs that are neither 2D nor 3D
        """
        hyper = self.hyperparameters

        # Initialize DenseVNet network layers
        net = self.create_network()

        #
        # Parameter handling
        #

        # Shape and dimension variable shortcuts
        channel_dim = len(input_tensor.shape) - 1
        input_size = input_tensor.shape.as_list()
        spatial_size = input_size[1:-1]
        n_spatial_dims = input_tensor.shape.ndims - 2

        # Quick access to hyperparams
        pkeep = hyper['p_channels_selected']

        # Validate input dimension with dilation rates
        modulo = 2 ** (len(hyper['dilation_rates']))
        assert layer_util.check_spatial_dims(input_tensor,
                                             lambda x: x % modulo == 0)

        #
        # Augmentation + Downsampling + Initial Layers
        #

        # On the fly data augmentation
        augment_layer = None
        if is_training and hyper['augmentation_scale'] > 0:
            if n_spatial_dims == 2:
                augmentation_class = Affine2DAugmentationLayer
            elif n_spatial_dims == 3:
                augmentation_class = Affine3DAugmentationLayer
            else:
                raise NotImplementedError(
                    'Affine augmentation only supports 2D and 3D images')

            augment_layer = augmentation_class(hyper['augmentation_scale'],
                                               'LINEAR', 'ZERO')
            input_tensor = augment_layer(input_tensor)

        # Variable storing all intermediate results -- VLinks
        all_segmentation_features = []

        # Downsample input to the network
        down_tensor = self.downsample_input(input_tensor, n_spatial_dims)
        downsampled_img = net.initial_bn(down_tensor, is_training=is_training)

        # Add initial downsampled image VLink
        all_segmentation_features.append(downsampled_img)

        # All results should match the downsampled input's shape
        output_shape = downsampled_img.shape.as_list()[1:-1]

        init_features = net.initial_conv(input_tensor, is_training=is_training)

        #
        # Dense VNet Main Block
        #

        # `down` will handle the input of each Dense VNet block
        # Initialize it by stacking downsampled image and initial conv features
        down = tf.concat([downsampled_img, init_features], channel_dim)

        # Process Dense VNet Blocks
        for dblock in net.dense_vblocks:
            # Get skip layer and activation output
            skip, down = dblock(down, is_training=is_training, keep_prob=pkeep)

            # Resize skip layer to original shape and add VLink
            skip = image_resize(skip, output_shape)
            all_segmentation_features.append(skip)

        # Concatenate all intermediate skip layers
        inter_results = tf.concat(all_segmentation_features, channel_dim)

        # Initial segmentation output
        seg_output = net.seg_layer(inter_results, is_training=is_training)

        #
        # Dense VNet End - Now postprocess outputs
        #

        # Refine segmentation with prior if any.
        # SpatialPriorBlock is a layer object: it has to be *called* to build
        # the prior tensor before it can be added to the logits.
        if self.architecture_parameters['use_prior']:
            xyz_prior = SpatialPriorBlock([12] * n_spatial_dims, output_shape)
            seg_output += xyz_prior()

        # Invert augmentation if any
        if augment_layer is not None:
            inverse_aug = augment_layer.inverse()
            seg_output = inverse_aug(seg_output)

        # Resize output to original size
        seg_output = image_resize(seg_output, spatial_size)

        # Segmentation results, scaled so the highest class index maps to 255
        # (the max() guards the single-class case against division by zero).
        seg_argmax = tf.to_float(tf.expand_dims(tf.argmax(seg_output, -1), -1))
        seg_summary = seg_argmax * (255. / max(self.num_classes - 1, 1))

        # Image Summary: per-image standardisation mapped roughly onto [0, 254]
        norm_axes = list(range(1, n_spatial_dims + 1))
        mean, var = tf.nn.moments(input_tensor, axes=norm_axes, keep_dims=True)
        timg = tf.to_float(input_tensor - mean) / (tf.sqrt(var) * 2.)
        timg = (timg + 1.) * 127.
        single_channel = tf.reduce_mean(timg, axis=-1, keep_dims=True)
        img_summary = tf.minimum(255., tf.maximum(0., single_channel))
        if n_spatial_dims == 2:
            tf.summary.image(
                tf.get_default_graph().unique_name('imgseg'),
                tf.concat([img_summary, seg_summary], 1),
                5, [tf.GraphKeys.SUMMARIES])
        elif n_spatial_dims == 3:
            # Show summaries
            image3_axial(
                tf.get_default_graph().unique_name('imgseg'),
                tf.concat([img_summary, seg_summary], 1),
                5, [tf.GraphKeys.SUMMARIES])
        else:
            raise NotImplementedError(
                'Image Summary only supports 2D and 3D images')

        return seg_output
Пример #2
0
    def layer_op(self,
                 input_tensor,
                 is_training=True,
                 layer_id=-1,
                 keep_prob=0.5,
                 **unused_kwargs):
        """
        Build the DenseVNet forward graph and its TensorBoard summaries.

        :param input_tensor: tensor to input to the network, laid out as
            ``[batch, spatial..., channels]``; every spatial size has to be
            divisible by ``2 ** len(dilation_rates)``
        :param is_training: boolean, True if network is in training mode;
            together with ``augmentation_scale > 0`` it also enables the
            on-the-fly affine augmentation
        :param layer_id: not in use
        :param keep_prob: keep probability forwarded to every dense V-block
        :param unused_kwargs: ignored
        :return: network prediction (logits resized to the input's spatial
            size)
        :raises NotImplementedError: if the input is neither 2D nor 3D
        """
        hyperparams = self.hyperparams

        # Validate that dilation rates are compatible with input dimensions
        modulo = 2**(len(hyperparams['dilation_rates']))
        assert layer_util.check_spatial_dims(input_tensor,
                                             lambda x: x % modulo == 0)

        # Perform on the fly data augmentation
        augment_layer = None
        if is_training and hyperparams['augmentation_scale'] > 0:
            augment_layer = AffineAugmentationLayer(
                hyperparams['augmentation_scale'], 'LINEAR', 'ZERO')
            input_tensor = augment_layer(input_tensor)

        ###################
        ### Feedforward ###
        ###################

        # Initialize network components
        dense_vnet = self.create_network()

        # Store output feature maps from each component
        feature_maps = []

        # Downsample input to the network
        downsample_layer = DownSampleLayer(func='AVG', kernel_size=3, stride=2)
        downsampled_tensor = downsample_layer(input_tensor)
        bn_layer = BNLayer()
        downsampled_tensor = bn_layer(downsampled_tensor,
                                      is_training=is_training)
        feature_maps.append(downsampled_tensor)

        # All feature maps should match the downsampled tensor's shape
        feature_map_shape = downsampled_tensor.shape.as_list()[1:-1]

        # Prepare initial input to dense_vblocks
        initial_features = dense_vnet.initial_conv(input_tensor,
                                                   is_training=is_training)
        channel_dim = len(input_tensor.shape) - 1
        down = tf.concat([downsampled_tensor, initial_features], channel_dim)

        # Feed downsampled input through dense_vblocks
        for dblock in dense_vnet.dense_vblocks:
            # Get skip layer and activation output
            skip, down = dblock(down,
                                is_training=is_training,
                                keep_prob=keep_prob)
            # Resize skip layer to original shape and add to feature maps
            skip = LinearResizeLayer(feature_map_shape)(skip)
            feature_maps.append(skip)

        # Merge feature maps
        all_features = tf.concat(feature_maps, channel_dim)

        # Perform final convolution to segment structures
        output = dense_vnet.final_conv(all_features, is_training=is_training)

        ######################
        ### Postprocessing ###
        ######################

        # Get the number of spatial dimensions of input tensor
        n_spatial_dims = input_tensor.shape.ndims - 2

        # Refine segmentation with prior
        if hyperparams['use_prior']:
            spatial_prior_shape = [hyperparams['prior_size']] * n_spatial_dims
            # Prior shape must be 4 or 5 dim to work with linear_resize layer
            # ie to conform to shape=[batch, X, Y, Z, channels]
            prior_shape = [1] + spatial_prior_shape + [1]
            spatial_prior = SpatialPriorBlock(prior_shape, feature_map_shape)
            output += spatial_prior()

        # Invert augmentation (only built when augmentation was applied)
        if augment_layer is not None:
            inverse_aug = augment_layer.inverse()
            output = inverse_aug(output)

        # Resize output to original size
        input_tensor_spatial_size = input_tensor.shape.as_list()[1:-1]
        output = LinearResizeLayer(input_tensor_spatial_size)(output)

        # Segmentation summary, scaled so the highest class index maps to 255
        # (the max() guards the single-class case against division by zero).
        seg_argmax = tf.to_float(tf.expand_dims(tf.argmax(output, -1), -1))
        seg_summary = seg_argmax * (255. / max(self.num_classes - 1, 1))

        # Image Summary: per-image standardisation mapped roughly onto [0, 254]
        norm_axes = list(range(1, n_spatial_dims + 1))
        mean, var = tf.nn.moments(input_tensor, axes=norm_axes, keep_dims=True)
        timg = tf.to_float(input_tensor - mean) / (tf.sqrt(var) * 2.)
        timg = (timg + 1.) * 127.
        single_channel = tf.reduce_mean(timg, -1, True)
        img_summary = tf.minimum(255., tf.maximum(0., single_channel))

        if n_spatial_dims == 2:
            tf.summary.image(tf.get_default_graph().unique_name('imgseg'),
                             tf.concat([img_summary, seg_summary], 1), 5,
                             [tf.GraphKeys.SUMMARIES])
        elif n_spatial_dims == 3:
            image3_axial(tf.get_default_graph().unique_name('imgseg'),
                         tf.concat([img_summary, seg_summary], 1), 5,
                         [tf.GraphKeys.SUMMARIES])
        else:
            raise NotImplementedError(
                'Image Summary only supports 2D and 3D images')

        return output
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Connect the sampled batch to the network and register outputs.

        In training mode this builds the optimiser, the loss (data loss plus
        the mean regularisation loss when ``net_param.decay > 0`` and such
        losses exist), the gradients, and several TensorBoard summaries.
        Otherwise it post-processes the logits (softmax / argmax / identity)
        and registers the network output together with the window locations,
        then initialises the aggregator for the configured sampler.

        :param outputs_collector: collector receiving console, summary and
            network-output variables
        :param gradients_collector: collector receiving the gradients
            (training mode only)
        :raises ValueError: in training mode if the batch has no ``label``
        """
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:

            label = data_dict.get('label', None)
            if label is None:
                raise ValueError(
                    'Training requires a "label" entry in the sampled batch.')

            # Changed label on 11/29/2017: This will generate a 2D label
            # from the 3D label provided in the input. Only suitable for
            # STNeuroNet. Max-pooling with a kernel spanning the whole of
            # axis 3 ('VALID', stride 1) collapses that axis to size 1.
            k = label.get_shape().as_list()
            label = tf.nn.max_pool3d(label, [1, 1, 1, k[3], 1],
                                     [1, 1, 1, 1, 1],
                                     'VALID',
                                     data_format='NDHWC')
            print('label shape is{}'.format(label.get_shape()))
            print('Image shape is{}'.format(image.get_shape()))
            print('Out shape is{}'.format(net_out.get_shape()))
            ####
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=label,
                                  weight_map=data_dict.get('weight', None))

            # Start from the data loss so `loss` is always bound, even when
            # decay > 0 but no regularisation losses were collected.
            loss = data_loss
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(term) for term in reg_losses])
                    loss = data_loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            # ADDED on 10/30 by Soltanian-Zadeh for tensorboard visualization.
            # Class indices are scaled so the highest index maps to 255; the
            # max() guards the single-class case against division by zero.
            class_scale = 255. / max(
                self.segmentation_param.num_classes - 1, 1)
            seg_summary = tf.to_float(
                tf.expand_dims(tf.argmax(net_out, -1), -1)) * class_scale
            label_summary = tf.to_float(
                tf.expand_dims(label, -1)) * class_scale
            m, v = tf.nn.moments(image, axes=[1, 2, 3], keep_dims=True)
            img_summary = tf.minimum(
                255.,
                tf.maximum(0., (tf.to_float(image - m) /
                                (tf.sqrt(v) * 2.) + 1.) * 127.))
            image3_axial('img', img_summary, 50, [tf.GraphKeys.SUMMARIES])
            image3_axial('seg', seg_summary, 5, [tf.GraphKeys.SUMMARIES])
            image3_axial('label', label_summary, 5, [tf.GraphKeys.SUMMARIES])
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)
            print('output shape is{}'.format(net_out.get_shape()))

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETORK_OUTPUT)
            # Third entry of the sampler table is the aggregator initialiser
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()
Пример #4
0
    def layer_op(self, input_tensor, is_training, layer_id=-1):
        """Build a 3D DenseVNet forward graph and its image summary.

        All layers are constructed inline on each call. This variant is 3D
        only: it uses ``Affine3DAugmentationLayer``, ``tf.nn.avg_pool3d`` and
        image moments over axes ``[1, 2, 3]``.

        :param input_tensor: batched image tensor laid out as
            ``[batch, spatial..., channels]``; every spatial size must be
            divisible by ``2 ** len(hyperparameters['dilation_rates'])``
        :param is_training: when truthy and ``augmentation_scale > 0`` an
            affine augmentation is applied to the input and inverted on the
            output; also forwarded to every sub-layer
        :param layer_id: unused
        :return: segmentation logits resized to the input's spatial size
        """
        hp = self.hyperparameters
        use_augmentation = is_training and hp['augmentation_scale'] > 0
        if use_augmentation:
            aug = Affine3DAugmentationLayer(hp['augmentation_scale'], 'LINEAR',
                                            'ZERO')
            input_tensor = aug(input_tensor)
        channel_dim = len(input_tensor.get_shape()) - 1
        input_size = input_tensor.get_shape().as_list()
        spatial_rank = len(input_size) - 2

        modulo = 2**(len(hp['dilation_rates']))
        assert layer_util.check_spatial_dims(input_tensor,
                                             lambda x: x % modulo == 0)

        # Each block downsamples to the next block's input channel count;
        # the last block has no downsampling (None).
        downsample_channels = list(hp['n_input_channels'][1:]) + [None]
        v_params = zip(hp['n_dense_channels'], hp['n_seg_channels'],
                       downsample_channels, hp['dilation_rates'])

        downsampled_img = BNLayer()(
            tf.nn.avg_pool3d(input_tensor, [1] + [3] * spatial_rank + [1],
                             [1] + [2] * spatial_rank + [1], 'SAME'),
            is_training=is_training)
        all_segmentation_features = [downsampled_img]
        output_shape = downsampled_img.get_shape().as_list()[1:-1]
        initial_features = ConvolutionalLayer(hp['n_input_channels'][0],
                                              kernel_size=5,
                                              stride=2)(
                                                  input_tensor,
                                                  is_training=is_training)

        down = tf.concat([downsampled_img, initial_features], channel_dim)
        for dense_ch, seg_ch, down_ch, dil_rate in v_params:
            sd = DenseFeatureStackBlockWithSkipAndDownsample(
                dense_ch,
                3,
                dil_rate,
                seg_ch,
                down_ch,
                self.architecture_parameters['use_bdo'],
                acti_func='relu')
            skip, down = sd(down,
                            is_training=is_training,
                            keep_prob=hp['p_channels_selected'])
            all_segmentation_features.append(image_resize(skip, output_shape))
        segmentation = ConvolutionalLayer(
            self.num_classes,
            kernel_size=hp['final_kernel'],
            with_bn=False,
            with_bias=True)(tf.concat(all_segmentation_features, channel_dim),
                            is_training=is_training)
        if self.architecture_parameters['use_prior']:
            # SpatialPriorBlock is a layer object; call it to obtain the prior
            # tensor before adding it to the logits.
            spatial_prior = SpatialPriorBlock([12] * spatial_rank,
                                              output_shape)
            segmentation = segmentation + spatial_prior()
        if use_augmentation:
            inverse_aug = aug.inverse()
            segmentation = inverse_aug(segmentation)
        segmentation = image_resize(segmentation, input_size[1:-1])
        # Scale class indices so the highest one maps to 255; max() guards
        # the single-class case against division by zero.
        seg_summary = tf.to_float(
            tf.expand_dims(tf.argmax(segmentation, -1),
                           -1)) * (255. / max(self.num_classes - 1, 1))
        m, v = tf.nn.moments(input_tensor, axes=[1, 2, 3], keep_dims=True)
        img_summary = tf.minimum(
            255.,
            tf.maximum(0., (tf.to_float(input_tensor - m) /
                            (tf.sqrt(v) * 2.) + 1.) * 127.))
        image3_axial('imgseg', tf.concat([img_summary, seg_summary], 1), 5,
                     [tf.GraphKeys.SUMMARIES])
        return segmentation
Пример #5
0
    def layer_op(self,
                 input_tensor,
                 is_training=True,
                 layer_id=-1,
                 keep_prob=0.5,
                 **unused_kwargs):
        """Build the DenseVNet forward graph and its TensorBoard summaries.

        :param input_tensor: batched image tensor laid out as
            ``[batch, spatial..., channels]``; every spatial size must be
            divisible by ``2 ** len(hyperparameters['dilation_rates'])``
        :param is_training: when truthy and ``augmentation_scale > 0`` an
            affine augmentation is applied to the input and inverted on the
            output; also forwarded to every sub-layer
        :param layer_id: unused
        :param keep_prob: keep probability forwarded to every dense V-block
        :param unused_kwargs: ignored
        :return: segmentation logits resized to the input's spatial size
        :raises NotImplementedError: for inputs that are neither 2D nor 3D
        """
        hyper = self.hyperparameters

        # Initialize DenseVNet network layers
        net = self.create_network()

        #
        # Parameter handling
        #

        # Shape and dimension variable shortcuts
        channel_dim = len(input_tensor.shape) - 1
        input_size = input_tensor.shape.as_list()
        spatial_size = input_size[1:-1]
        n_spatial_dims = input_tensor.shape.ndims - 2

        # Validate input dimension with dilation rates
        modulo = 2**(len(hyper['dilation_rates']))
        assert layer_util.check_spatial_dims(input_tensor,
                                             lambda x: x % modulo == 0)

        #
        # Augmentation + Downsampling + Initial Layers
        #

        # On the fly data augmentation; `augment_layer` stays None otherwise
        augment_layer = None
        if is_training and hyper['augmentation_scale'] > 0:
            augment_layer = AffineAugmentationLayer(
                hyper['augmentation_scale'], 'LINEAR', 'ZERO')
            input_tensor = augment_layer(input_tensor)

        # Variable storing all intermediate results -- VLinks
        all_segmentation_features = []

        # Downsample input to the network
        ave_downsample_layer = DownSampleLayer(func='AVG',
                                               kernel_size=3,
                                               stride=2)
        down_tensor = ave_downsample_layer(input_tensor)
        downsampled_img = net.initial_bn(down_tensor, is_training=is_training)

        # Add initial downsampled image VLink
        all_segmentation_features.append(downsampled_img)

        # All results should match the downsampled input's shape
        output_shape = downsampled_img.shape.as_list()[1:-1]

        init_features = net.initial_conv(input_tensor, is_training=is_training)

        #
        # Dense VNet Main Block
        #

        # `down` will handle the input of each Dense VNet block
        # Initialize it by stacking downsampled image and initial conv features
        down = tf.concat([downsampled_img, init_features], channel_dim)

        # Process Dense VNet Blocks
        for dblock in net.dense_vblocks:
            # Get skip layer and activation output
            skip, down = dblock(down,
                                is_training=is_training,
                                keep_prob=keep_prob)

            # Resize skip layer to original shape and add VLink
            skip = LinearResizeLayer(output_shape)(skip)
            all_segmentation_features.append(skip)

        # Concatenate all intermediate skip layers
        inter_results = tf.concat(all_segmentation_features, channel_dim)

        # Initial segmentation output
        seg_output = net.seg_layer(inter_results, is_training=is_training)

        #
        # Dense VNet End - Now postprocess outputs
        #

        # Refine segmentation with prior if any.
        # SpatialPriorBlock is a layer object: it has to be *called* to build
        # the prior tensor before it can be added to the logits.
        if self.architecture_parameters['use_prior']:
            xyz_prior = SpatialPriorBlock([12] * n_spatial_dims, output_shape)
            seg_output += xyz_prior()

        # Invert augmentation if any
        if augment_layer is not None:
            inverse_aug = augment_layer.inverse()
            seg_output = inverse_aug(seg_output)

        # Resize output to original size
        seg_output = LinearResizeLayer(spatial_size)(seg_output)

        # Segmentation results, scaled so the highest class index maps to 255
        # (the max() guards the single-class case against division by zero).
        seg_argmax = tf.to_float(tf.expand_dims(tf.argmax(seg_output, -1), -1))
        seg_summary = seg_argmax * (255. / max(self.num_classes - 1, 1))

        # Image Summary: per-image standardisation mapped roughly onto [0, 254]
        norm_axes = list(range(1, n_spatial_dims + 1))
        mean, var = tf.nn.moments(input_tensor, axes=norm_axes, keep_dims=True)
        timg = tf.to_float(input_tensor - mean) / (tf.sqrt(var) * 2.)
        timg = (timg + 1.) * 127.
        single_channel = tf.reduce_mean(timg, -1, True)
        img_summary = tf.minimum(255., tf.maximum(0., single_channel))
        if n_spatial_dims == 2:
            tf.summary.image(tf.get_default_graph().unique_name('imgseg'),
                             tf.concat([img_summary, seg_summary], 1), 5,
                             [tf.GraphKeys.SUMMARIES])
        elif n_spatial_dims == 3:
            # Show summaries
            image3_axial(tf.get_default_graph().unique_name('imgseg'),
                         tf.concat([img_summary, seg_summary], 1), 5,
                         [tf.GraphKeys.SUMMARIES])
        else:
            raise NotImplementedError(
                'Image Summary only supports 2D and 3D images')

        return seg_output
Пример #6
0
    def layer_op(self, input_tensor, is_training, layer_id=-1):
        """Build a 3D DenseVNet forward graph and its image summary.

        All layers are constructed inline on each call. This variant is 3D
        only: it uses ``Affine3DAugmentationLayer``, ``tf.nn.avg_pool3d`` and
        image moments over axes ``[1, 2, 3]``.

        :param input_tensor: batched image tensor laid out as
            ``[batch, spatial..., channels]``; every spatial size must be
            divisible by ``2 ** len(hyperparameters['dilation_rates'])``
        :param is_training: when truthy and ``augmentation_scale > 0`` an
            affine augmentation is applied to the input and inverted on the
            output; also forwarded to every sub-layer
        :param layer_id: unused
        :return: segmentation logits resized to the input's spatial size
        """
        hp = self.hyperparameters
        use_augmentation = is_training and hp['augmentation_scale'] > 0
        if use_augmentation:
            aug = Affine3DAugmentationLayer(hp['augmentation_scale'],
                                            'LINEAR', 'ZERO')
            input_tensor = aug(input_tensor)
        channel_dim = len(input_tensor.get_shape()) - 1
        input_size = input_tensor.shape.as_list()
        spatial_rank = len(input_size) - 2

        modulo = 2 ** (len(hp['dilation_rates']))
        assert layer_util.check_spatial_dims(input_tensor,
                                             lambda x: x % modulo == 0)

        # Each block downsamples to the next block's input channel count;
        # the last block has no downsampling (None).
        downsample_channels = list(hp['n_input_channels'][1:]) + [None]
        v_params = zip(hp['n_dense_channels'],
                       hp['n_seg_channels'],
                       downsample_channels,
                       hp['dilation_rates'])

        pooled = tf.nn.avg_pool3d(input_tensor,
                                  [1] + [3] * spatial_rank + [1],
                                  [1] + [2] * spatial_rank + [1],
                                  'SAME')
        downsampled_img = BNLayer()(pooled, is_training=is_training)
        all_segmentation_features = [downsampled_img]
        output_shape = downsampled_img.shape.as_list()[1:-1]
        initial_features = ConvolutionalLayer(
            hp['n_input_channels'][0],
            kernel_size=5, stride=2)(input_tensor, is_training=is_training)

        down = tf.concat([downsampled_img, initial_features], channel_dim)
        for dense_ch, seg_ch, down_ch, dil_rate in v_params:
            sd = DenseFeatureStackBlockWithSkipAndDownsample(
                dense_ch,
                3,
                dil_rate,
                seg_ch,
                down_ch,
                self.architecture_parameters['use_bdo'],
                acti_func='relu')
            skip, down = sd(down,
                            is_training=is_training,
                            keep_prob=hp['p_channels_selected'])
            all_segmentation_features.append(image_resize(skip, output_shape))
        segmentation = ConvolutionalLayer(
            self.num_classes,
            kernel_size=hp['final_kernel'],
            with_bn=False,
            with_bias=True)(tf.concat(all_segmentation_features, channel_dim),
                            is_training=is_training)
        if self.architecture_parameters['use_prior']:
            # SpatialPriorBlock is a layer object; call it to obtain the prior
            # tensor before adding it to the logits.
            spatial_prior = SpatialPriorBlock([12] * spatial_rank,
                                              output_shape)
            segmentation = segmentation + spatial_prior()
        if use_augmentation:
            inverse_aug = aug.inverse()
            segmentation = inverse_aug(segmentation)
        segmentation = image_resize(segmentation, input_size[1:-1])
        # Scale class indices so the highest one maps to 255; max() guards
        # the single-class case against division by zero.
        seg_argmax = tf.to_float(
            tf.expand_dims(tf.argmax(segmentation, -1), -1))
        seg_summary = seg_argmax * (255. / max(self.num_classes - 1, 1))
        m, v = tf.nn.moments(input_tensor, axes=[1, 2, 3], keep_dims=True)
        img_summary = tf.minimum(
            255.,
            tf.maximum(0., (tf.to_float(input_tensor - m) /
                            (tf.sqrt(v) * 2.) + 1.) * 127.))
        image3_axial('imgseg', tf.concat([img_summary, seg_summary], 1),
                     5, [tf.GraphKeys.SUMMARIES])
        return segmentation