コード例 #1
0
ファイル: user_functions.py プロジェクト: cryoem/test
def reference4( ref_data ):
	"""
	Prepare the reference volume for 3D alignment: normalize, threshold,
	low-pass filter and, if requested, center it.

	Input: list ref_data
	   0 - mask
	   1 - center flag
	   2 - raw average
	   3 - fsc result (not read by this function; filter settings are hard-coded)
	Output: (volume, [sx, sy, sz]); the shifts are all zero unless the center
	        flag equals 1, in which case they are the phase_cog() values that
	        were removed from the volume.
	"""
	from utilities      import print_msg
	from filter         import fit_tanh, filt_tanl, filt_gaussl
	from fundamentals   import fshift, fft
	from morphology     import threshold

	raw_average = ref_data[2]
	mask        = ref_data[0]

	# Subtract stat[0] and divide by stat[1] as reported by Util.infomask
	# (presumably mean and standard deviation under the mask - TODO confirm).
	mask_stat  = Util.infomask(raw_average, mask, False)
	normalized = raw_average - mask_stat[0]
	Util.mul_scalar(normalized, 1.0/mask_stat[1])
	normalized = threshold(normalized)

	# Fixed filter chain: fft -> filt_tanl(0.35, 0.2) -> filt_gaussl(0.3) -> fft.
	in_fourier   = fft(normalized)
	tanl_passed  = filt_tanl(in_fourier, 0.35, 0.2)
	gauss_passed = filt_gaussl(tanl_passed, 0.3)
	volf         = fft(gauss_passed)

	# No centering requested: report zero shifts.
	if ref_data[1] != 1:
		return  volf, [0.0]*3

	cs = volf.phase_cog()
	print_msg("Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n"%(cs[0], cs[1], cs[2]))
	return  fshift(volf, -cs[0], -cs[1], -cs[2]), cs
コード例 #2
0
ファイル: user_functions.py プロジェクト: cryoem/test
def ref_7grp( ref_data ):
	"""
	Prepare the reference in 3D alignment, i.e., low-pass filter and center.

	Input: list ref_data
	   0 - mask
	   1 - center flag
	   2 - raw average
	   3 - fsc result
	Output: (volf, cs) - the filtered, masked and (optionally) centered
	        reference, and the shifts [sx, sy, sz] removed from it.  The shifts
	        are all zero when the center flag is not 1.
	"""
	from utilities      import print_msg
	from filter         import fit_tanh, filt_tanl, filt_gaussinv
	from fundamentals   import fshift
	from morphology     import threshold
	from math           import sqrt

	# Default shifts.  Without this initialisation the final return raised
	# UnboundLocalError whenever the center flag was not 1.
	cs = [0.0]*3

	# Statistics are taken over the whole volume (no mask passed to infomask);
	# the mask is applied only after thresholding.
	stat = Util.infomask(ref_data[2], None, False)
	volf = ref_data[2] - stat[0]
	Util.mul_scalar(volf, 1.0/stat[1])
	volf = Util.muln_img(threshold(volf), ref_data[0])

	# Tangent low-pass filter with parameters fitted to the fsc result.
	fl, aa = fit_tanh(ref_data[3])
	msg = "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"%(fl, aa)
	print_msg(msg)
	volf = filt_tanl(volf, fl, aa)
	if ref_data[1] == 1:
		cs    = volf.phase_cog()
		msg = "Center x =	%10.3f        Center y       = %10.3f        Center z       = %10.3f\n"%(cs[0], cs[1], cs[2])
		print_msg(msg)
		volf  = fshift(volf, -cs[0], -cs[1], -cs[2])

	# Argument handed to filt_gaussinv; previously assigned but the literal
	# 10.0 was passed instead of the named value.
	B_factor = 10.0
	volf = filt_gaussinv( volf, B_factor )
	return  volf, cs
コード例 #3
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def reference4(ref_data):
    """
    Prepare the reference volume for 3D alignment: normalize, threshold,
    low-pass filter and, if requested, center it.

    Input: list ref_data
       0 - mask
       1 - center flag
       2 - raw average
       3 - fsc result (not read by this function; filter settings are hard-coded)
    Output: (volume, [sx, sy, sz]); the shifts are all zero unless the center
            flag equals 1, in which case they are the phase_cog() values that
            were removed from the volume.
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl, filt_gaussl
    from fundamentals import fshift, fft
    from morphology import threshold

    raw_average = ref_data[2]
    mask = ref_data[0]

    # Subtract stat[0] and divide by stat[1] as reported by Util.infomask
    # (presumably mean and standard deviation under the mask - TODO confirm).
    mask_stat = Util.infomask(raw_average, mask, False)
    normalized = raw_average - mask_stat[0]
    Util.mul_scalar(normalized, 1.0 / mask_stat[1])
    normalized = threshold(normalized)

    # Fixed filter chain: fft -> filt_tanl(0.35, 0.2) -> filt_gaussl(0.3) -> fft.
    in_fourier = fft(normalized)
    tanl_passed = filt_tanl(in_fourier, 0.35, 0.2)
    gauss_passed = filt_gaussl(tanl_passed, 0.3)
    volf = fft(gauss_passed)

    # No centering requested: report zero shifts.
    if ref_data[1] != 1:
        return volf, [0.0] * 3

    cs = volf.phase_cog()
    print_msg(
        "Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n"
        % (cs[0], cs[1], cs[2]))
    return fshift(volf, -cs[0], -cs[1], -cs[2]), cs
コード例 #4
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def ref_7grp(ref_data):
    """
    Prepare the reference in 3D alignment, i.e., low-pass filter and center.

    Input: list ref_data
       0 - mask
       1 - center flag
       2 - raw average
       3 - fsc result
    Output: (volf, cs) - the filtered, masked and (optionally) centered
            reference, and the shifts [sx, sy, sz] removed from it.  The shifts
            are all zero when the center flag is not 1.
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl, filt_gaussinv
    from fundamentals import fshift
    from morphology import threshold
    from math import sqrt

    # Default shifts.  Without this initialisation the final return raised
    # UnboundLocalError whenever the center flag was not 1.
    cs = [0.0] * 3

    # Statistics are taken over the whole volume (no mask passed to infomask);
    # the mask is applied only after thresholding.
    stat = Util.infomask(ref_data[2], None, False)
    volf = ref_data[2] - stat[0]
    Util.mul_scalar(volf, 1.0 / stat[1])
    volf = Util.muln_img(threshold(volf), ref_data[0])

    # Tangent low-pass filter with parameters fitted to the fsc result.
    fl, aa = fit_tanh(ref_data[3])
    msg = "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n" % (
        fl, aa)
    print_msg(msg)
    volf = filt_tanl(volf, fl, aa)
    if ref_data[1] == 1:
        cs = volf.phase_cog()
        msg = "Center x =	%10.3f        Center y       = %10.3f        Center z       = %10.3f\n" % (
            cs[0], cs[1], cs[2])
        print_msg(msg)
        volf = fshift(volf, -cs[0], -cs[1], -cs[2])

    # Argument handed to filt_gaussinv; previously assigned but the literal
    # 10.0 was passed instead of the named value.
    B_factor = 10.0
    volf = filt_gaussinv(volf, B_factor)
    return volf, cs
コード例 #5
0
def fshift(img, x, y, z=0, out=None):
    ''' Shift an image

    The input is converted with numpy2em, shifted by the EMAN2/Sparx
    ``fundamentals.fshift`` routine and converted back with em2numpy.

    * has alternative

    :Parameters:

    img : array
          Image data
    x : float
        Translation in the x-direction
    y : float
        Translation in the y-direction
    z : float
        Translation in the z-direction
    out : array
          Output array; when omitted a copy of `img` is allocated and filled

    :Returns:

    img : array
          Transformed image (the `out` array)

    :Raises:

    ImportError
        If the EMAN2/Sparx `fundamentals` module is not available
    '''

    # Call form of raise: valid on both Python 2 and Python 3, unlike the
    # Python-2-only "raise Exc, msg" statement.
    if fundamentals is None:
        raise ImportError("EMAN2/Sparx library not available, fshift requires EMAN2/Sparx")
    if out is None:
        out = img.copy()
    emdata = numpy2em(img)
    emdata = fundamentals.fshift(emdata, x, y, z)
    # Write into the caller-supplied (or freshly copied) buffer in place.
    out[:, :] = em2numpy(emdata)
    return out
コード例 #6
0
ファイル: eman2_utility.py プロジェクト: raj347/arachnid
def fshift(img, x, y, z=0, out=None):
    ''' Shift an image

    The input is converted with numpy2em, shifted by the EMAN2/Sparx
    ``fundamentals.fshift`` routine and converted back with em2numpy.

    * has alternative

    :Parameters:

    img : array
          Image data
    x : float
        Translation in the x-direction
    y : float
        Translation in the y-direction
    z : float
        Translation in the z-direction
    out : array
          Output array; when omitted a copy of `img` is allocated and filled

    :Returns:

    img : array
          Transformed image (the `out` array)

    :Raises:

    ImportError
        If the EMAN2/Sparx `fundamentals` module is not available
    '''

    # Call form of raise: valid on both Python 2 and Python 3, unlike the
    # Python-2-only "raise Exc, msg" statement.
    if fundamentals is None:
        raise ImportError("EMAN2/Sparx library not available, fshift requires EMAN2/Sparx")
    if out is None:
        out = img.copy()
    emdata = numpy2em(img)
    emdata = fundamentals.fshift(emdata, x, y, z)
    # Write into the caller-supplied (or freshly copied) buffer in place.
    out[:, :] = em2numpy(emdata)
    return out
コード例 #7
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def ref_ali3d(ref_data):
    """
    Prepare the reference in 3D alignment, i.e., low-pass filter and center.

    Input: list ref_data
       0 - mask
       1 - center flag
       2 - raw average
       3 - fsc result
    Output: (volf, cs) - the filtered, masked and (optionally) centered
            reference, and the shifts [sx, sy, sz] removed from it (zeros
            unless the center flag equals 1).

    Increments the module-level counter ref_ali2d_counter on every call and
    logs it together with the masked "dot" comparison of the raw average with
    itself.
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl
    from fundamentals import fshift
    from morphology import threshold

    global ref_ali2d_counter
    ref_ali2d_counter = ref_ali2d_counter + 1

    raw_average = ref_data[2]
    mask = ref_data[0]

    goal = raw_average.cmp("dot", raw_average, {"negative": 0, "mask": mask})
    print_msg("ref_ali3d    Step = %5d        GOAL = %10.3e\n" %
              (ref_ali2d_counter, goal))

    # Subtract stat[0] and divide by stat[1] as reported by Util.infomask
    # (presumably mean and standard deviation under the mask - TODO confirm),
    # then multiply by the mask.
    mask_stat = Util.infomask(raw_average, mask, False)
    volf = raw_average - mask_stat[0]
    Util.mul_scalar(volf, 1.0 / mask_stat[1])
    Util.mul_img(volf, mask)

    # Tangent low-pass filter with parameters fitted to the fsc result.
    fl, aa = fit_tanh(ref_data[3])
    print_msg(
        "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"
        % (fl, aa))
    volf = filt_tanl(volf, fl, aa)

    # No centering requested: report zero shifts.
    if ref_data[1] != 1:
        return volf, [0.0] * 3

    cs = volf.phase_cog()
    print_msg(
        "Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n"
        % (cs[0], cs[1], cs[2]))
    return fshift(volf, -cs[0], -cs[1], -cs[2]), cs
コード例 #8
0
ファイル: user_functions.py プロジェクト: cryoem/test
def ref_ali3d( ref_data ):
	"""
	Prepare the reference in 3D alignment, i.e., low-pass filter and center.

	Input: list ref_data
	   0 - mask
	   1 - center flag
	   2 - raw average
	   3 - fsc result
	Output: (volf, cs) - the filtered, masked and (optionally) centered
	        reference, and the shifts [sx, sy, sz] removed from it (zeros
	        unless the center flag equals 1).

	Increments the module-level counter ref_ali2d_counter on every call and
	logs it together with the masked "dot" comparison of the raw average with
	itself.
	"""
	from utilities      import print_msg
	from filter         import fit_tanh, filt_tanl
	from fundamentals   import fshift
	from morphology     import threshold

	global  ref_ali2d_counter
	ref_ali2d_counter = ref_ali2d_counter + 1

	raw_average = ref_data[2]
	mask        = ref_data[0]

	goal = raw_average.cmp("dot", raw_average, {"negative":0, "mask":mask} )
	print_msg("ref_ali3d    Step = %5d        GOAL = %10.3e\n"%(ref_ali2d_counter,goal))

	# Subtract stat[0] and divide by stat[1] as reported by Util.infomask
	# (presumably mean and standard deviation under the mask - TODO confirm),
	# then multiply by the mask.
	mask_stat = Util.infomask(raw_average, mask, False)
	volf = raw_average - mask_stat[0]
	Util.mul_scalar(volf, 1.0/mask_stat[1])
	Util.mul_img(volf, mask)

	# Tangent low-pass filter with parameters fitted to the fsc result.
	fl, aa = fit_tanh(ref_data[3])
	print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"%(fl, aa))
	volf = filt_tanl(volf, fl, aa)

	# No centering requested: report zero shifts.
	if ref_data[1] != 1:
		return  volf, [0.0]*3

	cs = volf.phase_cog()
	print_msg("Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n"%(cs[0], cs[1], cs[2]))
	return  fshift(volf, -cs[0], -cs[1], -cs[2]), cs
コード例 #9
0
ファイル: functions.py プロジェクト: leschzinerlab/EMAN2recon
def ref_ali3d( ref_data ):
	"""
	Normalize, mask, tangent-filter and optionally center a reference volume.

	Input: list ref_data
	   0 - mask
	   1 - center flag
	   2 - raw average
	   3 - input handed to fit_tanh to obtain the filter parameters
	Output: (volf, cs, fl) - the processed volume, the shifts [sx, sy, sz]
	        removed from it (zeros unless the center flag equals 1), and the
	        first value returned by fit_tanh.
	"""
	from utilities      import print_msg
	from filter	 import fit_tanh, filt_tanl
	from fundamentals   import fshift
	from morphology     import threshold

	raw_average = ref_data[2]
	mask        = ref_data[0]

	# The comparison result is never used (it is superseded by the fit_tanh
	# output below); the call itself is retained so behaviour is unchanged.
	raw_average.cmp("dot", raw_average, {"negative":0, "mask":mask} )

	mask_stat = Util.infomask(raw_average, mask, False)
	volf = raw_average - mask_stat[0]
	Util.mul_scalar(volf, 1.0/mask_stat[1])
	Util.mul_img(volf, mask)

	fl, aa = fit_tanh(ref_data[3])
	volf = filt_tanl(volf, fl, aa)
	volf.process_inplace("normalize")

	# No centering requested: report zero shifts.
	if ref_data[1] != 1:
		return  volf, [0.0]*3, fl

	cs = volf.phase_cog()
	return  fshift(volf, -cs[0], -cs[1], -cs[2]), cs, fl
コード例 #10
0
ファイル: functions.py プロジェクト: alushinlab/microtubules
def ref_ali3d(ref_data):
    """
    Normalize, mask, tangent-filter and optionally center a reference volume.

    Input: list ref_data
       0 - mask
       1 - center flag
       2 - raw average
       3 - input handed to fit_tanh to obtain the filter parameters
    Output: (volf, cs, fl) - the processed volume, the shifts [sx, sy, sz]
            removed from it (zeros unless the center flag equals 1), and the
            first value returned by fit_tanh.
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl
    from fundamentals import fshift
    from morphology import threshold

    raw_average = ref_data[2]
    mask = ref_data[0]

    # The comparison result is never used (it is superseded by the fit_tanh
    # output below); the call itself is retained so behaviour is unchanged.
    raw_average.cmp("dot", raw_average, {"negative": 0, "mask": mask})

    mask_stat = Util.infomask(raw_average, mask, False)
    volf = raw_average - mask_stat[0]
    Util.mul_scalar(volf, 1.0 / mask_stat[1])
    Util.mul_img(volf, mask)

    fl, aa = fit_tanh(ref_data[3])
    volf = filt_tanl(volf, fl, aa)
    volf.process_inplace("normalize")

    # No centering requested: report zero shifts.
    if ref_data[1] != 1:
        return volf, [0.0] * 3, fl

    cs = volf.phase_cog()
    return fshift(volf, -cs[0], -cs[1], -cs[2]), cs, fl
コード例 #11
0
ファイル: user_functions.py プロジェクト: cryoem/test
def dovolume( ref_data ):
	"""
	Prepare the reference in 3D alignment; this function corresponds to what
	do_volume does.

	Input: list ref_data
	   0 - mask
	   1 - center flag
	   2 - raw average
	   3 - fsc result (not read by this function)
	Output: (vol, cs) - the processed volume and the shifts [sx, sy, sz]
	        removed from it (zeros unless the center flag equals 1).

	Optional files, looked up by relative path:
	   flaa.txt        - first row supplies the tangent filter fl and aa;
	                     0.4 and 0.2 are used when it cannot be read.
	   pwreference.txt - reference power spectrum table; when it cannot be
	                     used the volume is only tangent-filtered.

	Increments the module-level counter ref_ali2d_counter on every call.
	"""
	from utilities      import print_msg, read_text_row, read_text_file
	from filter         import fit_tanh, filt_tanl, filt_table, filt_btwl
	from fundamentals   import fshift, rops_table, fftip, fft
	from morphology     import threshold

	global  ref_ali2d_counter
	ref_ali2d_counter += 1

	goal = ref_data[2].cmp("dot",ref_data[2], {"negative":0, "mask":ref_data[0]} )
	print_msg("do_volume user function    Step = %5d        GOAL = %10.3e\n"%(ref_ali2d_counter,goal))

	stat = Util.infomask(ref_data[2], ref_data[0], False)
	vol = ref_data[2] - stat[0]
	Util.mul_scalar(vol, 1.0/stat[1])
	vol = threshold(vol)

	# Best-effort read of the filter parameters.  "except Exception" keeps the
	# fallback but no longer swallows KeyboardInterrupt/SystemExit.
	try:
		row = read_text_row("flaa.txt")[0]
		fl, aa = row[0], row[1]
	except Exception:
		fl = 0.4
		aa = 0.2
	msg = "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"%(fl, aa)
	print_msg(msg)

	fftip(vol)
	# Best-effort power spectrum adjustment against pwreference.txt.
	try:
		rt = read_text_file( "pwreference.txt" )
		ro = rops_table(vol)
		#  Here unless I am mistaken it is enough to take the beginning of the reference pw.
		#  range (not xrange) so that a Python 3 NameError cannot be mistaken
		#  for "reference file unavailable" by the handler below.
		for i in range(1,len(ro)):  ro[i] = (rt[i]/ro[i])**0.5
		vol = fft( filt_table( filt_tanl(vol, fl, aa), ro) )
		msg = "Power spectrum adjusted\n"
		print_msg(msg)
	except Exception:
		vol = fft( filt_tanl(vol, fl, aa) )

	stat = Util.infomask(vol, ref_data[0], False)
	vol -= stat[0]
	Util.mul_scalar(vol, 1.0/stat[1])
	vol = threshold(vol)
	vol = filt_btwl(vol, 0.38, 0.5)
	Util.mul_img(vol, ref_data[0])

	if ref_data[1] == 1:
		# Center the volume that is actually returned.  The previous code
		# referenced an undefined name "volf" here and raised NameError.
		cs = vol.phase_cog()
		msg = "Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n"%(cs[0], cs[1], cs[2])
		print_msg(msg)
		vol = fshift(vol, -cs[0], -cs[1], -cs[2])
	else:
		cs = [0.0]*3

	return  vol, cs
コード例 #12
0
def helicalshiftali_MPI(stack,
                        maskfile=None,
                        maxit=100,
                        CTF=False,
                        snr=1.0,
                        Fourvar=False,
                        search_rng=-1):
    """
    Iteratively estimate an x-shift for every filament segment in `stack`
    under MPI and store the result in each image's "xform.align2d" header.

    In each iteration the shifted segments are averaged over all processes,
    the average is masked, every local segment is cross-correlated with it,
    and Util.helixshiftali picks a shift and inclination per filament from
    which the per-segment x-shifts are derived.

    Parameters:
      stack      - image stack name; read with EMData.read_images and queried
                   for the "filament" and "ptcl_source_coord" attributes
      maskfile   - mask image file; when None a default mask is built from
                   model_blank and pad
      maxit      - number of iterations (converted with int())
      CTF        - when true, images whose "ctf_applied" attribute is 0 are
                   passed through filt_ctf and the average is divided by the
                   accumulated ctf_2_sum
      snr        - 1/snr is added to ctf_2_sum on the main node (CTF only)
      Fourvar    - when true, the average is divided by the variance returned
                   by varf2d_MPI
      search_rng - search range handed to Util.helixshiftali; also enters the
                   maximum inclination computation

    Returns nothing.  Side effects visible here: writes "tavg.hdf" on the main
    node each iteration, logs via print_msg/print, and sends the updated
    headers to the main node for writing.
    """
    from applications import MPI_start_end
    from utilities import model_circle, model_blank, get_image, peak_search, get_im, pad
    from utilities import reduce_EMData_to_root, bcast_EMData_to_all, send_attr_dict, file_type, bcast_number_to_all, bcast_list_to_all
    from pap_statistics import varf2d_MPI
    from fundamentals import fft, ccf, rot_shift3D, rot_shift2D, fshift
    from utilities import get_params2D, set_params2D, chunks_distribution
    from utilities import print_msg, print_begin_msg, print_end_msg
    import os
    import sys
    from mpi import mpi_init, mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
    from mpi import mpi_reduce, mpi_bcast, mpi_barrier, mpi_gatherv
    from mpi import MPI_SUM, MPI_FLOAT, MPI_INT
    from time import time
    from pixel_error import ordersegments
    from math import sqrt, atan2, tan, pi

    nproc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("helical-shiftali_MPI")

    max_iter = int(maxit)
    # Only the main node reads the filament membership and coordinates.  The
    # per-filament lists are flattened (lengths in inidl, members in
    # tfilaments) so they can be broadcast as plain lists.
    if (myid == main_node):
        infils = EMUtil.get_all_attributes(stack, "filament")
        ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
        filaments = ordersegments(infils, ptlcoords)
        total_nfils = len(filaments)
        inidl = [0] * total_nfils
        for i in range(total_nfils):
            inidl[i] = len(filaments[i])
        linidl = sum(inidl)
        # nima (total number of segments) is defined on the main node only.
        nima = linidl
        tfilaments = []
        for i in range(total_nfils):
            tfilaments += filaments[i]
        del filaments
    else:
        total_nfils = 0
        linidl = 0
    total_nfils = bcast_number_to_all(total_nfils, source_node=main_node)
    if myid != main_node:
        inidl = [-1] * total_nfils
    inidl = bcast_list_to_all(inidl, myid, source_node=main_node)
    linidl = bcast_number_to_all(linidl, source_node=main_node)
    if myid != main_node:
        tfilaments = [-1] * linidl
    tfilaments = bcast_list_to_all(tfilaments, myid, source_node=main_node)
    # Rebuild the per-filament lists on every node from the flattened data.
    filaments = []
    iendi = 0
    for i in range(total_nfils):
        isti = iendi
        iendi = isti + inidl[i]
        filaments.append(tfilaments[isti:iendi])
    del tfilaments, inidl

    if myid == main_node:
        print_msg("total number of filaments: %d" % total_nfils)
    # NOTE(review): the label passed to ERROR is "ehelix_MPI", not this
    # function's name - confirm whether that is intended.
    if total_nfils < nproc:
        ERROR(
            'number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'
            % (nproc, total_nfils), "ehelix_MPI", 1, myid)

    #  balanced load
    # chunks_distribution is given [length, index] pairs; this node keeps the
    # filaments whose indices appear in its own chunk.
    temp = chunks_distribution([[len(filaments[i]), i]
                                for i in range(len(filaments))],
                               nproc)[myid:myid + 1][0]
    filaments = [filaments[temp[i][1]] for i in range(len(temp))]
    nfils = len(filaments)

    #filaments = [[0,1]]
    #print "filaments",filaments
    # list_of_particles: all local segment ids in filament order.
    # indcs[i] = [start, end) range of filament i within that list.
    list_of_particles = []
    indcs = []
    k = 0
    for i in range(nfils):
        list_of_particles += filaments[i]
        k1 = k + len(filaments[i])
        indcs.append([k, k1])
        k = k1
    data = EMData.read_images(stack, list_of_particles)
    ldata = len(data)
    print("ldata=", ldata)
    nx = data[0].get_xsize()
    ny = data[0].get_ysize()
    # Default mask: a (2*mrad+1) x ny block of ones padded with zeros to
    # nx x ny.
    if maskfile == None:
        mrad = min(nx, ny) // 2 - 2
        mask = pad(model_blank(2 * mrad + 1, ny, 1, 1.0), nx, ny, 1, 0.0)
    else:
        mask = get_im(maskfile)

    # apply initial xform.align2d parameters stored in header
    # The original transforms are kept in init_params so the shifts found
    # below can be composed with them at the end.
    init_params = []
    for im in range(ldata):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'],
                               p['mirror'], p['scale'])

    if CTF:
        from filter import filt_ctf
        from morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None
        ctf_abs_sum = None

    from utilities import info

    # Tag each image with its id, subtract st[0] from Util.infomask, and leave
    # every image as the output of fft (CTF-filtered when requested).
    for im in range(ldata):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            qctf = data[im].get_attr("ctf_applied")
            if qctf == 0:
                data[im] = filt_ctf(fft(data[im]), ctf_params)
                data[im].set_attr('ctf_applied', 1)
            elif qctf != 1:
                ERROR('Incorrectly set qctf flag', "helicalshiftali_MPI", 1,
                      myid)
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)
        else:
            data[im] = fft(data[im])

    del list_of_particles

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
    if CTF:
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Add 1/snr to ctf_2_sum at even x indices and 0 at odd ones
            # (presumably real/imaginary parts of a packed complex image -
            # TODO confirm).
            temp = EMData(nx, ny, 1, False)
            tsnr = 1. / snr
            for i in range(0, nx + 2, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, tsnr)
                    temp.set_value_at(i + 1, j, 0.0)
            #info(ctf_2_sum)
            Util.add_img(ctf_2_sum, temp)
            #info(ctf_2_sum)
            del temp

    total_iter = 0
    # Current x-shift estimate for every local segment.
    shift_x = [0.0] * ldata

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        # Sum the local segments with their current shifts applied, then
        # reduce the sum onto the main node.
        avg = EMData(nx, ny, 1, False)
        for im in range(ldata):
            Util.add_img(avg, fshift(data[im], shift_x[im]))

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF: tavg = Util.divn_filter(avg, ctf_2_sum)
            else: tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = model_blank(nx, ny)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
                # NOTE(review): vav_r is not used again in this function.
                vav_r = Util.pack_complex_to_real(vav)
            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            tavg.write_image("tavg.hdf", Iter)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)

        if Fourvar: del vav
        bcast_EMData_to_all(tavg, myid, main_node)
        tavg = fft(tavg)

        sx_sum = 0.0
        nxc = nx // 2

        for ifil in range(nfils):
            """
			# Calculate filament average
			avg = EMData(nx, ny, 1, False)
			filnima = 0
			for im in xrange(indcs[ifil][0], indcs[ifil][1]):
				Util.add_img(avg, data[im])
				filnima += 1
			tavg = Util.mult_scalar(avg, 1.0/float(filnima))
			"""
            # Calculate 1D ccf between each segment and filament average
            nsegms = indcs[ifil][1] - indcs[ifil][0]
            ctx = [None] * nsegms
            pcoords = [None] * nsegms
            for im in range(indcs[ifil][0], indcs[ifil][1]):
                ctx[im - indcs[ifil][0]] = Util.window(ccf(tavg, data[im]), nx,
                                                       1)
                pcoords[im - indcs[ifil][0]] = data[im].get_attr(
                    'ptcl_source_coord')
                #ctx[im-indcs[ifil][0]].write_image("ctx.hdf",im-indcs[ifil][0])
                #print "  CTX  ",myid,im,Util.infomask(ctx[im-indcs[ifil][0]], None, True)
            # search for best x-shift
            # cents is the index of the middle segment of the filament.
            cents = nsegms // 2

            # dst: larger of the distances from the middle segment to the
            # first and to the last segment, from the source coordinates.
            dst = sqrt(
                max((pcoords[cents][0] - pcoords[0][0])**2 +
                    (pcoords[cents][1] - pcoords[0][1])**2,
                    (pcoords[cents][0] - pcoords[-1][0])**2 +
                    (pcoords[cents][1] - pcoords[-1][1])**2))
            # Largest inclination angle to search and the number of
            # inclination steps (kang) derived from it.
            maxincline = atan2(ny // 2 - 2 - float(search_rng), dst)
            kang = int(dst * tan(maxincline) + 0.5)
            #print  "  settings ",nsegms,cents,dst,search_rng,maxincline,kang

            # ## C code for alignment. @ming
            # results[0] is used as the integer shift (sib), results[1] as the
            # slope (bang) and results[2] as the score (qm).
            results = [0.0] * 3
            results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline,
                                         kang, search_rng, nxc)
            sib = int(results[0])
            bang = results[1]
            qm = results[2]
            #print qm, sib, bang

            # qm = -1.e23
            #
            # 			for six in xrange(-search_rng, search_rng+1,1):
            # 				q0 = ctx[cents].get_value_at(six+nxc)
            # 				for incline in xrange(kang+1):
            # 					qt = q0
            # 					qu = q0
            # 					if(kang>0):  tang = tan(maxincline/kang*incline)
            # 					else:        tang = 0.0
            # 					for kim in xrange(cents+1,nsegms):
            # 						dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
            # 						xl = dst*tang+six+nxc
            # 						ixl = int(xl)
            # 						dxl = xl - ixl
            # 						#print "  A  ", ifil,six,incline,kim,xl,ixl,dxl
            # 						qt += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            # 						xl = -dst*tang+six+nxc
            # 						ixl = int(xl)
            # 						dxl = xl - ixl
            # 						qu += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            # 					for kim in xrange(cents):
            # 						dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
            # 						xl = -dst*tang+six+nxc
            # 						ixl = int(xl)
            # 						dxl = xl - ixl
            # 						qt += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            # 						xl =  dst*tang+six+nxc
            # 						ixl = int(xl)
            # 						dxl = xl - ixl
            # 						qu += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            # 					if( qt > qm ):
            # 						qm = qt
            # 						sib = six
            # 						bang = tang
            # 					if( qu > qm ):
            # 						qm = qu
            # 						sib = six
            # 						bang = -tang
            #if incline == 0:  print  "incline = 0  ",six,tang,qt,qu
            #print qm,six,sib,bang
            #print " got results   ",indcs[ifil][0], indcs[ifil][1], ifil,myid,qm,sib,tang,bang,len(ctx),Util.infomask(ctx[0], None, True)
            # Per-segment shift: sib plus or minus (distance to the middle
            # segment) * bang, with the sign depending on which side of the
            # middle segment the segment lies.
            for im in range(indcs[ifil][0], indcs[ifil][1]):
                kim = im - indcs[ifil][0]
                dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 +
                           (pcoords[cents][1] - pcoords[kim][1])**2)
                if (kim < cents): xl = -dst * bang + sib
                else: xl = dst * bang + sib
                shift_x[im] = xl

            # Average shift
            sx_sum += shift_x[indcs[ifil][0] + cents]

        # #print myid,sx_sum,total_nfils
        sx_sum = mpi_reduce(sx_sum, 1, MPI_FLOAT, MPI_SUM, main_node,
                            MPI_COMM_WORLD)
        if myid == main_node:
            sx_sum = float(sx_sum[0]) / total_nfils
            print_msg("Average shift  %6.2f\n" % (sx_sum))
        else:
            sx_sum = 0.0
        # NOTE(review): this unconditional reset discards the average computed
        # above on every node, so the broadcast value is 0.0 and the
        # subtraction below leaves shift_x unchanged.  It looks like the
        # re-centering was disabled on purpose - confirm.
        sx_sum = 0.0
        sx_sum = bcast_number_to_all(sx_sum, source_node=main_node)
        for im in range(ldata):
            shift_x[im] -= sx_sum
            #print  "   %3d  %6.3f"%(im,shift_x[im])
        #exit()

    # combine shifts found with the original parameters
    for im in range(ldata):
        t1 = Transform()
        ##import random
        ##shix=random.randint(-10, 10)
        ##t1.set_params({"type":"2D","tx":shix})
        t1.set_params({"type": "2D", "tx": shift_x[im]})
        # combine t0 and t1
        tt = t1 * init_params[im]
        data[im].set_attr("xform.align2d", tt)
    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi_barrier(MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    # The main node receives the attributes and writes them (bdb stacks use a
    # dedicated receiver); every other node sends its attributes.
    if myid == main_node:
        from utilities import file_type
        if (file_type(stack) == "bdb"):
            from utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata,
                               nproc)
        else:
            from utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
    else:
        send_attr_dict(main_node, data, par_str, 0, ldata)
    if myid == main_node: print_end_msg("helical-shiftali_MPI")
コード例 #13
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def dovolume(ref_data):
    """
    Prepare the reference in 3D alignment; this function corresponds to what
    do_volume does.

    Input: list ref_data
       0 - mask
       1 - center flag
       2 - raw average
       3 - fsc result (not read by this function)
    Output: (vol, cs) - the processed volume and the shifts [sx, sy, sz]
            removed from it (zeros unless the center flag equals 1).

    Optional files, looked up by relative path:
       flaa.txt        - first row supplies the tangent filter fl and aa;
                         0.4 and 0.2 are used when it cannot be read.
       pwreference.txt - reference power spectrum table; when it cannot be
                         used the volume is only tangent-filtered.

    Increments the module-level counter ref_ali2d_counter on every call.
    """
    from utilities import print_msg, read_text_row, read_text_file
    from filter import fit_tanh, filt_tanl, filt_table, filt_btwl
    from fundamentals import fshift, rops_table, fftip, fft
    from morphology import threshold

    global ref_ali2d_counter
    ref_ali2d_counter += 1

    goal = ref_data[2].cmp("dot", ref_data[2], {
        "negative": 0,
        "mask": ref_data[0]
    })
    print_msg("do_volume user function    Step = %5d        GOAL = %10.3e\n" %
              (ref_ali2d_counter, goal))

    stat = Util.infomask(ref_data[2], ref_data[0], False)
    vol = ref_data[2] - stat[0]
    Util.mul_scalar(vol, 1.0 / stat[1])
    vol = threshold(vol)

    # Best-effort read of the filter parameters.  "except Exception" keeps the
    # fallback but no longer swallows KeyboardInterrupt/SystemExit.
    try:
        row = read_text_row("flaa.txt")[0]
        fl, aa = row[0], row[1]
    except Exception:
        fl = 0.4
        aa = 0.2
    msg = "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n" % (
        fl, aa)
    print_msg(msg)

    fftip(vol)
    # Best-effort power spectrum adjustment against pwreference.txt.
    try:
        rt = read_text_file("pwreference.txt")
        ro = rops_table(vol)
        #  Here unless I am mistaken it is enough to take the beginning of the reference pw.
        #  range (not xrange) so that a Python 3 NameError cannot be mistaken
        #  for "reference file unavailable" by the handler below.
        for i in range(1, len(ro)):
            ro[i] = (rt[i] / ro[i])**0.5
        vol = fft(filt_table(filt_tanl(vol, fl, aa), ro))
        msg = "Power spectrum adjusted\n"
        print_msg(msg)
    except Exception:
        vol = fft(filt_tanl(vol, fl, aa))

    stat = Util.infomask(vol, ref_data[0], False)
    vol -= stat[0]
    Util.mul_scalar(vol, 1.0 / stat[1])
    vol = threshold(vol)
    vol = filt_btwl(vol, 0.38, 0.5)
    Util.mul_img(vol, ref_data[0])

    if ref_data[1] == 1:
        # Center the volume that is actually returned.  The previous code
        # referenced an undefined name "volf" here and raised NameError.
        cs = vol.phase_cog()
        msg = "Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n" % (
            cs[0], cs[1], cs[2])
        print_msg(msg)
        vol = fshift(vol, -cs[0], -cs[1], -cs[2])
    else:
        cs = [0.0] * 3

    return vol, cs
コード例 #14
0
ファイル: sxshiftali.py プロジェクト: cpsemmens/eman2
def helicalshiftali_MPI(stack, maskfile=None, maxit=100, CTF=False, snr=1.0, Fourvar=False, search_rng=-1):
	"""
	Iteratively find an x-shift for every helical segment in a stack, one
	straight line per filament, with the filaments distributed over MPI ranks.

	Outline of what the code does:
	  1. The main node reads the "filament" and "ptcl_source_coord" header
	     attributes of the stack, groups the images with ordersegments() and
	     broadcasts the grouping to all ranks.
	  2. Filaments are divided among ranks by chunks_distribution(); each rank
	     reads its own images, applies the "xform.align2d" transform stored in
	     the header, subtracts st[0] returned by Util.infomask (presumably the
	     average under the mask) and converts the images with fft().
	  3. For maxit iterations: the shifted images are summed over all ranks,
	     the sum is normalized on the main node, masked and broadcast; for each
	     filament, one-row cross-correlations between this average and every
	     segment are handed to Util.helixshiftali, whose result (a shift and a
	     slope) defines the x-shift of every segment of that filament.
	  4. The shifts are combined with the initial transforms and written back
	     to the headers ("xform.align2d" and "ID") through the main node.

	Parameters:
	  stack      - name of the image stack (file or bdb) holding the segments.
	  maskfile   - name of a mask image; if None, a rectangle of width
	               2*(min(nx,ny)//2-2)+1 and height ny, padded to nx x ny, is used.
	  maxit      - number of iterations (converted with int()).
	  CTF        - if true, images with header "ctf_applied" == 0 are passed
	               through filt_ctf, and the summed average is divided by the
	               accumulated ctf_img()**2 sum plus 1/snr.
	  snr        - its reciprocal is added to the squared-CTF sum (CTF only).
	  Fourvar    - if true, the average is divided by the first image returned
	               by varf2d_MPI before masking.
	  search_rng - passed to Util.helixshiftali; also enters the maximum
	               incline as atan2(ny//2-2-search_rng, dst).

	Returns None.  Side effects visible here: on the main node every iteration
	writes the current average to "tavg.hdf" (image number = iteration), and
	messages are emitted via print_msg; each rank prints its image count.

	If there are fewer filaments than MPI processes, this is reported through
	ERROR().
	"""
	from applications import MPI_start_end
	from utilities    import model_circle, model_blank, get_image, peak_search, get_im, pad
	from utilities    import reduce_EMData_to_root, bcast_EMData_to_all, send_attr_dict, file_type, bcast_number_to_all, bcast_list_to_all
	from statistics   import varf2d_MPI
	from fundamentals import fft, ccf, rot_shift3D, rot_shift2D, fshift
	from utilities    import get_params2D, set_params2D, chunks_distribution
	from utilities    import print_msg, print_begin_msg, print_end_msg
	import os
	import sys
	from mpi          import mpi_init, mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
	from mpi          import mpi_reduce, mpi_bcast, mpi_barrier, mpi_gatherv
	from mpi          import MPI_SUM, MPI_FLOAT, MPI_INT
	from time         import time
	from pixel_error  import ordersegments
	from math         import sqrt, atan2, tan, pi

	nproc = mpi_comm_size(MPI_COMM_WORLD)
	myid = mpi_comm_rank(MPI_COMM_WORLD)
	main_node = 0

	# Determined once here and reused when the headers are written at the end.
	ftp = file_type(stack)

	if myid == main_node:
		print_begin_msg("helical-shiftali_MPI")

	max_iter = int(maxit)

	# --- Filament bookkeeping: built on the main node, then broadcast. ---
	# The nested list is sent as a flat list (tfilaments) plus the list of
	# per-filament lengths (inidl) and rebuilt identically on every rank.
	if myid == main_node:
		infils = EMUtil.get_all_attributes(stack, "filament")
		ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
		filaments = ordersegments(infils, ptlcoords)
		total_nfils = len(filaments)
		inidl = [len(filaments[i]) for i in xrange(total_nfils)]
		linidl = sum(inidl)
		# nima (total number of images) is defined, and used, on the main node only.
		nima = linidl
		tfilaments = []
		for i in xrange(total_nfils):  tfilaments += filaments[i]
		del filaments
	else:
		total_nfils = 0
		linidl = 0
	total_nfils = bcast_number_to_all(total_nfils, source_node = main_node)
	if myid != main_node:
		inidl = [-1]*total_nfils
	inidl = bcast_list_to_all(inidl, myid, source_node = main_node)
	linidl = bcast_number_to_all(linidl, source_node = main_node)
	if myid != main_node:
		tfilaments = [-1]*linidl
	tfilaments = bcast_list_to_all(tfilaments, myid, source_node = main_node)
	filaments = []
	iendi = 0
	for i in xrange(total_nfils):
		isti = iendi
		iendi = isti+inidl[i]
		filaments.append(tfilaments[isti:iendi])
	del tfilaments,inidl

	if myid == main_node:
		print_msg( "total number of filaments: %d"%total_nfils)
	if total_nfils< nproc:
		ERROR('number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'%(nproc, total_nfils), "helicalshiftali_MPI", 1,myid)

	#  balanced load: keep only the filaments assigned to this rank
	temp = chunks_distribution([[len(filaments[i]), i] for i in xrange(len(filaments))], nproc)[myid:myid+1][0]
	filaments = [filaments[temp[i][1]] for i in xrange(len(temp))]
	nfils     = len(filaments)

	# list_of_particles: image numbers owned by this rank, filament after filament.
	# indcs[ifil] = [first, last+1] positions of filament ifil within that list.
	list_of_particles = []
	indcs = []
	k = 0
	for i in xrange(nfils):
		list_of_particles += filaments[i]
		k1 = k+len(filaments[i])
		indcs.append([k,k1])
		k = k1
	data = EMData.read_images(stack, list_of_particles)
	ldata = len(data)
	# Same text as the former Python 2 statement `print "ldata=", ldata`,
	# written so that it also parses under Python 3.
	print("ldata= %d" % ldata)
	nx = data[0].get_xsize()
	ny = data[0].get_ysize()
	if maskfile is None:
		mrad = min(nx, ny)//2-2
		mask = pad( model_blank(2*mrad+1, ny, 1, 1.0), nx, ny, 1, 0.0)
	else:
		mask = get_im(maskfile)

	# apply initial xform.align2d parameters stored in header
	init_params = []
	for im in xrange(ldata):
		t = data[im].get_attr('xform.align2d')
		init_params.append(t)
		p = t.get_params("2d")
		data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'], p['mirror'], p['scale'])

	if CTF:
		from filter import filt_ctf
		from morphology   import ctf_img
		ctf_abs_sum = EMData(nx, ny, 1, False)
		ctf_2_sum = EMData(nx, ny, 1, False)
	else:
		ctf_2_sum = None
		ctf_abs_sum = None

	from utilities import info

	for im in xrange(ldata):
		data[im].set_attr('ID', list_of_particles[im])
		st = Util.infomask(data[im], mask, False)
		data[im] -= st[0]
		if CTF:
			ctf_params = data[im].get_attr("ctf")
			qctf = data[im].get_attr("ctf_applied")
			if qctf == 0:
				data[im] = filt_ctf(fft(data[im]), ctf_params)
				data[im].set_attr('ctf_applied', 1)
			elif qctf != 1:
				ERROR('Incorrectly set qctf flag', "helicalshiftali_MPI", 1,myid)
			# NOTE(review): when qctf == 1 the image is not passed through fft(),
			# unlike the other two branches - confirm this is intended.
			ctfimg = ctf_img(nx, ctf_params, ny=ny)
			Util.add_img2(ctf_2_sum, ctfimg)
			Util.add_img_abs(ctf_abs_sum, ctfimg)
		else:  data[im] = fft(data[im])

	del list_of_particles

	if CTF:
		reduce_EMData_to_root(ctf_2_sum, myid, main_node)
		reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
		if myid != main_node:
			del ctf_2_sum
			del ctf_abs_sum
		else:
			# Add 1/snr to the summed squared CTF.  Values are written at even
			# x-indices with 0.0 at the following odd index (presumably the
			# real/imaginary parts of the complex image - confirm).
			temp = EMData(nx, ny, 1, False)
			tsnr = 1./snr
			for i in xrange(0,nx+2,2):
				for j in xrange(ny):
					temp.set_value_at(i,j,tsnr)
					temp.set_value_at(i+1,j,0.0)
			Util.add_img(ctf_2_sum, temp)
			del temp

	total_iter = 0
	shift_x = [0.0]*ldata

	for Iter in xrange(max_iter):
		if myid == main_node:
			start_time = time()
			print_msg("Iteration #%4d\n"%(total_iter))
		total_iter += 1

		# Sum of all images, each shifted by its current x-shift, over all ranks.
		avg = EMData(nx, ny, 1, False)
		for im in xrange(ldata):
			Util.add_img(avg, fshift(data[im], shift_x[im]))

		reduce_EMData_to_root(avg, myid, main_node)

		if myid == main_node:
			if CTF:  tavg = Util.divn_filter(avg, ctf_2_sum)
			else:    tavg = Util.mult_scalar(avg, 1.0/float(nima))
		else:
			# placeholder to receive the broadcast below
			tavg = model_blank(nx,ny)

		if Fourvar:
			bcast_EMData_to_all(tavg, myid, main_node)
			vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

		if myid == main_node:
			if Fourvar:
				tavg    = fft(Util.divn_img(fft(tavg), vav))
				# vav_r is not used again in this function.
				vav_r   = Util.pack_complex_to_real(vav)
			# normalize and mask tavg in real space
			tavg = fft(tavg)
			stat = Util.infomask( tavg, mask, False )
			tavg -= stat[0]
			Util.mul_img(tavg, mask)
			tavg.write_image("tavg.hdf",Iter)

		if Fourvar:  del vav
		bcast_EMData_to_all(tavg, myid, main_node)
		tavg = fft(tavg)

		sx_sum = 0.0
		nxc = nx//2

		for ifil in xrange(nfils):
			# The reference is the global average tavg computed above (a
			# per-filament average was an earlier, disabled alternative).
			ibeg, iend = indcs[ifil][0], indcs[ifil][1]
			nsegms = iend-ibeg

			# One-row ccf between the average and each segment of this filament,
			# together with the segment coordinates from the header.
			ctx = [None]*nsegms
			pcoords = [None]*nsegms
			for im in xrange(ibeg, iend):
				ctx[im-ibeg] = Util.window(ccf(tavg, data[im]), nx, 1)
				pcoords[im-ibeg] = data[im].get_attr('ptcl_source_coord')

			# The middle segment is the pivot of the line fitted to the filament.
			cents = nsegms//2
			xc, yc = pcoords[cents][0], pcoords[cents][1]

			# Largest distance from the pivot to either end segment.
			dst = sqrt(max((xc - pcoords[0][0])**2 + (yc - pcoords[0][1])**2, (xc - pcoords[-1][0])**2 + (yc - pcoords[-1][1])**2))
			maxincline = atan2(ny//2-2-float(search_rng),dst)
			kang = int(dst*tan(maxincline)+0.5)

			# Search for the best x-shift and incline, done in C.  The Python
			# loop this call replaced tried every integer shift in
			# [-search_rng, search_rng] and kang+1 slopes up to tan(maxincline),
			# with both signs, summing linearly interpolated ccf values along
			# the line through the pivot and keeping the largest sum; the C
			# routine is presumed equivalent - confirm.
			# results[0] = shift at the pivot, results[1] = slope; results[2]
			# (presumably the best score) is not needed here.
			results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline, kang, search_rng,nxc)
			sib = int(results[0])
			bang = results[1]

			# Shift of each segment: pivot shift plus slope times distance from
			# the pivot, with opposite sign on the two sides of the pivot.
			for kim in xrange(nsegms):
				dst = sqrt((xc - pcoords[kim][0])**2 + (yc - pcoords[kim][1])**2)
				if kim < cents:  xl = -dst*bang+sib
				else:            xl =  dst*bang+sib
				shift_x[ibeg+kim] = xl

			# Accumulate the pivot shifts for the average shift
			sx_sum += shift_x[ibeg+cents]

		sx_sum = mpi_reduce(sx_sum, 1, MPI_FLOAT, MPI_SUM, main_node, MPI_COMM_WORLD)
		if myid == main_node:
			sx_sum = float(sx_sum[0])/total_nfils
			print_msg("Average shift  %6.2f\n"%(sx_sum))
		else:
			sx_sum = 0.0
		# NOTE(review): the average shift computed above is only printed; it is
		# reset to zero before the broadcast, so the subtraction below leaves
		# shift_x unchanged.  Kept exactly as found - confirm whether removing
		# the average shift was disabled on purpose.
		sx_sum = 0.0
		sx_sum = bcast_number_to_all(sx_sum, source_node = main_node)
		for im in xrange(ldata):
			shift_x[im] -= sx_sum

	# combine shifts found with the original parameters
	for im in xrange(ldata):
		t1 = Transform()
		t1.set_params({"type":"2D","tx":shift_x[im]})
		# combine the initial transform and t1
		tt = t1*init_params[im]
		data[im].set_attr("xform.align2d", tt)

	# write out headers and STOP, under MPI writing has to be done sequentially
	mpi_barrier(MPI_COMM_WORLD)
	par_str = ["xform.align2d", "ID"]
	if myid == main_node:
		if ftp == "bdb":
			from utilities import recv_attr_dict_bdb
			recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata, nproc)
		else:
			from utilities import recv_attr_dict
			recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
	else:           send_attr_dict(main_node, data, par_str, 0, ldata)
	if myid == main_node: print_end_msg("helical-shiftali_MPI")