Beispiel #1
0
            def testGetHyperslab(self):
                """hyperslab should be same as slice from data array"""

                # full-volume slices along each axis
                src = minc2_file(inputFile_ushort)
                src.setup_standard_order()
                full_slices = (src.data[10, :, :],
                               src.data[:, 10, :],
                               src.data[:, :, 10])
                src.close()

                # same slices via hyperslab reads
                reader = minc2_file(inputFile_ushort)
                reader.setup_standard_order()
                slabs = (reader.load_hyperslab_t([10, None, None]).squeeze(),
                         reader.load_hyperslab_t([None, 10, None]).squeeze(),
                         reader.load_hyperslab_t([None, None, 10]).squeeze())
                reader.close()

                # element-wise identical -> zero mean squared difference
                for full, slab in zip(full_slices, slabs):
                    self.assertEqual(torch.mean((full - slab)**2), 0.0)
Beispiel #2
0
    def testSetHyperslabFloat(self):
        """setting hyperslab should change underlying volume (float)"""

        # pull a reference slice (both access styles) out of the input volume
        src = minc2_file(inputFile_ushort)
        dims = src.store_dims()
        src.setup_standard_order()
        hyperslab_a = src.data[10, :, :]
        hyperslab_a_ = src[10, :, :]
        src.close()

        print("Hyperslab:", hyperslab_a.shape)
        print("Hyperslab2:", hyperslab_a_.shape)
        print("dims:", dims)

        # write the slice into a fresh float32 volume
        dst = minc2_file()
        dst.define(dims, 'float32', 'float32')
        dst.create(outputFilename)
        dst.setup_standard_order()

        # float32 storage needs no slice normalization
        dst.save_hyperslab(hyperslab_a, [10, None, None])
        dst.close()

        # read it back and verify the round trip is lossless
        reread = minc2_file(outputFilename)
        hyperslab_b = reread.load_hyperslab([10, None, None])

        print(N.average((hyperslab_a - hyperslab_b)**2))

        self.assertEqual(N.average((hyperslab_a - hyperslab_b)**2), 0.0)
        reread.close()
Beispiel #3
0
    def testSetSliceFloat(self):
        """volume slice setting should change underlying volume (float)"""

        # read a reference slice from the input volume
        v = minc2_file(inputFile_ushort)
        dims = v.store_dims()
        v.setup_standard_order()
        hyperslab_a = v.data[10, :, :]
        v.close()

        v2 = minc2_file()
        v2.define(dims, 'float32', 'float32')
        v2.create(outputFilename)
        v2.setup_standard_order()

        # because we are saving float32 , we don't need slice normalization
        v2[10, :, :] = hyperslab_a
        v2.close()

        # BUG FIX: re-open the file we just wrote (outputFilename), not the
        # original input -- otherwise the assertion would pass even if the
        # slice write had no effect at all.
        v3 = minc2_file(outputFilename)
        v3.setup_standard_order()
        hyperslab_b = v3[10, :, :]
        v3.close()

        self.assertEqual(N.average((hyperslab_a - hyperslab_b)**2), 0.0)
Beispiel #4
0
            def testSetHyperslabShort(self):
                """setting hyperslab should change underlying volume (short)"""

                # reference slab from the input volume
                src = minc2_file(inputFile_ushort)
                dims = src.store_dims()
                src.setup_standard_order()
                slab_in = src.load_hyperslab_t([10, None, None])

                # write into a uint16 volume; set the volume range so the
                # short representation can encode the data
                dst = minc2_file()
                dst.define(dims, 'uint16', 'float32')  # , global_scaling=True
                dst.create(outputFilename)
                dst.set_volume_range(torch.min(slab_in), torch.max(slab_in))
                dst.setup_standard_order()

                # slice normalization applies here
                dst.save_hyperslab_t(slab_in, [10, None, None])
                slab_out = dst.load_hyperslab_t([10, None, None])

                # round-trip error should vanish to 8 decimals
                self.assertAlmostEqual(
                    torch.mean((slab_in - slab_out)**2).item(), 0.0, 8)
                dst.close()
                src.close()
Beispiel #5
0
 def testHyperslabArray(self):
     """hyperslab should be reinsertable into volume"""
     if False:  # body intentionally disabled, kept for reference
         src = minc2_file(inputFile_ushort)
         dst = minc2_file()
         dst.create(outputFilename)
         dst.close()
         src.close()
def save_labels(outfile, reference, data, history=None):
    """Write *data* to *outfile* as a byte-stored integer label volume.

    The output inherits dimensions and metadata from *reference*.
    The *history* argument is currently unused (see TODO below).

    NOTE(review): the output handle is never close()d here, unlike other
    call sites in this codebase -- confirm minc2_file flushes on GC.
    """
    # TODO: add history
    ref = minc2_file(reference)
    out = minc2_file()
    # byte storage with integer representation -- appropriate for labels
    out.define(ref.store_dims(), minc2_file.MINC2_BYTE, minc2_file.MINC2_INT)
    out.create(outfile)
    out.copy_metadata(ref)
    out.setup_standard_order()
    out.save_complete_volume(data)
Beispiel #7
0
 def testVectorRead(self):
     """make sure that a vector file can be read correctly"""
     vol = minc2_file(inputVector)
     vol.setup_standard_order()
     # the leading representation dimension of a vector volume must be
     # the vector dimension
     self.assertEqual(vol.representation_dims()[0].id,
                      minc2_file.MINC2_DIM_VEC)
     vol.close()
Beispiel #8
0
    def testNonDefaultDirCos3DVFF(self):
        """testing reading the direction cosines of a file with non-standard values (volumeFromFile)"""
        v = minc2_file(input3DdirectionCosines)
        v.setup_standard_order()
        dims = v.representation_dims()

        # Compare each axis' direction cosines against mincinfo output.
        # (Collapses three copy-pasted popen blocks into one loop.)
        for axis, space in enumerate(("xspace", "yspace", "zspace")):
            pipe = os.popen(
                "mincinfo -attvalue %s:direction_cosines %s" %
                (space, input3DdirectionCosines), "r")
            from_file = pipe.read().rstrip().split(" ")
            pipe.close()

            for k in range(3):
                self.assertAlmostEqual(dims[axis].dir_cos[k],
                                       float(from_file[k]), 8)
Beispiel #9
0
 def testVectorRead2(self):
     """make sure that volume has four dimensions"""
     vol = minc2_file(inputVector)
     # both the reported dimensionality and the loaded array must be 4-D
     self.assertEqual(vol.ndim(), 4)
     self.assertEqual(len(vol.data.shape), 4)
     vol.close()
Beispiel #10
0
 def testWorldToVoxel(self):
     """testing world_to_voxel conversion of a file with non-standard values"""
     vol = minc2_file(input3DdirectionCosines)
     vol.setup_standard_order()
     ijk = vol.world_to_voxel(N.array([50.0, -80.0, -30.0]))
     # reference voxel coordinates for this world point
     expected = (-6.362013409627703453, 6.6280285942264356436,
                 75.806692060998855709)
     for k, e in enumerate(expected):
         self.assertAlmostEqual(ijk[k], e, 8)
Beispiel #11
0
 def testFromFileDataDouble(self):
     """ensure that double data is read correct with a precision of 8 decimals on a call to aveage()"""
     vol = minc2_file(inputFile_double)
     mean_data = N.average(vol.data)
     vol.close()
     # mincstats provides the reference mean
     pipe = os.popen("mincstats -mean -quiet %s" % inputFile_double, "r")
     mean_ref = float(pipe.read())
     pipe.close()
     self.assertAlmostEqual(mean_data, mean_ref, 8)
Beispiel #12
0
 def testFromFileDataFloat(self):
     """ensure that float data is read correct with a precision of 8 decimals on a call to aveage()"""
     vol = minc2_file(inputFile_float)
     # load as float64 to match mincstats precision
     mean_data = N.average(vol.load_complete_volume('float64'))
     vol.close()
     pipe = os.popen("mincstats -mean -quiet %s" % inputFile_float, "r")
     mean_ref = float(pipe.read())
     pipe.close()
     self.assertAlmostEqual(mean_data, mean_ref, 8)
Beispiel #13
0
    def testVoxelToWorld(self):
        """testing voxel_to_world conversion of a file with non-standard values"""
        vol = minc2_file(input3DdirectionCosines)
        vol.setup_standard_order()

        xyz = vol.voxel_to_world(N.array([10, 20, 30]))
        # reference world coordinates for voxel (10, 20, 30)
        expected = (-0.32337910608109865507, -73.698635237250869068,
                    -29.421450173791534155)
        for k, e in enumerate(expected):
            self.assertAlmostEqual(xyz[k], e, 8)
Beispiel #14
0
 def testFromFileDataInt(self):
     """ensure that int data is read correct with a precision of 8 decimals on a call to average()"""
     vol = minc2_file(inputFile_int)
     # load as a double tensor so the mean matches mincstats precision
     mean_data = vol.load_complete_volume_tensor(
         'torch.DoubleTensor').mean().item()
     vol.close()
     pipe = os.popen("mincstats -mean -quiet %s" % inputFile_int, "r")
     mean_ref = float(pipe.read())
     pipe.close()
     self.assertAlmostEqual(mean_data, mean_ref, 8)
Beispiel #15
0
            def testSetHyperslabFloat(self):
                """setting hyperslab should change underlying volume (float)"""

                # reference slab from the input file
                src = minc2_file(inputFile_ushort)
                dims = src.store_dims()
                src.setup_standard_order()
                slab_in = src.load_hyperslab_t([10, None, None])

                dst = minc2_file()
                dst.define(dims, 'float32', 'float32')
                dst.create(outputFilename)
                dst.setup_standard_order()

                # float32 storage: slice normalization is unnecessary
                dst.save_hyperslab_t(slab_in, [10, None, None])
                slab_out = dst.load_hyperslab_t([10, None, None])
                self.assertEqual(N.average((slab_in - slab_out)**2),
                                 0.0)
                dst.close()
                src.close()
Beispiel #16
0
def merge_segmentations(inputs, output, partition, parameters):
    """Stitch per-partition segmentation volumes into one output volume.

    Each input i contributes the strip [strip*i, strip*(i+1)) along axis 2;
    the last partition absorbs the remainder. The result is written to
    *output*, imitating the geometry of the first input.

    parameters -- dict; only 'patch_size' is read (currently unused here).
    """
    out = None
    strip = None
    for i in range(len(inputs)):
        d = minc2_file(inputs[i]).data
        if out is None:
            out = np.zeros(d.shape, dtype=np.int32)
            # BUG FIX: integer division -- '/' yields a float on Python 3,
            # and float slice bounds raise TypeError below
            strip = d.shape[2] // partition

        beg = strip * i
        end = strip * (i + 1)

        # last partition takes whatever is left
        if i == (partition - 1):
            end = d.shape[2]

        out[:, :, beg:end] = d[:, :, beg:end]

    out_i = minc2_file()
    out_i.imitate(inputs[0], path=output)
    out_i.data = out
Beispiel #17
0
    def testSlicingGet(self):
        """volume slice should be same as slice from data array"""

        # slices taken from the fully loaded data array
        src = minc2_file(inputFile_ushort)
        src.setup_standard_order()
        from_data = (src.data[10, :, :], src.data[:, 10, :],
                     src.data[:, :, 10])
        src.close()

        # the same slices via direct volume indexing
        vol = minc2_file(inputFile_ushort)
        vol.setup_standard_order()
        sliced = (vol[10, :, :], vol[:, 10, :], vol[:, :, 10])
        vol.close()

        for a, b in zip(from_data, sliced):
            self.assertEqual(N.average((a - b)**2), 0.0)
Beispiel #18
0
    def testDims(self):
        """Check data dimensions are correct"""
        vol = minc2_file(inputFile_double)
        vol.setup_standard_order()
        dims = vol.store_dims()

        self.assertEqual(len(dims), 3)
        # expected store order and lengths: X=125, Y=150, Z=100
        expected = ((minc2_file.MINC2_DIM_X, 125),
                    (minc2_file.MINC2_DIM_Y, 150),
                    (minc2_file.MINC2_DIM_Z, 100))
        for d, (dim_id, length) in zip(dims, expected):
            self.assertEqual(d.id, dim_id)
            self.assertEqual(d.length, length)
Beispiel #19
0
def load_minc_images(path, winsorize_low=5, winsorize_high=95):
    """Load three mid-volume orthogonal slices of a minc file as tensors.

    Intensities are winsorized to the [winsorize_low, winsorize_high]
    percentile range (mapped to [-0.5, 0.5]); each slice is vertically
    flipped, rescaled to fit within 256x256, zero-padded to 256x256 and
    center-cropped to 224x224. Returns a list of three 1xHxW float tensors.

    NOTE(review): relies on module-level `transform` (presumably
    skimage.transform) and `torch`; `multichannel=` was removed in newer
    scikit-image releases -- confirm the pinned version.
    """
    from minc2_simple import minc2_file
    import numpy as np

    input_minc = minc2_file(path)
    input_minc.setup_standard_order()

    sz = input_minc.shape

    # three central slices, one per axis
    input_images = [
        input_minc[sz[0] // 2, :, :], input_minc[:, :, sz[2] // 2],
        input_minc[:, sz[1] // 2, :]
    ]

    # normalize between 5 and 95th percentile
    _all_voxels = np.concatenate(tuple((np.ravel(i) for i in input_images)))
    # _all_voxels=input_minc[:,:,:] # this is slower
    _min = np.percentile(_all_voxels, winsorize_low)
    _max = np.percentile(_all_voxels, winsorize_high)
    input_images = [(i - _min) * (1.0 / (_max - _min)) - 0.5
                    for i in input_images]

    # flip, resize and crop
    for i in range(3):
        # scale factor that fits the slice inside 256x256
        _scale = min(256.0 / input_images[i].shape[0],
                     256.0 / input_images[i].shape[1])
        # vertical flip and resize
        input_images[i] = transform.rescale(input_images[i][::-1, :],
                                            _scale,
                                            mode='constant',
                                            clip=False,
                                            anti_aliasing=False,
                                            multichannel=False)

        sz = input_images[i].shape
        # pad image
        dummy = np.zeros((256, 256), )
        dummy[int((256 - sz[0]) / 2):int((256 - sz[0]) / 2) + sz[0],
              int((256 - sz[1]) / 2):int((256 - sz[1]) / 2) +
              sz[1]] = input_images[i]

        # crop
        input_images[i] = dummy[16:240, 16:240]

    return [torch.from_numpy(i).float().unsqueeze_(0) for i in input_images]
Beispiel #20
0
    def testWorldToVoxelVec(self):
        """Compare against binary world to voxel"""
        v = minc2_file(input3DdirectionCosines)
        v.setup_standard_order()

        # grid of world-space sample points
        x, y, z = N.meshgrid(N.linspace(-10, 10, 3), N.linspace(0, 20, 3),
                             N.linspace(-5, 15, 3))
        xyz = N.column_stack((N.ravel(x), N.ravel(y), N.ravel(z)))

        ijk = v.world_to_voxel(xyz)

        for i, pt in enumerate(xyz):
            pipe = os.popen(
                "worldtovoxel {} {} {} {}".format(input3DdirectionCosines,
                                                  pt[0], pt[1], pt[2]), "r")
            # FIX: use a distinct comprehension variable -- the original
            # reused 'i', shadowing the enumerate index (and clobbering it
            # outright under Python 2 scoping)
            from_file = [float(tok) for tok in pipe.read().rstrip().split(" ")]
            pipe.close()
            for k in range(3):
                self.assertAlmostEqual(from_file[k], ijk[i, k], 8)
Beispiel #21
0
    def testDefaultDirCos3DVFF(self):
        """testing reading the direction cosines of a file with standard values (volumeFromFile)"""
        v = minc2_file(inputFile_ushort)
        #
        # This file was created without explicitly setting the direction
        # cosines; in that case the attribute is absent altogether, so we
        # check against the known defaults -- libminc supplies the correct
        # (identity) values.
        #
        v.setup_standard_order()
        dims = v.representation_dims()
        # direction cosines default to the 3x3 identity matrix
        for axis in range(3):
            for k in range(3):
                expected = 1.0 if axis == k else 0.0
                self.assertAlmostEqual(dims[axis].dir_cos[k], expected, 8)
Beispiel #22
0
    def get_minc_metadata(self, files):
        """Yield (filename, metadata) pairs for each minc2 file in *files*.

        Converts the minc2-simple metadata dictionary (bytes keys/values,
        exactly two levels deep) into a str -> str dict usable by datalad:
        keys are decoded to str, and values that are unhashable or not
        valid UTF-8 (hence not JSON-serializable) are skipped with a debug
        log entry.
        """
        for f in files:
            # Get all minc headers from the minc2 file.
            m = minc2_file(f)
            meta = m.metadata()

            strmeta = {}
            for key in meta:
                # Convert the key from bytes to string so that datalad
                # doesn't die on key.startswith.
                kd = key.decode()
                strmeta[kd] = {}
                for subkey in meta[key]:
                    # Do the same for the subkeys.
                    skd = subkey.decode()
                    # Only keep hashable values -- others can't be indexed.
                    if meta[key][subkey].__hash__:
                        try:
                            # The value must decode as UTF-8, otherwise it
                            # can't be serialized to JSON / indexed by
                            # datalad.  (FIX: dropped an unused local that
                            # duplicated this lookup.)
                            strmeta[kd][skd] = meta[key][subkey].decode(
                                'utf-8')
                        except UnicodeDecodeError:
                            lgr.debug("Skipped %s.%s in %s" % (
                                key,
                                subkey,
                                f,
                            ))
            yield f, strmeta
Beispiel #23
0
def load_image(infile):
    """Read *infile* and return its voxels as a float array in standard order."""
    vol = minc2_file(infile)
    vol.setup_standard_order()
    return vol.load_complete_volume(minc2_file.MINC2_FLOAT)
Beispiel #24
0
# BUG FIX: the script uses minc2_file below but only imported minc2_xfm,
# which would raise NameError at runtime.
from minc2_simple import minc2_xfm
from minc2_simple import minc2_file

import sys
import numpy as np


if __name__ == "__main__":

    if len(sys.argv) < 3:
        print("Usage: {} input.mnc output.mnc".format(sys.argv[0]))
        sys.exit(1)

    infile = sys.argv[1]
    outfile = sys.argv[2]

    m = minc2_file(infile)
    o = minc2_file()
    # will create file with same dimensions
    o.define(m.store_dims(), minc2_file.MINC2_BYTE, minc2_file.MINC2_FLOAT)

    print("Will create new volume...")
    o.create(outfile)

    meta = m.metadata()

    print("Metadata:")
    print(repr(meta))

    print("History:")
    print(m.read_attribute("", "history"))
    
Beispiel #25
0
def qc(
        input,
        output,
        image_range=None,
        mask=None,
        mask_range=None,
        title=None,
        image_cmap='gray',
        mask_cmap='red',
        samples=5,
        mask_bg=None,
        use_max=False,
        use_over=False,
        show_image_bar=False,  # TODO:implement this?
        show_overlay_bar=False,
        dpi=100,
        ialpha=0.8,
        oalpha=0.2,
        format=None):
    """QC image generation, drop-in replacement for minc_qc.pl
    Arguments:
        input -- input minc file
        output -- output QC graphics file 
        
    Keyword arguments:
        image_range -- (optional) intensity range for image
        mask  -- (optional) input mask file
        mask_range -- (optional) mask file range
        title  -- (optional) QC title
        image_cmap -- (optional) color map name for image, 
                       possibilities: red, green,blue and anything from matplotlib
        mask_cmap -- (optional) color map for mask, default red
        samples -- number of slices to show , default 5
        mask_bg  -- (optional) level for mask to treat as background
        use_max -- (optional) use 'max' colour mixing
        use_over -- (optional) use 'over' colour mixing
        show_image_bar -- show color bar for intensity range, default false
        show_overlay_bar  -- show color bar for mask intensity range, default false
        dpi -- graphics file DPI, default 100
        ialpha -- alpha channel for colour mixing of main image
        oalpha -- alpha channel for colour mixing of mask image

    Raises:
        ValueError -- if the mask volume shape differs from the image shape
    """

    # load the image volume in standard order
    _img = minc2_file(input)
    _img.setup_standard_order()
    _idata = _img.load_complete_volume(minc2_file.MINC2_FLOAT)
    _idims = _img.representation_dims()

    data_shape = _idata.shape
    spacing = [_idims[0].step, _idims[1].step, _idims[2].step]

    _ovl = None
    _odata = None
    omin = 0
    omax = 1

    if mask is not None:
        _ovl = minc2_file(mask)
        _ovl.setup_standard_order()
        _ovl_data = _ovl.load_complete_volume(minc2_file.MINC2_FLOAT)
        if _ovl_data.shape != data_shape:
            # BUG FIX: raising a plain string is a TypeError on Python 3;
            # raise a proper exception instead
            raise ValueError(
                "Overlay shape does not match image!\nOvl={} Image={}".format(
                    repr(_ovl_data.shape), repr(data_shape)))
        if mask_range is None:
            omin = np.nanmin(_ovl_data)
            omax = np.nanmax(_ovl_data)
        else:
            omin = mask_range[0]
            omax = mask_range[1]
        _odata = _ovl_data

        # mask out background voxels so they render transparent
        if mask_bg is not None:
            _odata = ma.masked_less(_odata, mask_bg)

    slices = []

    # setup ranges
    vmin = vmax = 0.0
    if image_range is not None:
        vmin = image_range[0]
        vmax = image_range[1]
    else:
        vmin = np.nanmin(_idata)
        vmax = np.nanmax(_idata)

    # copy the colormaps before mutating (set_bad) the overlay one
    cm = copy.copy(plt.get_cmap(image_cmap))
    cmo = copy.copy(plt.get_cmap(mask_cmap))
    cmo.set_bad('k', alpha=0.0)

    cNorm = colors.Normalize(vmin=vmin, vmax=vmax)
    oNorm = colors.Normalize(vmin=omin, vmax=omax)

    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cm)
    oscalarMap = cmx.ScalarMappable(norm=oNorm, cmap=cmo)
    aspects = []

    # axial slices
    for j in range(0, samples):
        i = int((data_shape[0] / samples) * j + (data_shape[0] % samples) / 2)
        si = scalarMap.to_rgba(_idata[i, :, :])

        if _ovl is not None:
            so = oscalarMap.to_rgba(_odata[i, :, :])
            if use_max: si = max_blend(si, so)
            elif use_over: si = over_blend(si, so, ialpha, oalpha)
            else: si = alpha_blend(si, so, ialpha, oalpha)
        slices.append(si)
        aspects.append(spacing[0] / spacing[1])
    # coronal slices
    for j in range(0, samples):
        i = int((data_shape[1] / samples) * j + (data_shape[1] % samples) / 2)
        si = scalarMap.to_rgba(_idata[:, i, :])

        if _ovl is not None:
            so = oscalarMap.to_rgba(_odata[:, i, :])
            if use_max: si = max_blend(si, so)
            elif use_over: si = over_blend(si, so, ialpha, oalpha)
            else: si = alpha_blend(si, so, ialpha, oalpha)
        slices.append(si)
        aspects.append(spacing[2] / spacing[0])

    # sagittal slices
    for j in range(0, samples):
        i = int((data_shape[2] / samples) * j + (data_shape[2] % samples) / 2)
        si = scalarMap.to_rgba(_idata[:, :, i])
        if _ovl is not None:
            so = oscalarMap.to_rgba(_odata[:, :, i])
            if use_max: si = max_blend(si, so)
            elif use_over: si = over_blend(si, so, ialpha, oalpha)
            else: si = alpha_blend(si, so, ialpha, oalpha)
        slices.append(si)
        aspects.append(spacing[2] / spacing[1])

    w, h = plt.figaspect(3.0 / samples)
    fig = plt.figure(figsize=(w, h))

    # lay the slices out on a 3 x samples grid
    ax = None
    imgplot = None
    for i, j in enumerate(slices):
        ax = plt.subplot2grid((3, samples), (int(i / samples), i % samples))
        imgplot = ax.imshow(j, origin='lower', cmap=cm, aspect=aspects[i])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.title.set_visible(False)
    # show for the last plot
    if show_image_bar:
        cbar = fig.colorbar(imgplot)

    if title is not None:
        plt.suptitle(title, fontsize=20)
        plt.subplots_adjust(wspace=0.0, hspace=0.0)
    else:
        plt.subplots_adjust(top=1.0,
                            bottom=0.0,
                            left=0.0,
                            right=1.0,
                            wspace=0.0,
                            hspace=0.0)

    plt.savefig(output, bbox_inches='tight', dpi=dpi, format=format)
    plt.close()
    plt.close('all')
Beispiel #26
0
def qc_field_contour(
        input,
        output,
        image_range=None,
        title=None,
        image_cmap='gray',
        samples=5,
        show_image_bar=False,  # TODO:implement this?
        dpi=100,
        format=None):
    """Render contour plots of evenly spaced slices of a field volume.

    Arguments:
        input -- input minc file
        output -- output QC graphics file
    Keyword arguments mirror qc(): image_range, title, image_cmap, samples,
    show_image_bar, dpi, format.
    """

    _img = minc2_file(input)
    _img.setup_standard_order()
    _idata = _img.load_complete_volume(minc2_file.MINC2_FLOAT)
    _idims = _img.representation_dims()

    data_shape = _idata.shape

    slices = []

    # setup ranges
    vmin = vmax = 0.0
    if image_range is not None:
        vmin = image_range[0]
        vmax = image_range[1]
    else:
        vmin = np.nanmin(_idata)
        vmax = np.nanmax(_idata)

    cm = plt.get_cmap(image_cmap)

    cNorm = colors.Normalize(vmin=vmin, vmax=vmax)

    # BUG FIX: slice indices must be ints -- on Python 3 '/' yields floats,
    # which are invalid array indices (matching int() usage in qc())
    for j in range(0, samples):
        i = int((data_shape[0] / samples) * j + (data_shape[0] % samples) / 2)
        slices.append(_idata[i, :, :])

    for j in range(0, samples):
        i = int((data_shape[1] / samples) * j + (data_shape[1] % samples) / 2)
        slices.append(_idata[:, i, :])

    for j in range(0, samples):
        i = int((data_shape[2] / samples) * j + (data_shape[2] % samples) / 2)
        slices.append(_idata[:, :, i])

    w, h = plt.figaspect(3.0 / samples)
    fig = plt.figure(figsize=(w, h))

    ax = None
    imgplot = None
    for i, j in enumerate(slices):
        # BUG FIX: subplot2grid location must be integer, not float
        ax = plt.subplot2grid((3, samples), (int(i / samples), i % samples))
        imgplot = ax.contour(j,
                             origin='lower',
                             cmap=cm,
                             norm=cNorm,
                             levels=np.linspace(vmin, vmax, 20))
        #plt.clabel(imgplot, inline=1, fontsize=8)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.title.set_visible(False)
    # show for the last plot
    if show_image_bar:
        cbar = fig.colorbar(imgplot)

    if title is not None:
        plt.suptitle(title, fontsize=20)
        plt.subplots_adjust(wspace=0.0, hspace=0.0)
    else:
        plt.subplots_adjust(top=1.0,
                            bottom=0.0,
                            left=0.0,
                            right=1.0,
                            wspace=0.0,
                            hspace=0.0)

    # FIX: honor the 'format' parameter (was accepted but silently ignored)
    plt.savefig(output, bbox_inches='tight', dpi=dpi, format=format)
    plt.close('all')
Beispiel #27
0
def errorCorrectionApply(input_images,
                         output,
                         input_mask=None,
                         parameters=None,
                         debug=False,
                         history=None,
                         input_auto=None,
                         partition=None,
                         part=None,
                         multilabel=1,
                         debug_files=None):
    """Apply a trained error-correction classifier to input volumes.

    Loads the input images, extracts per-voxel features (optionally with
    normalized voxel coordinates and a mask), runs a first classifier to
    predict corrections, and -- for multilabel problems -- a second
    classifier on the voxels the first one marked, then writes the padded
    result to *output* (imitating the first input's geometry).

    parameters -- dict of settings; 'training' (required) is the model
        file/pickle, plus feature/method options read below.
    input_auto -- optional initial segmentation used to seed the output.
    partition/part -- process only one partition strip of the volume.
    multilabel -- >1 enables the second-stage classifier.
    debug_files -- optional pair of paths for dumping raw predictions.
    history -- currently unused here.  NOTE(review): confirm intended use.
    """
    try:
        use_coord = parameters.get('use_coord', True)
        use_joint = parameters.get('use_joint', True)
        patch_size = parameters.get('patch_size', 1)
        normalize_input = parameters.get('normalize_input', True)
        primary_features = parameters.get('primary_features', 1)

        method = parameters.get('method', 'lSVC')
        method2 = parameters.get('method2', method)

        training = parameters['training']

        clf = None
        clf2 = None

        # strip overlap used when partitioning the volume
        border = patch_size * 2

        if patch_size == 0:
            border = 2

        if debug:
            print(
                "Running error-correction, input_image:{} trining:{} partition:{} part:{} output:{} input_auto:{}"
                .format(repr(input_images), training, partition, part, output,
                        input_auto))

        # load the trained model(s): xgboost boosters or a pickled pair
        if method == 'xgb' and method2 == 'xgb':
            # need to convert from Unicode
            _training = str(training)
            clf = xgb.Booster(model_file=_training)
            if multilabel > 1:
                clf2 = xgb.Booster(model_file=_training + '_2')
        else:
            with open(training, 'rb') as f:
                c = pickle.load(f)
                clf = c[0]
                clf2 = c[1]

        if debug:
            print(clf)
            print(clf2)
            print("Loading input images...")

        input_data = [
            minc2_file(k).load_complete_volume('float32') for k in input_images
        ]
        shape = input_data[0].shape

        #features = [ extract_part( minc.Image(k, dtype=np.float32).data, partition, part, border) for k in inp[0:-3] ]
        #if normalize_input:
        #features = [ extract_part( preprocessing.scale( k ), partition, part, border) for k in input_data ]
        #else:
        features = [
            extract_part(k, partition, part, border) for k in input_data
        ]

        coords = None

        # normalized voxel coordinates in [-1, 1] per axis
        if use_coord:
            c = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]]
            coords = [
                extract_part((c[j] - shape[j] / 2.0) / (shape[j] / 2.0),
                             partition, part, border) for j in range(3)
            ]

        if debug:
            print("Features data size:{}".format(len(features)))

        mask = None

        mask_size = shape[0] * shape[1] * shape[2]

        if input_mask is not None:
            mask = extract_part(
                minc2_file(input_mask).data, partition, part, border)
            mask_size = np.sum(mask)

        out_cls = None
        out_corr = None

        test_x = convert_image_list([
            prepare_features(features,
                             coords,
                             mask=mask,
                             use_coord=use_coord,
                             use_joint=use_joint,
                             patch_size=patch_size,
                             primary_features=primary_features)
        ])

        # seed the outputs from the initial segmentation when available
        if input_auto is not None:
            out_corr = np.copy(
                extract_part(
                    minc2_file(input_auto).data, partition, part,
                    border))  # use input data
            out_cls = np.copy(
                extract_part(
                    minc2_file(input_auto).data, partition, part,
                    border))  # use input data
        else:
            out_corr = np.zeros(shape, dtype=np.int32)
            out_cls = np.zeros(shape, dtype=np.int32)

        # DummyClassifier means "nothing learned" -- skip prediction entirely
        if mask_size > 0 and not isinstance(clf, dummy.DummyClassifier):
            if debug:
                print("Running classifier 1 ...")

            if method != 'xgb':
                pred = np.asarray(clf.predict(test_x), dtype=np.int32)
            else:
                xg_predict = xgb.DMatrix(test_x)
                pred = np.array(clf.predict(xg_predict), dtype=np.int32)

            # optionally dump first-stage predictions for debugging
            if debug_files is not None:
                out_dbg = np.zeros(shape, dtype=np.int32)
                if mask is not None:
                    out_dbg[mask > 0] = pred
                else:
                    out_dbg = pred

                out_dbg_m = minc2_file()
                out_dbg_m.imitate(input_images[0], path=debug_files[0])
                out_dbg_m.data = pad_data(out_dbg, shape, partition, part,
                                          border)

            if mask is not None:
                out_corr[mask > 0] = pred
            else:
                out_corr = pred

            # second stage: refine labels inside the first stage's output
            if multilabel > 1 and clf2 is not None:
                if mask is not None:
                    mask = np.logical_and(mask > 0, out_corr > 0)
                else:
                    mask = (out_corr > 0)

                if debug:
                    print("Running classifier 2 ...")

                test_x = convert_image_list([
                    prepare_features(features,
                                     coords,
                                     mask=mask,
                                     use_coord=use_coord,
                                     use_joint=use_joint,
                                     patch_size=patch_size,
                                     primary_features=primary_features)
                ])
                if method2 != 'xgb':
                    pred = np.asarray(clf2.predict(test_x), dtype=np.int32)
                else:
                    xg_predict = xgb.DMatrix(test_x)
                    pred = np.array(clf2.predict(xg_predict), dtype=np.int32)

                out_cls[mask > 0] = pred

                if debug_files is not None:
                    out_dbg = np.zeros(shape, dtype=np.int32)
                    if mask is not None:
                        out_dbg[mask > 0] = pred
                    else:
                        out_dbg = pred

                    out_dbg_m = minc2_file()
                    out_dbg_m.imitate(input_images[0], path=debug_files[1])
                    out_dbg_m.data = pad_data(out_dbg, shape, partition, part,
                                              border)

            else:
                out_cls = out_corr

        else:
            pass  # nothing to do!

        if debug:
            print("Saving output...")

        out = minc2_file()
        out.imitate(input_images[0], path=output)
        out.data = pad_data(out_cls, shape, partition, part, border)

    except mincError as e:
        print("Exception in errorCorrectionApply:{}".format(str(e)))
        traceback.print_exc(file=sys.stdout)
        raise
    except:
        print("Exception in errorCorrectionApply:{}".format(sys.exc_info()[0]))
        traceback.print_exc(file=sys.stdout)
        raise
Beispiel #28
0
def errorCorrectionTrain(input_images,
                         output,
                         parameters=None,
                         debug=False,
                         partition=None,
                         part=None,
                         multilabel=1):
    """Train error-correction classifier(s) and save them to *output*.

    Arguments:
        input_images -- list of training samples; each sample is a list of
            minc file names laid out as
            [feature_1, ..., feature_n, auto_seg, mask, ground_truth],
            i.e. inp[-3] is the automatic segmentation, inp[-2] the mask
            (the mask entry may be None) and inp[-1] the ground truth.
        output -- file name for the trained classifier(s); for the
            xgb/xgb method combination the native booster format is used
            (output and output + '_2'), otherwise a pickle of [clf, clf2].

    Keyword arguments:
        parameters -- dict of training parameters (method, method2,
            patch_size, use_coord, use_joint, ...); None means all defaults
        debug -- print diagnostic information
        partition -- number of partitions the volumes are split into
        part -- which partition to process
        multilabel -- number of labels in the segmentation; when > 1 a
            second "direct" classifier is trained on voxels where the
            automatic segmentation disagrees with the ground truth

    Raises:
        re-raises any exception after printing a traceback
    """
    try:
        # The signature allows parameters=None, but the original crashed
        # on parameters.get(); treat None as "all defaults".
        if parameters is None:
            parameters = {}

        use_coord = parameters.get('use_coord', True)
        use_joint = parameters.get('use_joint', True)
        patch_size = parameters.get('patch_size', 1)

        # pad each partition with a border so patches near the partition
        # edges still have full spatial context
        border = patch_size * 2

        if patch_size == 0:
            border = 2

        normalize_input = parameters.get('normalize_input', True)  # currently unused here

        method = parameters.get('method', 'lSVC')
        method2 = parameters.get('method2', method)
        method_n = parameters.get('method_n', 15)
        method2_n = parameters.get('method2_n', method_n)  # currently unused here
        method_random = parameters.get('method_random', None)
        method_max_features = parameters.get('method_max_features', 'auto')
        method_n_jobs = parameters.get('method_n_jobs', 1)
        primary_features = parameters.get('primary_features', 1)

        training_images = []
        training_diff = []
        training_images_direct = []
        training_direct = []

        if debug:
            print("errorCorrectionTrain use_coord={} use_joint={} patch_size={} normalize_input={} method={} output={} partition={} part={}".\
                    format(repr(use_coord),repr(use_joint),repr(patch_size),repr(normalize_input),method,output,partition,part))

        coords = None
        total_mask_size = 0
        total_diff_mask_size = 0

        for (i, inp) in enumerate(input_images):
            mask = None
            diff = None
            mask_diff = None

            if inp[-2] is not None:
                mask = extract_part(
                    minc2_file(inp[-2]).data, partition, part, border)

            ground_data = minc2_file(inp[-1]).data
            auto_data = minc2_file(inp[-3]).data

            ground_shape = ground_data.shape
            ground = extract_part(ground_data, partition, part, border)
            auto = extract_part(auto_data, partition, part, border)

            shape = ground_shape
            # normalized voxel coordinates in [-1, 1), computed once and
            # reused for all samples (all volumes share the same grid)
            if coords is None and use_coord:
                c = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]]
                coords = [
                    extract_part((c[j] - shape[j] / 2.0) / (shape[j] / 2.0),
                                 partition, part, border) for j in range(3)
                ]

            features = [
                extract_part(minc2_file(k).data, partition, part, border)
                for k in inp[0:-3]
            ]

            # Voxel count used only to detect an empty training set below.
            # BUG FIX: the original computed np.sum(mask) only inside
            # "if debug:", so total_mask_size depended on the debug flag.
            if mask is not None:
                mask_size = np.sum(mask)
            else:
                mask_size = shape[0] * shape[1] * shape[2]

            if debug:
                print("Training data size:{}".format(len(features)))
                if mask is not None:
                    print("Mask size:{}".format(mask_size))
                else:
                    print("Mask absent")
            total_mask_size += mask_size

            if multilabel > 1:
                diff = (ground != auto)
                # BUG FIX: the original did np.sum(mask) unconditionally,
                # which crashes when mask is None; this tally is not read
                # anywhere else, so count disagreement voxels in that case.
                total_diff_mask_size += np.sum(mask) if mask is not None else np.sum(diff)

                if mask is not None:
                    mask_diff = diff & (mask > 0)
                    print("Sample {} mask_diff={} diff={}".format(
                        i, np.sum(mask_diff), np.sum(diff)))
                    #print(mask_diff)
                    training_diff.append(diff[mask > 0])
                    training_direct.append(ground[mask_diff])
                else:
                    mask_diff = diff
                    training_diff.append(diff)
                    training_direct.append(ground[diff])

                # classifier 1: detect voxels that need correction
                training_images.append(
                    prepare_features(features,
                                     coords,
                                     mask=mask,
                                     use_coord=use_coord,
                                     use_joint=use_joint,
                                     patch_size=patch_size,
                                     primary_features=primary_features))

                # classifier 2: assign the correct label on those voxels
                training_images_direct.append(
                    prepare_features(features,
                                     coords,
                                     mask=mask_diff,
                                     use_coord=use_coord,
                                     use_joint=use_joint,
                                     patch_size=patch_size,
                                     primary_features=primary_features))

            else:
                # binary case: train a single classifier on the ground truth
                mask_diff = mask
                if mask is not None:
                    training_diff.append(ground[mask > 0])
                else:
                    training_diff.append(ground)

                training_images.append(
                    prepare_features(features,
                                     coords,
                                     mask=mask,
                                     use_coord=use_coord,
                                     use_joint=use_joint,
                                     patch_size=patch_size,
                                     primary_features=primary_features))

            if debug:
                print("feature size:{}".format(len(training_images[-1])))

            if i == 0 and parameters.get('dump', False):
                print("Dumping feature images...")
                for (j, k) in enumerate(training_images[-1]):
                    # BUG FIX: original referenced the undefined name
                    # `images[0]`; features[0] has the extracted-part shape
                    # that k's values were sampled from.
                    test = np.zeros_like(features[0])
                    # NOTE(review): this dump path assumes a mask is
                    # present (mask > 0 fails on None) -- confirm callers
                    test[mask > 0] = k
                    out = minc2_file()
                    out.imitate(inp[0], path="dump_{}.mnc".format(j))
                    out.data = test

        # calculate normalization coefficients

        if debug: print("Done")

        clf = None
        clf2 = None

        if total_mask_size > 0:
            training_X = convert_image_list(training_images)
            training_Y = np.ravel(
                np.concatenate(tuple(j for j in training_diff)))

            if debug: print("Fitting 1st...")

            if method == "xgb":
                clf = None
            elif method == "SVM":
                clf = svm.SVC()
            elif method == "nuSVM":
                clf = svm.NuSVC()
            elif method == 'NC':
                clf = neighbors.NearestCentroid()
            elif method == 'NN':
                clf = neighbors.KNeighborsClassifier(method_n)
            elif method == 'RanForest':
                clf = ensemble.RandomForestClassifier(
                    n_estimators=method_n,
                    n_jobs=method_n_jobs,
                    max_features=method_max_features,
                    random_state=method_random)
            elif method == 'AdaBoost':
                clf = ensemble.AdaBoostClassifier(n_estimators=method_n,
                                                  random_state=method_random)
            elif method == 'AdaBoostPP':
                clf = Pipeline(steps=[('normalizer', Normalizer()),
                                      ('AdaBoost',
                                       ensemble.AdaBoostClassifier(
                                           n_estimators=method_n,
                                           random_state=method_random))])
            elif method == 'tree':
                clf = tree.DecisionTreeClassifier(random_state=method_random)
            elif method == 'ExtraTrees':
                clf = ensemble.ExtraTreesClassifier(
                    n_estimators=method_n,
                    max_features=method_max_features,
                    n_jobs=method_n_jobs,
                    random_state=method_random)
            elif method == 'Bagging':
                clf = ensemble.BaggingClassifier(
                    n_estimators=method_n,
                    max_features=method_max_features,
                    n_jobs=method_n_jobs,
                    random_state=method_random)
            elif method == 'dumb':
                clf = dummy.DummyClassifier(strategy="constant", constant=0)
            else:
                clf = svm.LinearSVC()

            #scores = cross_validation.cross_val_score(clf, training_X, training_Y)
            #print scores
            if method == "xgb":
                xg_train = xgb.DMatrix(training_X, label=training_Y)
                param = {}
                num_round = 100
                # use softmax multi-class classification
                param['objective'] = 'multi:softmax'
                # scale weight of positive examples
                param['eta'] = 0.1
                param['max_depth'] = 8
                param['silent'] = 1
                param['nthread'] = 4
                param['num_class'] = 2
                clf = xgb.train(param, xg_train, num_round)
            elif method != 'dumb':
                clf.fit(training_X, training_Y)

            if multilabel > 1 and method != 'dumb':
                if debug: print("Fitting direct...")

                training_X = convert_image_list(training_images_direct)
                training_Y = np.ravel(
                    np.concatenate(tuple(j for j in training_direct)))

                if method2 == "xgb":
                    clf2 = None
                # BUG FIX: this was "if", breaking the chain -- with
                # method2 == "xgb" the else branch built a throwaway
                # LinearSVC before the booster overwrote it below
                elif method2 == "SVM":
                    clf2 = svm.SVC()
                elif method2 == "nuSVM":
                    clf2 = svm.NuSVC()
                elif method2 == 'NC':
                    clf2 = neighbors.NearestCentroid()
                elif method2 == 'NN':
                    clf2 = neighbors.KNeighborsClassifier(method_n)
                elif method2 == 'RanForest':
                    clf2 = ensemble.RandomForestClassifier(
                        n_estimators=method_n,
                        n_jobs=method_n_jobs,
                        max_features=method_max_features,
                        random_state=method_random)
                elif method2 == 'AdaBoost':
                    clf2 = ensemble.AdaBoostClassifier(
                        n_estimators=method_n, random_state=method_random)
                elif method2 == 'AdaBoostPP':
                    clf2 = Pipeline(steps=[('normalizer', Normalizer()),
                                           ('AdaBoost',
                                            ensemble.AdaBoostClassifier(
                                                n_estimators=method_n,
                                                random_state=method_random))])
                elif method2 == 'tree':
                    clf2 = tree.DecisionTreeClassifier(
                        random_state=method_random)
                elif method2 == 'ExtraTrees':
                    clf2 = ensemble.ExtraTreesClassifier(
                        n_estimators=method_n,
                        max_features=method_max_features,
                        n_jobs=method_n_jobs,
                        random_state=method_random)
                elif method2 == 'Bagging':
                    clf2 = ensemble.BaggingClassifier(
                        n_estimators=method_n,
                        max_features=method_max_features,
                        n_jobs=method_n_jobs,
                        random_state=method_random)
                elif method2 == 'dumb':
                    clf2 = dummy.DummyClassifier(strategy="constant",
                                                 constant=0)
                else:
                    clf2 = svm.LinearSVC()

                if method2 == "xgb":
                    xg_train = xgb.DMatrix(training_X, label=training_Y)

                    param = {}
                    num_round = 100
                    # use softmax multi-class classification
                    param['objective'] = 'multi:softmax'
                    # scale weight of positive examples
                    param['eta'] = 0.1
                    param['max_depth'] = 8
                    param['silent'] = 1
                    param['nthread'] = 4
                    param['num_class'] = multilabel

                    clf2 = xgb.train(param, xg_train, num_round)

                # BUG FIX: original tested `method` here, so the second
                # classifier was never fitted when method == 'dumb' and
                # was fitted even when method2 == 'dumb'
                elif method2 != 'dumb':
                    clf2.fit(training_X, training_Y)

            #print(clf.score(training_X,training_Y))

            if debug:
                print(clf)
                print(clf2)
        else:
            print("Warning : zero total mask size!, using null classifier")
            clf = dummy.DummyClassifier(strategy="constant", constant=0)

        if method == 'xgb' and method2 == 'xgb':
            # xgboost boosters are saved in their native format
            clf.save_model(output)
            clf2.save_model(output + '_2')
        else:
            with open(output, 'wb') as f:
                pickle.dump([clf, clf2], f, -1)

    except mincError as e:
        # BUG FIX: the original message said "linear_registration"
        print("Exception in errorCorrectionTrain:{}".format(str(e)))
        traceback.print_exc(file=sys.stdout)
        raise
    except:
        print("Exception in errorCorrectionTrain:{}".format(sys.exc_info()[0]))
        traceback.print_exc(file=sys.stdout)
        raise
    c = np.mgrid[0:img.shape[0], 0:img.shape[1], 0:img.shape[2]]

    seg = np.zeros_like(img, dtype=np.int32)

    seg=( c[2]>center[0] )*1+\
        ( c[1]>center[1] )*2+\
        ( c[0]>center[2] )*4+ 1

    seg[img < 50] = 0

    return np.asarray(seg, dtype=np.int32)


if __name__ == "__main__":
    options = parse_options()
    print(repr(options))

    # renamed from `input` to avoid shadowing the builtin
    in_img = minc2_file(options.input)
    in_img.setup_standard_order()
    data = in_img.load_complete_volume(minc2_file.MINC2_FLOAT)

    #seg=np.zeros_like( input.data, dtype=np.int32 )
    # convert the requested world-space centre into voxel coordinates.
    # BUG FIX: xrange is Python 2 only; the rest of this file targets
    # Python 3 (f-strings, print()), so it raised NameError.
    center_vox = [(options.center[i] - in_img.representation_dims()[i].start) /
                  in_img.representation_dims()[i].step for i in range(3)]
    print(repr(center_vox))
    seg = dumb_segment(data, center_vox)

    save_labels(options.output, in_img, seg)

# kate: space-indent on; indent-width 4; indent-mode python;replace-tabs on;word-wrap-column 80;show-tabs on
Beispiel #30
0
def qc(
    input,
    output,
    image_range=None,
    mask=None,
    mask_range=None,
    title=None,
    image_cmap='gray',
    mask_cmap='red',
    samples=6,
    mask_bg=None,
    use_max=False,
    use_over=False,
    show_image_bar=False,   # TODO:implement this?
    show_overlay_bar=False,
    dpi=100,
    ialpha=0.8,
    oalpha=0.2,
    format=None,
    bg_color=None,
    fg_color=None
    ):
    """QC image generation, drop-in replacement for minc_qc.pl
    Arguments:
        input -- input minc file
        output -- output QC graphics file 
        
    Keyword arguments:
        image_range -- (optional) intensity range for image
        mask  -- (optional) input mask file
        mask_range -- (optional) mask file range
        title  -- (optional) QC title
        image_cmap -- (optional) color map name for image, 
                       possibilities: red, green,blue and anything from matplotlib
        mask_cmap -- (optional) color map for mask, default red
        samples -- number of slices per dimension to show, default 6
                   should be even number, becaus it will be split across two rows 
        mask_bg  -- (optional) level for mask to treat as background
        use_max -- (optional) use 'max' colour mixing
        use_over -- (optional) use 'over' colour mixing
        show_image_bar -- show color bar for intensity range, default false
        show_overlay_bar  -- show color bar for mask intensity range, default false
        dpi -- graphics file DPI, default 100
        ialpha -- alpha channel for colour mixing of main image
        oalpha -- alpha channel for colour mixing of mask image
        format -- (optional) output graphics format, passed to matplotlib savefig
        bg_color -- (optional) figure background/edge colour
        fg_color -- (optional) figure text colour

    Raises:
        ValueError -- when the mask shape does not match the image shape
    """
    
    _img=minc2_file(input)
    _img.setup_standard_order()
    _idata=_img.load_complete_volume(minc2_file.MINC2_FLOAT)
    _idims=_img.representation_dims()
    
    data_shape=_idata.shape
    # order of dimensions in representation dimension is reversed
    spacing=[_idims[2].step, _idims[1].step, _idims[0].step]
    
    _ovl=None
    _odata=None
    omin=0
    omax=1
    # setup view: samples axial + samples sagittal + samples coronal slices
    # laid out on a (rows x columns) grid
    columns=samples//2
    rows=(samples//columns)*3

    if mask is not None:
        _ovl=minc2_file(mask)
        _ovl.setup_standard_order()
        _ovl_data=_ovl.load_complete_volume(minc2_file.MINC2_FLOAT)
        if _ovl_data.shape != data_shape:
            # BUG FIX: the original raised a plain string, which is a
            # TypeError in Python 3 and masked the real error
            raise ValueError(
                "Overlay shape does not match image!\nOvl={} Image={}".format(
                    repr(_ovl_data.shape), repr(data_shape)))
        if mask_range is None:
            omin=np.nanmin(_ovl_data)
            omax=np.nanmax(_ovl_data)
        else:
            omin=mask_range[0]
            omax=mask_range[1]
        _odata=_ovl_data
        
        if mask_bg is not None:
            # hide sub-threshold mask voxels (rendered fully transparent)
            _odata=ma.masked_less(_odata, mask_bg)
        
    slices=[]
    
    # setup ranges
    vmin=vmax=0.0
    if image_range is not None:
        vmin=image_range[0]
        vmax=image_range[1]
    else:
        vmin=np.nanmin(_idata)
        vmax=np.nanmax(_idata)

    # copy the colormaps so set_bad() does not mutate the global registry
    cm = copy.copy(plt.get_cmap(image_cmap))
    cmo= copy.copy(plt.get_cmap(mask_cmap))

    cm.set_bad('k', alpha = 1.0)
    cmo.set_bad('k',alpha = 0.0)

    cNorm  = colors.Normalize(vmin=vmin, vmax=vmax)
    oNorm  = colors.Normalize(vmin=omin, vmax=omax)
    
    scalarMap  = cmx.ScalarMappable(norm=cNorm, cmap=cm)
    oscalarMap = cmx.ScalarMappable(norm=oNorm, cmap=cmo)
    aspects = []
    
    # axial slices
    # NOTE(review): the constants (10..150 of 181 etc.) appear tuned for
    # ICBM-sized sampling and are rescaled to the actual volume shape
    for j in range(0, samples):
        i=int(10+(150.0-10.0)*j/(samples-1))
        i=int(data_shape[0]*i/181.0)

        si=scalarMap.to_rgba(_idata[i , : ,:])

        if _ovl is not None:
            so = oscalarMap.to_rgba(_odata[i , : ,:])
            if    use_max: si=max_blend(si, so)
            elif use_over: si=over_blend(si, so, ialpha, oalpha)
            else:          si=alpha_blend(si, so, ialpha, oalpha)
        slices.append( si )
        aspects.append( spacing[1] / spacing[2] )

    # sagittal slices: first half ascending, second half descending so the
    # two hemispheres mirror each other across the page
    for j in range(0, samples//2):
        i=int(28.0+(166.0-28.0)*j/(samples-1))
        i=int(data_shape[2]*i/193.0)

        si=scalarMap.to_rgba(_idata[: , : , i])
        if _ovl is not None:
            so = oscalarMap.to_rgba(_odata[: , : , i])
            if    use_max: si=max_blend(si,so)
            elif use_over: si=over_blend(si,so, ialpha, oalpha)
            else:          si=alpha_blend(si, so, ialpha, oalpha)
        slices.append( si )
        aspects.append( spacing[0] / spacing[1] )

    for j in range(samples-1, samples//2-1,-1):
        i=int(28.0+(166.0-28.0)*j/(samples-1))
        i=int(data_shape[2]*i/193.0)

        si=scalarMap.to_rgba(_idata[: , : , i])
        if _ovl is not None:
            so = oscalarMap.to_rgba(_odata[: , : , i])
            if    use_max: si=max_blend(si,so)
            elif use_over: si=over_blend(si,so, ialpha, oalpha)
            else:          si=alpha_blend(si, so, ialpha, oalpha)
        slices.append( si )
        aspects.append( spacing[0] / spacing[1] )

    # coronal slices
    for j in range(0,samples):
        i=int(25+(195.0-25.0)*j/(samples-1))
        i=int(data_shape[1]*i/217.0)
        si=scalarMap.to_rgba(_idata[: , i ,:])
        
        if _ovl is not None:
            so = oscalarMap.to_rgba(_odata[: , i ,:])
            if    use_max: si=max_blend(si,so)
            elif use_over: si=over_blend(si,so, ialpha, oalpha)
            else:          si=alpha_blend(si, so, ialpha, oalpha)
        slices.append( si )
        aspects.append( spacing[0]/spacing[2] )

    rc={'interactive': False}
    rc['figure.frameon']=False
    #rc['aa']=True

    if bg_color is not None:
         rc['figure.edgecolor']=bg_color
         rc['figure.facecolor']=bg_color     
         rc['grid.color']=bg_color
    if fg_color is not None:
         rc['text.color']=fg_color
         #axes.labelcolor
         #axes.titlecolor

    with matplotlib.rc_context(rc):
    #with plt.style.context('dark_background'):
        w, h = plt.figaspect(rows/columns)
        fig = plt.figure(figsize=(w,h))
        # if bg_color is not None:
        #     fig.set_facecolor(bg_color)
        #     fig.set_edgecolor(bg_color)
        #    fig.set_frameon(False)
        #outer_grid = gridspec.GridSpec((len(slices)+1)/2, 2, wspace=0.0, hspace=0.0)
        ax=None
        imgplot=None
        
        print(f"rows:{rows} columns:{columns}")
        for i,j in enumerate(slices):
            ax =  plt.subplot2grid( (rows, columns), ( i//columns, i%columns) )
            # if bg_color is not None:
            #     ax.set_facecolor(bg_color)
            imgplot = ax.imshow(j, origin='lower',  cmap=cm, aspect=aspects[i])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.title.set_visible(False)

        # show for the last plot
        if show_image_bar:
            fig.colorbar(imgplot)
        
        if title is not None:
            plt.suptitle(title,fontsize=20)
            plt.subplots_adjust(wspace = 0.0 ,hspace=0.0)
        else:
            plt.subplots_adjust(top=1.0,bottom=0.0,left=0.0,right=1.0,wspace = 0.0 ,hspace=0.0)

        # , facecolor=fig.get_facecolor(), edgecolor='none')
        plt.savefig(output, bbox_inches='tight', dpi=dpi, format=format)
        plt.close()
        plt.close('all')