Example #1
def omni_generator(f_input, begin=0, end=64, step=64):
    """Yield (start_index, chunk) pairs, reading a TIFF stack in chunks of `step` slices."""
    for i in range(begin, end, step):
        if i + step > end:
            data = dxchange.read_tiff_stack(f_input, ind=range(i, end))
        else:
            data = dxchange.read_tiff_stack(f_input, ind=range(i, i + step))
        yield (i, data)
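A minimal usage sketch for the generator above (the stack file name and index range are hypothetical; assumes the numbered TIFF files exist on disk):

for offset, chunk in omni_generator('proj_00000.tiff', begin=0, end=256, step=64):
    print(offset, chunk.shape)  # each chunk holds at most `step` slices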
Example #2
def omni_read(f_input, begin=None, end=None):
    '''Read a TIFF image, a TIFF stack, or an HDF5 dataset ("file.h5:dataset").'''
    if not f_input:
        return None
    f_input = os.path.abspath(f_input)
    matches = re.match(
        r'^(?P<dirname>.*)/(?P<fname>.*)\.(?P<ext>[^:]*)($|:(?P<dataset>.*$))',
        f_input)
    # print(matches.groupdict())
    if matches['ext'] == 'tif' or matches['ext'] == 'tiff':
        if begin is not None and end is not None:
            data = dxchange.read_tiff_stack(f_input, ind=range(begin, end))
        else:
            # print(f_input)
            S, L = check_stack_len(f_input)
            # print(S,L)
            if L > 1:
                data = dxchange.read_tiff_stack(f_input, ind=range(S, S + L))
            else:
                data = dxchange.read_tiff(f_input)
    elif matches['ext'] == 'h5' or matches['ext'] == 'hdf5':
        tokens = f_input.split(':')
        # open the file read-only and close it once the dataset has been copied
        with h5py.File(tokens[0], 'r') as h5_file:
            data = np.asarray(h5_file[tokens[1]])
    else:
        print('not implemented file type')
        data = None
    return data
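The regex above peels an optional ':dataset' suffix off the path; that suffix is how an HDF5 dataset is selected. A minimal usage sketch (file names are hypothetical):

vol = omni_read('/data/recon_00000.tiff')        # single TIFF or TIFF stack
vol = omni_read('/data/scan.h5:/exchange/data')  # HDF5 file plus dataset path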
Example #3
def get_reslice(recon_folder,
                chunk_size=50,
                slice_y=None,
                slice_x=None,
                rotate=0):

    filelist = glob.glob(os.path.join(recon_folder, 'recon*.tiff'))
    inds = []
    digit = None
    for i in filelist:
        i = os.path.split(i)[-1]
        regex = re.compile(r'\d+')
        a = regex.findall(i)[0]
        if digit is None:
            digit = len(a)
        inds.append(int(a))
    chunks = []
    chunk_st = np.min(inds)
    chunk_end = chunk_st + chunk_size
    sino_end = np.max(inds) + 1

    while chunk_end < sino_end:
        chunks.append((chunk_st, chunk_end))
        chunk_st = chunk_end
        chunk_end += chunk_size
    chunks.append((chunk_st, sino_end))

    a = dxchange.read_tiff_stack(filelist[0], range(chunks[0][0],
                                                    chunks[0][1]), digit)
    if rotate != 0:
        a = scipy.ndimage.rotate(a, rotate, axes=(1, 2), reshape=False)
    if slice_y is not None:
        reslice = a[:, slice_y, :]
    elif slice_x is not None:
        reslice = a[:, ::-1, slice_x]
    else:
        raise ValueError('Either slice_y or slice_x must be specified.')

    for (chunk_st, chunk_end) in chunks[1:]:
        a = dxchange.read_tiff_stack(filelist[0], range(chunk_st, chunk_end),
                                     digit)
        if rotate != 0:
            a = scipy.ndimage.rotate(a, rotate, axes=(1, 2), reshape=False)
        if slice_y is not None:
            reslice = np.append(reslice, a[:, slice_y, :], axis=0)
        elif slice_x is not None:
            reslice = np.append(reslice, a[:, ::-1, slice_x], axis=0)
        else:
            raise ValueError('Either slice_y or slice_x must be specified.')

    return reslice
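The while loop partitions the index range [min(inds), max(inds)] into chunk_size pieces, with the final chunk absorbing the remainder: for indices 0..119 and chunk_size=50 the chunk list is [(0, 50), (50, 100), (100, 120)].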
Example #4
def load_raw(top, index_start):
    """
    Function description.

    Parameters
    ----------
    parameter_01 : type
        Description.

    parameter_02 : type
        Description.

    parameter_03 : type
        Description.

    Returns
    -------
    return_01
        Description.
    """
    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tif'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)
    return rdata
Example #5
def load_raw(top, index_start):
    """
    Load a stack of tiff images.

    Parameters
    ----------
    top : str
        Top data directory.

    index_start : int
        Image index start.

    Returns
    -------
    ndarray
        3D stack of images.
    """
    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tif'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)
    return rdata
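Both versions above take the second directory entry as the file-name template and concatenate it with top, which assumes a trailing slash and a stable listing order, yet os.listdir() guarantees no order. A more defensive sketch of the same loader (an illustration, not the original author's code):

import fnmatch
import os

import dxchange  # assumed available, as in the examples above


def load_raw_sorted(top, index_start):
    """Like load_raw, but independent of directory-listing order."""
    tiffs = sorted(fnmatch.filter(os.listdir(top), '*.tif'))
    fname = os.path.join(top, tiffs[0])  # first file serves as the name template
    ind_tomo = range(index_start, index_start + len(tiffs))
    return dxchange.read_tiff_stack(fname, ind=ind_tomo)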
Example #6
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument("top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start", nargs='?', const=1, type=int, default=1, help="index of the first image: 10001 (default 1)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tif'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)

    particle_bed_reference = particle_bed_location(rdata[0], plot=False)

    print("Particle bed location: ", particle_bed_reference)
    print("Laser on?: ", laser_on(rdata, particle_bed_reference))
    print("Shutter closed on image: ", shutter_off(rdata))
Example #7
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start",
                        nargs='?',
                        const=1,
                        type=int,
                        default=0,
                        help="index of the first image: 100 (default 0)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tiff'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)
    slider(rdata)
Example #8
    def execute(self, startswith, seperator, idx, output_dataobj):
        if (startswith is None) or (seperator is None) or (idx is None):
            self.LogError('Please enter scanning requirements for files')
            return False
        filelist = os.listdir(self.ipath)
        filelist = list(
            filter(lambda filename: filename.startswith(startswith), filelist))
        if len(filelist) == 0:
            self.LogError('Cannot scan the required files, please check!')
            return False
        else:
            filelist.sort()
            self.LogInfo('First Input File: ' + filelist[0])
        if not (filelist[0].split('.')[-1].upper() == 'TIF'):
            self.LogError('Input files must be TIF')
            return False

        fileidxs = list(
            map(
                lambda filename: '.'.join(filename.split('.')[:-1]).split(
                    seperator)[idx], filelist))
        if len(fileidxs) == 0:
            self.LogError('Cannot scan the required files, please check!')
            return False
        else:
            self.LogInfo('Found ' + str(len(fileidxs)) + ' files in path ' +
                         self.ipath)

        self.data[output_dataobj] = dxchange.read_tiff_stack(self.ipath + '/' +
                                                             filelist[0],
                                                             ind=fileidxs)
        self.LogInfo('Loaded data ' + output_dataobj + ' from ' + self.ipath)

        return True
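Note that fileidxs holds substrings, while dxchange.read_tiff_stack formats the ind values into the numbered file-name template and generally expects integers; converting first is the safer pattern (a minimal sketch, assuming the extracted tokens are purely numeric):

fileidxs = ['0001', '0002', '0003']    # as produced by the split above
fileidxs = [int(s) for s in fileidxs]  # -> [1, 2, 3]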
Example #9
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument("top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start", nargs='?', const=0, type=int, default=0, help="index of the first image: 10001 (default 0)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), 'sdat*'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    print(fname, ind_tomo)

    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)
    # byteswap each projection; index by 0-based array position,
    # not by the file indices in ind_tomo
    sdata = rdata.copy()
    for index in range(sdata.shape[0]):
        print(index)
        sdata[index] = rdata[index].byteswap()
    #rdata[0] = rdata[0].byteswap()
    print(sdata.shape)
    slider(sdata)
Example #10
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument("top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start", nargs='?', const=1, type=int, default=0, help="index of the first image: 100 (default 0)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tiff'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)

    # View the data
    slider(rdata)

    # Apply the sobel filter
    ndata = scale_to_one(rdata)
    ndata = sobel_stack(ndata)
    slider(ndata)

    blur_radius = 3.0
    threshold = .04
    nddata = label(ndata, blur_radius, threshold)
    slider(nddata)
Example #11
def large_data_generator(stack_name,
                         begin=0,
                         end=64,
                         step=64,
                         dtype=None,
                         multi=False):
    for i in range(begin, end, step):
        if i + step > end:
            data = dxchange.read_tiff_stack(stack_name, ind=range(i, end))
        else:
            data = dxchange.read_tiff_stack(stack_name, ind=range(i, i + step))
        if not multi and dtype == 'uint32':
            data = np.nan_to_num(data > 0)
        if dtype:
            data = data.astype(dtype)
        data = np.moveaxis(data, 0, 2)
        yield (i, data)
Example #12
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start",
                        nargs='?',
                        const=1,
                        type=int,
                        default=1,
                        help="index of the first image: 10001 (default 1)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tif'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)
    particle_bed_reference = particle_bed_location(rdata[0], plot=False)
    print("Particle bed location: ", particle_bed_reference)

    # Cut the images to remove the particle bed
    cdata = rdata[:, 0:particle_bed_reference, :]

    # Find the image when the shutter starts to close
    dark_index = shutter_off(rdata)
    print("shutter closes on image: ", dark_index)
    # Set the [start, end] index of the blocked images, flat and dark.
    flat_range = [0, 1]
    data_range = [48, dark_index]
    dark_range = [dark_index, nfile]

    # # for fast testing
    # data_range = [48, dark_index]

    flat = cdata[flat_range[0]:flat_range[1], :, :]
    proj = cdata[data_range[0]:data_range[1], :, :]
    dark = np.zeros(
        (dark_range[1] - dark_range[0], proj.shape[1], proj.shape[2]))

    # if you want to use the shutter closed images as dark uncomment this:
    #dark = cdata[dark_range[0]:dark_range[1], :, :]

    ndata = tomopy.normalize(proj, flat, dark)
    ndata = tomopy.normalize_bg(ndata, air=int(ndata.shape[2] / 2.5))
    ndata = tomopy.minus_log(ndata)
    sharpening(ndata)
    slider(ndata)
Example #13
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument("top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start", nargs='?', const=1, type=int, default=1, help="index of the first image: 1000 (default 1)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[0]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tif'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)
    
    fname = top + template

    print(nfile, index_start, index_end, fname)

    # Select the sinogram range to reconstruct.
    start = 0
    end = 512
    sino = (start, end)

    # Read the tiff raw data.
    ndata = dxchange.read_tiff_stack(fname, ind=ind_tomo, slc=(sino, None))
 
    # Normalize to 1 using the air counts
    ndata = tomopy.normalize_bg(ndata, air=5)

    # Set data collection angles as equally spaced between 0-180 degrees.
    theta = tomopy.angles(ndata.shape[0])

    ndata = tomopy.minus_log(ndata)

    # Set binning and number of iterations
    binning = 8
    iters = 21

    print("Original", ndata.shape)
    ndata = tomopy.downsample(ndata, level=binning, axis=1)
#    ndata = tomopy.downsample(ndata, level=binning, axis=2)
    print("Processing:", ndata.shape)

    fdir = 'aligned' + '/noblur_iter_' + str(iters) + '_bin_' + str(binning) 

    print(fdir)
    cprj, sx, sy, conv = alignment.align_seq(ndata, theta, fdir=fdir, iters=iters, pad=(10, 10), blur=False, save=True, debug=True)

    np.save(fdir + '/shift_x', sx)
    np.save(fdir + '/shift_y', sy)

    # Write aligned projections as stack of TIFs.
    dxchange.write_tiff_stack(cprj, fname=fdir + '/radios/image')
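np.save appends the '.npy' extension when it is missing, so with the values above (iters=21, binning=8) the shifts can be reloaded later with:

sx = np.load('aligned/noblur_iter_21_bin_8/shift_x.npy')
sy = np.load('aligned/noblur_iter_21_bin_8/shift_y.npy')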
Example #14
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument("top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start", nargs='?', const=1, type=int, default=1, help="index of the first image: 1000 (default 1)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[0]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tif'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)
    
    fname = top + template

    print(nfile, index_start, index_end, fname)

    # Select the sinogram range to reconstruct.
    start = 0
    end = 512
    sino = (start, end)

    # Read the tiff raw data.
    ndata = dxchange.read_tiff_stack(fname, ind=ind_tomo, slc=(sino, None))

    print(ndata.shape)
    binning = 8
    ndata = tomopy.downsample(ndata, level=binning, axis=1)
    print(ndata.shape)
    
    # Normalize to 1 using the air counts
    ndata = tomopy.normalize_bg(ndata, air=5)

    ## slider(ndata)

    # Set data collection angles as equally spaced between 0-180 degrees.
    theta = tomopy.angles(ndata.shape[0])
   
    rot_center = 960
    print("Center of rotation: ", rot_center)

    ndata = tomopy.minus_log(ndata)

    # Reconstruct object using Gridrec algorithm.
    rec = tomopy.recon(ndata, theta, center=rot_center, algorithm='gridrec')

    # Mask each reconstructed slice with a circle.
    rec = tomopy.circ_mask(rec, axis=0, ratio=0.95)

    # Write data as stack of TIFs.
    dxchange.write_tiff_stack(rec, fname='/local/dataraid/mark/rec/recon')
Example #15
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start",
                        nargs='?',
                        const=1,
                        type=int,
                        default=1,
                        help="index of the first image: 1000 (default 1)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[0]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tiff'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    print(nfile, index_start, index_end, fname)

    # Select the sinogram range to reconstruct.
    start = 70
    end = 72
    sino = (start, end)

    # Read the tiff raw data.
    ndata = dxchange.read_tiff_stack(fname, ind=ind_tomo, slc=(sino, None))

    # Set data collection angles as equally spaced between 0-180 degrees.
    theta = tomopy.angles(ndata.shape[0])

    rot_center = 251
    print("Center of rotation: ", rot_center)

    #ndata = tomopy.minus_log(ndata)

    # Reconstruct object using Gridrec algorithm.
    rec = tomopy.recon(ndata, theta, center=rot_center, algorithm='gridrec')

    # Mask each reconstructed slice with a circle.
    rec = tomopy.circ_mask(rec, axis=0, ratio=0.95)

    # Write data as stack of TIFs.
    dxchange.write_tiff_stack(rec, fname='/local/dataraid/mark/rec/recon')
Example #16
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start",
                        nargs='?',
                        const=1,
                        type=int,
                        default=1,
                        help="index of the first image: 10001 (default 1)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tiff'))
    index_end = index_start + nfile
    index_end = 10  # override the computed end index; only the first images are read
    ind_tomo = range(index_start, index_end)

    fname = top + template

    print(nfile, index_start, index_end, fname)
    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)
    ndata = rdata

    #    slider(ndata)

    #    ndata = sobel_stack(rdata)
    #    print(ndata[0, :, :])
    #    slider(ndata)

    #    blur_radius = 2
    #    threshold = 0.0000001
    #    ndata = label(rdata, blur_radius, threshold)
    #    slider(ndata)

    idata = se.rescale_intensity(ndata[0, :, :], out_range=(0, 256))
    distance = ndi.distance_transform_edt(idata)
    # recent scikit-image returns peak coordinates; rebuild the boolean
    # mask that the old indices=False keyword used to produce
    coords = sf.peak_local_max(distance,
                               labels=idata,
                               footprint=np.ones((3, 3)))
    local_maxi = np.zeros(distance.shape, dtype=bool)
    local_maxi[tuple(coords.T)] = True
    markers = ndi.label(local_maxi)[0]
    labels = seg.watershed(-distance, markers, mask=idata)
    slider(labels)
Example #17
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument("top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start", nargs='?', const=1, type=int, default=1, help="index of the first image: 10001 (default 1)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tif'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)
    particle_bed_reference = particle_bed_location(rdata[0], plot=False)
    print("Particle bed location: ", particle_bed_reference)
    
    # Cut the images to remove the particle bed
    cdata = rdata[:, 0:particle_bed_reference, :]

    # Find the image when the shutter starts to close
    dark_index = shutter_off(rdata)
    print("shutter closes on image: ", dark_index)
    # Set the [start, end] index of the blocked images, flat and dark.
    flat_range = [0, 1]
    data_range = [48, dark_index]
    dark_range = [dark_index, nfile]

    # # for fast testing
    # data_range = [48, dark_index]

    flat = cdata[flat_range[0]:flat_range[1], :, :]
    proj = cdata[data_range[0]:data_range[1], :, :]
    dark = np.zeros((dark_range[1]-dark_range[0], proj.shape[1], proj.shape[2]))  

    # if you want to use the shutter closed images as dark uncomment this:
    #dark = cdata[dark_range[0]:dark_range[1], :, :]  

    ndata = tomopy.normalize(proj, flat, dark)
    ndata = tomopy.normalize_bg(ndata, air=int(ndata.shape[2]/2.5))
    ndata = tomopy.minus_log(ndata)
    sharpening(ndata)
    slider(ndata)
Example #18
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start",
                        nargs='?',
                        const=1,
                        type=int,
                        default=0,
                        help="index of the first image: 100 (default 0)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tiff'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)

    # View the data
    slider(rdata)

    # Apply the sobel filter
    ndata = scale_to_one(rdata)
    ndata = sobel_stack(ndata)
    slider(ndata)

    blur_radius = 3.0
    threshold = .04
    nddata = label(ndata, blur_radius, threshold)
    slider(nddata)
Example #19
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument("top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start", nargs='?', const=1, type=int, default=0, help="index of the first image: 100 (default 0)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[1]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tiff'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    # Read the tiff raw data.
    rdata = dxchange.read_tiff_stack(fname, ind=ind_tomo)
    slider(rdata)
Example #20
    return data


if __name__ == "__main__":

    in_file = sys.argv[1]
    order = int(sys.argv[2])
    niter = int(sys.argv[3])  # iteration count (avoids shadowing the builtin iter)
    binning = int(sys.argv[4])
    #idF = in_file.find('rect')
    out_file = in_file + 'rotated' + str(order) + '_' + str(niter) + '_' + str(
        binning)
    print('rotate', out_file)
    #data = dxchange.read_tiff_stack(in_file+'/results_admm/u/r_00000.tiff', ind=range(0, 2048))
    data = dxchange.read_tiff_stack(in_file + '/data/of_recon/recon/iter' +
                                    str(niter) + '_00000.tiff',
                                    ind=range(0, 2048 // pow(2, binning)))

    data = rotate_batch(data, 51, order)
    data = data.swapaxes(0, 2)
    data = rotate_batch(data, 34, order)
    data = data.swapaxes(0, 2)

    # data = rotate_batch(data, -39)
    # data = data.swapaxes(0,1)
    # data = rotate_batch(data, 34)
    # data = data.swapaxes(0,1)

    dxchange.write_tiff_stack(data[400 // pow(2, binning):800 //
                                   pow(2, binning)],
                              out_file + '/r')
Example #21
patch_size = (dim_img, dim_img)
batch_size = 50
nb_classes = 2
nb_epoch = 12

# number of convolutional filters to use
nb_filters = 32
# size of pooling area for max pooling
nb_pool = 2
# convolution kernel size
nb_conv = 3

fname = '../../test/test_data/1038.tiff'
ind_uncenter1 = range(1038, 1047)
ind_uncenter2 = range(1049, 1057)
uncenter1 = dxchange.read_tiff_stack(fname, ind=ind_uncenter1, digit=4)
uncenter2 = dxchange.read_tiff_stack(fname, ind=ind_uncenter2, digit=4)
uncenter = np.concatenate((uncenter1, uncenter2), axis=0)
uncenter = nor_data(uncenter)
print(uncenter.shape)
uncenter = img_window(uncenter[:, 360:1460, 440:1440], 200)
print(uncenter.shape)
uncenter_patches = extract_3d(uncenter, patch_size, 1)
np.random.shuffle(uncenter_patches)
print(uncenter_patches.shape)
# print uncenter_patches.shape
center_img = dxchange.read_tiff('../../test/test_data/1048.tiff')
center_img = nor_data(center_img)
print(center_img.shape)
center_img = img_window(center_img[360:1460, 440:1440], 400)
center_patches = extract_3d(center_img, patch_size, 1)
Example #22
    radialprofile = (tbinre+1j*tbinim) / np.sqrt(nr)

    return radialprofile


if __name__ == "__main__":

    fname1 = sys.argv[1]
    fname2 = sys.argv[2]
    fnameout = sys.argv[3]
    wsize = int(sys.argv[4])
    pixel = float(sys.argv[5])
    frac = float(sys.argv[6])
    sslice = int(sys.argv[7])
    
    f1 = dxchange.read_tiff_stack(fname1, ind=np.arange(sslice,sslice+wsize))
    f2 = dxchange.read_tiff_stack(fname2, ind=np.arange(sslice,sslice+wsize))

    f1 = f1[f1.shape[0]//2-wsize//2:f1.shape[0]//2+wsize // 2, f1.shape[1]//2-wsize //
            2:f1.shape[1]//2+wsize//2, f1.shape[2]//2-wsize//2:f1.shape[2]//2+wsize//2]
    f2 = f2[f2.shape[0]//2-wsize//2:f2.shape[0]//2+wsize // 2, f2.shape[1]//2-wsize //
            2:f2.shape[1]//2+wsize//2, f2.shape[2]//2-wsize//2:f2.shape[2]//2+wsize//2]

    print(1)
    ff1 = sp.fft.fftshift(sp.fft.fftn(sp.fft.fftshift(f1),workers=-1))
    print(ff1.shape)    
    ff2 = sp.fft.fftshift(sp.fft.fftn(sp.fft.fftshift(f2),workers=-1))
    print(3)
    
    frc1 = radial_profile3d(ff1*np.conj(ff2), np.array(ff1.shape)//2) /\
        np.sqrt(radial_profile3d(np.abs(ff1)**2, np.array(ff1.shape)//2) *
                radial_profile3d(np.abs(ff2)**2, np.array(ff2.shape)//2))
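The quantity computed here is the Fourier shell correlation, FRC(r) = sum_r F1·conj(F2) / sqrt(sum_r |F1|^2 · sum_r |F2|^2), a standard resolution metric that compares two independent reconstructions shell by shell.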
Example #23
from tomoalign.utils import find_min_max
from skimage.registration import phase_cross_correlation
#from silx.image import sift
from pystackreg import StackReg

data_prefix = '/local/data/vnikitin/nanomax/'
if __name__ == "__main__":

    # Model parameters
    n = 512  # object size n x,y
    nz = 512  # object size in z
    nmodes = 4
    nscan = 10000
    ntheta = 173  # number of angles (rotations)
    data = dxchange.read_tiff_stack('sorted/psiangle' + str(nmodes) +
                                    str(nscan) + '/r_00000.tiff',
                                    ind=np.arange(0, 173)).astype('float32')
    data_sift = dxchange.read_tiff('siftaligned/data.tif').astype('float32')
    print(data.shape)
    print(data_sift.shape)
    shift = np.zeros((ntheta, 2), dtype='float32')
    for k in range(ntheta):
        shift0, error, diffphase = phase_cross_correlation(data[k],
                                                           data_sift[k],
                                                           upsample_factor=1)
        print(shift0)
        data[k] = np.roll(data[k], (-int(shift0[0]), -int(shift0[1])),
                          axis=(0, 1))
        shift[k] = shift0
    dxchange.write_tiff_stack(data,
                              'sift_aligned_check/psiangle' + str(nmodes) +
                              str(nscan) + '/r')
Example #24
# for k in range(ndsets):
#     data[k*nth:(k+1)*nth] = np.load(fname+'_bin'+str(binning)+str(k)+'.npy').astype('float32')
#     theta[k*nth:(k+1)*nth] = np.load(fname+'_theta'+str(k)+'.npy').astype('float32')
# data[np.isnan(data)]=0
# data-=np.mean(data)
# for k in range(7):
#     dxchange.write_tiff(data[k*100]-data[0], '/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/Run4_9_1_40min_8keV_phase_100proj_per_rot_interlaced_1201prj_1s_024/proj/d/d_0000'+str(k),overwrite=True)
#     dxchange.write_tiff(data[k*100], '/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/Run4_9_1_40min_8keV_phase_100proj_per_rot_interlaced_1201prj_1s_024/proj/r/r_0000'+str(k),overwrite=True)
# exit()

# u = dxchange.read_tiff_stack('/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/Run4_9_1_40min_8keV_phase_100proj_per_rot_interlaced_1201prj_1s_024/results_admm/u/r_00000.tiff',ind=np.arange(0,1024))
# ucg = dxchange.read_tiff_stack('/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/Run4_9_1_40min_8keV_phase_100proj_per_rot_interlaced_1201prj_1s_024/results_cg/u/r_00000.tiff',ind=np.arange(0,1024))
# ucgn = dxchange.read_tiff_stack('/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/Run4_9_1_40min_8keV_phase_100proj_per_rot_interlaced_1201prj_1s_024/revision_cg_nop_1/results/u/r_00000.tiff',ind=np.arange(0,1024))
# un = dxchange.read_tiff_stack('/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/Run4_9_1_40min_8keV_phase_100proj_per_rot_interlaced_1201prj_1s_024/results_admm_reg3e-06/u/r_00000.tiff',ind=np.arange(0,1024))
upsi3 = dxchange.read_tiff_stack(
    '/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/c2_4_64_80_60_4_FBP_full/tomo_delta__ram-lak_freqscl_1.00_0001.tif',
    ind=np.arange(1, 1025))
#upsi7 = dxchange.read_tiff_stack('/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/c7_4_64_80_60_4_SART_full/tomo_delta__ram-lak_freqscl_1.00_0001.tif',ind=np.arange(1,1025))

# u=  u[:,u.shape[1]//2-612:u.shape[1]//2+612,u.shape[2]//2-612:u.shape[2]//2+612]
# un=un[:,un.shape[1]//2-612:un.shape[1]//2+612,un.shape[2]//2-612:un.shape[2]//2+612]
upsi3 = upsi3[:, upsi3.shape[1] // 2 - 612:upsi3.shape[1] // 2 + 612,
              upsi3.shape[2] // 2 - 612:upsi3.shape[2] // 2 + 612]

# vmin=-0.0018
# vmax=0.0018
# a=u[u.shape[0]//2];a[0]=vmin;a[1]=vmax
# plt.imsave('/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/Run4_9_1_40min_8keV_phase_100proj_per_rot_interlaced_1201prj_1s_024/figs/uz.png',a,vmin=vmin,vmax=vmax,cmap='gray')
# a=u[:,u.shape[1]//2];a[0]=vmin;a[1]=vmax
# plt.imsave('/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/Run4_9_1_40min_8keV_phase_100proj_per_rot_interlaced_1201prj_1s_024/figs/uy.png',a,vmin=vmin,vmax=vmax,cmap='gray')
# a=u[:,:,u.shape[2]//2];a[0]=vmin;a[1]=vmax
Example #25
    ##fname = '/local/decarlo/data/hzg/nanotomography/scan_renamed_450projections_crop_aligned/align_iter_40/radios/image_00000.tiff'
    ##fname = '/local/decarlo/data/hzg/nanotomography/scan_renamed_450projections_crop_rotate_aligned/align_iter_39/radios/image_00000.tiff'
    fname = '/local/decarlo/data/hzg/nanotomography/scan_renamed_450projections_crop_rotate_aligned/align_iter_40/radios/image_00000.tiff'

    sample_detector_distance = 18.8e2
    detector_pixel_size_x = 19.8e-7
    monochromator_energy = 11.0

    # for scan_renamed_450projections
    proj_start = 0
    proj_end = 451

    ind_tomo = range(proj_start, proj_end)

    # Read normalized, centered, -log() data generated by Doga's alignment routine.
    proj = dxchange.read_tiff_stack(fname, ind=ind_tomo, digit=5)

    # Set data collection angles as equally spaced between 0-180 degrees.
    theta = tomopy.angles(proj.shape[0])

    rot_center = (proj.shape[2]) / 2.0
    print("Center of rotation: ", rot_center)

    # Reconstruct object using Gridrec algorithm.
    rec = tomopy.recon(proj, theta, center=rot_center, algorithm='gridrec')

    # Mask each reconstructed slice with a circle.
    rec = tomopy.circ_mask(rec, axis=0, ratio=0.95)

    # Write data as stack of TIFs.
    ##fname='/local/decarlo/data/hzg/nanotomography/scan_renamed_450projections_crop_aligned/align_iter_40/recon_dir/aligned_gridrec/recon'
Example #26
import numpy as np
import dxchange
import sys
import matplotlib.pyplot as plt

name = sys.argv[1]
id = sys.argv[2]
a = dxchange.read_tiff_stack(name + '/r_00000.tiff',
                             ind=np.arange(0, 512))[:, 64:-64, 64:-64]
print(a.shape)
print(np.linalg.norm(a))
m = np.mean(a[128:-128, 128:-128, 128:-128])
s = np.std(a[128:-128, 128:-128, 128:-128])
a[a > m + 2.5 * s] = m + 2.5 * s
a[a < m - 2.5 * s] = m - 2.5 * s
plt.imsave('figs/z' + str(id) + '.png', a[a.shape[0] // 2], cmap='gray')
plt.imsave('figs/y' + str(id) + '.png', a[:, a.shape[1] // 2], cmap='gray')
plt.imsave('figs/x' + str(id) + '.png', a[:, :, a.shape[2] // 2], cmap='gray')
Example #27
import os
import glob
import dxchange
import numpy as np
import h5py

print("Working Dir")
print(os.getcwd())

# cube = dxchange.read_tiff_stack('cube/z01.tiff',np.arange(1,25)) #raw data 8bit, change "256" to # of sections
cube1 = dxchange.read_tiff_stack('G0013_cube/z0001.tif', np.arange(
    1, 100))  #raw data 8bit, change "256" to # of sections
#cube2 = dxchange.read_tiff_stack('inference_cube_02/z_0000.tif',np.arange(0,75)) #raw data 8bit, change "256" to # of sections

#cube3 = np.append(cube2,cube,axis=0)

# print (cube.shape)
# print ('Mean : '+str(cube.mean()))
# print ('Std : '+str(cube.std()))

print(cube1.shape)
print('Mean : ' + str(cube1.mean()))
print('Std : ' + str(cube1.std()))

#h5file = h5py.File('inference_data.h5', 'w')
#h5file.create_dataset('inference1',data=cube3)
h5file = h5py.File('G0013_cube_data.h5', 'w')
h5file.create_dataset('inference1', data=cube1)
h5file.close()

print("Finished!! Goodbye!!")
Example #28
import numpy as np
import dxchange
from transform import train_patch, predict_patch, train_filter, predict_filter
batch_size = 2200
nb_epoch = 40
patch_step = 1
nb_filters = 16
nb_conv = 3
patch_size = 64


spath = '/home/beams/YANGX/cnn_prj_enhance/tf_prd_battery_20170501/'
ipath = 'weights/tf_mouse.h5'
wpath = 'weights/tf_battery.h5'

proj_start = 1200
proj_end = 1201
ind_tomo = range(proj_start, proj_end)
fname = '/home/beams1/YANGX/cnn_prj_enhance/battery1_ds/prj_00000.tiff'

#
# imgx = dxchange.read_tiff('/home/beams1/YANGX/cnn_prj_enhance/battery1_train/trainx.tif')
# imgy = dxchange.read_tiff('/home/beams1/YANGX/cnn_prj_enhance/battery1_train/trainy.tif')
#
# mdl = train_patch(imgx, imgy, patch_size, 3, nb_filters, nb_conv, batch_size, nb_epoch, ipath)
# mdl.save_weights(wpath)



img_n = dxchange.read_tiff_stack(fname, ind_tomo, digit=5)
predict_patch(img_n, patch_size, 1, nb_filters, nb_conv, batch_size, wpath, spath)
Example #29
def main(arg):

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "top", help="top directory where the tiff images are located: /data/")
    parser.add_argument("start",
                        nargs='?',
                        const=1,
                        type=int,
                        default=1,
                        help="index of the first image: 1000 (default 1)")

    args = parser.parse_args()

    top = args.top
    index_start = int(args.start)

    template = os.listdir(top)[0]

    nfile = len(fnmatch.filter(os.listdir(top), '*.tif'))
    index_end = index_start + nfile
    ind_tomo = range(index_start, index_end)

    fname = top + template

    print(nfile, index_start, index_end, fname)

    # Select the sinogram range to reconstruct.
    start = 0
    end = 512
    sino = (start, end)

    # Read the tiff raw data.
    ndata = dxchange.read_tiff_stack(fname, ind=ind_tomo, slc=(sino, None))

    # Normalize to 1 using the air counts
    ndata = tomopy.normalize_bg(ndata, air=5)

    # Set data collection angles as equally spaced between 0-180 degrees.
    theta = tomopy.angles(ndata.shape[0])

    ndata = tomopy.minus_log(ndata)

    # Set binning and number of iterations
    binning = 8
    iters = 21

    print("Original", ndata.shape)
    ndata = tomopy.downsample(ndata, level=binning, axis=1)
    #    ndata = tomopy.downsample(ndata, level=binning, axis=2)
    print("Processing:", ndata.shape)

    fdir = 'aligned' + '/noblur_iter_' + str(iters) + '_bin_' + str(binning)

    print(fdir)
    cprj, sx, sy, conv = alignment.align_seq(ndata,
                                             theta,
                                             fdir=fdir,
                                             iters=iters,
                                             pad=(10, 10),
                                             blur=False,
                                             save=True,
                                             debug=True)

    np.save(fdir + '/shift_x', sx)
    np.save(fdir + '/shift_y', sy)

    # Write aligned projections as stack of TIFs.
    dxchange.write_tiff_stack(cprj, fname=fdir + '/radios/image')
Example #30
    tbinre = np.bincount(r.ravel(), data.real.ravel())
    tbinim = np.bincount(r.ravel(), data.imag.ravel())

    nr = np.bincount(r.ravel())
    radialprofile = (tbinre + 1j * tbinim) / np.sqrt(nr)

    return radialprofile


wsize = 160
fname1 = '/data/staff/tomograms/vviknik/tomoalign_vincent_data/psi/d/results_admm0/u/r_00000.tiff'
fname2 = '/data/staff/tomograms/vviknik/tomoalign_vincent_data/psi/d/results_admm1/u/r_00000.tiff'
# fname1 = '/data/staff/tomograms/vviknik/tomoalign_vincent_data/brain/Brain_Petrapoxy_day2_2880prj_1440deg_167/cga_resolution1__1440/rect3/r_00000.tiff'
# fname2 = '/data/staff/tomograms/vviknik/tomoalign_vincent_data/brain/Brain_Petrapoxy_day2_2880prj_1440deg_167/cga_resolution2__1440/rect3/r_00000.tiff'

f1 = dxchange.read_tiff_stack(fname1, ind=np.arange(0,
                                                    320))[:].astype('float32')
f2 = dxchange.read_tiff_stack(fname2, ind=np.arange(0,
                                                    320))[:].astype('float32')
f1 = f1[f1.shape[0] // 2, f1.shape[1] // 2 - wsize:f1.shape[1] // 2 + wsize,
        f1.shape[2] // 2 - wsize:f1.shape[2] // 2 + wsize]
f2 = f2[f2.shape[0] // 2, f2.shape[1] // 2 - wsize:f2.shape[1] // 2 + wsize,
        f2.shape[2] // 2 - wsize:f2.shape[2] // 2 + wsize]
print(f1.shape)
dxchange.write_tiff(
    f1, '/data/staff/tomograms/vviknik/tomoalign_vincent_data/psi/t1')
dxchange.write_tiff(
    f2, '/data/staff/tomograms/vviknik/tomoalign_vincent_data/psi/t2')
ff1 = np.fft.fftshift(np.fft.fft2(f1))
ff2 = np.fft.fftshift(np.fft.fft2(f2))
print(ff1.shape)
print(np.array(ff1.shape) // 2)
Example #31
import os
import glob
import dxchange
import numpy as np
import h5py

print("Working Dir")
print(os.getcwd())

cube = dxchange.read_tiff_stack('cube02/z01.tif', np.arange(
    1, 50))  #raw data 8bit, change "256" to # of sections
labels = dxchange.read_tiff_stack('labels02/l01.tif', np.arange(
    1, 50))  #label "ground truth" uint 8 or 32

print('Cube Properties!')
print(cube.dtype)
print(cube.shape)

print('Mean : ' + str(cube.mean()))
print('Std : ' + str(cube.std()))

print('Labels Properties!')
print(labels.dtype)
print(labels.shape)

print('Ids Properties!')
ids = np.unique(labels, return_counts=True)
print(ids)

#raf added here to pad256
# cube  = np.pad(cube,((115,116),(0,0),(0,0)),'reflect')
Example #32
# ndsets = 7
# nth = 100
# data = np.zeros([ndsets*nth,2048//pow(2,binning),2448//pow(2,binning)],dtype='float32')
# theta = np.zeros(ndsets*nth,dtype='float32')
# for k in range(ndsets):
#     data[k*nth:(k+1)*nth] = np.load(fname+'_bin'+str(binning)+str(k)+'.npy').astype('float32')
#     theta[k*nth:(k+1)*nth] = np.load(fname+'_theta'+str(k)+'.npy').astype('float32')
# data[np.isnan(data)]=0
# data-=np.mean(data)
# for k in range(7):
#     dxchange.write_tiff(data[k*100]-data[0], '/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/Run4_9_1_40min_8keV_phase_100proj_per_rot_interlaced_1201prj_1s_024/proj/d/d_0000'+str(k),overwrite=True)
#     dxchange.write_tiff(data[k*100], '/data/staff/tomograms/vviknik/tomoalign_vincent_data/mask/Run4_9_1_40min_8keV_phase_100proj_per_rot_interlaced_1201prj_1s_024/proj/r/r_0000'+str(k),overwrite=True)
# exit()

u = dxchange.read_tiff_stack(
    '/data/staff/tomograms/vviknik/tomoalign_vincent_data/psi/d/results_admm/u/r_00000.tiff',
    ind=np.arange(0, 320))[10:-10, 70:-70, 70:-70]
ucg = dxchange.read_tiff_stack(
    '/data/staff/tomograms/vviknik/tomoalign_vincent_data/psi/d/results_cg/u/r_00000.tiff',
    ind=np.arange(0, 320))[10:-10, 70:-70, 70:-70]
un = dxchange.read_tiff_stack(
    '/data/staff/tomograms/vviknik/tomoalign_vincent_data/psi/d/results_admm_reg4e-06/u/r_00000.tiff',
    ind=np.arange(0, 320))[10:-10, 70:-70, 70:-70]

ucgn = dxchange.read_tiff_stack(
    '/data/staff/tomograms/vviknik/tomoalign_vincent_data/psi/TIFF_beta_ram-lak_freqscl_1.00_nonrigid_SART/tomo_delta__ram-lak_freqscl_1.00_0001.tif',
    ind=np.arange(1, 321)).astype('float32')[10:-10, 70:-70,
                                             70:-70]  #.swapaxes(1,2)

vmax = -5e-4
vmin = -0.006
Example #33
def reconstruct_fullfield(fname, theta_st=0, theta_end=PI, n_epochs='auto', crit_conv_rate=0.03, max_nepochs=200,
                          alpha=1e-7, alpha_d=None, alpha_b=None, gamma=1e-6, learning_rate=1.0,
                          output_folder=None, minibatch_size=None, save_intermediate=False, full_intermediate=False,
                          energy_ev=5000, psize_cm=1e-7, n_epochs_mask_release=None, cpu_only=False, save_path='.',
                          phantom_path='phantom', shrink_cycle=20, core_parallelization=True, free_prop_cm=None,
                          multiscale_level=1, n_epoch_final_pass=None, initial_guess=None, n_batch_per_update=5,
                          dynamic_rate=True, probe_type='plane', probe_initial=None, probe_learning_rate=1e-3,
                          pupil_function=None, theta_downsample=None, forward_algorithm='fresnel', random_theta=True,
                          object_type='normal', **kwargs):
    """
    Reconstruct a beyond depth-of-focus object.
    :param fname: Filename and path of raw data file. Must be in HDF5 format.
    :param theta_st: Starting rotation angle.
    :param theta_end: Ending rotation angle.
    :param n_epochs: Number of epochs to be executed. If given 'auto', optimizer will stop
                     when reduction rate of loss function goes below crit_conv_rate.
    :param crit_conv_rate: Reduction rate of loss function below which the optimizer should
                           stop.
    :param max_nepochs: The maximum number of epochs to be executed if n_epochs is 'auto'.
    :param alpha: Weighting coefficient for both delta and beta regularizer. Should be None
                  if alpha_d and alpha_b are specified.
    :param alpha_d: Weighting coefficient for delta regularizer.
    :param alpha_b: Weighting coefficient for beta regularizer.
    :param gamma: Weighting coefficient for TV regularizer.
    :param learning_rate: Learning rate of ADAM.
    :param output_folder: Name of output folder. Put None for auto-generated pattern.
    :param downsample: Downsampling (not implemented yet).
    :param minibatch_size: Size of minibatch.
    :param save_intermediate: Whether to save the object after each epoch.
    :param energy_ev: Beam energy in eV.
    :param psize_cm: Pixel size in cm.
    :param n_epochs_mask_release: The number of epochs after which the finite support mask
                                  is released. Put None to disable this feature.
    :param cpu_only: Whether to disable GPU.
    :param save_path: The location of finite support mask, the prefix of output_folder and
                      other metadata.
    :param phantom_path: The location of phantom objects (for test version only).
    :param shrink_cycle: Shrink-wrap is executed per every this number of epochs.
    :param core_parallelization: Whether to use Horovod for parallelized computation within
                                 this function.
    :param free_prop_cm: The distance to propagate the wavefront in free space after exiting
                         the sample, in cm.
    :param multiscale_level: The level of multiscale processing. When this number is m and
                             m > 1, m - 1 low-resolution reconstructions will be performed
                             before reconstructing with the original resolution. The downsampling
                             factor for these coarse reconstructions will be [2^(m - 1),
                             2^(m - 2), ..., 2^1].
    :param n_epoch_final_pass: specify a number of iterations for the final pass if multiscale
                               is activated. If None, it will be the same as n_epoch.
    :param initial_guess: supply an initial guess. If None, object will be initialized with noises.
    :param n_batch_per_update: number of minibatches during which gradients are accumulated, after
                               which obj is updated.
    :param dynamic_rate: when n_batch_per_update > 1, adjust learning rate dynamically to allow it
                         to decrease with epoch number
    :param probe_type: type of wavefront. Can be 'plane', 'fixed', 'optimizable', 'point',
                           or 'gaussian'. If 'optimizable', the probe function will be
                           optimized along with the object.
    :param probe_initial: can be provided for 'optimizable' probe_type, and must be provided for
                              'fixed'.
    """

    def rotate_and_project(i, loss, obj_delta, obj_beta):

        warnings.warn('Obsolete function. The output loss is scaled by minibatch_size. Proceed with caution.')
        obj_rot = tf_rotate(tf.stack([obj_delta, obj_beta], axis=-1), this_theta_batch[i], interpolation='BILINEAR')
        if not cpu_only:
            with tf.device('/gpu:0'):
                exiting = multislice_propagate(obj_rot[:, :, :, 0], obj_rot[:, :, :, 1], probe_real, probe_imag, energy_ev, psize_cm * ds_level, h=h, free_prop_cm=free_prop_cm)
        else:
            exiting = multislice_propagate(obj_rot[:, :, :, 0], obj_rot[:, :, :, 1], probe_real, probe_imag, energy_ev, psize_cm * ds_level, h=h, free_prop_cm=free_prop_cm)
        loss += tf.reduce_mean(tf.squared_difference(tf.abs(exiting), tf.abs(this_prj_batch[i])))
        # i = tf.add(i, 1)
        return (i, loss, obj_delta, obj_beta)

    def rotate_and_project_batch(obj_delta, obj_beta):

        obj_rot_batch = []
        for i in range(minibatch_size):
            obj_rot_batch.append(tf_rotate(tf.stack([obj_delta, obj_beta], axis=-1), this_theta_batch[i], interpolation='BILINEAR'))
        # obj_rot = apply_rotation(obj, coord_ls[rand_proj], 'arrsize_64_64_64_ntheta_500')
        obj_rot_batch = tf.stack(obj_rot_batch)
        if probe_type == 'point':
            exiting_batch = multislice_propagate_spherical(obj_rot_batch[:, :, :, :, 0], obj_rot_batch[:, :, :, :, 1],
                                                           probe_real, probe_imag, energy_ev,
                                                           psize_cm * ds_level, dist_to_source_cm, det_psize_cm,
                                                           theta_max, phi_max, free_prop_cm,
                                                           obj_batch_shape=[minibatch_size, *obj_size])
        else:
            if forward_algorithm == 'fresnel':
                exiting_batch = multislice_propagate_batch(obj_rot_batch[:, :, :, :, 0], obj_rot_batch[:, :, :, :, 1],
                                                        probe_real, probe_imag, energy_ev,
                                                        psize_cm * ds_level, free_prop_cm=free_prop_cm, obj_batch_shape=[minibatch_size, *obj_size])
            elif forward_algorithm == 'fd':
                exiting_batch = multislice_propagate_fd(obj_rot_batch[:, :, :, :, 0], obj_rot_batch[:, :, :, :, 1],
                                                        probe_real, probe_imag, energy_ev,
                                                        psize_cm * ds_level, free_prop_cm=free_prop_cm,
                                                        obj_batch_shape=[minibatch_size, *obj_size])
        loss = tf.reduce_mean(tf.squared_difference(tf.abs(exiting_batch), tf.abs(this_prj_batch)), name='loss')
        return loss, exiting_batch

    # import Horovod or its fake shell
    if core_parallelization is False:
        warnings.warn('Parallelization is disabled in the reconstruction routine. ')
        from pseudo import hvd
    else:
        try:
            import horovod.tensorflow as hvd
            hvd.init()
        except:
            from pseudo import Hvd
            hvd = Hvd()
            warnings.warn('Unable to import Horovod.')
        try:
            assert hvd.mpi_threads_supported()
        except:
            warnings.warn('MPI multithreading is not supported.')
        try:
            import mpi4py.rc
            mpi4py.rc.initialize = False
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
            mpi4py_is_ok = True
            assert hvd.size() == comm.Get_size()
        except:
            warnings.warn('Unable to import mpi4py. Using multiple threads with n_epoch set to "auto" may lead to undefined behaviors.')
            from pseudo import Mpi
            comm = Mpi()
            mpi4py_is_ok = False

    # global_step = tf.Variable(0, trainable=False, name='global_step')

    t0 = time.time()

    # read data
    print_flush('Reading data...')
    f = h5py.File(os.path.join(save_path, fname), 'r')
    prj_0 = f['exchange/data'][...].astype('complex64')
    theta = -np.linspace(theta_st, theta_end, prj_0.shape[0], dtype='float32')
    n_theta = len(theta)
    prj_theta_ind = np.arange(n_theta, dtype=int)
    if theta_downsample is not None:
        prj_0 = prj_0[::theta_downsample]
        theta = theta[::theta_downsample]
        prj_theta_ind = prj_theta_ind[::theta_downsample]
        n_theta = len(theta)
    original_shape = prj_0.shape
    comm.Barrier()
    print_flush('Data reading: {} s'.format(time.time() - t0))
    print_flush('Data shape: {}'.format(original_shape))
    comm.Barrier()

    if probe_type == 'point':
        dist_to_source_cm = kwargs['dist_to_source_cm']
        det_psize_cm = kwargs['det_psize_cm']
        theta_max = kwargs['theta_max']
        phi_max = kwargs['phi_max']

    initializer_flag = False

    if output_folder is None:
        output_folder = 'recon_360_minibatch_{}_' \
                        'mskrls_{}_' \
                        'shrink_{}_' \
                        'iter_{}_' \
                        'alphad_{}_' \
                        'alphab_{}_' \
                        'gamma_{}_' \
                        'rate_{}_' \
                        'energy_{}_' \
                        'size_{}_' \
                        'ntheta_{}_' \
                        'prop_{}_' \
                        'ms_{}_' \
                        'cpu_{}' \
            .format(minibatch_size, n_epochs_mask_release, shrink_cycle,
                    n_epochs, alpha_d, alpha_b,
                    gamma, learning_rate, energy_ev,
                    prj_0.shape[-1], prj_0.shape[0], free_prop_cm,
                    multiscale_level, cpu_only)
        if abs(PI - theta_end) < 1e-3:
            output_folder += '_180'

    if save_path != '.':
        output_folder = os.path.join(save_path, output_folder)

    for ds_level in range(multiscale_level - 1, -1, -1):

        graph = tf.Graph()
        graph.as_default()

        ds_level = 2 ** ds_level
        print_flush('Multiscale downsampling level: {}'.format(ds_level))
        comm.Barrier()

        # downsample data
        prj = np.copy(prj_0)
        if ds_level > 1:
            prj = prj[:, ::ds_level, ::ds_level]
            prj = prj.astype('complex64')
        comm.Barrier()

        dim_y, dim_x = prj.shape[-2:]
        if random_theta:
            prj_dataset = tf.data.Dataset.from_tensor_slices((theta, prj)).shard(hvd.size(), hvd.rank()).shuffle(
                buffer_size=100).repeat().batch(minibatch_size)
        else:
            prj_dataset = tf.data.Dataset.from_tensor_slices((theta, prj)).shard(hvd.size(), hvd.rank()).repeat().batch(minibatch_size)
        prj_iter = prj_dataset.make_one_shot_iterator()
        this_theta_batch, this_prj_batch = prj_iter.get_next()
        comm.Barrier()

        # # read rotation data
        # try:
        #     coord_ls = read_all_origin_coords('arrsize_64_64_64_ntheta_500', n_theta)
        # except:
        #     save_rotation_lookup([dim_y, dim_x, dim_x], n_theta)
        #     coord_ls = read_all_origin_coords('arrsize_64_64_64_ntheta_500', n_theta)

        if minibatch_size is None:
            minibatch_size = n_theta

        if n_epochs_mask_release is None:
            n_epochs_mask_release = np.inf

        # =============== finite support mask ==============
        try:
            mask = dxchange.read_tiff_stack(os.path.join(save_path, 'fin_sup_mask', 'mask_00000.tiff'), range(prj_0.shape[1]), 5)
        except:
            try:
                mask = dxchange.read_tiff(os.path.join(save_path, 'fin_sup_mask', 'mask.tiff'))
            except:
                obj_pr = dxchange.read_tiff_stack(os.path.join(save_path, 'paganin_obj/recon_00000.tiff'), range(prj_0.shape[1]), 5)
                obj_pr = gaussian_filter(np.abs(obj_pr), sigma=3, mode='constant')
                mask = np.zeros_like(obj_pr)
                mask[obj_pr > 1e-5] = 1
                dxchange.write_tiff_stack(mask, os.path.join(save_path, 'fin_sup_mask/mask'), dtype='float32', overwrite=True)
        if ds_level > 1:
            mask = mask[::ds_level, ::ds_level, ::ds_level]
        mask_np = mask
        mask = tf.convert_to_tensor(mask, dtype=tf.float32, name='mask')
        dim_z = mask.shape[-1]

        # unify random seed for all threads
        comm.Barrier()
        seed = int(time.time() / 60)
        np.random.seed(seed)
        comm.Barrier()

        # initializer_flag = True
        np.random.seed(int(time.time()))
        if initializer_flag == False:
            if initial_guess is None:
                print_flush('Initializing with Gaussian random.')
                # grid_delta = np.load(os.path.join(phantom_path, 'grid_delta.npy'))
                # grid_beta = np.load(os.path.join(phantom_path, 'grid_beta.npy'))
                obj_delta_init = np.random.normal(size=[dim_y, dim_x, dim_z], loc=8.7e-7, scale=1e-7) * mask_np
                obj_beta_init = np.random.normal(size=[dim_y, dim_x, dim_z], loc=5.1e-8, scale=1e-8) * mask_np
                obj_delta_init[obj_delta_init < 0] = 0
                obj_beta_init[obj_beta_init < 0] = 0
            else:
                print_flush('Using supplied initial guess.')
                sys.stdout.flush()
                obj_delta_init = initial_guess[0]
                obj_beta_init = initial_guess[1]
        else:
            print_flush('Initializing with the upsampled result of the previous multiscale pass.')
            obj_delta_init = dxchange.read_tiff(os.path.join(output_folder, 'delta_ds_{}.tiff'.format(ds_level * 2)))
            obj_beta_init = dxchange.read_tiff(os.path.join(output_folder, 'beta_ds_{}.tiff'.format(ds_level * 2)))
            obj_delta_init = upsample_2x(obj_delta_init)
            obj_beta_init = upsample_2x(obj_beta_init)
            obj_delta_init += np.random.normal(size=[dim_y, dim_x, dim_z], loc=8.7e-7, scale=1e-7) * mask_np
            obj_beta_init += np.random.normal(size=[dim_y, dim_x, dim_z], loc=5.1e-8, scale=1e-8) * mask_np
            obj_delta_init[obj_delta_init < 0] = 0
            obj_beta_init[obj_beta_init < 0] = 0
        obj_size = obj_delta_init.shape
        if object_type == 'phase_only':
            obj_beta_init[...] = 0
            obj_delta = tf.Variable(initial_value=obj_delta_init, dtype=tf.float32, name='obj_delta')
            obj_beta = tf.constant(obj_beta_init, dtype=tf.float32, name='obj_beta')
        elif object_type == 'absorption_only':
            obj_delta_init[...] = 0
            obj_delta = tf.constant(obj_delta_init, dtype=tf.float32, name='obj_delta')
            obj_beta = tf.Variable(initial_value=obj_beta_init, dtype=tf.float32, name='obj_beta')
        else:
            obj_delta = tf.Variable(initial_value=obj_delta_init, dtype=tf.float32, name='obj_delta')
            obj_beta = tf.Variable(initial_value=obj_beta_init, dtype=tf.float32, name='obj_beta')
        # ====================================================



        if probe_type == 'plane':
            probe_real = tf.constant(np.ones([dim_y, dim_x]), dtype=tf.float32)
            probe_imag = tf.constant(np.zeros([dim_y, dim_x]), dtype=tf.float32)
        elif probe_type == 'optimizable':
            if probe_initial is not None:
                probe_mag, probe_phase = probe_initial
                probe_real, probe_imag = mag_phase_to_real_imag(probe_mag, probe_phase)
            else:
                # probe_mag = np.ones([dim_y, dim_x])
                # probe_phase = np.zeros([dim_y, dim_x])
                back_prop_cm = (free_prop_cm + (psize_cm * obj_size[2])) if free_prop_cm is not None else (psize_cm * obj_size[2])
                probe_init = create_probe_initial_guess(os.path.join(save_path, fname), back_prop_cm * 1.e7, energy_ev, psize_cm * 1.e7)
                probe_real = probe_init.real
                probe_imag = probe_init.imag
            if pupil_function is not None:
                probe_real = probe_real * pupil_function
                probe_imag = probe_imag * pupil_function
                pupil_function = tf.convert_to_tensor(pupil_function, dtype=tf.float32)
            probe_real = tf.Variable(probe_real, dtype=tf.float32, trainable=True)
            probe_imag = tf.Variable(probe_imag, dtype=tf.float32, trainable=True)
            if pupil_function is not None:
                # build the pupil-enforcing assign ops once, so the graph does
                # not grow when the constraint is applied inside the epoch loop
                apply_pupil_ops = [tf.assign(probe_real, probe_real * pupil_function),
                                   tf.assign(probe_imag, probe_imag * pupil_function)]
        elif probe_type == 'fixed':
            probe_mag, probe_phase = probe_initial
            probe_real, probe_imag = mag_phase_to_real_imag(probe_mag, probe_phase)
            probe_real = tf.constant(probe_real, dtype=tf.float32)
            probe_imag = tf.constant(probe_imag, dtype=tf.float32)
        elif probe_type == 'point':
            # this should be in spherical coordinates
            probe_real = tf.constant(np.ones([dim_y, dim_x]), dtype=tf.float32)
            probe_imag = tf.constant(np.zeros([dim_y, dim_x]), dtype=tf.float32)
        elif probe_type == 'gaussian':
            probe_mag_sigma = kwargs['probe_mag_sigma']
            probe_phase_sigma = kwargs['probe_phase_sigma']
            probe_phase_max = kwargs['probe_phase_max']
            py = np.arange(obj_size[0]) - (obj_size[0] - 1.) / 2
            px = np.arange(obj_size[1]) - (obj_size[1] - 1.) / 2
            pxx, pyy = np.meshgrid(px, py)
            probe_mag = np.exp(-(pxx ** 2 + pyy ** 2) / (2 * probe_mag_sigma ** 2))
            probe_phase = probe_phase_max * np.exp(
                -(pxx ** 2 + pyy ** 2) / (2 * probe_phase_sigma ** 2))
            probe_real, probe_imag = mag_phase_to_real_imag(probe_mag, probe_phase)
            probe_real = tf.constant(probe_real, dtype=tf.float32)
            probe_imag = tf.constant(probe_imag, dtype=tf.float32)
        else:
            raise ValueError('Invalid probe type. Choose from \'plane\', \'optimizable\', \'fixed\', \'point\', \'gaussian\'.')

        # =============finite support===================
        obj_delta = obj_delta * mask
        obj_beta = obj_beta * mask
        # obj_delta = tf.nn.relu(obj_delta)
        # obj_beta = tf.nn.relu(obj_beta)
        # ==============================================
        # ================shrink wrap===================
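        # Once i_epoch exceeds shrink_cycle, voxels whose delta has collapsed
        # to ~0 (below 1e-15) are dropped from the support mask, so the finite
        # support tracks the current object estimate.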
        def shrink_wrap():
            boolean = tf.cast(obj_delta > 1e-15, dtype=tf.float32)
            _mask = mask * boolean
            return _mask
        def return_mask(): return mask
        # if initializer_flag == False:
        i_epoch = tf.Variable(0, trainable=False, dtype='float32')
        increment_i_epoch = tf.assign_add(i_epoch, 1.)
        if shrink_cycle is not None:
            mask = tf.cond(tf.greater(i_epoch, shrink_cycle), shrink_wrap, return_mask)
        # if hvd.rank() == 0 and hvd.local_rank() == 0:
        #     dxchange.write_tiff(np.squeeze(sess.run(mask)),
        #                         os.path.join(save_path, 'fin_sup_mask/runtime_mask/epoch_{}'.format(epoch)), dtype='float32', overwrite=True)
        # ==============================================

        # generate Fresnel kernel
        voxel_nm = np.array([psize_cm] * 3) * 1.e7 * ds_level
        lmbda_nm = 1240. / energy_ev
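        # lambda [nm] = hc / E, with hc ~= 1240 eV*nm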
        delta_nm = voxel_nm[-1]
        kernel = get_kernel(delta_nm, lmbda_nm, voxel_nm, [dim_y, dim_x, dim_z])
        h = tf.convert_to_tensor(kernel, dtype=tf.complex64, name='kernel')

        # loss = tf.constant(0.)
        # i = tf.constant(0)
        # for j in range(minibatch_size):
        #     i, loss, obj = rotate_and_project(i, loss, obj)
        loss, exiting = rotate_and_project_batch(obj_delta, obj_beta)

        # loss = loss / n_theta + alpha * tf.reduce_sum(tf.image.total_variation(obj))
        # loss = loss / n_theta + gamma * energy_leak(obj, mask_add)
        if alpha_d is None:
            reg_term = alpha * (tf.norm(obj_delta, ord=1) + tf.norm(obj_beta, ord=1)) + gamma * total_variation_3d(obj_delta)
        else:
            if gamma == 0:
                reg_term = alpha_d * tf.norm(obj_delta, ord=1) + alpha_b * tf.norm(obj_beta, ord=1)
            else:
                reg_term = alpha_d * tf.norm(obj_delta, ord=1) + alpha_b * tf.norm(obj_beta, ord=1) + gamma * total_variation_3d(obj_delta)
        loss = loss + reg_term

        if probe_type == 'optimizable':
            probe_reg = 1.e-10 * (tf.image.total_variation(tf.reshape(probe_real, [dim_y, dim_x, -1])) +
                                  tf.image.total_variation(tf.reshape(probe_imag, [dim_y, dim_x, -1])))
            loss = loss + probe_reg
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('regularizer', reg_term)
        tf.summary.scalar('error', loss - reg_term)

        if dynamic_rate and n_batch_per_update > 1:
            # modifier =  1. / n_batch_per_update
            modifier = tf.exp(-i_epoch) * (n_batch_per_update - 1) + 1
            optimizer = tf.train.AdamOptimizer(learning_rate=float(learning_rate) * hvd.size() * modifier)
        else:
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate * hvd.size())
        optimizer = hvd.DistributedOptimizer(optimizer, name='distopt_{}'.format(ds_level))
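        # With n_batch_per_update > 1, gradients are accumulated over several
        # minibatches and applied in a single step, emulating a larger
        # effective batch without the corresponding memory cost.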
        if n_batch_per_update > 1:
            accum_grad_delta = tf.Variable(tf.zeros_like(obj_delta.initialized_value()), trainable=False)
            accum_grad_beta = tf.Variable(tf.zeros_like(obj_beta.initialized_value()), trainable=False)
            this_grad_delta = optimizer.compute_gradients(loss, [obj_delta])
            this_grad_delta = this_grad_delta[0]
            this_grad_beta = optimizer.compute_gradients(loss, [obj_beta])
            this_grad_beta = this_grad_beta[0]
            initialize_grad_delta = accum_grad_delta.assign(tf.zeros_like(accum_grad_delta))
            initialize_grad_beta = accum_grad_beta.assign(tf.zeros_like(accum_grad_beta))
            accum_op_delta = accum_grad_delta.assign_add(this_grad_delta[0])
            accum_op_beta = accum_grad_beta.assign_add(this_grad_beta[0])
            update_obj_delta = optimizer.apply_gradients([(accum_grad_delta / n_batch_per_update, this_grad_delta[1])])
            update_obj_beta = optimizer.apply_gradients([(accum_grad_beta / n_batch_per_update, this_grad_beta[1])])
        else:
            ###
            this_grad_delta = optimizer.compute_gradients(loss, obj_delta)
            ###
            if object_type in ('normal', 'phase_only', 'absorption_only'):
                optimizer = optimizer.minimize(loss)
            else:
                raise ValueError('Invalid object type.')
        # if minibatch_size >= n_theta:
        #     optimizer = optimizer.minimize(loss, var_list=[obj])
        # hooks = [hvd.BroadcastGlobalVariablesHook(0)]

        if probe_type == 'optimizable':
            optimizer_probe = tf.train.AdamOptimizer(learning_rate=probe_learning_rate * hvd.size())
            optimizer_probe = hvd.DistributedOptimizer(optimizer_probe, name='distopt_probe_{}'.format(ds_level))
            if n_batch_per_update > 1:
                accum_grad_probe = [tf.Variable(tf.zeros_like(probe_real.initialized_value()), trainable=False),
                                    tf.Variable(tf.zeros_like(probe_imag.initialized_value()), trainable=False)]
                this_grad_probe = optimizer_probe.compute_gradients(loss, [probe_real, probe_imag])
                initialize_grad_probe = [accum_grad_probe[i].assign(tf.zeros_like(accum_grad_probe[i])) for i in range(2)]
                accum_op_probe = [accum_grad_probe[i].assign_add(this_grad_probe[i][0]) for i in range(2)]
                update_probe = [optimizer_probe.apply_gradients([(accum_grad_probe[i] / n_batch_per_update, this_grad_probe[i][1])]) for i in range(2)]
            else:
                optimizer_probe = optimizer_probe.minimize(loss, var_list=[probe_real, probe_imag])

        # =============finite support===================
        obj_delta = obj_delta * mask
        obj_beta = obj_beta * mask
        obj_delta = tf.nn.relu(obj_delta)
        obj_beta = tf.nn.relu(obj_beta)
        # ==============================================

        loss_ls = []
        reg_ls = []

        merged_summary_op = tf.summary.merge_all()

        # create benchmarking metadata
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        if cpu_only:
            sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True))
        else:
            config = tf.ConfigProto(log_device_placement=False)
            config.gpu_options.allow_growth = True
            config.gpu_options.visible_device_list = str(hvd.local_rank())
            sess = tf.Session(config=config)

        sess.run(tf.global_variables_initializer())
        hvd.broadcast_global_variables(0)
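        # broadcasting from rank 0 ensures every worker starts from identical
        # initial variables before distributed optimization begins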

        if hvd.rank() == 0:
            preset = 'pp' if probe_type == 'point' else 'fullfield'
            create_summary(output_folder, locals(), preset=preset)
            summary_writer = tf.summary.FileWriter(os.path.join(output_folder, 'tb'), sess.graph)

        t0 = time.time()

        print_flush('Optimizer started.')

        n_loop = n_epochs if n_epochs != 'auto' else max_nepochs
        if ds_level == 1 and n_epoch_final_pass is not None:
            n_loop = n_epoch_final_pass
        n_batch = int(np.ceil(float(n_theta) / minibatch_size) / hvd.size())
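        # each rank therefore processes roughly
        # n_theta / (minibatch_size * hvd.size()) minibatches per epoch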
        t00 = time.time()
        for epoch in range(n_loop):
            if mpi4py_is_ok:
                stop_iteration = False
            else:
                with open('.stop_iteration', 'w') as stop_iteration_file:
                    stop_iteration_file.write('False')
            sess.run(increment_i_epoch)
            if minibatch_size <= n_theta:
                batch_counter = 0
                for i_batch in range(n_batch):
                    try:
                        if n_batch_per_update > 1:
                            t0_batch = time.time()
                            if probe_type == 'optimizable':
                                _, _, _, current_loss, current_reg, current_probe_reg, summary_str = sess.run(
                                    [accum_op_delta, accum_op_beta, accum_op_probe, loss, reg_term, probe_reg, merged_summary_op], options=run_options,
                                    run_metadata=run_metadata)
                            else:
                                _, _, current_loss, current_reg, summary_str = sess.run(
                                    [accum_op_delta, accum_op_beta, loss, reg_term, merged_summary_op], options=run_options,
                                    run_metadata=run_metadata)
                                current_probe_reg = None
                            print_flush('Minibatch done in {} s (rank {}); current loss = {}, probe reg. = {}.'.format(time.time() - t0_batch, hvd.rank(), current_loss, current_probe_reg))
                            batch_counter += 1
                            if batch_counter == n_batch_per_update or i_batch == n_batch - 1:
                                sess.run([update_obj_delta, update_obj_beta])
                                sess.run([initialize_grad_delta, initialize_grad_beta])
                                if probe_type == 'optimizable':
                                    sess.run(update_probe)
                                    sess.run(initialize_grad_probe)
                                batch_counter = 0
                                print_flush('Gradient applied.')
                        else:
                            t0_batch = time.time()
                            if probe_type == 'optimizable':
                                _, _, current_loss, current_reg, current_probe_reg, summary_str = sess.run([optimizer, optimizer_probe, loss, reg_term, probe_reg, merged_summary_op], options=run_options, run_metadata=run_metadata)
                                print_flush(
                                    'Minibatch done in {} s (rank {}); current loss = {}, probe reg. = {}.'.format(
                                        time.time() - t0_batch, hvd.rank(), current_loss, current_probe_reg))
                            else:
                                _, current_loss, current_reg, summary_str, mask_int, current_obj, current_grad = sess.run([optimizer, loss, reg_term, merged_summary_op, mask, obj_delta, this_grad_delta], options=run_options, run_metadata=run_metadata)
                                print_flush(
                                    'Minibatch done in {} s (rank {}); current loss = {}; current reg = {}.'.format(
                                        time.time() - t0_batch, hvd.rank(), current_loss, current_reg))
                                # dxchange.write_tiff(current_obj[int(105./256*current_obj.shape[0])], os.path.join(output_folder, 'intermediate_minibatch/delta'), dtype='float32')
                                # dxchange.write_tiff(current_grad[0][0][int(105./256*current_obj.shape[0])], os.path.join(output_folder, 'intermediate_minibatch/grad'), dtype='float32')
                                # dxchange.write_tiff(abs(mask_int), os.path.join(output_folder, 'masks', 'mask_{}'.format(epoch)), dtype='float32')
                                # dxchange.write_tiff(exiting_wave, save_path + '/exit', dtype='float32')
                        # enforce pupil function on the optimizable probe,
                        # using the assign ops built at graph-construction time
                        if probe_type == 'optimizable' and pupil_function is not None:
                            sess.run(apply_pupil_ops)

                    except tf.errors.OutOfRangeError:
                        break
            else:
                if probe_type == 'optimizable':
                    _, _, current_loss, current_reg, summary_str = sess.run([optimizer, optimizer_probe, loss, reg_term, merged_summary_op], options=run_options, run_metadata=run_metadata)
                else:
                    _, current_loss, current_reg, summary_str = sess.run([optimizer, loss, reg_term, merged_summary_op], options=run_options, run_metadata=run_metadata)

            # timeline for benchmarking
            if hvd.rank() == 0:
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                try:
                    os.makedirs(os.path.join(output_folder, 'profiling'))
                except OSError:
                    pass
                with open(os.path.join(output_folder, 'profiling', 'time_{}.json'.format(epoch)), 'w') as f:
                    f.write(ctf)

            # check stopping criterion
            if n_epochs == 'auto':
                if len(loss_ls) > 0:
                    print_flush('Reduction rate of loss is {}.'.format((current_loss - loss_ls[-1]) / loss_ls[-1]))
                    sys.stdout.flush()
                if len(loss_ls) > 0 and -crit_conv_rate < (current_loss - loss_ls[-1]) / loss_ls[-1] < 0 and hvd.rank() == 0:
                    loss_ls.append(current_loss)
                    reg_ls.append(current_reg)
                    summary_writer.add_summary(summary_str, epoch)
                    if mpi4py_is_ok:
                        stop_iteration = True
                    else:
                        with open('.stop_iteration', 'w') as stop_iteration_file:
                            stop_iteration_file.write('True')
                comm.Barrier()
                if mpi4py_is_ok:
                    stop_iteration = comm.bcast(stop_iteration, root=0)
                else:
                    with open('.stop_iteration', 'r') as stop_iteration_file:
                        stop_iteration = stop_iteration_file.read()
                    stop_iteration = (stop_iteration == 'True')
                if stop_iteration:
                    break
            # if epoch < n_epochs_mask_release:
            #     # ================shrink wrap===================
            #     if epoch % shrink_cycle == 0 and epoch > 0:
            #         mask_temp = sess.run(obj_delta > 1e-8)
            #         boolean = tf.convert_to_tensor(mask_temp, dtype=tf.float32)
            #         mask = mask * boolean
            #         if hvd.rank() == 0 and hvd.local_rank() == 0:
            #             dxchange.write_tiff(np.squeeze(sess.run(mask)),
            #                                 os.path.join(save_path, 'fin_sup_mask/runtime_mask/epoch_{}'.format(epoch)), dtype='float32', overwrite=True)
            #     # ==============================================
            # =============finite support===================
            # obj_delta = obj_delta * mask
            # obj_beta = obj_beta * mask
            # obj_delta = tf.nn.relu(obj_delta)
            # obj_beta = tf.nn.relu(obj_beta)
            # ==============================================
            if hvd.rank() == 0:
                loss_ls.append(current_loss)
                reg_ls.append(current_reg)
                summary_writer.add_summary(summary_str, epoch)
            if save_intermediate and hvd.rank() == 0:
                temp_obj_delta, temp_obj_beta = sess.run([obj_delta, obj_beta])
                temp_obj_delta = np.abs(temp_obj_delta)
                temp_obj_beta = np.abs(temp_obj_beta)
                if full_intermediate:
                    dxchange.write_tiff(temp_obj_delta,
                                        fname=os.path.join(output_folder, 'intermediate', 'ds_{}_iter_{:03d}'.format(ds_level, epoch)),
                                        dtype='float32',
                                        overwrite=True)
                else:
                    dxchange.write_tiff(temp_obj_delta[int(temp_obj_delta.shape[0] / 2), :, :],
                                        fname=os.path.join(output_folder, 'intermediate', 'ds_{}_iter_{:03d}'.format(ds_level, epoch)),
                                        dtype='float32',
                                        overwrite=True)
                    probe_current_real, probe_current_imag = sess.run([probe_real, probe_imag])
                    probe_current_mag, probe_current_phase = real_imag_to_mag_phase(probe_current_real, probe_current_imag)
                    dxchange.write_tiff(probe_current_mag,
                                        fname=os.path.join(output_folder, 'intermediate', 'probe',
                                                           'mag_ds_{}_iter_{:03d}'.format(ds_level, epoch)),
                                        dtype='float32',
                                        overwrite=True)
                    dxchange.write_tiff(probe_current_phase,
                                        fname=os.path.join(output_folder, 'intermediate', 'probe',
                                                           'phase_ds_{}_iter_{:03d}'.format(ds_level, epoch)),
                                        dtype='float32',
                                        overwrite=True)
                dxchange.write_tiff(temp_obj_delta, os.path.join(output_folder, 'current', 'delta'), dtype='float32', overwrite=True)
                print_flush('Iteration {} (rank {}); loss = {}; time = {} s'.format(epoch, hvd.rank(), current_loss, time.time() - t00))
            sys.stdout.flush()
            # except:
            #     # if one thread breaks out after meeting stopping criterion, intercept Horovod error and break others
            #     break

            print_flush('Total time: {}'.format(time.time() - t0))
        sys.stdout.flush()

        if hvd.rank() == 0:

            res_delta, res_beta = sess.run([obj_delta, obj_beta])
            res_delta *= mask_np
            res_beta *= mask_np
            res_delta = np.clip(res_delta, 0, None)
            res_beta = np.clip(res_beta, 0, None)
            dxchange.write_tiff(res_delta, fname=os.path.join(output_folder, 'delta_ds_{}'.format(ds_level)), dtype='float32', overwrite=True)
            dxchange.write_tiff(res_beta, fname=os.path.join(output_folder, 'beta_ds_{}'.format(ds_level)), dtype='float32', overwrite=True)

            probe_final_real, probe_final_imag = sess.run([probe_real, probe_imag])
            probe_final_mag, probe_final_phase = real_imag_to_mag_phase(probe_final_real, probe_final_imag)
            dxchange.write_tiff(probe_final_mag, fname=os.path.join(output_folder, 'probe_mag_ds_{}'.format(ds_level)), dtype='float32', overwrite=True)
            dxchange.write_tiff(probe_final_phase, fname=os.path.join(output_folder, 'probe_phase_ds_{}'.format(ds_level)), dtype='float32', overwrite=True)

            error_ls = np.array(loss_ls) - np.array(reg_ls)

            x = len(loss_ls)
            plt.figure()
            plt.semilogy(range(x), loss_ls, label='Total loss')
            plt.semilogy(range(x), reg_ls, label='Regularizer')
            plt.semilogy(range(x), error_ls, label='Error term')
            plt.legend()
            try:
                os.makedirs(os.path.join(output_folder, 'convergence'))
            except OSError:
                pass
            plt.savefig(os.path.join(output_folder, 'convergence', 'converge_ds_{}.png'.format(ds_level)), format='png')
            np.save(os.path.join(output_folder, 'convergence', 'total_loss_ds_{}'.format(ds_level)), loss_ls)
            np.save(os.path.join(output_folder, 'convergence', 'reg_ds_{}'.format(ds_level)), reg_ls)
            np.save(os.path.join(output_folder, 'convergence', 'error_ds_{}'.format(ds_level)), error_ls)

            print_flush('Clearing current graph...')
        sess.run(tf.global_variables_initializer())
        sess.close()
        tf.reset_default_graph()
        initializer_flag = True
        print_flush('Current iteration finished.')
Example #35
0
    '220': 1400,
    '221': 3000,
    '222': 2200,
}

######################################################
file_name = sys.argv[1]
center = centers[file_name[-6:-3]]
ntheta = nthetas[file_name[-6:-3]]
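# the 3-character sample ID parsed from the file name selects the rotation
# center and the number of projection angles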

ngpus = 8
pnz = 16  # chunk size for slices
ptheta = 20  # chunk size for angles

# read data
prj = dxchange.read_tiff_stack(f'{file_name[:-3]}/data/d_00000.tiff',
                               ind=range(ntheta))
theta = np.load(file_name[:-3] + '/data/theta.npy')
nz, n = prj.shape[1:]

niteradmm = [96, 48, 24, 12]  # number of iterations in the ADMM scheme
# niteradmm = [2, 2, 2]  # number of iterations in the ADMM scheme
startwin = [256, 128, 64, 32]  # starting window size in optical flow estimation
stepwin = [2, 2, 2, 2]  # step for decreasing the window size in optical flow estimation
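# coarse-to-fine schedule: each level halves the optical-flow window while
# the per-level ADMM iteration count decreases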

res = tomoalign.admm_of_levels(prj,
                               theta,
                               pnz,
                               ptheta,
                               center,
Example #36
0
import sys

import dxchange
import matplotlib.pyplot as plt
if __name__ == "__main__":

    in_file = sys.argv[1]
    out_file = sys.argv[2]
    # mul = 2
    binning = 0
    data = dxchange.read_tiff_stack(in_file + '/r_00000.tiff', ind=[691, 657, 672])

    # bin0 cuts
    vmin = -0.001 * 2 / 3
    vmax = 0.002 * 2 / 3
    plt.imsave(out_file + '2000bin0p1.png', data[0], vmin=vmin, vmax=vmax, cmap='gray')
    plt.imsave(out_file + '2000bin0p2.png', data[1], vmin=vmin, vmax=vmax, cmap='gray')
    plt.imsave(out_file + '2000bin0p3.png', data[2], vmin=vmin, vmax=vmax, cmap='gray')
Example #37
0
import numpy as np
import dxchange


def radial_profile3d(data, center):
    # reconstructed header and radius map (assumption): bin the complex
    # volume by integer radial distance from `center`
    z, y, x = np.indices(data.shape)
    r = np.sqrt((z - center[0])**2 + (y - center[1])**2 + (x - center[2])**2).astype(np.int64)
    tbinre = np.bincount(r.ravel(), data.real.ravel())
    tbinim = np.bincount(r.ravel(), data.imag.ravel())

    nr = np.bincount(r.ravel())
    radialprofile = (tbinre + 1j * tbinim) / np.sqrt(nr)

    return radialprofile


wsize = 272
fname1 = '/data/staff/tomograms/vviknik/tomoalign_vincent_data/psi/d/results_admm0/u/r_00000.tiff'
fname2 = '/data/staff/tomograms/vviknik/tomoalign_vincent_data/psi/d/results_admm1/u/r_00000.tiff'
# fname1 = '/data/staff/tomograms/vviknik/tomoalign_vincent_data/brain/Brain_Petrapoxy_day2_2880prj_1440deg_167/cga_resolution1__1440/rect3/r_00000.tiff'
# fname2 = '/data/staff/tomograms/vviknik/tomoalign_vincent_data/brain/Brain_Petrapoxy_day2_2880prj_1440deg_167/cga_resolution2__1440/rect3/r_00000.tiff'

f1 = dxchange.read_tiff_stack(fname1, ind=np.arange(0, 544))[:]
f2 = dxchange.read_tiff_stack(fname2, ind=np.arange(0, 544))[:]
f1 = f1[f1.shape[0] // 2 - wsize:f1.shape[0] // 2 + wsize,
        f1.shape[1] // 2 - wsize:f1.shape[1] // 2 + wsize,
        f1.shape[2] // 2 - wsize:f1.shape[2] // 2 + wsize]
f2 = f2[f2.shape[0] // 2 - wsize:f2.shape[0] // 2 + wsize,
        f2.shape[1] // 2 - wsize:f2.shape[1] // 2 + wsize,
        f2.shape[2] // 2 - wsize:f2.shape[2] // 2 + wsize]
print(f1.shape)

ff1 = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(f1)))
ff2 = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(f2)))

frc1 = radial_profile3d(ff1 * np.conj(ff2), np.array(ff1.shape) // 2) / \
    np.sqrt(radial_profile3d(np.abs(ff1) ** 2, np.array(ff1.shape) // 2) *
            radial_profile3d(np.abs(ff2) ** 2, np.array(ff1.shape) // 2))
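# Fourier shell correlation: per radial shell, the cross-power of the two
# reconstructions normalized by the geometric mean of their power spectra;
# values near 1 indicate reproducible detail at that spatial frequency.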
# print(np.min(frc1[:wsize].real))