Example #1
0
def convertArrayToMuliLabelImage(arr, templateImg):
    """Merge four per-class channels into a single UInt8 multi-label image.

    Channel ``c`` of ``arr`` (``c`` in 0..3) contributes label ``c + 1``
    wherever its value is greater than zero:
    channel 0 -> 1 (pz), 1 -> 2 (cz), 2 -> 3 (us), 3 -> 4 (afs).
    Where several channels are positive at the same voxel their labels are
    summed, exactly as when the per-channel label images are added together.
    Non-positive values contribute 0.

    ``arr`` is not modified (writing the labels back into ``arr[c, ...]``
    would overwrite the caller's data, because those slices are views).

    :param arr: array indexable as ``arr[channel, ...]`` with at least four
        channels; each channel must have the voxel layout of ``templateImg``
        in NumPy order (the reverse of ``templateImg.GetSize()``).
    :param templateImg: SimpleITK image whose geometry (origin, spacing,
        direction) is copied onto the result.
    :return: ``sitk.Image`` with pixel type UInt8.
    """
    import numpy as np

    label_arr = np.zeros(np.shape(arr[0]), dtype=np.uint8)
    for channel, label in enumerate((1, 2, 3, 4)):
        # Boolean mask -> uint8 times the label keeps the sum in uint8
        # (maximum 1 + 2 + 3 + 4 = 10, so no overflow).
        positive = np.asarray(arr[channel]) > 0
        label_arr += positive.astype(np.uint8) * np.uint8(label)

    output_image = sitk.GetImageFromArray(label_arr)
    # Raises if the array layout does not match templateImg's size.
    output_image.CopyInformation(templateImg)
    return output_image
    def run(self, SegSelectList, spineNode):
        """Combine the selected segmentations into one label map shown in Slicer.

        Every selector in ``SegSelectList`` that has a current node contributes
        its volume multiplied by ``vert_encode[vert_list[idx]]``, where ``idx``
        is the selector's position in the list. All contributions are summed.
        The sum is pushed to Slicer as a ``vtkMRMLLabelMapVolumeNode`` named
        ``<spineNode name>_seg`` and set as the label layer of the slice viewers.

        NOTE(review): ``vert_encode`` and ``vert_list`` are names defined outside
        this method; presumably they map selector positions to label values --
        confirm against their definitions.

        :param SegSelectList: selector widgets exposing ``currentNode()``.
        :param spineNode: node whose name prefixes the created volume.
        :return: the created label-map volume node (also stored in
            ``self.segCombineVol``).
        :raises ValueError: if no selector has a node selected.
        """
        seg_full = None
        for idx, seg_select in enumerate(SegSelectList):
            seg_node = seg_select.currentNode()
            if seg_node is None:
                continue

            seg_img = sitkUtils.PullVolumeFromSlicer(seg_node)
            seg_img_encode = sitk.Multiply(seg_img,
                                           vert_encode[vert_list[idx]])
            seg_full = (seg_img_encode if seg_full is None
                        else sitk.Add(seg_full, seg_img_encode))

        if seg_full is None:
            # Fail early with a clear message instead of handing None to
            # PushVolumeToSlicer.
            raise ValueError('No segmentation selected; nothing to combine.')

        self.segCombineVol = sitkUtils.PushVolumeToSlicer(
            seg_full,
            name=spineNode.GetName() + '_seg',
            className='vtkMRMLLabelMapVolumeNode')

        slicer.util.setSliceViewerLayers(label=self.segCombineVol,
                                         labelOpacity=1)

        return self.segCombineVol
Example #3
0
def remove_small_connected_component(mask, labels, threshold):
    """Remove connected components smaller than ``threshold`` voxels, per label.

    For every value in ``labels``:

    1. The binary mask ``mask == label`` is split into fully-connected
       components.
    2. Components with fewer than ``threshold`` voxels are dropped
       (``sitk.RelabelComponent`` with ``minimumObjectSize=threshold``).
    3. The surviving voxels are written back with their original label value.

    Voxels whose label is not listed in ``labels`` become 0.

    :param mask: the multi-label mask (``sitk.Image``).
    :param labels: list of label values to clean.
    :param threshold: minimum component size in voxels; smaller components
        are set to 0.
    :return: the cleaned mask, with the pixel type of ``mask``.
    """
    assert isinstance(mask, sitk.Image)
    assert isinstance(labels, list)

    if not labels:
        # Nothing to keep: an all-zero image with mask's geometry and type.
        empty = sitk.Image(mask.GetSize(), mask.GetPixelID())
        empty.CopyInformation(mask)
        return empty

    cc_filter = sitk.ConnectedComponentImageFilter()
    cc_filter.SetFullyConnected(True)

    result = None
    for label in labels:
        components = cc_filter.Execute(mask == label)
        # Second argument is the minimum object size: smaller components -> 0.
        kept = sitk.RelabelComponent(components, threshold) > 0
        # Scale every label's binary mask (including the first one) by its
        # label value so the output carries the original label IDs.
        term = label * kept
        result = term if result is None else sitk.Add(result, term)

    return sitk.Cast(result, mask.GetPixelID())
Example #4
0
 def relabelImage(self, labelImage, newRegion, newLabel):
     """Paint ``newLabel`` into ``labelImage`` over the voxels of ``newRegion``.

     Both inputs are cast to Int16. Voxels selected by ``sitk.BinaryNot`` of
     the region keep their original label. Everywhere else the output is
     ``newRegion * newLabel``.

     NOTE(review): presumably ``newRegion`` is a 0/1 mask (BinaryNot's default
     foreground value is 1) -- confirm with callers.

     :param labelImage: existing label map.
     :param newRegion: mask of the region to relabel.
     :param newLabel: label value written inside the region.
     :return: Int16 ``sitk.Image``.
     """
     def to_int16(image):
         return sitk.Cast(image, sitk.sitkInt16)

     labels16 = to_int16(labelImage)
     region16 = to_int16(newRegion)
     kept_outside = sitk.Mask(labels16, sitk.BinaryNot(region16))
     painted_inside = sitk.Multiply(region16, newLabel)
     return sitk.Add(kept_outside, painted_inside)
Example #5
0
def image_add(input_image1, input_image2):
    """Return the element-wise sum ``input_image1 + input_image2``.

    Thin wrapper around ``sitk.Add``.

    :param input_image1: left operand
    :param input_image2: right operand
    :return: the image produced by ``sitk.Add``
    """
    summed = sitk.Add(input_image1, input_image2)
    return summed
def SmoothByCurvatureFlow(shape, iterations):
    """Smooth a binary ``shape`` with a curvature-driven level-set evolution.

    Steps:

    1. Build a signed distance map of ``shape`` (inside positive, not
       squared, using image spacing).
    2. Pad one slice at each end of the third axis.
    3. Run ``ShapeDetectionLevelSet`` on a constant speed image of 1.0 for
       ``iterations`` iterations.
    4. Crop the padding away again.

    :param shape: binary ``sitk.Image`` (3-D, given the 3-element pad).
    :param iterations: number of level-set iterations.
    :return: binary image of voxels whose final level-set value is > 0.
    """
    margin = [0, 0, 1]

    phi = sitk.SignedDanielssonDistanceMap(shape, True, False, True)
    phi = sitk.ZeroFluxNeumannPad(phi, margin, margin)

    # Uniform speed: every voxel of the feature image equals 1.0.
    speed = sitk.Image(phi.GetSize(), sitk.sitkFloat32)
    speed.CopyInformation(phi)
    speed = sitk.Add(speed, 1.0)

    # Positional arguments after the two images: maximum RMS error 0.0,
    # propagation scaling 0.0, curvature scaling 1.0, then the iteration count.
    phi = sitk.ShapeDetectionLevelSet(phi, speed, 0.0, 0.0, 1.0, iterations)
    phi = sitk.Crop(phi, margin, margin)

    return sitk.Greater(phi, 0)
 def relabelImage(self, labelImage, newRegion, newLabel):
     """Return ``labelImage`` with the voxels of ``newRegion`` set to ``newLabel``.

     Both images are converted to Int16 first. The original labels survive
     wherever ``sitk.BinaryNot(newRegion)`` is non-zero. The term
     ``newRegion * newLabel`` is then added on top.

     NOTE(review): this assumes ``newRegion`` holds 0/1 values -- confirm.

     :param labelImage: label map to update.
     :param newRegion: mask selecting the voxels to overwrite.
     :param newLabel: value assigned inside ``newRegion``.
     :return: relabeled Int16 ``sitk.Image``.
     """
     int16 = sitk.sitkInt16
     base = sitk.Cast(labelImage, int16)
     region = sitk.Cast(newRegion, int16)

     background_only = sitk.Mask(base, sitk.BinaryNot(region))
     return sitk.Add(background_only, sitk.Multiply(region, newLabel))
Example #8
0
def AutoLungSegment(image, l=0.05, u=0.4, NPthresh=1e5):
    """
    Segments the lungs, generating a bounding box

    Args
        image (sitk.Image)      : the input image
        l (float)               : the lower (normalised) threshold
        u (float)               : the upper (normalised) threshold
        NPthresh (int or float) : lower limit of voxel counts for a structure to be tested

    Returns
        maskBox (tuple of int)  : bounding box of the automatically segmented lungs, as
                                  returned by LabelShapeStatisticsImageFilter.GetBoundingBox
                                  (start index per dimension, then size per dimension)
        maskBinary (sitk.Image) : the segmented lungs (+/- airways)

    Regions come from ThresholdAndMeasureLungVolume. A region is kept when its
    perimeter/border ratio is <= 5e-4 and it has more than NPthresh voxels.
    If no region qualifies, the region with the smallest ratio is used instead.
    """

    # Normalise image intensity (values outside [-1000, 500] are set to -1000 first)
    imNorm = sitk.Normalize(
        sitk.Threshold(image, -1000, 500, outsideValue=-1000))

    # Calculate the label maps and metrics on non-connected regions.
    # NOTE(review): NP, PBR and labels are assumed to be 1-D arrays with one
    # entry per region -- confirm against ThresholdAndMeasureLungVolume.
    NP, PBR, mask, labels = ThresholdAndMeasureLungVolume(imNorm, l, u)
    candidates = np.flatnonzero(np.logical_and(PBR <= 5e-4, NP > NPthresh))

    if candidates.size == 0:
        print("     Warning - non-zero perimeter/border ratio")
        candidates = np.array([np.argmin(PBR)])

    # Plain Python ints: calling int() on size-1 arrays is deprecated in NumPy.
    validLabels = [int(value) for value in np.ravel(np.asarray(labels)[candidates])]

    maskBinary = sitk.Equal(mask, validLabels[0])
    for value in validLabels[1:]:
        maskBinary = sitk.Add(maskBinary, sitk.Equal(mask, value))
    maskBinary = maskBinary > 0

    label_shape_analysis = sitk.LabelShapeStatisticsImageFilter()
    label_shape_analysis.Execute(maskBinary)
    # maskBinary only contains 0/1, so the segmented lungs carry label 1.
    maskBox = label_shape_analysis.GetBoundingBox(1)

    return maskBox, maskBinary
Example #9
0
def mask(image: sitk.Image,
         mask: sitk.Image,
         jacobian: bool = False) -> sitk.Image:
    r""" Mask an image (special meaning for Jacobian maps).

    The result is ``image * mask``, with ``mask`` cast to the pixel type of
    ``image`` and given its geometry. When ``jacobian`` is true, every voxel
    where ``mask`` is exactly zero is additionally set to one.

    Parameters
    ----------
    image : sitk.Image
        Input image to be masked (possibly float).
    mask : sitk.Image
        Mask (possibly float).
    jacobian : bool
        If true, the background after masking is set to
        one, if false to zero.

    Returns
    -------
    sitk.Image
        The masked image.
    """

    cast_mask = sitk.Cast(mask, image.GetPixelID())
    cast_mask.CopyInformation(image)
    result = sitk.Multiply(image, cast_mask)

    if jacobian:
        # 1 where the mask is zero, 0 elsewhere. Computed directly in
        # SimpleITK rather than via a NumPy round trip.
        background = sitk.Cast(sitk.Equal(mask, 0), image.GetPixelID())
        background.CopyInformation(image)
        result = sitk.Add(result, background)

    return result
def FillBody(otsuMultiple):
    """Fill the body region of a multi-level label image.

    The function works in three steps:

    1. In every (x, z) column, find the voxel with the largest y index whose
       value is 2. Set every voxel with a larger y index in that column to 1.
       Columns without a 2 are left unchanged.
    2. Fill holes slice-wise in the foreground (> 0) with
       ``SlicewiseFillHole``.
    3. Return the voxel-wise maximum of the step-1 image and that mask.

    Step 1 is vectorised with NumPy. Per-voxel ``GetPixel`` calls plus a
    whole-image ``sitk.Paste`` copy per column would make the work grow far
    faster than the image size.

    NOTE(review): ``otsuMultiple`` is presumably a 3-D scalar label image from
    a multi-level Otsu threshold -- confirm with callers.

    :param otsuMultiple: 3-D label ``sitk.Image``.
    :return: image of the same pixel type as ``otsuMultiple``.
    """
    import numpy as np

    labels = sitk.GetArrayFromImage(otsuMultiple)  # indexed [z, y, x]
    n_rows = labels.shape[1]

    is_two = labels == 2
    has_two = is_two.any(axis=1)  # shape (z, x)
    # Largest y index holding a 2 in each (z, x) column (meaningful only where has_two).
    last_two = n_rows - 1 - np.argmax(is_two[:, ::-1, :], axis=1)

    rows = np.arange(n_rows)[np.newaxis, :, np.newaxis]
    after_last_two = has_two[:, np.newaxis, :] & (rows > last_two[:, np.newaxis, :])
    labels[after_last_two] = 1

    filled = sitk.GetImageFromArray(labels)
    filled.CopyInformation(otsuMultiple)

    binary = sitk.Greater(filled, 0)
    binary = SlicewiseFillHole(binary)

    return sitk.Maximum(filled, binary)
Example #11
0
def compose_displacements(*fields: sitk.Image) -> sitk.Image:
    r""" Compose multiple displacement fields.

    Compute the composition pairwise and iteratively. For a couple
    of displacements :math:`d_1` and :math:`d_2`
    associated to the transforms :math:`f_1` and :math:`f_2`, the
    composition

    .. math::
        (f_2 \circ f_1) (x) = f_2(f_1(x))

    is obtained by resampling :math:`d_2` with :math:`d_1` and then
    summing.

    Parameters
    ----------
    fields : sitk.Image
        Variadic list of displacement fields, applied in the given order.

    Returns
    -------
    sitk.Image
        The composition of the input displacement fields, on the grid of
        the first field.
    """

    def _compose_pair(accumulated: sitk.Image, following: sitk.Image) -> sitk.Image:
        # Sample the next field at the points already displaced by
        # ``accumulated``, on exactly the grid of ``accumulated``, then add.
        warped = sitk.Warp(following,
                           accumulated,
                           outputSize=accumulated.GetSize(),
                           outputSpacing=accumulated.GetSpacing(),
                           outputOrigin=accumulated.GetOrigin(),
                           outputDirection=accumulated.GetDirection())
        warped.CopyInformation(accumulated)
        return sitk.Add(accumulated, warped)

    # Start from a copy of the first field (returned as-is if it is the only one).
    composed = sitk.Image(fields[0])
    for following in fields[1:]:
        composed = _compose_pair(composed, following)
    return composed
    def findBreastImplantsRegion(self, EdgeImage, fiducialCoord):
        """Grow a region from seed point(s) over ``EdgeImage``.

        The method runs in three steps:

        1. Fast marching from ``fiducialCoord`` (normalization factor 0.5,
           stopping value 10000) gives an arrival-time map.
        2. Arrival times are clamped to 10000, cast to Float32 and shifted by
           -50. The initial level set is therefore negative near the seeds:
           0..10000 becomes -50..9950.
        3. A geodesic active contour evolves that level set on ``EdgeImage``.
           It uses propagation 1.0, curvature 0.0, advection 0.0 and 1000
           iterations, without reversing the expansion direction.

        NOTE(review): ``trialPoints`` takes a sequence of seed indices; this
        assumes ``fiducialCoord`` is already such a sequence -- confirm.

        :param EdgeImage: feature/speed image used by both filters.
        :param fiducialCoord: seed point(s) for fast marching.
        :return: the evolved level-set image.
        """
        arrival_time = sitk.FastMarching(EdgeImage,
                                         trialPoints=tuple(fiducialCoord),
                                         normalizationFactor=0.5,
                                         stoppingValue=10000)

        clamped = sitk.Cast(sitk.Minimum(arrival_time, 10000.0), sitk.sitkFloat32)
        initial_level_set = sitk.Add(clamped, -50.0)

        contour_settings = {
            'propagationScaling': 1.0,
            'curvatureScaling': 0.0,
            'advectionScaling': 0.0,
            'numberOfIterations': 1000,
            'reverseExpansionDirection': False,
        }
        return sitk.GeodesicActiveContourLevelSet(initial_level_set,
                                                  EdgeImage,
                                                  **contour_settings)
Example #13
0
def _atlas_find_label_files(root):
    """Collect ``*labels_native.nii.gz`` and ``*affine.txt`` paths under *root*.

    Both lists follow ``os.walk`` order, so the n-th label file is paired with
    the n-th affine file. One pair per subject directory is assumed.

    :param root: directory searched recursively.
    :return: ``(label_paths, affine_paths)``.
    """
    label_paths = []
    affine_paths = []
    for dirpath, _subdirs, files in os.walk(root):
        for file_name in files:
            if file_name.endswith("labels_native.nii.gz"):
                label_paths.append(os.path.join(dirpath, file_name))
            if file_name.endswith("affine.txt"):
                affine_paths.append(os.path.join(dirpath, file_name))
    return label_paths, affine_paths


def _atlas_load_resampled_labels(root):
    """Read every label image under *root*, resampled with its affine transform.

    Nearest-neighbour interpolation keeps the label values intact. The default
    pixel value is 0 and the output keeps the input pixel type.

    :param root: directory searched recursively.
    :return: list of resampled ``sitk.Image`` label maps.
    :raises ValueError: if the numbers of label and affine files differ,
        because pairing them by position would then mix up subjects.
    """
    label_paths, affine_paths = _atlas_find_label_files(root)
    if len(label_paths) != len(affine_paths):
        raise ValueError(
            'Found %d label files but %d affine files under %r'
            % (len(label_paths), len(affine_paths), root))

    resampled = []
    for label_path, affine_path in zip(label_paths, affine_paths):
        transform = sitk.ReadTransform(affine_path)
        labels_image = sitk.ReadImage(label_path)
        resampled.append(
            sitk.Resample(labels_image, transform, sitk.sitkNearestNeighbor, 0,
                          labels_image.GetPixelIDValue()))
    return resampled


def _atlas_mean_label_map(label_images, label):
    """Voxel-wise mean of ``Threshold(image, label, label, 0)`` over *label_images*.

    Each term keeps the value ``label`` where the image equals ``label`` and is
    0 elsewhere. The result lies in ``[0, label]``: it is ``label`` times the
    fraction of images carrying that label. Terms are cast to Float32 before
    summing, so the sum cannot overflow the label pixel type and the division
    is not truncated to an integer.

    :param label_images: non-empty list of label maps on a common grid.
    :param label: label value to average.
    :return: Float32 ``sitk.Image``.
    """
    total = None
    for image in label_images:
        term = sitk.Cast(sitk.Threshold(image, label, label, 0), sitk.sitkFloat32)
        total = term if total is None else sitk.Add(total, term)
    return sitk.Divide(total, len(label_images))


def atlas_creation(
        train_dir='C:/Users/Admin/PycharmProjects/MyMIALab/data/train',
        test_dir='C:/Users/Admin/PycharmProjects/MyMIALab/data/test',
        result_dir='C:/Users/Admin/PycharmProjects/MyMIALab/bin/mia-result'):
    """Build a label atlas from training labels and evaluate it on test labels.

    Steps:
      1. Resample every training ``labels_native`` image with its affine
         transform (nearest neighbour).
      2. For each structure (labels 1-5: white matter, grey matter,
         hippocampus, amygdala, thalamus), average its label-valued mask over
         all training images in Float32. Write the result as
         ``<structure>_map_no_threshold.nii``.
      3. Binarise each map with ``BinaryThreshold(map, 0, label, label, 0)``
         and write ``<structure>_map.nii``.
      4. Resample the test labels the same way and write the first one to
         ``test.nii``.
      5. Evaluate each binary map against every test subject (Dice,
         sensitivity, precision, fallout). Report to the console and to
         ``Results_<Structure>.csv``.

    :param train_dir: root of the training data (searched recursively).
    :param test_dir: root of the test data (searched recursively).
    :param result_dir: directory receiving all written images and CSV files.
    :raises ValueError: if no training or test subjects are found, or if the
        label and affine file counts differ.
    """
    # (display name, file-name stem); the label value is the 1-based position.
    structures = (
        ('White Matter', 'white_matter'),
        ('Grey Matter', 'grey_matter'),
        ('Hippocampus', 'hippocampus'),
        ('Amygdala', 'amygdala'),
        ('Thalamus', 'thalamus'),
    )

    train_labels = _atlas_load_resampled_labels(train_dir)
    if not train_labels:
        raise ValueError('No training labels found under %r' % train_dir)

    binary_maps = []
    for label, (_, stem) in enumerate(structures, start=1):
        mean_map = _atlas_mean_label_map(train_labels, label)
        sitk.WriteImage(
            mean_map,
            os.path.join(result_dir, stem + '_map_no_threshold.nii'),
            False)

        # NOTE(review): [0, label] covers every value mean_map can take
        # (background included), so this marks the whole image with ``label``.
        # A cut-off above 0 (e.g. a majority vote) was presumably intended --
        # confirm before relying on these maps.
        binary_map = sitk.BinaryThreshold(mean_map, 0, label, label, 0)
        sitk.WriteImage(
            binary_map,
            os.path.join(result_dir, stem + '_map.nii'),
            False)
        binary_maps.append(binary_map)

    test_labels = _atlas_load_resampled_labels(test_dir)
    if not test_labels:
        raise ValueError('No test labels found under %r' % test_dir)

    # Save the first test patient's labels for visual inspection.
    sitk.WriteImage(test_labels[0], os.path.join(result_dir, 'test.nii'), False)

    for label, ((name, _), binary_map) in enumerate(zip(structures, binary_maps),
                                                     start=1):
        evaluator = eval_.Evaluator(eval_.ConsoleEvaluatorWriter(5))
        evaluator.metrics = [
            metric.DiceCoefficient(),
            metric.Sensitivity(),
            metric.Precision(),
            metric.Fallout(),
        ]
        evaluator.add_writer(
            eval_.CSVEvaluatorWriter(
                os.path.join(result_dir, 'Results_' + name + '.csv')))
        evaluator.add_label(label, name)
        for patient_idx, test_image in enumerate(test_labels):
            evaluator.evaluate(test_image, binary_map,
                               'Patient ' + str(patient_idx))
Example #14
0
    def _makeSTL(self):
        """Segment cells in the gray image stack and write one STL surface per cell.

        Steps performed below:
          1. Read the ``*.tif`` slice series in ``self._gray_dir`` with
             ``vtkTIFFReader`` and apply ``self._pixel_dim`` as spacing.
          2. Crop the volume to the bounding box of its non-zero voxels plus a
             5-voxel margin, clamped to the volume.
          3. Denoise the crop with SimpleITK patch-based denoising and build an
             edge image from the bounded reciprocal of the gradient magnitude.
          4. Choose a threshold fraction ``t`` of the maximum denoised intensity.
             It starts at 0.5 and is possibly lowered in 0.01 steps towards
             ``self._tratio``; it is set to ``self._tratio`` when at most one
             seed component is found. Threshold, open and fill holes to get
             seeds, then label their connected components.
          5. Evolve a geodesic active contour from each component's signed
             distance map. Sum the binarised results into ``segmentation``.
          6. Write ``segmentation * 100`` back into the full array, then:
             - run VTK marching cubes;
             - split the result by connectivity;
             - for each region build a smoothed Delaunay3D mesh, appended to
               ``self.cells``.
          7. For each mesh:
             - append its surface points to ``self.points``;
             - append its volume to ``self.volumes``;
             - write ``cellNN.stl`` (numbered from ``self._counter``) into
               ``self._vol_dir + '_surfaces'``.
          8. If ``self._display`` is set, open an interactive VTK render window.
          9. Delete ``self._gray_dir``.

        NOTE(review): Python 2 code (print statements, ``xrange``,
        ``string.replace``).
        """
        local_dir = self._gray_dir
        surface_dir = self._vol_dir+'_surfaces'+self._path_dlm
        # Create the output directory. NOTE(review): the bare except ignores
        # every error, not only "already exists" -- consider catching OSError.
        try:
            os.mkdir(surface_dir)
        except:
            pass
        files = fnmatch.filter(sorted(os.listdir(local_dir)),'*.tif')
        # Numeric suffix of the first file name including ".tif" (e.g. "0001.tif").
        counter = re.search("[0-9]*\.tif", files[0]).group()
        # File-name prefix = first file name with that numeric suffix removed.
        prefix = self._path_dlm+string.replace(files[0],counter,'')
        # Number of digits in the suffix, used as the zero-padding width below.
        counter = str(len(counter)-4)
        prefixImageName = local_dir + prefix

        # Read the series of 2D TIFF slices that make up the volume. File names
        # are prefixImageName followed by a zero-padded slice number.
        # NOTE(review): the in-plane extent is hard-coded to 0..100 -- confirm
        # the slices are always 101 x 101 pixels.
        v16=vtk.vtkTIFFReader()

        v16.SetFilePrefix(prefixImageName)
        v16.SetDataExtent(0,100,0,100,1,len(files))
        v16.SetFilePattern("%s%0"+counter+"d.tif")
        v16.Update()

        im = v16.GetOutput()
        im.SetSpacing(self._pixel_dim[0],self._pixel_dim[1],self._pixel_dim[2])

        # Export the VTK volume to a NumPy array.
        v = vte.vtkImageExportToArray()
        v.SetInputData(im)

        n = np.float32(v.GetArray())
        # Bounding box (start inclusive, stop exclusive) of all non-zero voxels.
        idx = np.argwhere(n)
        (ystart,xstart,zstart), (ystop,xstop,zstop) = idx.min(0),idx.max(0)+1
        I,J,K = n.shape
        # Grow the box by 5 voxels on each side, clamped to the array bounds.
        if ystart > 5:
            ystart -= 5
        else:
            ystart = 0
        if ystop < I-5:
            ystop += 5
        else:
            ystop = I
        if xstart > 5:
            xstart -= 5
        else:
            xstart = 0
        if xstop < J-5:
            xstop += 5
        else:
            xstop = J
        if zstart > 5:
            zstart -= 5
        else:
            zstart = 0
        if zstop < K-5:
            zstop += 5
        else:
            zstop = K

        # Work on the cropped sub-volume from here on.
        a = n[ystart:ystop,xstart:xstop,zstart:zstop]
        itk_img = sitk.GetImageFromArray(a)
        itk_img.SetSpacing([self._pixel_dim[0],self._pixel_dim[1],self._pixel_dim[2]])

        print "\n"
        print "-------------------------------------------------------"
        print "-- Applying Patch Based Denoising - this can be slow --"
        print "-------------------------------------------------------"
        print "\n"
        pb = sitk.PatchBasedDenoisingImageFilter()
        pb.KernelBandwidthEstimationOn()
        pb.SetNoiseModel(3) #use a Poisson noise model since this is confocal
        pb.SetNoiseModelFidelityWeight(1)
        pb.SetNumberOfSamplePatches(20)
        pb.SetPatchRadius(4)
        pb.SetNumberOfIterations(10)

        fimg = pb.Execute(itk_img)
        b = sitk.GetArrayFromImage(fimg)
        # Maximum denoised intensity; thresholds below are fractions of it.
        intensity = b.max()

        #grad = sitk.GradientMagnitudeRecursiveGaussianImageFilter()
        #grad.SetSigma(0.05)
        # Edge/speed image for the active contours: bounded reciprocal of the
        # gradient magnitude (computed with image spacing), as Float32.
        gf = sitk.GradientMagnitudeImageFilter()
        gf.UseImageSpacingOn()
        grad = gf.Execute(fimg)
        edge = sitk.Cast(sitk.BoundedReciprocal( grad ),sitk.sitkFloat32)


        print "\n"
        print "-------------------------------------------------------"
        print "---- Thresholding to deterimine initial level sets ----"
        print "-------------------------------------------------------"
        print "\n"
        # NOTE(review): BinaryThreshold's upper threshold is left at its default
        # (255 in SimpleITK), so voxels brighter than that are excluded from
        # the seeds -- confirm the intensity range of fimg.
        t = 0.5
        seed = sitk.BinaryThreshold(fimg,t*intensity)
        #Opening (Erosion/Dilation) step to remove islands smaller than 2 voxels in radius)
        seed = sitk.BinaryMorphologicalOpening(seed,2)
        seed = sitk.BinaryFillhole(seed!=0)
        #Get connected regions
        r = sitk.ConnectedComponent(seed)
        labels = sitk.GetArrayFromImage(r)
        ids = sorted(np.unique(labels))
        # N counts distinct values including the background 0, so N > 2 means
        # at least two seed components.
        N = len(ids)
        if N > 2:
            # Lower t in 0.01 steps (not below self._tratio) while the number of
            # components does not drop below the largest count seen so far.
            i = np.copy(N)
            while i == N and (t-self._tratio)>-1e-7:
                t -= 0.01
                seed = sitk.BinaryThreshold(fimg,t*intensity)
                #Opening (Erosion/Dilation) step to remove islands smaller than 2 voxels in radius)
                seed = sitk.BinaryMorphologicalOpening(seed,2)
                seed = sitk.BinaryFillhole(seed!=0)
                #Get connected regions
                r = sitk.ConnectedComponent(seed)
                labels = sitk.GetArrayFromImage(r)
                i = len(np.unique(labels))
                if i > N:
                    N = np.copy(i)
            # Step back to the last threshold before the loop stopped.
            t+=0.01
        else:
            t = np.copy(self._tratio)
        # Final seeds with the chosen threshold fraction t.
        seed = sitk.BinaryThreshold(fimg,t*intensity)
        #Opening (Erosion/Dilation) step to remove islands smaller than 2 voxels in radius)
        seed = sitk.BinaryMorphologicalOpening(seed,2)
        seed = sitk.BinaryFillhole(seed!=0)
        #Get connected regions
        r = sitk.ConnectedComponent(seed)
        labels = sitk.GetArrayFromImage(r)
        # Component ids without the smallest value (the background 0).
        labels = np.unique(labels)[1:]

        '''
        labels[labels==0] = -1
        labels = sitk.GetImageFromArray(labels)
        labels.SetSpacing([self._pixel_dim[0],self._pixel_dim[1],self._pixel_dim[2]])
        #myshow3d(labels,zslices=range(20))
        #plt.show()
        ls = sitk.ScalarChanAndVeseDenseLevelSetImageFilter()
        ls.UseImageSpacingOn()
        ls.SetLambda2(1.5)
        #ls.SetCurvatureWeight(1.0)
        ls.SetAreaWeight(1.0)
        #ls.SetReinitializationSmoothingWeight(1.0)
        ls.SetNumberOfIterations(100)
        seg = ls.Execute(sitk.Cast(labels,sitk.sitkFloat32),sitk.Cast(fimg,sitk.sitkFloat32))
        seg = sitk.Cast(seg,sitk.sitkUInt8)
        seg = sitk.BinaryMorphologicalOpening(seg,1)
        seg = sitk.BinaryFillhole(seg!=0)
        #Get connected regions
        #r = sitk.ConnectedComponent(seg)
        contours = sitk.BinaryContour(seg)
        myshow3d(sitk.LabelOverlay(sitk.Cast(fimg,sitk.sitkUInt8),contours),zslices=range(fimg.GetSize()[2]))
        plt.show()
        '''

        # Accumulate one geodesic-active-contour result per seed component.
        segmentation = sitk.Image(r.GetSize(),sitk.sitkUInt8)
        segmentation.SetSpacing([self._pixel_dim[0],self._pixel_dim[1],self._pixel_dim[2]])
        for l in labels:
            # Initial level set: squared signed distance to this component
            # (inside negative).
            d = sitk.SignedMaurerDistanceMap(r==l,insideIsPositive=False,squaredDistance=True,useImageSpacing=True)
            #d = sitk.BinaryThreshold(d,-1000,0)
            #d = sitk.Cast(d,edge.GetPixelIDValue() )*-1+0.5
            #d = sitk.Cast(d,edge.GetPixelIDValue() )
            seg = sitk.GeodesicActiveContourLevelSetImageFilter()
            seg.SetPropagationScaling(1.0)
            seg.SetAdvectionScaling(1.0)
            seg.SetCurvatureScaling(0.5)
            seg.SetMaximumRMSError(0.01)
            levelset = seg.Execute(d,edge)
            # Voxels with level-set value in [-1000, 0] count as this cell.
            levelset = sitk.BinaryThreshold(levelset,-1000,0)
            segmentation = sitk.Add(segmentation,levelset)
            print ("RMS Change for Cell %d: "% l,seg.GetRMSChange())
            print ("Elapsed Iterations for Cell %d: "% l, seg.GetElapsedIterations())
        '''
        contours = sitk.BinaryContour(segmentation)
        myshow3d(sitk.LabelOverlay(sitk.Cast(fimg,sitk.sitkUInt8),contours),zslices=range(fimg.GetSize()[2]))
        plt.show()
        '''

        # Write the segmentation (scaled by 100) back into the full-size array.
        # NOTE(review): the segmentation array is UInt8; where 3 or more contours
        # overlap, value*100 may wrap around -- confirm overlaps cannot occur.
        n[ystart:ystop,xstart:xstop,zstart:zstop] = sitk.GetArrayFromImage(segmentation)*100

        # Re-import the array into VTK. Note: ``i`` is rebound from the integer
        # used earlier to this importer object.
        i = vti.vtkImageImportFromArray()
        i.SetDataSpacing([self._pixel_dim[0],self._pixel_dim[1],self._pixel_dim[2]])
        i.SetDataExtent([0,100,0,100,1,len(files)])
        i.SetArray(n)
        i.Update()

        # NOTE(review): each ThresholdBy* call replaces the previous range, and
        # neither ReplaceIn nor ReplaceOut is enabled, so this filter likely
        # passes values through unchanged -- confirm intent.
        thres=vtk.vtkImageThreshold()
        thres.SetInputData(i.GetOutput())
        thres.ThresholdByLower(0)
        thres.ThresholdByUpper(101)

        # Iso-surface at value 1.
        iso=vtk.vtkImageMarchingCubes()
        iso.SetInputConnection(thres.GetOutputPort())
        iso.SetValue(0,1)

        # Label every connected surface region (also used for display colouring).
        regions = vtk.vtkConnectivityFilter()
        regions.SetInputConnection(iso.GetOutputPort())
        regions.SetExtractionModeToAllRegions()
        regions.ColorRegionsOn()
        regions.Update()

        # Build one smoothed Delaunay3D mesh per connected region.
        N = regions.GetNumberOfExtractedRegions()
        for i in xrange(N):
            r = vtk.vtkConnectivityFilter()
            r.SetInputConnection(iso.GetOutputPort())
            r.SetExtractionModeToSpecifiedRegions()
            r.AddSpecifiedRegion(i)
            g = vtk.vtkExtractUnstructuredGrid()
            g.SetInputConnection(r.GetOutputPort())
            geo = vtk.vtkGeometryFilter()
            geo.SetInputConnection(g.GetOutputPort())
            geo.Update()
            t = vtk.vtkTriangleFilter()
            t.SetInputConnection(geo.GetOutputPort())
            t.Update()
            cleaner = vtk.vtkCleanPolyData()
            cleaner.SetInputConnection(t.GetOutputPort())
            s = vtk.vtkSmoothPolyDataFilter()
            s.SetInputConnection(cleaner.GetOutputPort())
            s.SetNumberOfIterations(50)
            dl = vtk.vtkDelaunay3D()
            dl.SetInputConnection(s.GetOutputPort())
            dl.Update()

            self.cells.append(dl)

        # Extract points and volume of each mesh and write it as STL.
        # NOTE(review): self.cells is indexed from 0; if it already held meshes
        # from an earlier call, these are not the meshes just built -- confirm
        # self.cells is empty on entry.
        for i in xrange(N):
            g = vtk.vtkGeometryFilter()
            g.SetInputConnection(self.cells[i].GetOutputPort())
            t = vtk.vtkTriangleFilter()
            t.SetInputConnection(g.GetOutputPort())

            #get the surface points of the cells and save to points attribute
            # NOTE(review): t.Update() is not called before t.GetOutput(), so
            # v may still be empty here -- confirm points are populated.
            v = t.GetOutput()
            points = []
            for j in xrange(v.GetNumberOfPoints()):
                p = [0,0,0]
                v.GetPoint(j,p)
                points.append(p)
            self.points.append(points)

            #get the volume of the cell
            vo = vtk.vtkMassProperties()
            vo.SetInputConnection(t.GetOutputPort())
            self.volumes.append(vo.GetVolume())

            stl = vtk.vtkSTLWriter()
            stl.SetInputConnection(t.GetOutputPort())
            stl.SetFileName(surface_dir+'cell%02d.stl' % (i+self._counter))
            stl.Write()

        if self._display:
            # Colour surfaces by their connectivity RegionId.
            skinMapper = vtk.vtkDataSetMapper()
            skinMapper.SetInputConnection(regions.GetOutputPort())
            skinMapper.SetScalarRange(regions.GetOutput().GetPointData().GetArray("RegionId").GetRange())
            skinMapper.SetColorModeToMapScalars()
            #skinMapper.ScalarVisibilityOff()
            skinMapper.Update()

            skin = vtk.vtkActor()
            skin.SetMapper(skinMapper)
            #skin.GetProperty().SetColor(0,0,255)

            # An outline provides context around the data.
            # NOTE(review): the outline actor gets no mapper (calls commented
            # out below), so it presumably renders nothing.
            outlineData = vtk.vtkOutlineFilter()
            outlineData.SetInputConnection(v16.GetOutputPort())

            mapOutline = vtk.vtkPolyDataMapper()
            mapOutline.SetInputConnection(outlineData.GetOutputPort())

            outline = vtk.vtkActor()
            #outline.SetMapper(mapOutline)
            #outline.GetProperty().SetColor(0,0,0)

            colorbar = vtk.vtkScalarBarActor()
            colorbar.SetLookupTable(skinMapper.GetLookupTable())
            colorbar.SetTitle("Cells")
            colorbar.SetNumberOfLabels(N)


            # Create the renderer, the render window, and the interactor. The renderer
            # draws into the render window, the interactor enables mouse- and
            # keyboard-based interaction with the data within the render window.
            #
            aRenderer = vtk.vtkRenderer()
            renWin = vtk.vtkRenderWindow()
            renWin.AddRenderer(aRenderer)
            iren = vtk.vtkRenderWindowInteractor()
            iren.SetRenderWindow(renWin)

            # It is convenient to create an initial view of the data. The FocalPoint
            # and Position form a vector direction. Later on (ResetCamera() method)
            # this vector is used to position the camera to look at the data in
            # this direction.
            aCamera = vtk.vtkCamera()
            aCamera.SetViewUp (0, 0, -1)
            aCamera.SetPosition (0, 1, 0)
            aCamera.SetFocalPoint (0, 0, 0)
            aCamera.ComputeViewPlaneNormal()

            # Actors are added to the renderer. An initial camera view is created.
            # The Dolly() method moves the camera towards the FocalPoint,
            # thereby enlarging the image.
            aRenderer.AddActor(outline)
            aRenderer.AddActor(skin)
            aRenderer.AddActor(colorbar)
            aRenderer.SetActiveCamera(aCamera)
            aRenderer.ResetCamera ()
            aCamera.Dolly(1.5)

            # Set a background color for the renderer and set the size of the
            # render window (expressed in pixels).
            aRenderer.SetBackground(0.0,0.0,0.0)
            renWin.SetSize(800, 600)

            # Note that when camera movement occurs (as it does in the Dolly()
            # method), the clipping planes often need adjusting. Clipping planes
            # consist of two planes: near and far along the view direction. The
            # near plane clips out objects in front of the plane the far plane
            # clips out objects behind the plane. This way only what is drawn
            # between the planes is actually rendered.
            aRenderer.ResetCameraClippingRange()

            # NOTE(review): this window-to-image filter is created but never used.
            im=vtk.vtkWindowToImageFilter()
            im.SetInput(renWin)

            # Blocks until the interactive window is closed.
            iren.Initialize();
            iren.Start();

        #remove gray directory
        shutil.rmtree(local_dir)
Example #15
0
def removeIslands(predictedArray):
    """Keep one connected component per label and reassign the removed islands.

    Each of the first five channels of ``predictedArray`` (in order: pz, cz,
    us, afs, bg) is thresholded at 0.5 and split by ``getConnectedComponents``
    into a kept component and the remaining components (presumably the
    largest component vs. the rest -- TODO confirm against that helper).
    Every voxel that lies in a removed component of ANY channel is then
    assigned to the channel whose kept component has the largest signed
    Maurer distance at that voxel.  Distances are positive inside, so the
    label the voxel is inside of, or closest to, wins.  Ties go to the lower
    channel index.

    :param predictedArray: 4-D array indexed as ``[channel, z, y, x]`` with at
        least 5 channels (assumed to hold per-label probabilities -- confirm).
    :return: float64 array of shape ``(5, z, y, x)`` holding the kept
        components as binary masks, with every island voxel set to 1 in the
        channel it was reassigned to.
    """
    n_labels = 5  # pz, cz, us, afs, bg

    kept_arrays = []
    distance_arrays = []
    added_otherCC = None
    for channel in range(n_labels):
        channel_img = sitk.GetImageFromArray(
            thresholdArray(predictedArray[channel, :, :, :], 0.5))
        kept_cc, other_cc = getConnectedComponents(channel_img)

        # Union (as a sum) of the discarded components of every label.
        if added_otherCC is None:
            added_otherCC = other_cc
        else:
            added_otherCC = sitk.Add(added_otherCC, other_cc)

        distance = sitk.SignedMaurerDistanceMap(kept_cc,
                                                insideIsPositive=True,
                                                squaredDistance=False,
                                                useImageSpacing=False)
        # GetArrayFromImage returns arrays indexed [z, y, x].
        kept_arrays.append(sitk.GetArrayFromImage(kept_cc))
        distance_arrays.append(sitk.GetArrayFromImage(distance))

    # Output shape follows the input volume instead of a fixed 32x168x168.
    finalPrediction = np.stack(kept_arrays).astype(np.float64)

    # Per voxel, the label with the largest signed distance.  np.argmax keeps
    # the first maximum, matching the earlier list.index(max(...)) tie-break.
    closest_label = np.argmax(np.stack(distance_arrays), axis=0)
    z_idx, y_idx, x_idx = np.nonzero(
        sitk.GetArrayFromImage(added_otherCC) > 0)
    finalPrediction[closest_label[z_idx, y_idx, x_idx], z_idx, y_idx,
                    x_idx] = 1

    return finalPrediction
Example #16
0
def multi_stage(setting, pair_info, overwrite=False):
    """Register one image pair with a coarse-to-fine pyramid of network stages.

    The stages in ``setting['ImagePyramidSchedule']`` are processed in order.
    At each stage the fixed/moving images (and masks, if ``setting['UseMask']``)
    are downsampled.  From the second stage on, the moving image is warped by
    the composed, upsampled DVF of the previous stages.  The stage's network
    then predicts a DVF patch by patch.  After the last stage the composed DVF
    is resampled to stage-1 size, written unless ``setting['WriteNoDVF']``,
    and used to write the final moved image.

    :param setting: experiment settings dictionary (e.g. 'ImagePyramidSchedule',
        'read_pair_mode', 'UseMask', 'network_dict', 'data', 'PadTo',
        'WriteAfterEachStage', 'WriteNoDVF', 'WriteMasksForLSTM').
        'GPUMemory' and 'NumberOfGPU' are overwritten in place.
    :param pair_info: information of the pair to be registered;
        ``pair_info[0]`` describes the fixed image, ``pair_info[1]`` the
        moving image.
    :param overwrite: if False, the pair is skipped when the final moved
        image already exists on disk.
    :return: The output moved images and dvf will be written to the disk.
             0: registration is performed correctly
             2: skip overwriting
             3: the dvf is available from the previous experiment [4, 2, 1].
                It was only upsampled and written.
    """
    stage_list = setting['ImagePyramidSchedule']
    if setting['read_pair_mode'] == 'synthetic':
        deformed_im_ext = pair_info[0].get('deformed_im_ext', None)
        im_info_su = {
            'data': pair_info[0]['data'],
            'deform_exp': pair_info[0]['deform_exp'],
            'type_im': pair_info[0]['type_im'],
            'cn': pair_info[0]['cn'],
            'dsmooth': pair_info[0]['dsmooth'],
            'padto': pair_info[0]['padto'],
            'deformed_im_ext': deformed_im_ext
        }
        moved_im_s0_address = su.address_generator(setting,
                                                   'MovedIm_AG',
                                                   stage=1,
                                                   **im_info_su)
        moved_torso_s1_address = su.address_generator(setting,
                                                      'MovedTorso_AG',
                                                      stage=1,
                                                      **im_info_su)
        # In synthetic mode both outputs must exist to count as already done.
        already_registered = (os.path.isfile(moved_im_s0_address) and
                              os.path.isfile(moved_torso_s1_address))
    else:
        moved_im_s0_address = su.address_generator(setting,
                                                   'MovedIm',
                                                   pair_info=pair_info,
                                                   stage=0,
                                                   stage_list=stage_list)
        already_registered = os.path.isfile(moved_im_s0_address)

    if already_registered:
        if not overwrite:
            logging.debug('overwrite=False, file ' + moved_im_s0_address +
                          ' already exists, skipping .....')
            return 2
        logging.debug('overwrite=True, file ' + moved_im_s0_address +
                      ' already exists, but overwriting .....')

    pair_stage1 = real_pair.Images(
        setting, pair_info, stage=1, padto=setting['PadTo']
        ['stage1'])  # just read the original images without any padding
    pyr = dict()  # pyr: a dictionary of pyramid images
    pyr['fixed_im_s1_sitk'] = pair_stage1.get_fixed_im_sitk()
    pyr['moving_im_s1_sitk'] = pair_stage1.get_moved_im_affine_sitk()
    pyr['fixed_im_s1'] = pair_stage1.get_fixed_im()
    pyr['moving_im_s1'] = pair_stage1.get_moved_im_affine()
    if setting['UseMask']:
        pyr['fixed_mask_s1_sitk'] = pair_stage1.get_fixed_mask_sitk()
        pyr['moving_mask_s1_sitk'] = pair_stage1.get_moved_mask_affine_sitk()
    if setting['read_pair_mode'] == 'real':
        os.makedirs(su.address_generator(setting,
                                         'full_reg_folder',
                                         pair_info=pair_info,
                                         stage_list=stage_list),
                    exist_ok=True)
    setting['GPUMemory'], setting['NumberOfGPU'] = tfu.client.read_gpu_memory()
    time_before_dvf = time.time()

    # check if DVF is available from the previous experiment [4, 2, 1]. Then just upsample it.
    if stage_list in [[4, 2], [4]]:
        dvf0_address = su.address_generator(setting,
                                            'dvf_s0',
                                            pair_info=pair_info,
                                            stage_list=stage_list)
        chosen_stage = 2 if stage_list == [4, 2] else 4
        dvf_s_up_address = su.address_generator(setting,
                                                'dvf_s_up',
                                                pair_info=pair_info,
                                                stage=chosen_stage,
                                                stage_list=[4, 2, 1])
        if os.path.isfile(dvf_s_up_address):
            logging.debug('DVF found from prev exp:' + dvf_s_up_address +
                          ', only performing upsampling')
            dvf_s_up = sitk.ReadImage(dvf_s_up_address)
            dvf0 = ip.resampler_sitk(
                dvf_s_up,
                scale=1 / (chosen_stage / 2),
                im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(),
                interpolator=sitk.sitkLinear)
            sitk.WriteImage(sitk.Cast(dvf0, sitk.sitkVectorFloat32),
                            dvf0_address)
            return 3

    # Settings of the moving image's dataset (pixel type and padding value).
    moving_data_setting = setting['data'][pair_info[1]['data']]

    for i_stage, stage in enumerate(stage_list):
        stage_str = str(stage)
        net_cfg = setting['network_dict']['stage' + stage_str]
        mask_to_zero_stage = net_cfg['MaskToZero']
        # Keys of this stage's entries in the pyramid dictionary.
        fixed_im_key = 'fixed_im_s' + stage_str + '_sitk'
        moving_im_key = 'moving_im_s' + stage_str + '_sitk'
        fixed_mask_key = 'fixed_mask_s' + stage_str + '_sitk'
        moving_mask_key = 'moving_mask_s' + stage_str + '_sitk'
        moved_im_key = 'moved_im_s' + stage_str + '_sitk'
        moved_mask_key = 'moved_mask_s' + stage_str + '_sitk'
        dvf_key = 'DVF_s' + stage_str

        if stage != 1:
            pyr[fixed_im_key] = ip.downsampler_gpu(
                pyr['fixed_im_s1_sitk'],
                stage,
                default_pixel_value=setting['data'][
                    pair_info[0]['data']]['DefaultPixelValue'])
            pyr[moving_im_key] = ip.downsampler_gpu(
                pyr['moving_im_s1_sitk'],
                stage,
                default_pixel_value=moving_data_setting['DefaultPixelValue'])
        if setting['UseMask']:
            pyr[fixed_mask_key] = ip.resampler_sitk(
                pyr['fixed_mask_s1_sitk'],
                scale=stage,
                im_ref=pyr[fixed_im_key],
                default_pixel_value=0,
                interpolator=sitk.sitkNearestNeighbor)
            pyr[moving_mask_key] = ip.resampler_sitk(
                pyr['moving_mask_s1_sitk'],
                scale=stage,
                im_ref=pyr[moving_im_key],
                default_pixel_value=0,
                interpolator=sitk.sitkNearestNeighbor)

            if setting['WriteMasksForLSTM']:
                # only to be used in sequential training (LSTM)
                if setting['read_pair_mode'] == 'synthetic':
                    fixed_mask_stage_address = su.address_generator(
                        setting,
                        'Deformed' + mask_to_zero_stage,
                        stage=stage,
                        **im_info_su)
                    moving_mask_stage_address = su.address_generator(
                        setting, mask_to_zero_stage, stage=stage, **im_info_su)
                    fixed_im_stage_address = su.address_generator(setting,
                                                                  'DeformedIm',
                                                                  stage=stage,
                                                                  **im_info_su)
                    sitk.WriteImage(
                        sitk.Cast(pyr[fixed_im_key],
                                  moving_data_setting['ImageByte']),
                        fixed_im_stage_address)
                    sitk.WriteImage(pyr[fixed_mask_key],
                                    fixed_mask_stage_address)
                    if im_info_su['dsmooth'] != 0 and stage == 4:
                        # do not overwrite the original images
                        moving_im_stage_address = su.address_generator(
                            setting, 'Im', stage=stage, **im_info_su)
                        sitk.WriteImage(
                            sitk.Cast(pyr[moving_im_key],
                                      moving_data_setting['ImageByte']),
                            moving_im_stage_address)
                        sitk.WriteImage(pyr[moving_mask_key],
                                        moving_mask_stage_address)
                else:
                    fixed_im_stage_address = su.address_generator(
                        setting, 'Im', stage=stage, **pair_info[0])
                    fixed_mask_stage_address = su.address_generator(
                        setting,
                        mask_to_zero_stage,
                        stage=stage,
                        **pair_info[0])
                    if not os.path.isfile(fixed_im_stage_address):
                        sitk.WriteImage(
                            sitk.Cast(pyr[fixed_im_key],
                                      moving_data_setting['ImageByte']),
                            fixed_im_stage_address)
                    if not os.path.isfile(fixed_mask_stage_address):
                        sitk.WriteImage(pyr[fixed_mask_key],
                                        fixed_mask_stage_address)
                    if i_stage == 0:
                        moved_im_affine_stage_address = su.address_generator(
                            setting,
                            'MovedImBaseReg',
                            pair_info=pair_info,
                            stage=stage,
                            **pair_info[1])
                        moved_mask_affine_stage_address = su.address_generator(
                            setting,
                            'Moved' + mask_to_zero_stage + 'BaseReg',
                            pair_info=pair_info,
                            stage=stage,
                            **pair_info[1])
                        if not os.path.isfile(moved_im_affine_stage_address):
                            sitk.WriteImage(
                                sitk.Cast(pyr[moving_im_key],
                                          moving_data_setting['ImageByte']),
                                moved_im_affine_stage_address)
                        if not os.path.isfile(moved_mask_affine_stage_address):
                            sitk.WriteImage(pyr[moving_mask_key],
                                            moved_mask_affine_stage_address)
        else:
            pyr[fixed_mask_key] = None
            pyr[moving_mask_key] = None

        # Keys (into pyr) of the moving image/mask fed to this stage's network.
        input_regnet_moving_mask = None
        if i_stage == 0:
            input_regnet_moving = moving_im_key
            if setting['UseMask']:
                input_regnet_moving_mask = moving_mask_key
        else:
            previous_pyramid = stage_list[i_stage - 1]
            dvf_composed_previous_up_sitk = 'DVF_s' + str(
                previous_pyramid) + '_composed_up_sitk'
            dvf_composed_previous_sitk = 'DVF_s' + str(
                previous_pyramid) + '_composed_sitk'
            if i_stage == 1:
                pyr[dvf_composed_previous_sitk] = pyr[
                    'DVF_s' + str(previous_pyramid) + '_sitk']
            else:
                # Composed DVF so far = composed, upsampled DVF up to the
                # stage before the previous one + the previous stage's DVF.
                pyr[dvf_composed_previous_sitk] = sitk.Add(
                    pyr['DVF_s' + str(stage_list[i_stage - 2]) +
                        '_composed_up_sitk'],
                    pyr['DVF_s' + str(previous_pyramid) + '_sitk'])
            pyr[dvf_composed_previous_up_sitk] = ip.upsampler_gpu(
                pyr[dvf_composed_previous_sitk],
                round(previous_pyramid / stage),
                output_shape_3d=pyr[fixed_im_key].GetSize()[::-1],
            )
            if setting['WriteAfterEachStage'] and not setting['WriteNoDVF']:
                sitk.WriteImage(
                    sitk.Cast(pyr[dvf_composed_previous_up_sitk],
                              sitk.sitkVectorFloat32),
                    su.address_generator(setting,
                                         'dvf_s_up',
                                         pair_info=pair_info,
                                         stage=previous_pyramid,
                                         stage_list=stage_list))

            dvf_t = sitk.DisplacementFieldTransform(
                pyr[dvf_composed_previous_up_sitk])
            # after this line DVF_composed_previous_up_sitk is converted to a transform. so we need to load it again.
            pyr[moved_im_key] = ip.resampler_by_transform(
                pyr[moving_im_key],
                dvf_t,
                default_pixel_value=moving_data_setting['DefaultPixelValue'])
            if setting['UseMask']:
                pyr[moved_mask_key] = ip.resampler_by_transform(
                    pyr[moving_mask_key],
                    dvf_t,
                    default_pixel_value=0,
                    interpolator=sitk.sitkNearestNeighbor)

            pyr[dvf_composed_previous_up_sitk] = dvf_t.GetDisplacementField()
            if setting['WriteAfterEachStage']:
                if setting['read_pair_mode'] == 'synthetic':
                    moved_im_s_address = su.address_generator(setting,
                                                              'MovedIm_AG',
                                                              stage=stage,
                                                              **im_info_su)
                    moved_mask_s_address = su.address_generator(
                        setting,
                        'Moved' + mask_to_zero_stage + '_AG',
                        stage=stage,
                        **im_info_su)
                else:
                    moved_im_s_address = su.address_generator(
                        setting,
                        'MovedIm',
                        pair_info=pair_info,
                        stage=stage,
                        stage_list=stage_list)
                    moved_mask_s_address = su.address_generator(
                        setting,
                        'Moved' + mask_to_zero_stage,
                        pair_info=pair_info,
                        stage=stage,
                        stage_list=stage_list)

                sitk.WriteImage(
                    sitk.Cast(pyr[moved_im_key],
                              moving_data_setting['ImageByte']),
                    moved_im_s_address)

                # The moved mask only exists when masks are used.
                if setting['UseMask'] and setting['WriteMasksForLSTM']:
                    sitk.WriteImage(pyr[moved_mask_key], moved_mask_s_address)

            input_regnet_moving = moved_im_key
            if setting['UseMask']:
                input_regnet_moving_mask = moved_mask_key

        # Full-volume DVF of this stage, filled patch by patch below.
        pyr[dvf_key] = np.zeros(
            np.r_[pyr[fixed_im_key].GetSize()[::-1], 3], dtype=np.float64)
        if net_cfg['R'] == 'Auto' and net_cfg['Ry'] == 'Auto':
            current_network_name = net_cfg['NetworkDesign']
            r_out_erode_default = net_cfg['Ry_erode']
            r_in, r_out, r_out_erode = network.utils.find_optimal_radius(
                pyr[fixed_im_key],
                current_network_name,
                r_out_erode_default,
                gpu_memory=setting['GPUMemory'],
                number_of_gpu=setting['NumberOfGPU'])
        else:
            r_in = net_cfg['R']  # Radius of normal resolution patch size. Total size is (2*R +1)
            r_out = net_cfg['Ry']  # Radius of output. Total size is (2*Ry +1)
            r_out_erode = net_cfg['Ry_erode']  # at the test time, sometimes there are some problems at the border

        logging.debug(
            'stage' + stage_str + ' ,' + pair_info[0]['data'] +
            ', CN{}, ImType{}, Size={}'.format(
                pair_info[0]['cn'], pair_info[0]['type_im'],
                pyr[fixed_im_key].GetSize()[::-1]) +
            ', ' + net_cfg['NetworkDesign'] +
            ': r_in:{}, r_out:{}, r_out_erode:{}'.format(
                r_in, r_out, r_out_erode))
        pair_pyramid = real_pair.Images(
            setting,
            pair_info,
            stage=stage,
            fixed_im_sitk=pyr[fixed_im_key],
            moved_im_affine_sitk=pyr[input_regnet_moving],
            fixed_mask_sitk=pyr[fixed_mask_key],
            # Without masks there is no key to look up; pyr[None] would raise.
            moved_mask_affine_sitk=(pyr[input_regnet_moving_mask]
                                    if input_regnet_moving_mask is not None
                                    else None),
            padto=setting['PadTo']['stage' + stage_str],
            r_in=r_in,
            r_out=r_out,
            r_out_erode=r_out_erode)

        # building and loading network
        tf.reset_default_graph()
        images_tf = tf.placeholder(
            tf.float32,
            shape=[None, 2 * r_in + 1, 2 * r_in + 1, 2 * r_in + 1, 2],
            name="Images")
        bn_training = tf.placeholder(tf.bool, name='bn_training')
        dvf_tf = getattr(getattr(network, net_cfg['NetworkDesign']),
                         'network')(images_tf, bn_training)
        logging.debug(' Total number of variables %s' % (np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ])))
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        # The session is closed when the stage's sweep ends, so its
        # resources are released before the next stage builds a new graph.
        with tf.Session() as sess:
            saver.restore(
                sess,
                su.address_generator(
                    setting,
                    'saved_model_with_step',
                    current_experiment=net_cfg['NetworkLoad'],
                    step=net_cfg['GlobalStepLoad']))
            while not pair_pyramid.get_sweep_completed():
                # pyr[dvf_key] is the numpy DVF being filled; dvf_np is one
                # output patch of the network. Both spatial windows are
                # given by next_sweep_patch().
                (batch_im, win_center, win_r_before, win_r_after,
                 predicted_begin,
                 predicted_end) = pair_pyramid.next_sweep_patch()
                time_before_gpu = time.time()
                [dvf_np] = sess.run([dvf_tf],
                                    feed_dict={
                                        images_tf: batch_im,
                                        bn_training: 0
                                    })
                time_after_gpu = time.time()
                logging.debug('GPU: ' + pair_info[0]['data'] +
                              ', CN{} center={} is done in {:.2f}s '.format(
                                  pair_info[0]['cn'], win_center,
                                  time_after_gpu - time_before_gpu))

                pyr[dvf_key][win_center[0] - win_r_before[0]:win_center[0] + win_r_after[0],
                             win_center[1] - win_r_before[1]:win_center[1] + win_r_after[1],
                             win_center[2] - win_r_before[2]:win_center[2] + win_r_after[2], :] = \
                    dvf_np[0, predicted_begin[0]:predicted_end[0],
                           predicted_begin[1]:predicted_end[1],
                           predicted_begin[2]:predicted_end[2], :]

        # Rescale the DVF based on the voxel spacing. This is applied once to
        # the whole field after the sweep. Applying it inside the patch loop
        # would rescale already-filled patches again on every iteration.
        spacing_ref = [1.0 * stage for _ in range(3)]
        spacing_current = pyr[fixed_im_key].GetSpacing()
        for dim in range(3):
            pyr[dvf_key][:, :, :, dim] = pyr[dvf_key][:, :, :, dim] * \
                spacing_current[dim] / spacing_ref[dim]

        pyr[dvf_key + '_sitk'] = ip.array_to_sitk(pyr[dvf_key],
                                                  im_ref=pyr[fixed_im_key],
                                                  is_vector=True)

        if i_stage == (len(stage_list) - 1):
            # when all stages are finished, final dvf and moved image are written
            dvf_composed_final_sitk = dvf_key + '_composed_sitk'
            if len(stage_list) == 1:
                # need to upsample in the case that last stage is not 1
                if stage == 1:
                    pyr[dvf_composed_final_sitk] = pyr[dvf_key + '_sitk']
                else:
                    pyr[dvf_composed_final_sitk] = ip.resampler_sitk(
                        pyr[dvf_key + '_sitk'],
                        scale=1 / stage,
                        im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(),
                        interpolator=sitk.sitkLinear)
            else:
                pyr[dvf_composed_final_sitk] = sitk.Add(
                    pyr['DVF_s' + str(stage_list[-2]) + '_composed_up_sitk'],
                    pyr[dvf_key + '_sitk'])
                if stage != 1:
                    pyr[dvf_composed_final_sitk] = ip.resampler_sitk(
                        pyr[dvf_composed_final_sitk],
                        scale=1 / stage,
                        im_ref_size=pyr['fixed_im_s1_sitk'].GetSize(),
                        interpolator=sitk.sitkLinear)
            if not setting['WriteNoDVF']:
                sitk.WriteImage(
                    sitk.Cast(pyr[dvf_composed_final_sitk],
                              sitk.sitkVectorFloat32),
                    su.address_generator(setting,
                                         'dvf_s0',
                                         pair_info=pair_info,
                                         stage_list=stage_list))
            dvf_t = sitk.DisplacementFieldTransform(
                pyr[dvf_composed_final_sitk])
            pyr['moved_im_s0_sitk'] = ip.resampler_by_transform(
                pyr['moving_im_s1_sitk'],
                dvf_t,
                default_pixel_value=moving_data_setting['DefaultPixelValue'])
            sitk.WriteImage(
                sitk.Cast(pyr['moved_im_s0_sitk'],
                          moving_data_setting['ImageByte']),
                moved_im_s0_address)

            if (setting['WriteMasksForLSTM'] and
                    setting['read_pair_mode'] == 'synthetic'):
                # Warp the full-resolution moving mask read from disk.
                moving_mask_sitk = sitk.ReadImage(
                    su.address_generator(setting,
                                         mask_to_zero_stage,
                                         stage=1,
                                         **im_info_su))
                moved_mask_stage1 = ip.resampler_by_transform(
                    moving_mask_sitk,
                    dvf_t,
                    default_pixel_value=0,
                    interpolator=sitk.sitkNearestNeighbor)
                moved_mask_stage1_address = su.address_generator(
                    setting,
                    'Moved' + mask_to_zero_stage + '_AG',
                    stage=1,
                    **im_info_su)
                sitk.WriteImage(moved_mask_stage1, moved_mask_stage1_address)
                logging.debug('writing ' + moved_mask_stage1_address)

    time_after_dvf = time.time()
    logging.debug(
        pair_info[0]['data'] + ', CN{}, ImType{} is done in {:.2f}s '.format(
            pair_info[0]['cn'], pair_info[0]['type_im'], time_after_dvf -
            time_before_dvf))

    return 0
Example #17
0
def load_atlas_custom_images(wdpath):
    """Build a custom label atlas from training labels and evaluate it on test data.

    Training label maps (files ending in ``labels_native.nii.gz``) and their
    transforms (files ending in ``affine.txt``) are collected recursively under
    ``wdpath``. They are resampled (nearest neighbour) onto the grid of the
    module-level ``atlas_t1`` image.

    For each label, the per-subject masks of all training subjects are averaged
    and written to ``../bin/custom_atlas_result/``, then thresholded into a
    binary map. The labels are 1 white matter, 2 grey matter, 3 hippocampus,
    4 amygdala and 5 thalamus. The binary maps are compared with the resampled
    labels found under ``../data/test`` (Dice coefficient and Hausdorff
    distance), and CSV results are written to ``../bin/DiceTestResult/``.

    :param wdpath: root folder containing the training data.
    :raises ValueError: if no training or test labels are found, or if the
        number of label files differs from the number of transform files.
    """
    # (label value, display name, single-subject file, map file stem,
    #  lower threshold of the binary map).
    # Each per-subject mask carries the label value k instead of 1, so an
    # averaged map lies in [0, k]. A lower threshold of 0.3 * k keeps voxels
    # labelled k in at least 30% of the training subjects.
    labels = (
        (1, 'White Matter', 'White_matter_label.nii', 'white_matter_map', 0.3),
        (2, 'Grey Matter', 'Grey_matter_label.nii', 'grey_matter_map', 0.6),
        (3, 'Hippocampus', 'Hippocampus_label.nii', 'hippocampus_map', 0.9),
        (4, 'Amygdala', 'Amygdala_label.nii', 'amygdala_map', 1.2),
        (5, 'Thalamus', 'Thalamus_label.nii', 'thalamus_map', 1.5),
    )
    atlas_dir = '../bin/custom_atlas_result/'
    dice_dir = '../bin/DiceTestResult/'

    # Training labels, resampled onto the atlas_t1 grid.
    train_labels, train_transforms = _atlas_read_labels_and_transforms(wdpath)
    if not train_labels:
        raise ValueError(
            'no "labels_native.nii.gz" file found under {}'.format(wdpath))
    train_resampled = _atlas_resample(train_labels, train_transforms)

    # masks[k][i]: voxels of training subject i equal to k (value k, 0 elsewhere).
    masks = {
        value: [sitk.Threshold(img, value, value, 0) for img in train_resampled]
        for value, _, _, _, _ in labels
    }

    # Save every label of the first training subject, plus its label map
    # resampled onto the atlas grid (to show the segmentation).
    first_subject = {
        file_name: masks[value][0] for value, _, file_name, _, _ in labels
    }
    first_subject['Train_image_1_resampled.nii'] = train_resampled[0]
    _atlas_write_images(atlas_dir, first_subject)

    # Mean over ALL training subjects, computed in float so fractional
    # agreement is kept (integer division would truncate it).
    maps = {value: _atlas_mean_image(masks[value])
            for value, _, _, _, _ in labels}
    _atlas_write_images(atlas_dir, {
        stem + '_no_threshold.nii': maps[value]
        for value, _, _, stem, _ in labels
    })

    # Binary map per label: value k where lower <= mean <= k, 0 elsewhere.
    binary_maps = {
        value: sitk.BinaryThreshold(maps[value], lower, value, value, 0)
        for value, _, _, _, lower in labels
    }
    _atlas_write_images(atlas_dir, {
        stem + '.nii': binary_maps[value] for value, _, _, stem, _ in labels
    })

    # Test labels, resampled onto the same atlas grid.
    test_labels, test_transforms = _atlas_read_labels_and_transforms(
        '../data/test')
    if not test_labels:
        raise ValueError('no "labels_native.nii.gz" file found under ../data/test')
    test_resampled = _atlas_resample(test_labels, test_transforms)
    _atlas_write_images(atlas_dir,
                        {'Test_data_1_resampled.nii': test_resampled[0]})

    # Dice coefficient and Hausdorff distance of every label for every test patient.
    os.makedirs(dice_dir, exist_ok=True)
    for value, name, _, _, _ in labels:
        evaluator = eval_.Evaluator(eval_.ConsoleEvaluatorWriter(5))
        evaluator.metrics = [
            metric.DiceCoefficient(),
            metric.HausdorffDistance()
        ]
        evaluator.add_writer(
            eval_.CSVEvaluatorWriter(
                os.path.join(dice_dir, 'DiceResults_' + name + '.csv')))
        evaluator.add_label(value, name)
        for j, test_img in enumerate(test_resampled):
            evaluator.evaluate(test_img, binary_maps[value],
                               'Patient ' + str(j))

    print("END Custom loadAtlas")


def _atlas_read_labels_and_transforms(root):
    """Return ``(label_images, transforms)`` found recursively under ``root``.

    Files ending in ``labels_native.nii.gz`` and ``affine.txt`` are paired by
    their position in ``os.walk`` order. Presumably each subject folder holds
    exactly one of each.

    :raises ValueError: if the number of label files and transform files differ.
    """
    label_images = []
    transforms = []
    for dirpath, _subdirs, files in os.walk(root):
        for name in files:
            path = os.path.join(dirpath, name)
            if name.endswith("labels_native.nii.gz"):
                label_images.append(sitk.ReadImage(path))
            elif name.endswith("affine.txt"):
                transforms.append(sitk.ReadTransform(path))
    if len(label_images) != len(transforms):
        raise ValueError(
            'found {} label files but {} transform files under {}'.format(
                len(label_images), len(transforms), root))
    return label_images, transforms


def _atlas_resample(label_images, transforms):
    """Resample each label image onto ``atlas_t1``, keeping its pixel type."""
    return [
        sitk.Resample(img, atlas_t1, tfm, sitk.sitkNearestNeighbor, 0,
                      img.GetPixelIDValue())
        for img, tfm in zip(label_images, transforms)
    ]


def _atlas_mean_image(images):
    """Return the voxel-wise mean of a non-empty list of images as float32."""
    total = None
    for img in images:
        img = sitk.Cast(img, sitk.sitkFloat32)
        total = img if total is None else sitk.Add(total, img)
    return sitk.Divide(total, float(len(images)))


def _atlas_write_images(folder, images_by_name):
    """Write each ``{file name: image}`` entry into ``folder`` with compression."""
    os.makedirs(folder, exist_ok=True)
    for file_name, image in images_by_name.items():
        sitk.WriteImage(image, os.path.join(folder, file_name), True)
Example #18
0
def multi_stage(setting, network_dict, pair_info, overwrite=False):
    """Register one image pair with a coarse-to-fine pyramid of RegNet networks.

    For each stage of ``setting['ImagePyramidSchedule']``:
    - the fixed and moving images are downsampled, together with the torso
      masks when ``setting['torsoMask']`` is set;
    - the moving image is warped with the DVF composed from the previous stages;
    - the stage network is swept patch by patch to predict this stage's DVF.

    After the last stage, the composed DVF and the final moved image are
    written to disk.

    Note: ``setting['R']``, ``setting['Ry']`` and ``setting['Ry_erode']`` are
    overwritten with the values of each stage's network.

    :param setting: experiment settings dictionary.
    :param network_dict: per-stage network settings, keyed ``'Stage' + str(stage)``.
    :param pair_info: information of the pair to be registered.
    :param overwrite: when False and the final moved image already exists,
        the pair is skipped.
    :return: The output moved images and dvf will be written to the disk.
             1: registration is performed correctly
             2: skip overwriting
    """
    stage_list = setting['ImagePyramidSchedule']
    final_moved_image_address = su.address_generator(setting, 'moved_image', pair_info=pair_info, stage=0, stage_list=stage_list)
    if os.path.isfile(final_moved_image_address):
        if not overwrite:
            print('overwrite=False, file '+final_moved_image_address+' already exists, skipping .....')
            return 2
        print('overwrite=True, file '+final_moved_image_address+' already exists, but overwriting .....')

    pair_stage1 = real_pair.Images(setting, pair_info, stage=1)
    pyr = dict()  # pyr: a dictionary of pyramid images
    pyr['fixed_im_s1_sitk'] = pair_stage1.get_fixed_im_sitk()
    pyr['moving_im_s1_sitk'] = pair_stage1.get_moved_im_affine_sitk()
    pyr['fixed_im_s1'] = pair_stage1.get_fixed_im()
    pyr['moving_im_s1'] = pair_stage1.get_moved_im_affine()
    if setting['torsoMask']:
        pyr['fixed_torso_s1_sitk'] = pair_stage1.get_fixed_torso_sitk()
        pyr['moving_torso_s1_sitk'] = pair_stage1.get_moved_torso_affine_sitk()
    os.makedirs(su.address_generator(setting, 'full_reg_folder', pair_info=pair_info, stage_list=stage_list), exist_ok=True)

    time_before_dvf = time.time()
    for i_stage, stage in enumerate(stage_list):
        s = str(stage)
        # --- image pyramid of this stage (stage 1 images come from pair_stage1)
        if stage != 1:
            pyr['fixed_im_s'+s+'_sitk'] = ip.downsampler_gpu(pyr['fixed_im_s1_sitk'], stage,
                                                              default_pixel_value=setting['data'][pair_info[0]['data']]['defaultPixelValue'])
            pyr['moving_im_s'+s+'_sitk'] = ip.downsampler_gpu(pyr['moving_im_s1_sitk'], stage,
                                                               default_pixel_value=setting['data'][pair_info[1]['data']]['defaultPixelValue'])
        if setting['torsoMask']:
            pyr['fixed_torso_s'+s+'_sitk'] = ip.downsampler_sitk(pyr['fixed_torso_s1_sitk'],
                                                                  stage,
                                                                  im_ref=pyr['fixed_im_s'+s+'_sitk'],
                                                                  default_pixel_value=0,
                                                                  interpolator=sitk.sitkNearestNeighbor)
            pyr['moving_torso_s'+s+'_sitk'] = ip.downsampler_sitk(pyr['moving_torso_s1_sitk'],
                                                                   stage,
                                                                   im_ref=pyr['moving_im_s'+s+'_sitk'],
                                                                   default_pixel_value=0,
                                                                   interpolator=sitk.sitkNearestNeighbor)
        else:
            pyr['fixed_torso_s'+s+'_sitk'] = None
            pyr['moving_torso_s'+s+'_sitk'] = None

        # --- moving input of the network: the raw moving image at the first
        # stage; afterwards, the moving image warped by the DVF composed so far.
        input_regnet_moving_torso = None
        if i_stage == 0:
            input_regnet_moving = 'moving_im_s'+s+'_sitk'
            if setting['torsoMask']:
                input_regnet_moving_torso = 'moving_torso_s'+s+'_sitk'
        else:
            previous_pyramid = stage_list[i_stage - 1]
            dvf_composed_previous_up_sitk = 'DVF_s'+str(previous_pyramid)+'_composed_up_sitk'
            dvf_composed_previous_sitk = 'DVF_s'+str(previous_pyramid)+'_composed_sitk'
            if i_stage == 1:
                pyr[dvf_composed_previous_sitk] = pyr['DVF_s'+str(previous_pyramid)+'_sitk']
            else:
                # Composition = upsampled composition of the earlier stages + previous stage DVF.
                pyr[dvf_composed_previous_sitk] = sitk.Add(pyr['DVF_s'+str(stage_list[i_stage-2])+'_composed_up_sitk'],
                                                           pyr['DVF_s'+str(previous_pyramid)+'_sitk'])
            pyr[dvf_composed_previous_up_sitk] = ip.upsampler_gpu(pyr[dvf_composed_previous_sitk],
                                                                  round(previous_pyramid/stage),
                                                                  dvf_output_size=pyr['fixed_im_s'+s+'_sitk'].GetSize()[::-1],
                                                                  )
            if setting['WriteAfterEachStage']:
                sitk.WriteImage(sitk.Cast(pyr[dvf_composed_previous_up_sitk], sitk.sitkVectorFloat32),
                                su.address_generator(setting, 'dvf_s_up', pair_info=pair_info, stage=previous_pyramid, stage_list=stage_list))

            dvf_t = sitk.DisplacementFieldTransform(pyr[dvf_composed_previous_up_sitk])
            # after this line DVF_composed_previous_up_sitk is converted to a transform. so we need to load it again.
            pyr['moved_im_s'+s+'_sitk'] = ip.resampler_by_dvf(pyr['moving_im_s'+s+'_sitk'],
                                                               dvf_t,
                                                               default_pixel_value=setting['data'][pair_info[1]['data']]['defaultPixelValue'])
            if setting['torsoMask']:
                pyr['moved_torso_s'+s+'_sitk'] = ip.resampler_by_dvf(pyr['moving_torso_s'+s+'_sitk'],
                                                                      dvf_t,
                                                                      default_pixel_value=0,
                                                                      interpolator=sitk.sitkNearestNeighbor)
            pyr[dvf_composed_previous_up_sitk] = dvf_t.GetDisplacementField()
            if setting['WriteAfterEachStage']:
                sitk.WriteImage(sitk.Cast(pyr['moved_im_s'+s+'_sitk'], setting['data'][pair_info[1]['data']]['imageByte']),
                                su.address_generator(setting, 'moved_image', pair_info=pair_info, stage=stage, stage_list=stage_list))
            input_regnet_moving = 'moved_im_s'+s+'_sitk'
            if setting['torsoMask']:
                input_regnet_moving_torso = 'moved_torso_s'+s+'_sitk'

        pyr['DVF_s'+s] = np.zeros(np.r_[pyr['fixed_im_s'+s+'_sitk'].GetSize()[::-1], 3], dtype=np.float64)
        # input_regnet_moving_torso stays None when torsoMask is off; pyr has no
        # None key, so pass None directly instead of indexing pyr with it.
        if input_regnet_moving_torso is None:
            moved_torso_sitk = None
        else:
            moved_torso_sitk = pyr[input_regnet_moving_torso]
        pair_pyramid = real_pair.Images(setting, pair_info, stage=stage,
                                        fixed_im_sitk=pyr['fixed_im_s'+s+'_sitk'],
                                        moved_im_affine_sitk=pyr[input_regnet_moving],
                                        fixed_torso_sitk=pyr['fixed_torso_s'+s+'_sitk'],
                                        moved_torso_affine_sitk=moved_torso_sitk
                                        )

        # --- building and loading the network of this stage
        network = network_dict['Stage'+s]
        tf.reset_default_graph()
        setting['R'] = network['R']    # Radius of normal resolution patch size. Total size is (2*R +1)
        setting['Ry'] = network['Ry']  # Radius of output. Total size is (2*Ry +1)
        setting['Ry_erode'] = network['Ry_erode']  # at the test time, sometimes there are some problems at the border
        images_tf = tf.placeholder(tf.float32,
                                   shape=[None, 2 * setting['R'] + 1, 2 * setting['R'] + 1, 2 * setting['R'] + 1, 2],
                                   name="Images")
        bn_training = tf.placeholder(tf.bool, name='bn_training')
        x_fixed = images_tf[:, :, :, :, 0, np.newaxis]
        x_deformed = images_tf[:, :, :, :, 1, np.newaxis]
        dvf_tf = getattr(RegNetModel, network['NetworkDesign'])(x_fixed, x_deformed, bn_training)
        logging.debug(' Total number of variables %s' % (np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])))
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
        # The session is closed when this stage is done, which frees its
        # resources before the next stage builds a new graph.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, su.address_generator(setting, 'saved_model_with_step',
                                                     current_experiment=network['NetworkLoad'],
                                                     step=network['GlobalStepLoad']))
            while not pair_pyramid.get_sweep_completed():
                # The pyr[DVF_S] is the numpy DVF which will be filled. the dvf_np is an output
                # path from the network. We control the spatial location of both dvf in this function
                batch_im, win_center, win_r_before, win_r_after, predicted_begin, predicted_end = pair_pyramid.next_sweep_patch()
                time_before_gpu = time.time()
                [dvf_np] = sess.run([dvf_tf], feed_dict={images_tf: batch_im, bn_training: 0})
                time_after_gpu = time.time()
                logging.debug('GPU: Data='+pair_info[0]['data']+' CN = {} center = {} is done in {:.2f}s '.format(
                    pair_info[0]['cn'], win_center, time_after_gpu - time_before_gpu))

                pyr['DVF_s'+s][win_center[0] - win_r_before[0]: win_center[0] + win_r_after[0],
                               win_center[1] - win_r_before[1]: win_center[1] + win_r_after[1],
                               win_center[2] - win_r_before[2]: win_center[2] + win_r_after[2], :] = \
                    dvf_np[0, predicted_begin[0]:predicted_end[0], predicted_begin[1]:predicted_end[1], predicted_begin[2]:predicted_end[2], :]

        pyr['DVF_s'+s+'_sitk'] = ip.array_to_sitk(pyr['DVF_s'+s],
                                                  im_ref=pyr['fixed_im_s'+s+'_sitk'],
                                                  is_vector=True)

        if i_stage == (len(stage_list)-1):
            # when all stages are finished, final dvf and moved image are written
            dvf_composed_final_sitk = 'DVF_s'+s+'_composed_sitk'
            if len(stage_list) == 1:
                pyr[dvf_composed_final_sitk] = pyr['DVF_s'+s+'_sitk']
            else:
                pyr[dvf_composed_final_sitk] = sitk.Add(pyr['DVF_s'+str(stage_list[-2])+'_composed_up_sitk'],
                                                        pyr['DVF_s'+s+'_sitk'])
            sitk.WriteImage(sitk.Cast(pyr[dvf_composed_final_sitk], sitk.sitkVectorFloat32),
                            su.address_generator(setting, 'dvf_s0', pair_info=pair_info, stage_list=stage_list))
            dvf_t = sitk.DisplacementFieldTransform(pyr[dvf_composed_final_sitk])
            pyr['moved_im_s0_sitk'] = ip.resampler_by_dvf(pyr['moving_im_s'+s+'_sitk'], dvf_t,
                                                          default_pixel_value=setting['data'][pair_info[1]['data']]['defaultPixelValue'])
            sitk.WriteImage(sitk.Cast(pyr['moved_im_s0_sitk'],
                                      setting['data'][pair_info[1]['data']]['imageByte']),
                            su.address_generator(setting, 'moved_image', pair_info=pair_info, stage=0, stage_list=stage_list))
    time_after_dvf = time.time()
    logging.debug('Data='+pair_info[0]['data']+' CN = {} ImType = {} is done in {:.2f}s '.format(
        pair_info[0]['cn'], pair_info[0]['type_im'], time_after_dvf - time_before_dvf))

    return 1
def evaluate_segmentation_performance_repeatability_holdout(
        testcases, dataset, n_folds=3):
    """
    Evaluates lesion detection and segmentation performance on hold-out test cases.
    Also evaluates repeatability of lesion segmentation in terms of dice similarity coefficient.

    For every case, the binarised predictions of the ``n_folds`` cross-validation
    models are summed separately for scan 1 and scan 2. The sums are then
    filtered, small lesions are removed, and the result is compared with the
    ground truth and with the other scan.

    testcases : filenames of testcases
    dataset: b-value settings or the dataset name
    n_folds: number of cross-validation folds whose predictions are combined (>= 1)

    returns: a tuple (Mean and standard deviation of dice of first scan,
                        Mean, standard deviation of dice of second scan,
                        Mean, standard deviation of dice between scans (repeatability)
                        # hits (No of lesions detected), # misses, # false positives for scan1,
                        # # hits (No of lesions detected), # misses, # false positives for scan2,
                        # agreement and disagreements between the network.)

    raises: ValueError if testcases is empty.
    """

    def seg_path(scan, fold, case, filename):
        # outputs/segmentations/<dataset>/<scan>_<fold>/<case>/<filename>
        return os.path.join('outputs', 'segmentations', f'{dataset}',
                            f'{scan}_{fold}', f'{case}', filename)

    dices = []
    # totals[k] = [hits, misses, false positives] of comparison k:
    # 0: gt vs scan 1, 1: gt vs scan 2, 2: scan 1 vs scan 2
    totals = [[0, 0, 0] for _ in range(3)]

    for case in testcases:
        # The ground truth is read from the scan-1 / fold-0 folder.
        gt_img = sitk.ReadImage(seg_path(1, 0, case, 'gt.nii.gz'))
        gt_arr = sitk.GetArrayFromImage(gt_img)
        gt_arr[gt_arr == 1] = 0  # label 1 is dropped before binarisation
        gt1 = sitk.GetImageFromArray(gt_arr)
        # GetImageFromArray resets origin/spacing/direction; restore them so
        # gt1 lies in the same physical space as the predictions.
        gt1.CopyInformation(gt_img)
        gt1 = DataUtil.convert2binary(gt1)

        probs1 = None
        probs2 = None
        for cv in range(n_folds):
            probs1_ = DataUtil.convert2binary(
                sitk.ReadImage(seg_path(1, cv, case, 'prob.nii.gz')))
            probs2_ = DataUtil.convert2binary(
                sitk.ReadImage(seg_path(2, cv, case, 'prob.nii.gz')))

            probs1 = probs1_ if probs1 is None else sitk.Add(probs1, probs1_)
            probs2 = probs2_ if probs2 is None else sitk.Add(probs2, probs2_)

        probs1 = filterSegmentation(probs1)
        probs2 = filterSegmentation(probs2)

        probs1 = removeSmallLesions(probs1)
        probs2 = removeSmallLesions(probs2)

        comparisons = ((gt1, probs1), (gt1, probs2), (probs1, probs2))
        case_dices = []
        for counts, (reference, prediction) in zip(totals, comparisons):
            dice, hits, misses, fps = get_dice_repeatability(reference, prediction)
            case_dices.append(dice)
            counts[0] += hits
            counts[1] += misses
            counts[2] += fps
        dices.append(tuple(case_dices))

    if not dices:
        raise ValueError('testcases is empty; nothing to evaluate')

    # Each per-case dice entry is iterable (or None, which is skipped);
    # flatten them into one list per comparison.
    dice1, dice2, dice3 = (
        [y for x in column if x is not None for y in x]
        for column in zip(*dices))
    (h1, m1, f1), (h2, m2, f2), (h3, m3, f3) = totals

    print(np.mean(dice1), np.std(dice1))
    print(np.mean(dice2), np.std(dice2))
    print(np.mean(dice3), np.std(dice3))

    print(h1, m1, f1)
    print(h2, m2, f2)
    print(h3, m3, f3)

    # The scan-vs-scan entry is (agreements, disagreements = misses + false positives).
    return ((np.mean(dice1), np.std(dice1)), (np.mean(dice2), np.std(dice2)),
            (np.mean(dice3), np.std(dice3)), (h1, m1, f1), (h2, m2, f2),
            (h3, m3 + f3))