Example #1
0
def downsample_channel(image_in,
                       ch,
                       resolution_level=-1,
                       dsfacs=[1, 4, 4, 1, 1],
                       ismask=False,
                       output='',
                       report=None):
    """Downsample one channel of an image (or a mask) and record QC stats.

    Args:
        image_in: path of the input image. A falsy value means "nothing to
            do" (e.g. no mask provided) and the function returns early.
        ch: channel index selected when the image has more than 3 dims.
        resolution_level: Imaris pyramid level to read; -1 reads
            ``image_in`` as given. Ignored when ``ismask`` is True.
        dsfacs: per-axis downsample factors; only the first
            ``data.ndim`` entries are used. Not modified.
        ismask: if True, reduce each block by its maximum and write to the
            'mask' dataset; otherwise block-average to float32 and write to
            the 'data' dataset.
        output: output path passed on to ``write_data``.
        report: dict that receives ``report['medians'][ods]`` and
            ``report['centreslices'][ods]``. A fresh dict is created when
            omitted; missing sub-dicts are created on demand.

    Returns:
        ``(mo, report, meds_mask)``: the written output image, the updated
        report, and the boolean mask ``data < 1000`` used for the medians.
        When ``image_in`` is falsy: ``(None, report, '')``.
    """
    if report is None:
        # A fresh dict per call avoids the shared-mutable-default pitfall.
        report = {}

    ods = 'mask' if ismask else 'data'

    # return in case no mask provided
    if not image_in:
        return None, report, ''

    if resolution_level != -1 and not ismask:  # we should have an Imaris pyramid
        image_in = '{}/DataSet/ResolutionLevel {}'.format(
            image_in, resolution_level)

    # load data: first timepoint and the requested channel only
    im = Image(image_in, permission='r')
    im.load(load_data=False)
    props = im.get_props()
    if len(im.dims) > 4:
        im.slices[im.axlab.index('t')] = slice(0, 1, 1)
        props = im.squeeze_props(props, dim=4)
    if len(im.dims) > 3:
        im.slices[im.axlab.index('c')] = slice(ch, ch + 1, 1)
        props = im.squeeze_props(props, dim=3)
    data = im.slice_dataset()
    im.close()

    # downsample
    dsfac = tuple(dsfacs[:len(data.shape)])
    if ismask:
        # block maximum: mask values are propagated, never averaged
        data = block_reduce(data, dsfac, np.max)
    else:
        data = downscale_local_mean(data, dsfac).astype('float32')

    # generate output
    props['axlab'] = 'zyx'  # FIXME: axlab returns as string-list
    props['shape'] = data.shape
    props['elsize'] = [es * ds for es, ds in zip(im.elsize[:3], dsfac)]
    props['slices'] = None
    mo = write_data(data, props, output, ods)

    # report data; sub-dicts are created if the caller did not provide them
    thr = 1000
    meds_mask = data < thr
    report.setdefault('medians', {})[ods] = get_zyx_medians(data, meds_mask)

    c_slcs = {dim: get_centreslice(mo, '', dim) for dim in 'zyx'}
    report.setdefault('centreslices', {})[ods] = c_slcs

    return mo, report, meds_mask
Example #2
0
def mergeblocks(
    images_in,
    dataslices=None,
    blocksize=[],
    blockmargin=[],
    blockrange=[],
    blockoffset=[0, 0, 0],
    fullsize=[],
    is_labelimage=False,
    relabel=False,
    neighbourmerge=False,
    save_fwmap=False,
    blockreduce=[],
    func='np.amax',
    datatype='',
    usempi=False,
    outputpath='',
    save_steps=False,
    protective=False,
):
    """Merge blocks of data into a single hdf5 file.

    Args:
        images_in: paths of the block images. The first one supplies the
            output properties (dtype, chunks, elsize, ...).
        dataslices: not used in this function.
        blocksize, blockmargin: forwarded to ``process_block``.
        blockrange: optional ``[start, stop]`` selecting a sub-range of
            ``images_in``.
        blockoffset: subtracted from ``fullsize`` to get the output size.
        fullsize: size of the full merged volume.
        is_labelimage, relabel, neighbourmerge, save_fwmap: forwarded to
            ``process_block``.
        blockreduce: optional per-axis reduction factors. When given, the
            output size is ``ceil((fullsize - blockoffset) / blockreduce)``
            and the element size is scaled by the same factors.
        func: reduction function name forwarded to ``process_block``.
        datatype: output dtype; falls back to the first block's dtype.
        usempi: distribute the blocks over MPI ranks.
        outputpath: output path. A ``.ims`` output is created without the
            derived properties.
        save_steps: not used in this function.
        protective: forwarded to ``Image.get_props``.

    Returns:
        The output ``LabelImage`` (already closed).
    """
    if blockrange:
        images_in = images_in[blockrange[0]:blockrange[1]]

    mpi = wmeMPI(usempi)

    im = Image(images_in[0], permission='r')
    im.load(mpi.comm, load_data=False)
    props = im.get_props(protective=protective, squeeze=True)
    ndim = im.get_ndim()

    props['dtype'] = datatype or props['dtype']
    props['chunks'] = props['chunks'] or None

    # get the size of the outputfile
    # TODO: option to derive fullsize from dset_names?
    if blockreduce:
        datasize = np.subtract(fullsize, blockoffset)
        # builtin float: the np.float alias was removed in NumPy 1.24
        outsize = [
            int(np.ceil(d / float(b)))
            for d, b in zip(datasize, blockreduce)
        ]
        props['elsize'] = [e * b for e, b in zip(im.elsize, blockreduce)]
    else:  # FIXME: 'zyx(c)' stack assumed
        outsize = np.subtract(fullsize, blockoffset)

    if ndim == 4:
        outsize = list(outsize) + [im.ds.shape[3]]  # TODO: flexible insert

    if outputpath.endswith('.ims'):
        mo = LabelImage(outputpath)
    else:
        props['shape'] = outsize
        mo = LabelImage(outputpath, **props)
    mo.create(comm=mpi.comm)

    mpi.blocks = [{'path': image_in} for image_in in images_in]
    mpi.nblocks = len(images_in)
    mpi.scatter_series()

    # merge the datasets; maxlabel is threaded through consecutive blocks
    maxlabel = 0
    for i in mpi.series:
        block = mpi.blocks[i]
        maxlabel = process_block(block['path'], ndim, blockreduce, func,
                                 blockoffset, blocksize, blockmargin, fullsize,
                                 mo, is_labelimage, relabel, neighbourmerge,
                                 save_fwmap, maxlabel, mpi)
        print('processed block {:03d}: {}'.format(i, block['path']))

    im.close()
    mo.close()

    return mo
Example #3
0
def apply_bias_field_full(image_in,
                          bias_in,
                          dsfacs=[1, 64, 64, 1],
                          in_place=False,
                          write_to_single_file=False,
                          blocksize_xy=1280,
                          outputpath='',
                          channel=None):
    """Divide an image by its bias field, processing it block by block.

    single-core in ~200 blocks (MPI is constructed with usempi=False).

    Args:
        image_in: input image path; opened 'r+' when ``in_place``.
        bias_in: bias-field image path; per-block values are obtained with
            ``get_bias_field_block`` at the block's full-resolution shape.
        dsfacs: not used in this function.
        in_place: write corrected blocks back into ``image_in``.
        write_to_single_file: write corrected blocks into the existing image
            at ``outputpath`` (channel slot 0). Not used when ``in_place``.
        blocksize_xy: block size along axes 1 and 2, with a one-chunk margin
            on those axes. If falsy, a single block spans the first 3 dims.
        outputpath: target image for ``write_to_single_file``.
        channel: if given, restrict ``im.slices`` to this channel before
            blocking.

    NOTE(review): when neither ``in_place`` nor ``write_to_single_file`` is
    set, each block's *bias field* (not the corrected data) is written to
    ``block['path']`` — confirm this is intended.
    """
    perm = 'r+' if in_place else 'r'
    im = Image(image_in, permission=perm)
    im.load(load_data=False)

    bf = Image(bias_in, permission='r')
    bf.load(load_data=False)

    if channel is not None:
        im.slices[3] = slice(channel, channel + 1)

    mo = None
    if write_to_single_file:  # assuming single-channel copied file here
        mo = Image(outputpath)
        mo.load()
        mo.slices[3] = slice(0, 1, 1)

    mpi = wmeMPI(usempi=False)
    mpi_nm = wmeMPI(usempi=False)  # same blocks, but without margins
    if blocksize_xy:
        blocksize = [im.dims[0], blocksize_xy, blocksize_xy, 1, 1]
        blockmargin = [0, im.chunks[1], im.chunks[2], 0, 0]
    else:
        # list() so this works whether im.dims is a list or a tuple
        blocksize = list(im.dims[:3]) + [1, 1]
        blockmargin = [0] * len(im.dims)
    mpi.set_blocks(im, blocksize, blockmargin)
    mpi_nm.set_blocks(im, blocksize)
    mpi.scatter_series()

    for i in mpi.series:
        print(i)
        block = mpi.blocks[i]
        block_nm = mpi_nm.blocks[i]

        # Slices that crop the margin-padded block data down to the
        # margin-free block, clipped to the padded block's actual shape.
        data_shape = list(im.slices2shape(block['slices']))
        data_slices = []
        for b_slc, n_slc, bs, ds in zip(block['slices'], block_nm['slices'],
                                        blocksize, data_shape):
            m_start = n_slc.start - b_slc.start
            m_stop = min(m_start + bs, ds)
            data_slices.append(slice(m_start, m_stop, None))

        # get the fullres image block
        im.slices = block['slices']
        data = im.slice_dataset().astype('float')

        # get the upsampled bias field and correct the block
        bias = get_bias_field_block(bf, im.slices, data.shape)
        data /= bias
        data = np.nan_to_num(data, copy=False)

        if in_place:
            im.slices = block_nm['slices']
            data = data[tuple(data_slices[:3])].astype(im.dtype)
            im.write(data)
        elif write_to_single_file:
            # copy so the channel override does not mutate mpi_nm's block
            mo.slices = list(block_nm['slices'])
            mo.slices[3] = slice(0, 1, 1)
            data = data[tuple(data_slices[:3])].astype(mo.dtype)
            mo.write(data)
        else:
            props = im.get_props()
            if len(im.dims) > 4:
                props = im.squeeze_props(props, dim=4)
            if len(im.dims) > 3:
                props = im.squeeze_props(props, dim=3)
            props['axlab'] = 'zyx'  # FIXME: axlab return as string-list
            props['shape'] = bias.shape
            props['slices'] = None
            props['dtype'] = bias.dtype
            mo = Image(block['path'], **props)  # FIXME: needs channel
            mo.create(comm=mpi.comm)
            mo.slices = None
            mo.set_slices()
            mo.write(data=bias)
            mo.close()

    im.close()
    bf.close()
    if write_to_single_file and mo is not None:
        # the single output file was opened above and must be closed too
        mo.close()
Example #4
0
def _run_stage(stage, step_idx, steps, read, process):
    """Run one pipeline stage (or reload its saved output) and print its timing.

    ``process()`` is called when ``step_idx`` is in ``steps``; otherwise the
    stage's earlier result is loaded with ``read()``. Returns that result.
    """
    t = time.time()
    if step_idx in steps:
        op = 'processing'
        result = process()
    else:
        op = 'reading'
        result = read()
    elapsed = time.time() - t
    print('{} ({}) took {:1f} s'.format(stage, op, elapsed))
    return result


def cell_segmentation(
    plan_path,
    memb_path,
    dapi_path,
    mean_path,
    dapi_shift_planes=0,
    nucl_opening_footprint=[3, 7, 7],
    dapi_filter='median',
    dapi_sigma=1,
    dapi_dog_sigma1=2,
    dapi_dog_sigma2=4,
    dapi_thr=0,
    sauvola_window_size=[19, 75, 75],
    sauvola_k=0.2,
    dapi_absmin=500,
    dapi_erodisk=0,
    dist_max=5,
    peaks_size=[11, 19, 19],
    peaks_thr=1.0,
    peaks_dil_footprint=[3, 7, 7],
    compactness=0.80,
    memb_filter='median',
    memb_sigma=3,
    planarity_thr=0.0005,
    dset_mask_filter='gaussian',
    dset_mask_sigma=50,
    dset_mask_thr=1000,
    steps=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    outputstem='',
    save_steps=False,
):
    """Segment cells from DAPI, membrane, planarity and mean-channel images.

    Runs a fixed sequence of stages (indices 0-8). A stage whose index is in
    ``steps`` is computed; any other stage reloads its earlier result from
    ``<outputstem>.h5/...``. Each stage prints how long it took.

    Returns:
        The post-processed watershed label image (stage 8).
    """
    step = 'segment'
    paths = get_paths(plan_path, -1, 0, outputstem, step, save_steps)
    # NOTE(review): `report` is assembled but never used or returned below.
    report = {
        'parameters': locals(),
        'paths': paths,
        'medians': {},
        'centreslices': {}
    }

    # load images
    im_dapi = Image(dapi_path)
    im_dapi.load()

    im_memb = MaskImage(memb_path)
    im_memb.load()

    im_plan = MaskImage(plan_path)
    im_plan.load()

    im_mean = Image(mean_path)
    im_mean.load()

    h5 = '{}.h5'.format(outputstem)

    # 0: preprocess dapi channel
    # .h5/nucl/dapi<_shifted><_opened><_preprocess>
    outstem = h5 + '/nucl/dapi'
    im_dapi_pp = _run_stage(
        'nucleus channel', 0, steps,
        read=lambda: get_image(outstem + '_preprocess'),
        process=lambda: preprocess_nucl(
            im_dapi, dapi_shift_planes, dapi_filter, dapi_sigma,
            outstem, save_steps),
    )

    # 1: create a nuclear mask from the dapi channel
    # .h5/nucl/dapi<_mask_thr><_sauvola><_mask><_mask_ero>
    outstem = h5 + '/nucl/dapi'
    im_dapi_mask = _run_stage(
        'nucleus mask', 1, steps,
        read=lambda: get_image(outstem + '_mask_ero'),
        process=lambda: create_nuclear_mask(
            im_dapi_pp, dapi_thr, sauvola_window_size, sauvola_k,
            dapi_absmin, dapi_erodisk, outstem, save_steps),
    )

    # 2: create a membrane mask from the membrane mean
    # .h5/memb/planarity<_mask>
    outstem = h5 + '/memb/planarity'
    im_memb_mask = _run_stage(
        'membrane mask', 2, steps,
        read=lambda: get_image(outstem + '_mask'),
        process=lambda: create_membrane_mask(
            im_plan, planarity_thr, outstem, save_steps),
    )

    # 3: combine nuclear and membrane mask
    # .h5/segm/seeds<_mask>
    outstem = h5 + '/segm/seeds'
    im_nucl_mask = _run_stage(
        'mask combination', 3, steps,
        read=lambda: get_image(outstem + '_mask'),
        process=lambda: combine_nucl_and_memb_masks(
            im_memb_mask, im_dapi_mask, nucl_opening_footprint,
            outstem, save_steps),
    )

    # 4: find seeds for watershed
    # .h5/segm/seeds<_edt><_mask_distmax><_peaks><_peaks_dil>
    outstem = h5 + '/segm/seeds'
    im_dt, im_peaks = _run_stage(
        'nucleus detection', 4, steps,
        read=lambda: (get_image(outstem + '_edt'),
                      get_image(outstem + '_peaks')),
        process=lambda: define_seeds(
            im_nucl_mask, im_memb_mask, im_dapi_pp,
            dapi_dog_sigma1, dapi_dog_sigma2, dist_max,
            peaks_size, peaks_thr, peaks_dil_footprint,
            outstem, save_steps),
    )

    # 5: preprocess membrane mean channel
    # .h5/memb/mean<_smooth>
    outstem = h5 + '/memb/mean'
    im_memb_pp = _run_stage(
        'membrane channel', 5, steps,
        read=lambda: get_image(outstem + '_smooth'),
        process=lambda: preprocess_memb(
            im_memb, memb_filter, memb_sigma, outstem, save_steps),
    )

    # 6: perform watershed from the peaks to fill the nuclei
    # .h5/segm/labels<_edt><_memb>
    outstem = h5 + '/segm/labels'
    im_ws = _run_stage(
        'watershed', 6, steps,
        read=lambda: get_image(outstem + '_memb', imtype='Label'),
        process=lambda: perform_watershed(
            im_peaks, im_memb_pp, im_dt, peaks_thr, memb_sigma,
            memb_filter, compactness, outstem, save_steps),
    )

    # 7: generate a dataset mask from the mean of all channels
    # .h5/mean<_smooth><_mask>
    outstem = h5 + '/mean'
    im_dset_mask = _run_stage(
        'dataset mask', 7, steps,
        read=lambda: get_image(outstem + '_mask', imtype='Mask'),
        process=lambda: create_dataset_mask(
            im_mean,
            filter=dset_mask_filter,
            sigma=dset_mask_sigma,
            threshold=dset_mask_thr,
            outstem=outstem,
            save_steps=save_steps,
        ),
    )

    # 8: filter the segments with the dataset mask
    # .h5/segm/labels<_memb_del>
    # .h5/mask
    outstem = h5 + '/segm/labels'
    im_ws_pp = _run_stage(
        'segment filter', 8, steps,
        read=lambda: get_image(outstem + '_memb_del', imtype='Label'),
        process=lambda: segmentation_postprocessing(
            im_dset_mask, im_ws, outstem, save_steps),
    )

    # write report
    generate_report('{}.h5/{}'.format(outputstem, 'mean_mask'))

    return im_ws_pp