コード例 #1
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def reference4(ref_data):
    """Prepare the reference for 3D alignment using fixed low-pass settings.

    ref_data layout:
        0 - mask, used only for the infomask() statistics
        1 - center flag; 1 requests centering via phase_cog()/fshift()
        2 - raw average
        3 - FSC result (not read by this variant; the filters are fixed)

    The raw average has infomask() statistic 0 subtracted and is scaled by
    1/statistic 1 (presumably mean and standard deviation). It is then
    thresholded and filtered with filt_tanl(0.35, 0.2) followed by
    filt_gaussl(0.3), both applied between two fft() calls.

    Returns (volume, shift). shift is [0.0, 0.0, 0.0] unless centering was
    requested, in which case it is the phase_cog() result.
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl, filt_gaussl
    from fundamentals import fshift, fft
    from morphology import threshold

    raw_avg = ref_data[2]
    stats = Util.infomask(raw_avg, ref_data[0], False)
    ref = raw_avg - stats[0]
    Util.mul_scalar(ref, 1.0 / stats[1])
    ref = threshold(ref)

    # fft() is applied before and after the filters (presumably forward, then back).
    spectrum = fft(ref)
    spectrum = filt_tanl(spectrum, 0.35, 0.2)
    spectrum = filt_gaussl(spectrum, 0.3)
    ref = fft(spectrum)

    shift = [0.0] * 3
    if ref_data[1] != 1:
        return ref, shift

    shift = ref.phase_cog()
    print_msg("Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n"
              % (shift[0], shift[1], shift[2]))
    ref = fshift(ref, -shift[0], -shift[1], -shift[2])
    return ref, shift
コード例 #2
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def helical(ref_data):
    """Prepare the reference for helical refinement with a fixed tangent low-pass.

    ref_data layout:
        0 - raw volume

    Increments the module-level counter ref_ali2d_counter. Subtracts the first
    infomask() statistic (presumably the mean), thresholds the result, and
    applies filt_tanl with cut-off 0.45 and fall-off 0.1.

    Returns only the filtered volume (no center shift).
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl
    from morphology import threshold

    global ref_ali2d_counter
    ref_ali2d_counter += 1
    print_msg("helical   #%6d\n" % (ref_ali2d_counter))

    raw = ref_data[0]
    offset = Util.infomask(raw, None, True)[0]
    vol = threshold(raw - offset)

    cutoff, falloff = 0.45, 0.1
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"
              % (cutoff, falloff))
    return filt_tanl(vol, cutoff, falloff)
コード例 #3
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def ref_7grp(ref_data):
    """Prepare the 3D alignment reference: normalize, mask, low-pass, center, then filt_gaussinv.

    ref_data layout:
        0 - mask (multiplied in with Util.muln_img)
        1 - center flag; 1 requests centering via phase_cog()/fshift()
        2 - raw average
        3 - FSC result, passed to fit_tanh() to choose the tangent filter

    Returns (volume, shift). shift is [0.0, 0.0, 0.0] when centering is not
    requested; otherwise it is the phase_cog() result.
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl, filt_gaussinv
    from fundamentals import fshift
    from morphology import threshold
    from math import sqrt

    # Default shift. It used to be commented out, so any call with a center
    # flag other than 1 raised UnboundLocalError at the return statement.
    cs = [0.0] * 3

    # Subtract infomask() statistic 0 and scale by 1/statistic 1
    # (presumably mean and standard deviation).
    stat = Util.infomask(ref_data[2], None, False)
    volf = ref_data[2] - stat[0]
    Util.mul_scalar(volf, 1.0 / stat[1])
    volf = Util.muln_img(threshold(volf), ref_data[0])

    fl, aa = fit_tanh(ref_data[3])
    msg = "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n" % (
        fl, aa)
    print_msg(msg)
    volf = filt_tanl(volf, fl, aa)
    if ref_data[1] == 1:
        cs = volf.phase_cog()
        msg = "Center x =\t%10.3f        Center y       = %10.3f        Center z       = %10.3f\n" % (
            cs[0], cs[1], cs[2])
        print_msg(msg)
        volf = fshift(volf, -cs[0], -cs[1], -cs[2])
    # Parameter handed to filt_gaussinv (named B_factor by the original author).
    B_factor = 10.0
    volf = filt_gaussinv(volf, B_factor)
    return volf, cs
コード例 #4
0
ファイル: user_functions.py プロジェクト: cryoem/test
def reference4( ref_data ):
    """Build the 3D alignment reference with fixed filter settings.

    ref_data layout:
        0 - mask, used for the infomask() statistics only
        1 - center flag; 1 requests centering via phase_cog()/fshift()
        2 - raw average
        3 - FSC result (unused; the filter parameters are hard-coded)

    The raw average is shifted by infomask() statistic 0, scaled by
    1/statistic 1, thresholded, and passed through fft() ->
    filt_tanl(0.35, 0.2) -> filt_gaussl(0.3) -> fft().

    Returns (volume, center). center is [0.0, 0.0, 0.0] unless the center
    flag equals 1.
    """
    from utilities      import print_msg
    from filter         import fit_tanh, filt_tanl, filt_gaussl
    from fundamentals   import fshift, fft
    from morphology     import threshold

    mask = ref_data[0]
    avg = ref_data[2]
    info = Util.infomask(avg, mask, False)
    vol = avg - info[0]
    Util.mul_scalar(vol, 1.0/info[1])
    vol = threshold(vol)

    vol = fft(vol)
    vol = filt_tanl(vol, 0.35, 0.2)
    vol = filt_gaussl(vol, 0.3)
    vol = fft(vol)

    center = [0.0]*3
    if ref_data[1] == 1:
        center = vol.phase_cog()
        print_msg("Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n" % (center[0], center[1], center[2]))
        vol = fshift(vol, -center[0], -center[1], -center[2])
    return vol, center
コード例 #5
0
ファイル: user_functions.py プロジェクト: cryoem/test
def ref_7grp( ref_data ):
    """Prepare the 3D alignment reference: normalize, mask, low-pass, center, then filt_gaussinv.

    ref_data layout:
        0 - mask (multiplied in with Util.muln_img)
        1 - center flag; 1 requests centering via phase_cog()/fshift()
        2 - raw average
        3 - FSC result, passed to fit_tanh() to choose the tangent filter

    Returns (volume, shift). shift is [0.0, 0.0, 0.0] when centering is not
    requested; otherwise it is the phase_cog() result.
    """
    from utilities      import print_msg
    from filter         import fit_tanh, filt_tanl, filt_gaussinv
    from fundamentals   import fshift
    from morphology     import threshold
    from math           import sqrt

    # Default shift. It used to be commented out, so any call with a center
    # flag other than 1 raised UnboundLocalError at the return statement.
    cs = [0.0]*3

    # Subtract infomask() statistic 0 and scale by 1/statistic 1
    # (presumably mean and standard deviation).
    stat = Util.infomask(ref_data[2], None, False)
    volf = ref_data[2] - stat[0]
    Util.mul_scalar(volf, 1.0/stat[1])
    volf = Util.muln_img(threshold(volf), ref_data[0])

    fl, aa = fit_tanh(ref_data[3])
    msg = "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"%(fl, aa)
    print_msg(msg)
    volf = filt_tanl(volf, fl, aa)
    if ref_data[1] == 1:
        cs = volf.phase_cog()
        msg = "Center x =\t%10.3f        Center y       = %10.3f        Center z       = %10.3f\n"%(cs[0], cs[1], cs[2])
        print_msg(msg)
        volf = fshift(volf, -cs[0], -cs[1], -cs[2])
    # Parameter handed to filt_gaussinv (named B_factor by the original author).
    B_factor = 10.0
    volf = filt_gaussinv( volf, B_factor )
    return volf, cs
コード例 #6
0
ファイル: user_functions.py プロジェクト: cryoem/test
def helical( ref_data ):
    """Low-pass the raw volume to serve as the helical refinement reference.

    ref_data layout:
        0 - raw volume

    Side effect: increments the module-level ref_ali2d_counter and prints it.
    The first infomask() statistic (presumably the mean) is subtracted, the
    volume is thresholded, and filt_tanl(0.45, 0.1) is applied.

    Returns the filtered volume only.
    """
    from utilities      import print_msg
    from filter         import fit_tanh, filt_tanl
    from morphology     import threshold

    global  ref_ali2d_counter
    ref_ali2d_counter += 1
    print_msg("helical   #%6d\n"%(ref_ali2d_counter))

    source = ref_data[0]
    info = Util.infomask(source, None, True)
    centered = threshold(source - info[0])

    cut = 0.45
    fall = 0.1
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n" % (cut, fall))
    return filt_tanl(centered, cut, fall)
コード例 #7
0
ファイル: user_functions.py プロジェクト: cryoem/test
def spruce_up_var_m( refdata ):
    """Filter, normalize and mask every reference volume of a multi-reference run.

    refdata layout:
        0 - numref: number of volumes stored in vol%04d.hdf
        1 - outdir: directory holding the input and output stacks
        2 - fscc: FSC data passed to minfilt(), or None for fixed settings
        3 - total_iter: iteration number used in both file names
        4 - varf: image given to filter_by_image(), or None to skip that step
        5 - mask: multiplied into every volume
        6 - ali50S: when true, volumes 1.. are aligned to volume 0 under mask-50S.spi

    Reads image i of <outdir>/vol%04d.hdf and writes the processed volume as
    image i of <outdir>/volf%04d.hdf. Returns None.
    """
    from utilities  import print_msg
    from utilities  import model_circle, get_im
    from filter     import filt_tanl, filt_gaussl
    from morphology import threshold
    import os

    numref     = refdata[0]
    outdir     = refdata[1]
    fscc       = refdata[2]
    total_iter = refdata[3]
    varf       = refdata[4]
    mask       = refdata[5]
    ali50S     = refdata[6]

    if ali50S:
        mask_50S = get_im( "mask-50S.spi" )

    if fscc is None:
        flmin = 0.4
        aamin = 0.1
    else:
        flmin, aamin, idmin = minfilt( fscc )

    # The format arguments used to be (fflmin, aamin); fflmin is undefined,
    # so every call raised NameError here.
    msg = "Minimum tangent filter:  cut-off frequency = %10.3f     fall-off = %10.3f\n"%(flmin, aamin)
    print_msg(msg)

    in_name  = os.path.join(outdir, "vol%04d.hdf" % total_iter)
    out_name = os.path.join(outdir, "volf%04d.hdf" % total_iter)
    for i in xrange(numref):
        volf = get_im( in_name, i )
        if varf is not None:
            volf = volf.filter_by_image( varf )
        volf = filt_tanl(volf, flmin, aamin)
        stat = Util.infomask(volf, mask, True)
        volf -= stat[0]
        Util.mul_scalar(volf, 1.0/stat[1])

        # Subtract the infomask() statistic taken over the difference of two
        # model_circle() masks of radii nx//2-2 and nx//2-6 (presumably a
        # thin outer shell), then apply the mask.
        nx = volf.get_xsize()
        shell = model_circle(nx//2-2, nx, nx, nx) - model_circle(nx//2-6, nx, nx, nx)
        stat = Util.infomask(volf, shell, True)
        volf -= stat[0]
        Util.mul_img( volf, mask )

        volf = threshold(volf)
        volf = filt_gaussl( volf, 0.4)

        if ali50S:
            if i == 0:
                # Volume 0 is the alignment target for all later volumes.
                v50S_0 = volf.copy()
                v50S_0 *= mask_50S
            else:
                from applications import ali_vol_3
                from fundamentals import rot_shift3D
                v50S_i = volf.copy()
                v50S_i *= mask_50S

                params = ali_vol_3(v50S_i, v50S_0, 10.0, 0.5, mask=mask_50S)
                volf = rot_shift3D( volf, params[0], params[1], params[2], params[3], params[4], params[5], 1.0)

        volf.write_image( out_name, i )
コード例 #8
0
ファイル: projection.py プロジェクト: a-re/EMAN2-classes
def cml_end_log(Ori):
    """Print the final phi/theta/psi of every projection.

    Ori is a flat sequence with four entries per projection. The first three
    entries of each group are printed as phi, theta and psi. The number of
    projections comes from the module-level g_n_prj.
    """
    from utilities import print_msg
    global g_n_prj
    print_msg('\n\n')
    for iprj in xrange(g_n_prj):
        base = 4 * iprj
        phi, theta, psi = Ori[base], Ori[base + 1], Ori[base + 2]
        line = 'Projection #%03i: phi %10.5f    theta %10.5f    psi %10.5f\n' % (iprj, phi, theta, psi)
        print_msg(line)
コード例 #9
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def ref_ali3dm_new(refdata):
    """Equalize and low-pass all reference volumes of a multi-reference run.

    refdata layout:
        0 - numref, 1 - outdir, 2 - fscc (None or data passed to minfilt()),
        3 - total_iter, 4 - varf (None or image for filter_by_image()),
        5 - mask, 6 - ali50S (align volumes with ali_nvol under mask-50S.spi)

    Each volume from <outdir>/vol%04d.hdf is normalized with infomask()
    statistics, masked and thresholded. Every volume other than volume idmin
    is rescaled by filt_table so that its rops_table matches volume idmin's.
    Each volume is then optionally aligned and variance-filtered, low-passed
    with filt_tanl(flmin, aamin), and written to <outdir>/volf%04d.hdf.
    """
    from utilities import print_msg
    from utilities import model_circle, get_im
    from filter import filt_tanl, filt_gaussl, filt_table
    from morphology import threshold
    from fundamentals import rops_table
    from alignment import ali_nvol
    from math import sqrt
    import os

    numref = refdata[0]
    outdir = refdata[1]
    fscc = refdata[2]
    total_iter = refdata[3]
    varf = refdata[4]
    mask = refdata[5]
    ali50S = refdata[6]

    if fscc is None:
        flmin, aamin, idmin = 0.38, 0.1, 0
    else:
        flmin, aamin, idmin = minfilt(fscc)
        aamin /= 2.0
    print_msg("Minimum tangent filter derived from volume %2d:  cut-off frequency = %10.3f, fall-off = %10.3f\n"
              % (idmin, flmin, aamin))

    in_name = os.path.join(outdir, "vol%04d.hdf" % total_iter)
    vol = []
    for i in xrange(numref):
        v = get_im(in_name, i)
        stat = Util.infomask(v, mask, False)
        v -= stat[0]
        v /= stat[1]
        v *= mask
        vol.append(threshold(v))
    del stat

    # Scale each non-reference volume so its rops_table matches volume idmin's.
    reftab = rops_table(vol[idmin])
    for i in xrange(numref):
        if i == idmin:
            continue
        vtab = rops_table(vol[i])
        gains = [sqrt(reftab[j] / vtab[j]) for j in xrange(len(vtab))]
        vol[i] = filt_table(vol[i], gains)

    if ali50S:
        vol = ali_nvol(vol, get_im("mask-50S.spi"))

    out_name = os.path.join(outdir, "volf%04d.hdf" % total_iter)
    for i in xrange(numref):
        if varf is not None:
            vol[i] = vol[i].filter_by_image(varf)
        filt_tanl(vol[i], flmin, aamin).write_image(out_name, i)
コード例 #10
0
ファイル: user_functions.py プロジェクト: cryoem/test
def ref_ali3dm_new( refdata ):
    """Equalize and low-pass all reference volumes of a multi-reference run.

    refdata layout:
        0 - numref: number of volumes in vol%04d.hdf
        1 - outdir: directory holding input and output stacks
        2 - fscc: None, or data passed to minfilt() to choose the filter
        3 - total_iter: iteration number used in the file names
        4 - varf: None, or an image passed to filter_by_image()
        5 - mask: multiplied into every volume
        6 - ali50S: when true, align the volumes with ali_nvol under mask-50S.spi

    Writes the processed volumes as images of <outdir>/volf%04d.hdf.
    """
    from utilities    import print_msg
    from utilities    import model_circle, get_im
    from filter       import filt_tanl, filt_gaussl, filt_table
    from morphology   import threshold
    from fundamentals import rops_table
    from alignment    import ali_nvol
    from math         import sqrt
    import   os

    numref     = refdata[0]
    outdir     = refdata[1]
    fscc       = refdata[2]
    total_iter = refdata[3]
    varf       = refdata[4]
    mask       = refdata[5]
    ali50S     = refdata[6]

    if fscc is None:
        flmin = 0.38
        aamin = 0.1
        idmin = 0
    else:
        flmin, aamin, idmin = minfilt( fscc )
        aamin /= 2.0
    msg = "Minimum tangent filter derived from volume %2d:  cut-off frequency = %10.3f, fall-off = %10.3f\n"%(idmin, flmin, aamin)
    print_msg(msg)

    vol = []
    for i in xrange(numref):
        vol.append(get_im( os.path.join(outdir, "vol%04d.hdf"%total_iter), i ))
        stat = Util.infomask( vol[i], mask, False )
        vol[i] -= stat[0]
        vol[i] /= stat[1]
        vol[i] *= mask
        vol[i] = threshold(vol[i])
    del stat

    # Rescale every volume other than idmin so its rops_table matches volume idmin's.
    reftab = rops_table( vol[idmin] )
    for i in xrange(numref):
        if i != idmin:
            vtab = rops_table( vol[i] )
            ftab = [None]*len(vtab)
            for j in xrange(len(vtab)):
                # This line used to mix tabs and spaces in its indentation
                # (TabError under Python 3); it is now indented consistently.
                ftab[j] = sqrt( reftab[j]/vtab[j] )
            vol[i] = filt_table( vol[i], ftab )

    if ali50S:
        vol = ali_nvol(vol, get_im( "mask-50S.spi" ))
    for i in xrange(numref):
        if varf is not None:
            vol[i] = vol[i].filter_by_image( varf )
        filt_tanl( vol[i], flmin, aamin ).write_image( os.path.join(outdir, "volf%04d.hdf" % total_iter), i )
コード例 #11
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def ref_aliB_cone(ref_data):
    """Prepare a 3D reference whose rops_table is rescaled toward a target spectrum.

    ref_data layout:
        0 - mask
        1 - target power values, indexed in step with rops_table(volume)
        2 - raw average
        3 - FSC result, passed to fit_tanh() to choose the tangent filter

    Steps: normalize with infomask() statistics (no mask), threshold, apply
    the mask, rescale with filt_table by sqrt(target / rops_table), low-pass
    with filt_tanl, then normalize again.

    Returns only the volume (no center shift).
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl
    from fundamentals import fshift
    from morphology import threshold
    from math import sqrt

    print_msg("ref_aliB_cone\n")

    raw = ref_data[2]
    stats = Util.infomask(raw, None, True)
    vol = raw - stats[0]
    Util.mul_scalar(vol, 1.0 / stats[1])

    vol = threshold(vol)
    Util.mul_img(vol, ref_data[0])

    from fundamentals import rops_table
    current_pw = rops_table(vol)
    target_pw = ref_data[1]
    gains = [sqrt(target_pw[k] / current_pw[k]) for k in xrange(len(current_pw))]
    from filter import filt_table
    vol = filt_table(vol, gains)

    cutoff, falloff = fit_tanh(ref_data[3])
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"
              % (cutoff, falloff))
    vol = filt_tanl(vol, cutoff, falloff)

    stats = Util.infomask(vol, None, True)
    vol -= stats[0]
    Util.mul_scalar(vol, 1.0 / stats[1])
    return vol
コード例 #12
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def spruce_up_variance(ref_data):
    """Prepare the 3D reference, optionally pre-filtered by a variance image.

    ref_data layout:
        0 - mask
        1 - center flag (not used; no centering is done)
        2 - raw average
        3 - FSC result (not used; the tangent filter is fixed at 0.22 / 0.15)
        4 - image passed to filter_by_image() (1.0/variance), or None to skip

    Returns (volume, [0.0, 0.0, 0.0]).
    """
    from utilities import print_msg
    from filter import filt_tanl, fit_tanh, filt_gaussl
    from morphology import threshold
    mask = ref_data[0]
    vol = ref_data[2]
    varf = ref_data[4]

    print_msg("spruce_up with variance\n")
    cs = [0.0] * 3

    # Without a variance image, start from the raw volume. Previously volf was
    # left unbound in that case and filt_tanl raised UnboundLocalError.
    # filt_tanl below rebinds volf, so the later in-place operations should
    # not touch ref_data[2] (assumes filt_tanl returns a new image, as its
    # use elsewhere in this module suggests).
    if varf is None:
        volf = vol
    else:
        volf = vol.filter_by_image(varf)

    #fl, aa = fit_tanh(ref_data[3])
    fl = 0.22
    aa = 0.15
    msg = "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n" % (
        fl, aa)
    print_msg(msg)
    volf = filt_tanl(volf, fl, aa)

    stat = Util.infomask(volf, None, True)
    volf = volf - stat[0]
    Util.mul_scalar(volf, 1.0 / stat[1])

    # Subtract the statistic taken over the difference of two model_circle()
    # masks (radii nx//2-2 and nx//2-6; presumably an outer shell).
    from utilities import model_circle
    nx = volf.get_xsize()
    stat = Util.infomask(
        volf,
        model_circle(nx // 2 - 2, nx, nx, nx) -
        model_circle(nx // 2 - 6, nx, nx, nx), True)

    volf -= stat[0]
    Util.mul_img(volf, mask)

    volf = threshold(volf)

    volf = filt_gaussl(volf, 0.4)
    return volf, cs
コード例 #13
0
ファイル: user_functions.py プロジェクト: cryoem/test
def ref_aliB_cone( ref_data ):
    """Prepare a 3D reference whose rops_table is rescaled toward a target spectrum.

    ref_data layout:
        0 - mask
        1 - target power values, indexed in step with rops_table(volume)
        2 - raw average
        3 - FSC result, passed to fit_tanh() to choose the tangent filter

    Returns only the processed volume.
    """
    from utilities      import print_msg
    from filter         import fit_tanh, filt_tanl
    from fundamentals   import fshift
    from morphology     import threshold
    from math           import sqrt

    print_msg("ref_aliB_cone\n")

    info = Util.infomask(ref_data[2], None, True)
    out = ref_data[2] - info[0]
    Util.mul_scalar(out, 1.0/info[1])
    out = threshold(out)
    Util.mul_img(out, ref_data[0])

    # Per-entry gain sqrt(target / current), applied with filt_table.
    from  fundamentals  import  rops_table
    measured = rops_table(out)
    table = []
    for n, power in enumerate(measured):
        table.append(sqrt(ref_data[1][n]/power))
    from filter import filt_table
    out = filt_table(out, table)

    fl, aa = fit_tanh(ref_data[3])
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n" % (fl, aa))
    out = filt_tanl(out, fl, aa)

    info = Util.infomask(out, None, True)
    out -= info[0]
    Util.mul_scalar(out, 1.0/info[1])
    return out
コード例 #14
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def ref_random(ref_data):
    """Prepare a 2D alignment reference by centering the raw average (no filtering).

    ref_data layout:
        0 - mask (unused)
        1 - center flag, forwarded to center_2D(); > 0 also prints the shift
        2 - raw average
        3 - FRC result (unused)

    Increments the module-level ref_ali2d_counter.

    Returns (centered average, [cx, cy]) as produced by center_2D().
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl
    from utilities import center_2D

    global ref_ali2d_counter
    ref_ali2d_counter += 1
    print_msg("ref_ali2d   #%6d\n" % (ref_ali2d_counter))

    center_flag = ref_data[1]
    tavg, cx, cy = center_2D(ref_data[2], center_flag)
    cs = [cx, cy]
    if center_flag > 0:
        print_msg("Center x =      %10.3f        Center y       = %10.3f\n" % (cx, cy))
    return tavg, cs
コード例 #15
0
ファイル: user_functions.py プロジェクト: cryoem/test
def ref_random( ref_data ):
    """Center the raw 2D average and return it as the next reference.

    ref_data layout:
        0 - mask (unused)
        1 - center flag passed to center_2D(); a positive value also logs the shift
        2 - raw average
        3 - FRC result (unused)

    Side effect: increments the module-level ref_ali2d_counter.
    Returns (average, [shift_x, shift_y]).
    """
    from utilities    import print_msg
    from filter       import fit_tanh, filt_tanl
    from utilities    import center_2D

    global  ref_ali2d_counter
    ref_ali2d_counter += 1
    print_msg("ref_ali2d   #%6d\n"%(ref_ali2d_counter))

    result = center_2D(ref_data[2], ref_data[1])
    centered = result[0]
    shifts = [0.0]*2
    shifts[0], shifts[1] = result[1], result[2]
    if ref_data[1] > 0:
        print_msg("Center x =      %10.3f        Center y       = %10.3f\n"%(shifts[0], shifts[1]))
    return centered, shifts
コード例 #16
0
ファイル: user_functions.py プロジェクト: cryoem/test
def spruce_up_variance( ref_data ):
    """Prepare the 3D reference, optionally pre-filtered by a variance image.

    ref_data layout:
        0 - mask
        1 - center flag (not used; no centering is done)
        2 - raw average
        3 - FSC result (not used; the tangent filter is fixed at 0.22 / 0.15)
        4 - image passed to filter_by_image() (1.0/variance), or None to skip

    Returns (volume, [0.0, 0.0, 0.0]).
    """
    from utilities      import print_msg
    from filter         import filt_tanl, fit_tanh, filt_gaussl
    from morphology     import threshold
    mask   = ref_data[0]
    vol    = ref_data[2]
    varf   = ref_data[4]

    print_msg("spruce_up with variance\n")
    cs = [0.0]*3

    # Without a variance image, start from the raw volume. Previously volf was
    # left unbound in that case and filt_tanl raised UnboundLocalError.
    # filt_tanl below rebinds volf (assumed to return a new image, as its use
    # elsewhere in this module suggests).
    if varf is None:
        volf = vol
    else:
        volf = vol.filter_by_image(varf)

    #fl, aa = fit_tanh(ref_data[3])
    fl = 0.22
    aa = 0.15
    msg = "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"%(fl, aa)
    print_msg(msg)
    volf = filt_tanl(volf, fl, aa)

    stat = Util.infomask(volf, None, True)
    volf = volf - stat[0]
    Util.mul_scalar(volf, 1.0/stat[1])

    # Subtract the statistic taken over the difference of two model_circle()
    # masks (radii nx//2-2 and nx//2-6; presumably an outer shell).
    from utilities import model_circle
    nx = volf.get_xsize()
    stat = Util.infomask(volf, model_circle(nx//2-2,nx,nx,nx)-model_circle(nx//2-6,nx,nx,nx), True)

    volf -= stat[0]
    Util.mul_img(volf, mask)

    volf = threshold(volf)

    volf = filt_gaussl(volf, 0.4)
    return volf, cs
コード例 #17
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def spruce_up(ref_data):
    """Prepare the 3D reference: normalize, threshold, filt_gaussinv, shell-subtract, mask, low-pass.

    ref_data layout:
        0 - mask
        1 - center flag (not used; no centering is done)
        2 - raw average
        3 - FSC result, passed to fit_tanh(); the returned fall-off is halved

    Returns (volume, [0.0, 0.0, 0.0]).
    """
    from utilities import print_msg
    from filter import filt_tanl, fit_tanh
    from morphology import threshold

    print_msg("Changed4 spruce_up\n")
    cs = [0.0] * 3

    raw = ref_data[2]
    stats = Util.infomask(raw, None, True)
    vol = raw - stats[0]
    Util.mul_scalar(vol, 1.0 / stats[1])
    vol = threshold(vol)

    # "Apply B-factor": filt_gaussinv with parameter 1/sqrt(2*14).
    from filter import filt_gaussinv
    from math import sqrt
    vol = filt_gaussinv(vol, 1.0 / sqrt(2. * 14.0), False)

    nx = vol.get_xsize()
    from utilities import model_circle
    outer = model_circle(nx // 2 - 2, nx, nx, nx)
    inner = model_circle(nx // 2 - 6, nx, nx, nx)
    stats = Util.infomask(vol, outer - inner, True)
    vol -= stats[0]
    Util.mul_img(vol, ref_data[0])

    cutoff, falloff = fit_tanh(ref_data[3])
    falloff /= 2
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"
              % (cutoff, falloff))
    return filt_tanl(vol, cutoff, falloff), cs
コード例 #18
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def ref_ali2d_c(ref_data):
    """Prepare a 2D reference with a cut-off that grows with the call count.

    ref_data layout:
        0 - mask (unused)
        1 - center flag; > 0 centers the filtered average with center_2D()
        2 - raw average
        3 - FRC result (unused)

    The module-level ref_ali2d_counter is incremented. The cut-off is
    min(0.1 + 0.003 * counter, 0.4) and the fall-off is 0.1.

    Returns (filtered average, [cx, cy]); the shift is [0.0, 0.0] when not centering.
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl
    from utilities import center_2D

    global ref_ali2d_counter
    ref_ali2d_counter += 1
    print_msg("ref_ali2d   #%6d\n" % (ref_ali2d_counter))

    cutoff = min(0.1 + ref_ali2d_counter * 0.003, 0.4)
    falloff = 0.1
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"
              % (cutoff, falloff))
    tavg = filt_tanl(ref_data[2], cutoff, falloff)

    shift = [0.0] * 2
    if ref_data[1] > 0:
        tavg, shift[0], shift[1] = center_2D(tavg, ref_data[1])
        print_msg("Center x = %10.3f, y       = %10.3f\n" % (shift[0], shift[1]))
    return tavg, shift
コード例 #19
0
ファイル: user_functions.py プロジェクト: cryoem/test
def julien( ref_data ):
    """Prepare a 2D reference with a cut-off ramp that restarts every 50 calls.

    ref_data layout:
        0 - mask (unused)
        1 - center flag; > 0 centers the filtered average with center_2D()
        2 - raw average
        3 - FRC result (unused)

    The module-level ref_ali2d_counter is incremented and wrapped modulo 50.
    The cut-off is min(0.1 + 0.003 * counter, 0.4) and the fall-off is 0.1.

    Returns (filtered average, [cx, cy]).
    """
    from utilities    import print_msg
    from filter       import fit_tanh, filt_tanl
    from utilities    import center_2D

    global  ref_ali2d_counter
    ref_ali2d_counter = (ref_ali2d_counter + 1) % 50
    print_msg("ref_ali2d   #%6d\n"%(ref_ali2d_counter))

    low = min(0.1+ref_ali2d_counter*0.003, 0.4)
    width = 0.1
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"%(low, width))
    average = filt_tanl(ref_data[2], low, width)

    offsets = [0.0]*2
    if ref_data[1] > 0:
        average, offsets[0], offsets[1] = center_2D(average, ref_data[1])
        print_msg("Center x = %10.3f, y       = %10.3f\n"%(offsets[0], offsets[1]))
    return average, offsets
コード例 #20
0
def helical2( ref_data ):
    """Prepare the helical reference and update the helical symmetry estimate.

    ref_data layout:
        2 - raw volume

    Steps: increment the module-level ref_ali2d_counter, subtract the first
    infomask() statistic (presumably the mean), threshold, and apply
    filt_tanl(0.25, 0.1). Then call helios() with the last values of the
    first two columns of symdoc.txt (presumably the current delta z and
    delta phi). The new values returned by helios() are printed and appended
    to symdoc.txt.

    Returns (volume, [0.0, 0.0, 0.0]).
    """
    from utilities      import print_msg
    from filter         import fit_tanh, filt_tanl
    from morphology     import threshold

    global  ref_ali2d_counter
    ref_ali2d_counter += 1
    print_msg("helical2   #%6d\n"%(ref_ali2d_counter))
    stat = Util.infomask(ref_data[2], None, True)
    volf = ref_data[2] - stat[0]
    volf = threshold(volf)
    fl = 0.25#0.17
    aa = 0.1
    msg = "Tangent filter:  cut-off frequency = %10.3f\t  fall-off = %10.3f\n"%(fl, aa)
    print_msg(msg)
    volf = filt_tanl(volf, fl, aa)
    from utilities import read_text_file
    dipr = read_text_file('symdoc.txt', -1)
    # TODO: pixel size, fract, rmax and rmin (presumably 2.175, 0.7, 30, 3 below)
    # are hard-coded; they should be read from an external text file.
    from alignment import helios
    volf, dp, dphi = helios(volf, 2.175, dipr[0][-1], dipr[1][-1], 0.7, 30, 3)
    print_msg("New delta z and delta phi\t  : %s,    %s\n\n"%(dp,dphi))
    # The context manager closes symdoc.txt even if the write raises;
    # previously the handle leaked in that case.
    with open('symdoc.txt', 'a') as fofo:
        fofo.write('  %12.4f   %12.4f\n'%(dp,dphi))
    return  volf,[0.0,0.0,0.0]
コード例 #21
0
ファイル: user_functions.py プロジェクト: cryoem/test
def steady( ref_data ):
    """Low-pass the raw 2D average with a cut-off that steps up every third call.

    ref_data layout:
        0 - mask (unused)
        1 - center flag (unused; no centering is done)
        2 - raw average
        3 - FRC result (unused)

    The module-level ref_ali2d_counter is incremented. The cut-off is
    0.12 + (counter // 3) * 0.1 and the fall-off is 0.1.

    Returns (filtered average, [0.0, 0.0]).
    """
    from utilities    import print_msg
    from filter       import fit_tanh, filt_tanl
    from utilities    import center_2D

    global  ref_ali2d_counter
    ref_ali2d_counter += 1
    print_msg("steady   #%6d\n"%(ref_ali2d_counter))

    step = ref_ali2d_counter//3
    cut = 0.12 + step*0.1
    fall = 0.1
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"%(cut, fall))
    return filt_tanl(ref_data[2], cut, fall), [0.0]*2
コード例 #22
0
ファイル: user_functions.py プロジェクト: cryoem/test
def helical2( ref_data ):
    """Apply a fixed tangent low-pass (0.17 / 0.2) to the helical reference volume.

    ref_data layout:
        0 - raw volume

    Increments the module-level ref_ali2d_counter. Returns only the
    filtered volume.
    """
    from utilities      import print_msg
    from filter         import fit_tanh, filt_tanl
    from morphology     import threshold

    global  ref_ali2d_counter
    ref_ali2d_counter += 1
    print_msg("helical2   #%6d\n"%(ref_ali2d_counter))

    volume = ref_data[0]
    cut, fall = 0.17, 0.2
    print_msg("Tangent filter:  cut-off frequency = %10.3f\t  fall-off = %10.3f\n"%(cut, fall))
    return filt_tanl(volume, cut, fall)
コード例 #23
0
ファイル: user_functions.py プロジェクト: cryoem/test
def spruce_up( ref_data ):
    """Normalize, sharpen, shell-subtract, mask and low-pass the 3D reference.

    ref_data layout:
        0 - mask
        1 - center flag (not used)
        2 - raw average
        3 - FSC result, passed to fit_tanh(); the fall-off it returns is halved

    Returns (volume, [0.0, 0.0, 0.0]).
    """
    from utilities      import print_msg
    from filter         import filt_tanl, fit_tanh
    from morphology     import threshold

    print_msg("Changed4 spruce_up\n")
    shift = [0.0]*3

    info = Util.infomask(ref_data[2], None, True)
    result = ref_data[2] - info[0]
    Util.mul_scalar(result, 1.0/info[1])
    result = threshold(result)

    # "Apply B-factor" step: filt_gaussinv with parameter 1/sqrt(28).
    from filter import filt_gaussinv
    from math import sqrt
    gauss_param = 1.0/sqrt(2.*14.0)
    result = filt_gaussinv(result, gauss_param, False)

    size = result.get_xsize()
    from utilities import model_circle
    ring = model_circle(size//2-2, size, size, size) - model_circle(size//2-6, size, size, size)
    info = Util.infomask(result, ring, True)
    result -= info[0]
    Util.mul_img(result, ref_data[0])

    fl, aa = fit_tanh(ref_data[3])
    aa = aa / 2
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"%(fl, aa))
    result = filt_tanl(result, fl, aa)
    return result, shift
コード例 #24
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def ref_ali3d(ref_data):
    """Prepare the 3D alignment reference and log a per-step goal value.

    ref_data layout:
        0 - mask
        1 - center flag; 1 requests centering via phase_cog()/fshift()
        2 - raw average
        3 - FSC result, passed to fit_tanh() to choose the tangent filter

    Increments the module-level ref_ali2d_counter and prints the value of
    raw.cmp("dot", raw, {"negative": 0, "mask": mask}) as GOAL. The volume
    is normalized with masked infomask() statistics, masked, and low-passed
    with filt_tanl.

    Returns (volume, shift). shift is [0.0, 0.0, 0.0] unless centering was requested.
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl
    from fundamentals import fshift
    from morphology import threshold

    global ref_ali2d_counter
    ref_ali2d_counter += 1

    mask = ref_data[0]
    raw = ref_data[2]
    goal = raw.cmp("dot", raw, {"negative": 0, "mask": mask})
    print_msg("ref_ali3d    Step = %5d        GOAL = %10.3e\n" % (ref_ali2d_counter, goal))

    shift = [0.0] * 3
    stats = Util.infomask(raw, mask, False)
    vol = raw - stats[0]
    Util.mul_scalar(vol, 1.0 / stats[1])
    Util.mul_img(vol, mask)

    cutoff, falloff = fit_tanh(ref_data[3])
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"
              % (cutoff, falloff))
    vol = filt_tanl(vol, cutoff, falloff)

    if ref_data[1] == 1:
        shift = vol.phase_cog()
        print_msg("Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n"
                  % (shift[0], shift[1], shift[2]))
        vol = fshift(vol, -shift[0], -shift[1], -shift[2])
    return vol, shift
コード例 #25
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def steady(ref_data):
    """Low-pass the raw 2D average with a stepwise-increasing cut-off.

    ref_data layout:
        0 - mask (unused)
        1 - center flag (unused)
        2 - raw average
        3 - FRC result (unused)

    Side effect: increments the module-level ref_ali2d_counter. The cut-off
    is 0.12 + (counter // 3) * 0.1 and the fall-off is 0.1.

    Returns (filtered average, [0.0, 0.0]).
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl
    from utilities import center_2D

    global ref_ali2d_counter
    ref_ali2d_counter += 1
    print_msg("steady   #%6d\n" % (ref_ali2d_counter))

    cutoff = 0.12 + (ref_ali2d_counter // 3) * 0.1
    falloff = 0.1
    print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"
              % (cutoff, falloff))
    filtered = filt_tanl(ref_data[2], cutoff, falloff)
    no_shift = [0.0] * 2
    return filtered, no_shift
コード例 #26
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def helical2(ref_data):
    """Low-pass filter a helical reference volume with a fixed tangent filter.

    ref_data is a list; only ref_data[0] (the volume) is read by this function.
    The volume is passed to filt_tanl with cut-off 0.17 and fall-off 0.2. It is
    not normalized or masked here.

    Every call increments the module-level counter ``ref_ali2d_counter``.
    Returns only the filtered volume, not a (volume, shifts) pair like the
    other reference functions.
    """
    from utilities import print_msg
    from filter import fit_tanh, filt_tanl
    from morphology import threshold

    global ref_ali2d_counter
    ref_ali2d_counter += 1
    print_msg("helical2   #%6d\n" % (ref_ali2d_counter))

    cutoff, falloff = 0.17, 0.2
    # NOTE: the message deliberately keeps the original tab character before "fall-off".
    print_msg("Tangent filter:  cut-off frequency = %10.3f\t  fall-off = %10.3f\n"
              % (cutoff, falloff))
    return filt_tanl(ref_data[0], cutoff, falloff)
コード例 #27
0
def if_error_then_all_processes_exit_program(error_status):
	"""Terminate every process when rank 0 reports an error; otherwise return None.

	error_status -- None or 0 means "no error"; any other value is an error.
	                If it is a 2-tuple (message, frameinfo), rank 0 prints the
	                message and "frameinfo.filename:frameinfo.lineno" in a banner.

	Only rank 0's value matters, because it is the one broadcast to all ranks.
	When the OMPI_COMM_WORLD_SIZE environment variable is absent, local stand-ins
	replace the MPI calls, so the function also works in a serial run.
	On error every process flushes stdout, calls mpi_finalize() and sys.exit(1).
	"""
	import sys, os
	from utilities import print_msg

	if "OMPI_COMM_WORLD_SIZE" in os.environ:
		from mpi import mpi_comm_rank, mpi_bcast, mpi_finalize, MPI_INT, MPI_COMM_WORLD
	else:
		# Serial stand-ins: rank is 0, "broadcast" echoes the value, finalize does nothing.
		def mpi_comm_rank(n): return 0
		def mpi_bcast(*largs): return [largs[0]]
		def mpi_finalize(): return None
		MPI_INT, MPI_COMM_WORLD = 0, 0

	rank = mpi_comm_rank(MPI_COMM_WORLD)

	failed = error_status != None and error_status != 0
	if failed:
		error_status_info = error_status
	local_flag = 1 if failed else 0

	# Every rank adopts rank 0's decision.
	global_flag = int(mpi_bcast(local_flag, 1, MPI_INT, 0, MPI_COMM_WORLD)[0])
	if global_flag <= 0:
		return

	if rank == 0 and type(error_status_info) is tuple and len(error_status_info) == 2:
		message, frameinfo = error_status_info
		banner = "***********************************\n"
		print_msg(banner)
		print_msg("** Error: %s\n"%message)
		print_msg(banner)
		print_msg("** Location: %s\n"%(frameinfo.filename + ":" + str(frameinfo.lineno)))
		print_msg(banner)
	sys.stdout.flush()
	mpi_finalize()
	sys.exit(1)
コード例 #28
0
ファイル: user_functions.py プロジェクト: cryoem/test
def ref_ali3d(ref_data):
	"""Prepare a 3D reference: normalize, mask, tangent low-pass filter, optionally center.

	ref_data is a list:
	  0 - mask
	  1 - center flag (== 1: shift the volume by minus its phase_cog() offsets)
	  2 - raw average (volume)
	  3 - fsc result, passed to fit_tanh to pick the tangent filter cut-off/fall-off
	Returns (volf, cs):
	  volf - the processed volume
	  cs   - the phase_cog() offsets, or [0.0, 0.0, 0.0] when no centering is requested
	Increments the module-level counter ref_ali2d_counter. It then prints the counter
	together with the masked self dot product ("GOAL") of the raw average.
	"""
	from utilities      import print_msg
	from filter         import fit_tanh, filt_tanl
	from fundamentals   import fshift
	from morphology     import threshold

	global  ref_ali2d_counter
	ref_ali2d_counter += 1

	mask    = ref_data[0]
	raw_vol = ref_data[2]

	goal = raw_vol.cmp("dot", raw_vol, {"negative":0, "mask":mask})
	print_msg("ref_ali3d    Step = %5d        GOAL = %10.3e\n"%(ref_ali2d_counter, goal))

	# Normalize with Util.infomask statistics measured inside the mask
	# (subtract stats[0], divide by stats[1]), then multiply by the mask.
	stats = Util.infomask(raw_vol, mask, False)
	volf = raw_vol - stats[0]
	Util.mul_scalar(volf, 1.0/stats[1])
	Util.mul_img(volf, mask)

	cutoff, falloff = fit_tanh(ref_data[3])
	print_msg("Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"%(cutoff, falloff))
	volf = filt_tanl(volf, cutoff, falloff)

	if ref_data[1] == 1:
		shifts = volf.phase_cog()
		print_msg("Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n"%(shifts[0], shifts[1], shifts[2]))
		return fshift(volf, -shifts[0], -shifts[1], -shifts[2]), shifts
	return volf, [0.0]*3
コード例 #29
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def dovolume(ref_data):
    """Prepare a 3D reference the way do_volume does: normalize, filter, mask, center.

    Input: list ref_data
        0 - mask
        1 - center flag (1 => center the volume using phase_cog())
        2 - raw average (volume)
        3 - fsc result (not read by this function)
    Returns (vol, cs):
        vol - the processed volume
        cs  - the applied centering offsets, or [0.0, 0.0, 0.0] when centering is off

    Tangent filter parameters come from the first row of "flaa.txt" as
    (cut-off, fall-off). If that file cannot be read or parsed, 0.4 and 0.2 are used.
    If "pwreference.txt" can be read and combined with the volume's rotational
    power spectrum, the spectrum is additionally rescaled toward it. Otherwise
    only the tangent filter is applied.
    """
    from utilities import print_msg, read_text_row, read_text_file
    from filter import filt_tanl, filt_table, filt_btwl
    from fundamentals import fshift, rops_table, fftip, fft
    from morphology import threshold

    global ref_ali2d_counter
    ref_ali2d_counter += 1

    mask = ref_data[0]

    # Masked dot product of the raw volume with itself, reported as the "GOAL".
    goal = ref_data[2].cmp("dot", ref_data[2], {
        "negative": 0,
        "mask": mask
    })
    print_msg("do_volume user function    Step = %5d        GOAL = %10.3e\n" %
              (ref_ali2d_counter, goal))

    # Normalize with the statistics measured inside the mask, then threshold.
    stat = Util.infomask(ref_data[2], mask, False)
    vol = ref_data[2] - stat[0]
    Util.mul_scalar(vol, 1.0 / stat[1])
    vol = threshold(vol)

    # Filter parameters: first row of flaa.txt, falling back to fixed defaults.
    try:
        row = read_text_row("flaa.txt")[0]
        fl, aa = row[0], row[1]
    except Exception:
        fl, aa = 0.4, 0.2
    msg = "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n" % (
        fl, aa)
    print_msg(msg)

    fftip(vol)  # in-place forward FFT
    try:
        rt = read_text_file("pwreference.txt")
        ro = rops_table(vol)
        #  Here unless I am mistaken it is enough to take the beginning of the reference pw.
        for i in xrange(1, len(ro)):
            ro[i] = (rt[i] / ro[i])**0.5
        adjusted = fft(filt_table(filt_tanl(vol, fl, aa), ro))
    except Exception:
        # Reference spectrum missing, too short, or incompatible: tangent filter only.
        vol = fft(filt_tanl(vol, fl, aa))
    else:
        # Assign only after the whole computation succeeded, so a late failure
        # can never leave `vol` half-processed before the fallback runs.
        vol = adjusted
        print_msg("Power spectrum adjusted\n")

    # Renormalize inside the mask, threshold, apply filt_btwl(0.38, 0.5), then mask.
    stat = Util.infomask(vol, mask, False)
    vol -= stat[0]
    Util.mul_scalar(vol, 1.0 / stat[1])
    vol = threshold(vol)
    vol = filt_btwl(vol, 0.38, 0.5)
    Util.mul_img(vol, mask)

    if ref_data[1] == 1:
        # Center the processed volume. The original code referenced an undefined
        # name `volf` here and raised NameError.
        cs = vol.phase_cog()
        msg = "Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n" % (
            cs[0], cs[1], cs[2])
        print_msg(msg)
        vol = fshift(vol, -cs[0], -cs[1], -cs[2])
    else:
        cs = [0.0] * 3

    return vol, cs
コード例 #30
0
ファイル: sxshiftali.py プロジェクト: cpsemmens/eman2
def shiftali_MPI(stack, maskfile=None, maxit=100, CTF=False, snr=1.0, Fourvar=False, search_rng=-1, oneDx=False, search_rng_y=-1):  
	"""Shift-only alignment of a 2D image stack to its own average, run under MPI.

	Each MPI process reads a contiguous slice of the images in `stack`. It applies
	the 2D parameters already stored in each image's 'xform.align2d' header, then
	Fourier transforms the images (CTF-filtered with filt_ctf when CTF is True).
	It then iterates up to `maxit` times:
	  1. The main node builds the average and broadcasts it. Construction:
	     - divide the summed images by the CTF accumulator when CTF is True,
	       otherwise by the total image count;
	     - optionally divide by the Fourier variance (Fourvar);
	     - subtract the mean inside the mask in real space and multiply by the mask.
	  2. Each image gets an integer shift from the cross-correlation peak with the
	     average. The search window is 2*search_rng+1 by 2*search_rng_y+1 pixels.
	     A range <= 0 means the full image size; oneDx=True uses a single row.
	  3. The rounded mean shift over all images is subtracted, and the remaining
	     shift is applied to each image in Fourier space.
	The loop ends early once no image on any process has a nonzero peak shift.
	Finally the accumulated shifts are composed with the original transforms.
	'xform.align2d' and 'ID' are then sent to the main node, which passes them
	to recv_attr_dict(_bdb) to write the headers.

	Parameters
	  stack        -- image stack name; a "bdb" stack is opened with db_open_dict first
	  maskfile     -- mask image file; None gives a circle of radius min(nx, ny)//2 - 2
	  maxit        -- maximum number of iterations
	  CTF          -- CTF-filter the images; the first image's 'ctf_applied' must be 0
	                  (a missing attribute defaults to 2 and is rejected via ERROR)
	  snr          -- added at even x indices of the CTF accumulator on the main node
	                  (presumably the real parts of the complex image -- TODO confirm)
	  Fourvar      -- divide the average by the variance returned by varf2d_MPI
	  search_rng   -- x search half-width (<= 0: full width)
	  oneDx        -- search for x shifts only; y shifts stay 0
	  search_rng_y -- y search half-width (<= 0: full height)
	"""
	from applications import MPI_start_end
	from utilities    import model_circle, model_blank, get_image, peak_search, get_im
	from utilities    import reduce_EMData_to_root, bcast_EMData_to_all, send_attr_dict, file_type, bcast_number_to_all, bcast_list_to_all
	from statistics   import varf2d_MPI
	from fundamentals import fft, ccf, rot_shift3D, rot_shift2D
	from utilities    import get_params2D, set_params2D
	from utilities    import print_msg, print_begin_msg, print_end_msg
	import os
	import sys
	from mpi 	  	  import mpi_init, mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
	from mpi 	  	  import mpi_reduce, mpi_bcast, mpi_barrier, mpi_gatherv
	from mpi 	  	  import MPI_SUM, MPI_FLOAT, MPI_INT
	from EMAN2	  	  import Processor
	from time         import time	
	
	number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
	myid = mpi_comm_rank(MPI_COMM_WORLD)
	main_node = 0
		
	ftp = file_type(stack)

	if myid == main_node:
		print_begin_msg("shiftali_MPI")

	max_iter=int(maxit)

	# Only the main node counts the images; the count is then broadcast to all.
	if myid == main_node:
		if ftp == "bdb":
			from EMAN2db import db_open_dict
			dummy = db_open_dict(stack, True)
		nima = EMUtil.get_image_count(stack)
	else:
		nima = 0
	nima = bcast_number_to_all(nima, source_node = main_node)
	list_of_particles = range(nima)
	
	# Each process keeps the slice [image_start, image_end) of the particle indices.
	image_start, image_end = MPI_start_end(nima, number_of_proc, myid)
	list_of_particles = list_of_particles[image_start: image_end]

	# read nx and ctf_app (if CTF) and broadcast to all nodes
	if myid == main_node:
		ima = EMData()
		ima.read_image(stack, list_of_particles[0], True)
		nx = ima.get_xsize()
		ny = ima.get_ysize()
		if CTF:	ctf_app = ima.get_attr_default('ctf_applied', 2)
		del ima
	else:
		nx = 0
		ny = 0
		if CTF:	ctf_app = 0
	nx = bcast_number_to_all(nx, source_node = main_node)
	ny = bcast_number_to_all(ny, source_node = main_node)
	if CTF:
		ctf_app = bcast_number_to_all(ctf_app, source_node = main_node)
		if ctf_app > 0:	ERROR("data cannot be ctf-applied", "shiftali_MPI", 1, myid)

	# Default mask: circle of radius min(nx, ny)//2 - 2.
	if maskfile == None:
		mrad = min(nx, ny)
		mask = model_circle(mrad//2-2, nx, ny)
	else:
		mask = get_im(maskfile)

	if CTF:
		from filter import filt_ctf
		from morphology   import ctf_img
		ctf_abs_sum = EMData(nx, ny, 1, False)
		ctf_2_sum = EMData(nx, ny, 1, False)
	else:
		ctf_2_sum = None

	# Read this process's images. Unless CACHE_DISABLE is set, processes take turns
	# in rank order, and for bdb stacks a barrier after each turn serializes the reads.
	from global_def import CACHE_DISABLE
	if CACHE_DISABLE:
		data = EMData.read_images(stack, list_of_particles)
	else:
		for i in xrange(number_of_proc):
			if myid == i:
				data = EMData.read_images(stack, list_of_particles)
			if ftp == "bdb": mpi_barrier(MPI_COMM_WORLD)


	# Tag each image with its global index and subtract its mean inside the mask.
	# With CTF, accumulate the CTF images: Util.add_img2 into ctf_2_sum
	# (presumably adds the square) and Util.add_img_abs into ctf_abs_sum.
	for im in xrange(len(data)):
		data[im].set_attr('ID', list_of_particles[im])
		st = Util.infomask(data[im], mask, False)
		data[im] -= st[0]
		if CTF:
			ctf_params = data[im].get_attr("ctf")
			ctfimg = ctf_img(nx, ctf_params, ny=ny)
			Util.add_img2(ctf_2_sum, ctfimg)
			Util.add_img_abs(ctf_abs_sum, ctfimg)

	# Sum the CTF accumulators onto the main node. Only the main node keeps them,
	# and it adds snr to ctf_2_sum (the divisor of the average below).
	# NOTE(review): ctf_abs_sum is not used after this point.
	if CTF:
		reduce_EMData_to_root(ctf_2_sum, myid, main_node)
		reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
	else:  ctf_2_sum = None
	if CTF:
		if myid != main_node:
			del ctf_2_sum
			del ctf_abs_sum
		else:
			temp = EMData(nx, ny, 1, False)
			for i in xrange(0,nx,2):
				for j in xrange(ny):
					temp.set_value_at(i,j,snr)
			Util.add_img(ctf_2_sum, temp)
			del temp

	total_iter = 0

	# apply initial xform.align2d parameters stored in header
	init_params = []
	for im in xrange(len(data)):
		t = data[im].get_attr('xform.align2d')
		init_params.append(t)
		p = t.get_params("2d")
		data[im] = rot_shift2D(data[im], p['alpha'], sx=p['tx'], sy=p['ty'], mirror=p['mirror'], scale=p['scale'])

	# fourier transform all images, and apply ctf if CTF
	for im in xrange(len(data)):
		if CTF:
			ctf_params = data[im].get_attr("ctf")
			data[im] = filt_ctf(fft(data[im]), ctf_params)
		else:
			data[im] = fft(data[im])

	# shift_x/shift_y: shifts accumulated over all iterations (applied in Fourier space).
	# ishift_x/ishift_y: raw integer peak shifts found in the current iteration.
	sx_sum=0
	sy_sum=0
	sx_sum_total=0
	sy_sum_total=0
	shift_x = [0.0]*len(data)
	shift_y = [0.0]*len(data)
	ishift_x = [0.0]*len(data)
	ishift_y = [0.0]*len(data)

	for Iter in xrange(max_iter):
		if myid == main_node:
			start_time = time()
			print_msg("Iteration #%4d\n"%(total_iter))
		total_iter += 1
		# Sum the local Fourier-space images; the sum is reduced onto the main node.
		avg = EMData(nx, ny, 1, False)
		for im in data:  Util.add_img(avg, im)

		reduce_EMData_to_root(avg, myid, main_node)

		if myid == main_node:
			if CTF:
				tavg = Util.divn_filter(avg, ctf_2_sum)
			else:	 tavg = Util.mult_scalar(avg, 1.0/float(nima))
		else:
			tavg = EMData(nx, ny, 1, False)                               

		# Optional: every process needs the average to estimate the Fourier variance.
		if Fourvar:
			bcast_EMData_to_all(tavg, myid, main_node)
			vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

		if myid == main_node:
			if Fourvar:
				tavg    = fft(Util.divn_img(fft(tavg), vav))
				vav_r	= Util.pack_complex_to_real(vav)

			# normalize and mask tavg in real space
			tavg = fft(tavg)
			stat = Util.infomask( tavg, mask, False ) 
			tavg -= stat[0]
			Util.mul_img(tavg, mask)
			# For testing purposes: shift tavg to some random place and see if the centering is still correct
			#tavg = rot_shift3D(tavg,sx=3,sy=-4)
			tavg = fft(tavg)

		if Fourvar:  del vav
		bcast_EMData_to_all(tavg, myid, main_node)

		# Cross-correlation search window (full image size when no range is given).
		sx_sum=0 
		sy_sum=0 
		if search_rng > 0: nwx = 2*search_rng+1
		else:              nwx = nx
		
		if search_rng_y > 0: nwy = 2*search_rng_y+1
		else:                nwy = ny

		# not_zero becomes 1 if any local image has a nonzero peak shift.
		not_zero = 0
		for im in xrange(len(data)):
			if oneDx:
				ctx = Util.window(ccf(data[im],tavg),nwx,1)
				p1  = peak_search(ctx)
				p1_x = -int(p1[0][3])
				ishift_x[im] = p1_x
				sx_sum += p1_x
			else:
				p1 = peak_search(Util.window(ccf(data[im],tavg), nwx,nwy))
				p1_x = -int(p1[0][4])
				p1_y = -int(p1[0][5])
				ishift_x[im] = p1_x
				ishift_y[im] = p1_y
				sx_sum += p1_x
				sy_sum += p1_y

			if not_zero == 0:
				if (not(ishift_x[im] == 0.0)) or (not(ishift_y[im] == 0.0)):
					not_zero = 1

		# Total the integer shifts over all processes and broadcast the totals.
		sx_sum = mpi_reduce(sx_sum, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)  

		if not oneDx:
			sy_sum = mpi_reduce(sy_sum, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)

		if myid == main_node:
			sx_sum_total = int(sx_sum[0])
			if not oneDx:
				sy_sum_total = int(sy_sum[0])
		else:
			sx_sum_total = 0	
			sy_sum_total = 0

		sx_sum_total = bcast_number_to_all(sx_sum_total, source_node = main_node)

		if not oneDx:
			sy_sum_total = bcast_number_to_all(sy_sum_total, source_node = main_node)

		# Subtract the rounded mean shift over all nima images (presumably so the
		# stack as a whole does not drift), then shift each image in Fourier space.
		sx_ave = round(float(sx_sum_total)/nima)
		sy_ave = round(float(sy_sum_total)/nima)
		for im in xrange(len(data)): 
			p1_x = ishift_x[im] - sx_ave
			p1_y = ishift_y[im] - sy_ave
			params2 = {"filter_type" : Processor.fourier_filter_types.SHIFT, "x_shift" : p1_x, "y_shift" : p1_y, "z_shift" : 0.0}
			data[im] = Processor.EMFourierFilter(data[im], params2)
			shift_x[im] += p1_x
			shift_y[im] += p1_y
		# stop if all shifts are zero
		not_zero = mpi_reduce(not_zero, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)  
		if myid == main_node:
			not_zero_all = int(not_zero[0])
		else:
			not_zero_all = 0
		not_zero_all = bcast_number_to_all(not_zero_all, source_node = main_node)

		if myid == main_node:
			print_msg("Time of iteration = %12.2f\n"%(time()-start_time))
			start_time = time()

		if not_zero_all == 0:  break

	#for im in xrange(len(data)): data[im] = fft(data[im])  This should not be required as only header information is used
	# combine shifts found with the original parameters
	for im in xrange(len(data)):		
		t0 = init_params[im]
		t1 = Transform()
		t1.set_params({"type":"2D","alpha":0,"scale":t0.get_scale(),"mirror":0,"tx":shift_x[im],"ty":shift_y[im]})
		# combine t0 and t1
		tt = t1*t0
		data[im].set_attr("xform.align2d", tt)  

	# write out headers and STOP, under MPI writing has to be done sequentially
	mpi_barrier(MPI_COMM_WORLD)
	par_str = ["xform.align2d", "ID"]
	if myid == main_node:
		from utilities import file_type
		if(file_type(stack) == "bdb"):
			from utilities import recv_attr_dict_bdb
			recv_attr_dict_bdb(main_node, stack, data, par_str, image_start, image_end, number_of_proc)
		else:
			from utilities import recv_attr_dict
			recv_attr_dict(main_node, stack, data, par_str, image_start, image_end, number_of_proc)
		
	else:           send_attr_dict(main_node, data, par_str, image_start, image_end)
	if myid == main_node: print_end_msg("shiftali_MPI")				
コード例 #31
0
ファイル: functions.py プロジェクト: alushinlab/microtubules
def ali3d_MPI(stack,
              ref_vol,
              outdir,
              maskfile=None,
              ir=1,
              ou=-1,
              rs=1,
              xr="4 2 2 1",
              yr="-1",
              ts="1 1 0.5 0.25",
              delta="10 6 4 4",
              an="-1",
              center=0,
              maxit=5,
              term=95,
              CTF=False,
              fourvar=False,
              snr=1.0,
              ref_a="S",
              sym="c1",
              sort=True,
              cutoff=999.99,
              pix_cutoff="0",
              two_tail=False,
              model_jump="1 1 1 1 1",
              restart=False,
              save_half=False,
              protos=None,
              oplane=None,
              lmask=-1,
              ilmask=-1,
              findseam=False,
              vertstep=None,
              hpars="-1",
              hsearch="0.0 50.0",
              full_output=False,
              compare_repro=False,
              compare_ref_free="-1",
              ref_free_cutoff="-1 -1 -1 -1",
              wcmask=None,
              debug=False,
              recon_pad=4,
              olmask=75):

    from alignment import Numrinit, prepare_refrings
    from utilities import model_circle, get_image, drop_image, get_input_from_string
    from utilities import bcast_list_to_all, bcast_number_to_all, reduce_EMData_to_root, bcast_EMData_to_all
    from utilities import send_attr_dict
    from utilities import get_params_proj, file_type
    from fundamentals import rot_avg_image
    import os
    import types
    from utilities import print_begin_msg, print_end_msg, print_msg
    from mpi import mpi_bcast, mpi_comm_size, mpi_comm_rank, MPI_FLOAT, MPI_COMM_WORLD, mpi_barrier, mpi_reduce
    from mpi import mpi_reduce, MPI_INT, MPI_SUM, mpi_finalize
    from filter import filt_ctf
    from projection import prep_vol, prgs
    from statistics import hist_list, varf3d_MPI, fsc_mask
    from numpy import array, bincount, array2string, ones

    number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0
    if myid == main_node:
        if os.path.exists(outdir):
            ERROR(
                'Output directory exists, please change the name and restart the program',
                "ali3d_MPI", 1)
        os.mkdir(outdir)
    mpi_barrier(MPI_COMM_WORLD)

    if debug:
        from time import sleep
        while not os.path.exists(outdir):
            print "Node ", myid, "  waiting..."
            sleep(5)

        info_file = os.path.join(outdir, "progress%04d" % myid)
        finfo = open(info_file, 'w')
    else:
        finfo = None
    mjump = get_input_from_string(model_jump)
    xrng = get_input_from_string(xr)
    if yr == "-1": yrng = xrng
    else: yrng = get_input_from_string(yr)
    step = get_input_from_string(ts)
    delta = get_input_from_string(delta)
    ref_free_cutoff = get_input_from_string(ref_free_cutoff)
    pix_cutoff = get_input_from_string(pix_cutoff)

    lstp = min(len(xrng), len(yrng), len(step), len(delta))
    if an == "-1":
        an = [-1] * lstp
    else:
        an = get_input_from_string(an)
    # make sure pix_cutoff is set for all iterations
    if len(pix_cutoff) < lstp:
        for i in xrange(len(pix_cutoff), lstp):
            pix_cutoff.append(pix_cutoff[-1])
    # don't waste time on sub-pixel alignment for low-resolution ang incr
    for i in range(len(step)):
        if (delta[i] > 4 or delta[i] == -1) and step[i] < 1:
            step[i] = 1

    first_ring = int(ir)
    rstep = int(rs)
    last_ring = int(ou)
    max_iter = int(maxit)
    center = int(center)

    nrefs = EMUtil.get_image_count(ref_vol)
    nmasks = 0
    if maskfile:
        # read number of masks within each maskfile (mc)
        nmasks = EMUtil.get_image_count(maskfile)
        # open masks within maskfile (mc)
        maskF = EMData.read_images(maskfile, xrange(nmasks))
    vol = EMData.read_images(ref_vol, xrange(nrefs))
    nx = vol[0].get_xsize()

    ## make sure box sizes are the same
    if myid == main_node:
        im = EMData.read_images(stack, [0])
        bx = im[0].get_xsize()
        if bx != nx:
            print_msg(
                "Error: Stack box size (%i) differs from initial model (%i)\n"
                % (bx, nx))
            sys.exit()
        del im, bx

    # for helical processing:
    helicalrecon = False
    if protos is not None or hpars != "-1" or findseam is True:
        helicalrecon = True
        # if no out-of-plane param set, use 5 degrees
        if oplane is None:
            oplane = 5.0
    if protos is not None:
        proto = get_input_from_string(protos)
        if len(proto) != nrefs:
            print_msg("Error: insufficient protofilament numbers supplied")
            sys.exit()
    if hpars != "-1":
        hpars = get_input_from_string(hpars)
        if len(hpars) != 2 * nrefs:
            print_msg("Error: insufficient helical parameters supplied")
            sys.exit()
    ## create helical parameter file for helical reconstruction
    if helicalrecon is True and myid == main_node:
        from hfunctions import createHpar
        # create initial helical parameter files
        dp = [0] * nrefs
        dphi = [0] * nrefs
        vdp = [0] * nrefs
        vdphi = [0] * nrefs
        for iref in xrange(nrefs):
            hpar = os.path.join(outdir, "hpar%02d.spi" % (iref))
            params = False
            if hpars != "-1":
                # if helical parameters explicitly given, set twist & rise
                params = [float(hpars[iref * 2]), float(hpars[(iref * 2) + 1])]
            dp[iref], dphi[iref], vdp[iref], vdphi[iref] = createHpar(
                hpar, proto[iref], params, vertstep)

    # get values for helical search parameters
    hsearch = get_input_from_string(hsearch)
    if len(hsearch) != 2:
        print_msg("Error: specify outer and inner radii for helical search")
        sys.exit()

    if last_ring < 0 or last_ring > int(nx / 2) - 2:
        last_ring = int(nx / 2) - 2

    if myid == main_node:
        #	import user_functions
        #	user_func = user_functions.factory[user_func_name]

        print_begin_msg("ali3d_MPI")
        print_msg("Input stack		 : %s\n" % (stack))
        print_msg("Reference volume	    : %s\n" % (ref_vol))
        print_msg("Output directory	    : %s\n" % (outdir))
        if nmasks > 0:
            print_msg("Maskfile (number of masks)  : %s (%i)\n" %
                      (maskfile, nmasks))
        print_msg("Inner radius		: %i\n" % (first_ring))
        print_msg("Outer radius		: %i\n" % (last_ring))
        print_msg("Ring step		   : %i\n" % (rstep))
        print_msg("X search range	      : %s\n" % (xrng))
        print_msg("Y search range	      : %s\n" % (yrng))
        print_msg("Translational step	  : %s\n" % (step))
        print_msg("Angular step		: %s\n" % (delta))
        print_msg("Angular search range	: %s\n" % (an))
        print_msg("Maximum iteration	   : %i\n" % (max_iter))
        print_msg("Center type		 : %i\n" % (center))
        print_msg("CTF correction	      : %s\n" % (CTF))
        print_msg("Signal-to-Noise Ratio       : %f\n" % (snr))
        print_msg("Reference projection method : %s\n" % (ref_a))
        print_msg("Symmetry group	      : %s\n" % (sym))
        print_msg("Fourier padding for 3D      : %i\n" % (recon_pad))
        print_msg("Number of reference models  : %i\n" % (nrefs))
        print_msg("Sort images between models  : %s\n" % (sort))
        print_msg("Allow images to jump	: %s\n" % (mjump))
        print_msg("CC cutoff standard dev      : %f\n" % (cutoff))
        print_msg("Two tail cutoff	     : %s\n" % (two_tail))
        print_msg("Termination pix error       : %f\n" % (term))
        print_msg("Pixel error cutoff	  : %s\n" % (pix_cutoff))
        print_msg("Restart		     : %s\n" % (restart))
        print_msg("Full output		 : %s\n" % (full_output))
        print_msg("Compare reprojections       : %s\n" % (compare_repro))
        print_msg("Compare ref free class avgs : %s\n" % (compare_ref_free))
        print_msg("Use cutoff from ref free    : %s\n" % (ref_free_cutoff))
        if protos:
            print_msg("Protofilament numbers	: %s\n" % (proto))
            print_msg("Using helical search range   : %s\n" % hsearch)
        if findseam is True:
            print_msg("Using seam-based reconstruction\n")
        if hpars != "-1":
            print_msg("Using hpars		  : %s\n" % hpars)
        if vertstep != None:
            print_msg("Using vertical step    : %.2f\n" % vertstep)
        if save_half is True:
            print_msg("Saving even/odd halves\n")
        for i in xrange(100):
            print_msg("*")
        print_msg("\n\n")
    if maskfile:
        if type(maskfile) is types.StringType: mask3D = get_image(maskfile)
        else: mask3D = maskfile
    else: mask3D = model_circle(last_ring, nx, nx, nx)

    numr = Numrinit(first_ring, last_ring, rstep, "F")
    mask2D = model_circle(last_ring, nx, nx) - model_circle(first_ring, nx, nx)

    fscmask = model_circle(last_ring, nx, nx, nx)
    if CTF:
        from filter import filt_ctf
    from reconstruction_rjh import rec3D_MPI_noCTF

    if myid == main_node:
        active = EMUtil.get_all_attributes(stack, 'active')
        list_of_particles = []
        for im in xrange(len(active)):
            if active[im]: list_of_particles.append(im)
        del active
        nima = len(list_of_particles)
    else:
        nima = 0
    total_nima = bcast_number_to_all(nima, source_node=main_node)

    if myid != main_node:
        list_of_particles = [-1] * total_nima
    list_of_particles = bcast_list_to_all(list_of_particles,
                                          source_node=main_node)

    image_start, image_end = MPI_start_end(total_nima, number_of_proc, myid)

    # create a list of images for each node
    list_of_particles = list_of_particles[image_start:image_end]
    nima = len(list_of_particles)
    if debug:
        finfo.write("image_start, image_end: %d %d\n" %
                    (image_start, image_end))
        finfo.flush()

    data = EMData.read_images(stack, list_of_particles)

    t_zero = Transform({
        "type": "spider",
        "phi": 0,
        "theta": 0,
        "psi": 0,
        "tx": 0,
        "ty": 0
    })
    transmulti = [[t_zero for i in xrange(nrefs)] for j in xrange(nima)]

    for iref, im in ((iref, im) for iref in xrange(nrefs)
                     for im in xrange(nima)):
        if nrefs == 1:
            transmulti[im][iref] = data[im].get_attr("xform.projection")
        else:
            # if multi models, keep track of eulers for all models
            try:
                transmulti[im][iref] = data[im].get_attr("eulers_txty.%i" %
                                                         iref)
            except:
                data[im].set_attr("eulers_txty.%i" % iref, t_zero)

    scoremulti = [[0.0 for i in xrange(nrefs)] for j in xrange(nima)]
    pixelmulti = [[0.0 for i in xrange(nrefs)] for j in xrange(nima)]
    ref_res = [0.0 for x in xrange(nrefs)]
    apix = data[0].get_attr('apix_x')

    # for oplane parameter, create cylindrical mask
    if oplane is not None and myid == main_node:
        from hfunctions import createCylMask
        cmaskf = os.path.join(outdir, "mask3D_cyl.mrc")
        mask3D = createCylMask(data, olmask, lmask, ilmask, cmaskf)
        # if finding seam of helix, create wedge masks
        if findseam is True:
            wedgemask = []
            for pf in xrange(nrefs):
                wedgemask.append(EMData())
            # wedgemask option
            if wcmask is not None:
                wcmask = get_input_from_string(wcmask)
                if len(wcmask) != 3:
                    print_msg(
                        "Error: wcmask option requires 3 values: x y radius")
                    sys.exit()

    # determine if particles have helix info:
    try:
        data[0].get_attr('h_angle')
        original_data = []
        boxmask = True
        from hfunctions import createBoxMask
    except:
        boxmask = False

    # prepare particles
    for im in xrange(nima):
        data[im].set_attr('ID', list_of_particles[im])
        data[im].set_attr('pix_score', int(0))
        if CTF:
            # only phaseflip particles, not full CTF correction
            ctf_params = data[im].get_attr("ctf")
            st = Util.infomask(data[im], mask2D, False)
            data[im] -= st[0]
            data[im] = filt_ctf(data[im], ctf_params, sign=-1, binary=1)
            data[im].set_attr('ctf_applied', 1)
        # for window mask:
        if boxmask is True:
            h_angle = data[im].get_attr("h_angle")
            original_data.append(data[im].copy())
            bmask = createBoxMask(nx, apix, ou, lmask, h_angle)
            data[im] *= bmask
            del bmask
    if debug:
        finfo.write('%d loaded  \n' % nima)
        finfo.flush()
    if myid == main_node:
        # initialize data for the reference preparation function
        ref_data = [mask3D, max(center, 0), None, None, None, None]
        # for method -1, switch off centering in user function

    from time import time

    #  this is needed for gathering of pixel errors
    disps = []
    recvcount = []
    disps_score = []
    recvcount_score = []
    for im in xrange(number_of_proc):
        if (im == main_node):
            disps.append(0)
            disps_score.append(0)
        else:
            disps.append(disps[im - 1] + recvcount[im - 1])
            disps_score.append(disps_score[im - 1] + recvcount_score[im - 1])
        ib, ie = MPI_start_end(total_nima, number_of_proc, im)
        recvcount.append(ie - ib)
        recvcount_score.append((ie - ib) * nrefs)

    pixer = [0.0] * nima
    cs = [0.0] * 3
    total_iter = 0
    volodd = EMData.read_images(ref_vol, xrange(nrefs))
    voleve = EMData.read_images(ref_vol, xrange(nrefs))

    if restart:
        # recreate initial volumes from alignments stored in header
        itout = "000_00"
        for iref in xrange(nrefs):
            if (nrefs == 1):
                modout = ""
            else:
                modout = "_model_%02d" % (iref)

            if (sort):
                group = iref
                for im in xrange(nima):
                    imgroup = data[im].get_attr('group')
                    if imgroup == iref:
                        data[im].set_attr('xform.projection',
                                          transmulti[im][iref])
            else:
                group = int(999)
                for im in xrange(nima):
                    data[im].set_attr('xform.projection', transmulti[im][iref])

            fscfile = os.path.join(outdir, "fsc_%s%s" % (itout, modout))

            vol[iref], fscc, volodd[iref], voleve[iref] = rec3D_MPI_noCTF(
                data,
                sym,
                fscmask,
                fscfile,
                myid,
                main_node,
                index=group,
                npad=recon_pad)

            if myid == main_node:
                if helicalrecon:
                    from hfunctions import processHelicalVol
                    vstep = None
                    if vertstep is not None:
                        vstep = (vdp[iref], vdphi[iref])
                    print_msg(
                        "Old rise and twist for model %i     : %8.3f, %8.3f\n"
                        % (iref, dp[iref], dphi[iref]))
                    hvals = processHelicalVol(vol[iref], voleve[iref],
                                              volodd[iref], iref, outdir,
                                              itout, dp[iref], dphi[iref],
                                              apix, hsearch, findseam, vstep,
                                              wcmask)
                    (vol[iref], voleve[iref], volodd[iref], dp[iref],
                     dphi[iref], vdp[iref], vdphi[iref]) = hvals
                    print_msg(
                        "New rise and twist for model %i     : %8.3f, %8.3f\n"
                        % (iref, dp[iref], dphi[iref]))
                    # get new FSC from symmetrized half volumes
                    fscc = fsc_mask(volodd[iref], voleve[iref], mask3D, rstep,
                                    fscfile)
                else:
                    vol[iref].write_image(
                        os.path.join(outdir, "vol_%s.hdf" % itout), -1)

                if save_half is True:
                    volodd[iref].write_image(
                        os.path.join(outdir, "volodd_%s.hdf" % itout), -1)
                    voleve[iref].write_image(
                        os.path.join(outdir, "voleve_%s.hdf" % itout), -1)

                if nmasks > 1:
                    # Read mask for multiplying
                    ref_data[0] = maskF[iref]
                ref_data[2] = vol[iref]
                ref_data[3] = fscc
                #  call user-supplied function to prepare reference image, i.e., center and filter it
                vol[iref], cs, fl = ref_ali3d(ref_data)
                vol[iref].write_image(
                    os.path.join(outdir, "volf_%s.hdf" % (itout)), -1)
                if (apix == 1):
                    res_msg = "Models filtered at spatial frequency of:\t"
                    res = fl
                else:
                    res_msg = "Models filtered at resolution of:       \t"
                    res = apix / fl
                ares = array2string(array(res), precision=2)
                print_msg("%s%s\n\n" % (res_msg, ares))

            bcast_EMData_to_all(vol[iref], myid, main_node)
            # write out headers, under MPI writing has to be done sequentially
            mpi_barrier(MPI_COMM_WORLD)

    # projection matching
    for N_step in xrange(lstp):
        terminate = 0
        Iter = -1
        while (Iter < max_iter - 1 and terminate == 0):
            Iter += 1
            total_iter += 1
            itout = "%03g_%02d" % (delta[N_step], Iter)
            if myid == main_node:
                print_msg(
                    "ITERATION #%3d, inner iteration #%3d\nDelta = %4.1f, an = %5.2f, xrange = %5.2f, yrange = %5.2f, step = %5.2f\n\n"
                    % (N_step, Iter, delta[N_step], an[N_step], xrng[N_step],
                       yrng[N_step], step[N_step]))

            for iref in xrange(nrefs):
                if myid == main_node: start_time = time()
                volft, kb = prep_vol(vol[iref])

                ## constrain projections to out of plane parameter
                theta1 = None
                theta2 = None
                if oplane is not None:
                    theta1 = 90 - oplane
                    theta2 = 90 + oplane
                refrings = prepare_refrings(volft,
                                            kb,
                                            nx,
                                            delta[N_step],
                                            ref_a,
                                            sym,
                                            numr,
                                            MPI=True,
                                            phiEqpsi="Minus",
                                            initial_theta=theta1,
                                            delta_theta=theta2)

                del volft, kb

                if myid == main_node:
                    print_msg(
                        "Time to prepare projections for model %i: %s\n" %
                        (iref, legibleTime(time() - start_time)))
                    start_time = time()

                for im in xrange(nima):
                    data[im].set_attr("xform.projection", transmulti[im][iref])
                    if an[N_step] == -1:
                        t1, peak, pixer[im] = proj_ali_incore(
                            data[im], refrings, numr, xrng[N_step],
                            yrng[N_step], step[N_step], finfo)
                    else:
                        t1, peak, pixer[im] = proj_ali_incore_local(
                            data[im], refrings, numr, xrng[N_step],
                            yrng[N_step], step[N_step], an[N_step], finfo)
                    #data[im].set_attr("xform.projection"%iref, t1)
                    if nrefs > 1:
                        data[im].set_attr("eulers_txty.%i" % iref, t1)
                    scoremulti[im][iref] = peak
                    from pixel_error import max_3D_pixel_error
                    # t1 is the current param, t2 is old
                    t2 = transmulti[im][iref]
                    pixelmulti[im][iref] = max_3D_pixel_error(t1, t2, numr[-3])
                    transmulti[im][iref] = t1

                if myid == main_node:
                    print_msg("Time of alignment for model %i: %s\n" %
                              (iref, legibleTime(time() - start_time)))
                    start_time = time()

            # gather scoring data from all processors
            from mpi import mpi_gatherv
            scoremultisend = sum(scoremulti, [])
            pixelmultisend = sum(pixelmulti, [])
            tmp = mpi_gatherv(scoremultisend, len(scoremultisend), MPI_FLOAT,
                              recvcount_score, disps_score, MPI_FLOAT,
                              main_node, MPI_COMM_WORLD)
            tmp1 = mpi_gatherv(pixelmultisend, len(pixelmultisend), MPI_FLOAT,
                               recvcount_score, disps_score, MPI_FLOAT,
                               main_node, MPI_COMM_WORLD)
            tmp = mpi_bcast(tmp, (total_nima * nrefs), MPI_FLOAT, 0,
                            MPI_COMM_WORLD)
            tmp1 = mpi_bcast(tmp1, (total_nima * nrefs), MPI_FLOAT, 0,
                             MPI_COMM_WORLD)
            tmp = map(float, tmp)
            tmp1 = map(float, tmp1)
            score = array(tmp).reshape(-1, nrefs)
            pixelerror = array(tmp1).reshape(-1, nrefs)
            score_local = array(scoremulti)
            mean_score = score.mean(axis=0)
            std_score = score.std(axis=0)
            cut = mean_score - (cutoff * std_score)
            cut2 = mean_score + (cutoff * std_score)
            res_max = score_local.argmax(axis=1)
            minus_cc = [0.0 for x in xrange(nrefs)]
            minus_pix = [0.0 for x in xrange(nrefs)]
            minus_ref = [0.0 for x in xrange(nrefs)]

            #output pixel errors
            if (myid == main_node):
                from statistics import hist_list
                lhist = 20
                pixmin = pixelerror.min(axis=1)
                region, histo = hist_list(pixmin, lhist)
                if (region[0] < 0.0): region[0] = 0.0
                print_msg(
                    "Histogram of pixel errors\n      ERROR       number of particles\n"
                )
                for lhx in xrange(lhist):
                    print_msg(" %10.3f     %7d\n" % (region[lhx], histo[lhx]))
                # Terminate if 95% within 1 pixel error
                im = 0
                for lhx in xrange(lhist):
                    if (region[lhx] > 1.0): break
                    im += histo[lhx]
                print_msg("Percent of particles with pixel error < 1: %f\n\n" %
                          (im / float(total_nima) * 100))
                term_cond = float(term) / 100
                if (im / float(total_nima) > term_cond):
                    terminate = 1
                    print_msg("Terminating internal loop\n")
                del region, histo
            terminate = mpi_bcast(terminate, 1, MPI_INT, 0, MPI_COMM_WORLD)
            terminate = int(terminate[0])

            for im in xrange(nima):
                if (sort == False):
                    data[im].set_attr('group', 999)
                elif (mjump[N_step] == 1):
                    data[im].set_attr('group', int(res_max[im]))

                pix_run = data[im].get_attr('pix_score')
                if (pix_cutoff[N_step] == 1
                        and (terminate == 1 or Iter == max_iter - 1)):
                    if (pixelmulti[im][int(res_max[im])] > 1):
                        data[im].set_attr('pix_score', int(777))

                if (score_local[im][int(res_max[im])] < cut[int(
                        res_max[im])]) or (two_tail and score_local[im][int(
                            res_max[im])] > cut2[int(res_max[im])]):
                    data[im].set_attr('group', int(888))
                    minus_cc[int(res_max[im])] = minus_cc[int(res_max[im])] + 1

                if (pix_run == 777):
                    data[im].set_attr('group', int(777))
                    minus_pix[int(
                        res_max[im])] = minus_pix[int(res_max[im])] + 1

                if (compare_ref_free != "-1") and (ref_free_cutoff[N_step] !=
                                                   -1) and (total_iter > 1):
                    id = data[im].get_attr('ID')
                    if id in rejects:
                        data[im].set_attr('group', int(666))
                        minus_ref[int(
                            res_max[im])] = minus_ref[int(res_max[im])] + 1

            minus_cc_tot = mpi_reduce(minus_cc, nrefs, MPI_FLOAT, MPI_SUM, 0,
                                      MPI_COMM_WORLD)
            minus_pix_tot = mpi_reduce(minus_pix, nrefs, MPI_FLOAT, MPI_SUM, 0,
                                       MPI_COMM_WORLD)
            minus_ref_tot = mpi_reduce(minus_ref, nrefs, MPI_FLOAT, MPI_SUM, 0,
                                       MPI_COMM_WORLD)
            if (myid == main_node):
                if (sort):
                    tot_max = score.argmax(axis=1)
                    res = bincount(tot_max)
                else:
                    res = ones(nrefs) * total_nima
                print_msg("Particle distribution:	     \t\t%s\n" % (res * 1.0))
                afcut1 = res - minus_cc_tot
                afcut2 = afcut1 - minus_pix_tot
                afcut3 = afcut2 - minus_ref_tot
                print_msg("Particle distribution after cc cutoff:\t\t%s\n" %
                          (afcut1))
                print_msg("Particle distribution after pix cutoff:\t\t%s\n" %
                          (afcut2))
                print_msg("Particle distribution after ref cutoff:\t\t%s\n\n" %
                          (afcut3))

            res = [0.0 for i in xrange(nrefs)]
            for iref in xrange(nrefs):
                if (center == -1):
                    from utilities import estimate_3D_center_MPI, rotate_3D_shift
                    dummy = EMData()
                    cs[0], cs[1], cs[2], dummy, dummy = estimate_3D_center_MPI(
                        data, total_nima, myid, number_of_proc, main_node)
                    cs = mpi_bcast(cs, 3, MPI_FLOAT, main_node, MPI_COMM_WORLD)
                    cs = [-float(cs[0]), -float(cs[1]), -float(cs[2])]
                    rotate_3D_shift(data, cs)

                if (sort):
                    group = iref
                    for im in xrange(nima):
                        imgroup = data[im].get_attr('group')
                        if imgroup == iref:
                            data[im].set_attr('xform.projection',
                                              transmulti[im][iref])
                else:
                    group = int(999)
                    for im in xrange(nima):
                        data[im].set_attr('xform.projection',
                                          transmulti[im][iref])
                if (nrefs == 1):
                    modout = ""
                else:
                    modout = "_model_%02d" % (iref)

                fscfile = os.path.join(outdir, "fsc_%s%s" % (itout, modout))
                vol[iref], fscc, volodd[iref], voleve[iref] = rec3D_MPI_noCTF(
                    data,
                    sym,
                    fscmask,
                    fscfile,
                    myid,
                    main_node,
                    index=group,
                    npad=recon_pad)

                if myid == main_node:
                    print_msg("3D reconstruction time for model %i: %s\n" %
                              (iref, legibleTime(time() - start_time)))
                    start_time = time()

                # Compute Fourier variance
                if fourvar:
                    outvar = os.path.join(outdir, "volVar_%s.hdf" % (itout))
                    ssnr_file = os.path.join(outdir, "ssnr_%s" % (itout))
                    varf = varf3d_MPI(data,
                                      ssnr_text_file=ssnr_file,
                                      mask2D=None,
                                      reference_structure=vol[iref],
                                      ou=last_ring,
                                      rw=1.0,
                                      npad=1,
                                      CTF=None,
                                      sign=1,
                                      sym=sym,
                                      myid=myid)
                    if myid == main_node:
                        print_msg(
                            "Time to calculate 3D Fourier variance for model %i: %s\n"
                            % (iref, legibleTime(time() - start_time)))
                        start_time = time()
                        varf = 1.0 / varf
                        varf.write_image(outvar, -1)
                else:
                    varf = None

                if myid == main_node:
                    if helicalrecon:
                        from hfunctions import processHelicalVol

                        vstep = None
                        if vertstep is not None:
                            vstep = (vdp[iref], vdphi[iref])
                        print_msg(
                            "Old rise and twist for model %i     : %8.3f, %8.3f\n"
                            % (iref, dp[iref], dphi[iref]))
                        hvals = processHelicalVol(vol[iref], voleve[iref],
                                                  volodd[iref], iref, outdir,
                                                  itout, dp[iref], dphi[iref],
                                                  apix, hsearch, findseam,
                                                  vstep, wcmask)
                        (vol[iref], voleve[iref], volodd[iref], dp[iref],
                         dphi[iref], vdp[iref], vdphi[iref]) = hvals
                        print_msg(
                            "New rise and twist for model %i     : %8.3f, %8.3f\n"
                            % (iref, dp[iref], dphi[iref]))
                        # get new FSC from symmetrized half volumes
                        fscc = fsc_mask(volodd[iref], voleve[iref], mask3D,
                                        rstep, fscfile)

                        print_msg(
                            "Time to search and apply helical symmetry for model %i: %s\n\n"
                            % (iref, legibleTime(time() - start_time)))
                        start_time = time()
                    else:
                        vol[iref].write_image(
                            os.path.join(outdir, "vol_%s.hdf" % (itout)), -1)

                    if save_half is True:
                        volodd[iref].write_image(
                            os.path.join(outdir, "volodd_%s.hdf" % (itout)),
                            -1)
                        voleve[iref].write_image(
                            os.path.join(outdir, "voleve_%s.hdf" % (itout)),
                            -1)

                    if nmasks > 1:
                        # Read mask for multiplying
                        ref_data[0] = maskF[iref]
                    ref_data[2] = vol[iref]
                    ref_data[3] = fscc
                    ref_data[4] = varf
                    #  call user-supplied function to prepare reference image, i.e., center and filter it
                    vol[iref], cs, fl = ref_ali3d(ref_data)
                    vol[iref].write_image(
                        os.path.join(outdir, "volf_%s.hdf" % (itout)), -1)
                    if (apix == 1):
                        res_msg = "Models filtered at spatial frequency of:\t"
                        res[iref] = fl
                    else:
                        res_msg = "Models filtered at resolution of:       \t"
                        res[iref] = apix / fl

                del varf
                bcast_EMData_to_all(vol[iref], myid, main_node)

                if compare_ref_free != "-1": compare_repro = True
                if compare_repro:
                    outfile_repro = comp_rep(refrings, data, itout, modout,
                                             vol[iref], group, nima, nx, myid,
                                             main_node, outdir)
                    mpi_barrier(MPI_COMM_WORLD)
                    if compare_ref_free != "-1":
                        ref_free_output = os.path.join(
                            outdir, "ref_free_%s%s" % (itout, modout))
                        rejects = compare(compare_ref_free, outfile_repro,
                                          ref_free_output, yrng[N_step],
                                          xrng[N_step], rstep, nx, apix,
                                          ref_free_cutoff[N_step],
                                          number_of_proc, myid, main_node)

            # retrieve alignment params from all processors
            par_str = ['xform.projection', 'ID', 'group']
            if nrefs > 1:
                for iref in xrange(nrefs):
                    par_str.append('eulers_txty.%i' % iref)

            if myid == main_node:
                from utilities import recv_attr_dict
                recv_attr_dict(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)

            else:
                send_attr_dict(main_node, data, par_str, image_start,
                               image_end)

            if myid == main_node:
                ares = array2string(array(res), precision=2)
                print_msg("%s%s\n\n" % (res_msg, ares))
                dummy = EMData()
                if full_output:
                    nimat = EMUtil.get_image_count(stack)
                    output_file = os.path.join(outdir, "paramout_%s" % itout)
                    foutput = open(output_file, 'w')
                    for im in xrange(nimat):
                        # save the parameters for each of the models
                        outstring = ""
                        dummy.read_image(stack, im, True)
                        param3d = dummy.get_attr('xform.projection')
                        g = dummy.get_attr("group")
                        # retrieve alignments in EMAN-format
                        pE = param3d.get_params('eman')
                        outstring += "%f\t%f\t%f\t%f\t%f\t%i\n" % (
                            pE["az"], pE["alt"], pE["phi"], pE["tx"], pE["ty"],
                            g)
                        foutput.write(outstring)
                    foutput.close()
                del dummy
            mpi_barrier(MPI_COMM_WORLD)


#	mpi_finalize()

    if myid == main_node: print_end_msg("ali3d_MPI")
コード例 #32
0
ファイル: user_functions.py プロジェクト: a-re/EMAN2-classes
def spruce_up_var_m(refdata):
    """Filter, normalize and mask every model volume of one iteration.

    refdata is a list:
      0 - numref: number of model volumes to process
      1 - outdir: directory holding vol%04d.hdf and receiving volf%04d.hdf
      2 - fscc: FSC result passed to minfilt, or None to use fixed
          tangent-filter parameters (cut-off 0.4, fall-off 0.1)
      3 - total_iter: iteration number used in the file names
      4 - varf: image applied with filter_by_image before filtering, or None
      5 - mask: mask used for normalization and multiplied into each volume
      6 - ali50S: if true, volumes 1..numref-1 are aligned (ali_vol_3) to
          volume 0, both masked with the mask read from "mask-50S.spi"

    Volume i is read from outdir/vol%04d.hdf (image i), and the result is
    written to outdir/volf%04d.hdf (image i). Nothing is returned.
    """
    from utilities import print_msg
    from utilities import model_circle, get_im
    from filter import filt_tanl, filt_gaussl
    from morphology import threshold
    import os

    numref = refdata[0]
    outdir = refdata[1]
    fscc = refdata[2]
    total_iter = refdata[3]
    varf = refdata[4]
    mask = refdata[5]
    ali50S = refdata[6]

    if ali50S:
        mask_50S = get_im("mask-50S.spi")

    if fscc is None:
        flmin = 0.4
        aamin = 0.1
    else:
        # minfilt is defined elsewhere; its third return value is not used here.
        flmin, aamin, _ = minfilt(fscc)

    msg = "Minimum tangent filter:  cut-off frequency = %10.3f     fall-off = %10.3f\n" % (
        flmin, aamin)
    print_msg(msg)

    for i in xrange(numref):
        volf = get_im(os.path.join(outdir, "vol%04d.hdf" % total_iter), i)
        if varf is not None:
            volf = volf.filter_by_image(varf)
        volf = filt_tanl(volf, flmin, aamin)

        # Normalize with the statistics returned by infomask for `mask`:
        # subtract stat[0] and divide by stat[1].
        stat = Util.infomask(volf, mask, True)
        volf -= stat[0]
        Util.mul_scalar(volf, 1.0 / stat[1])

        # Subtract the average over a shell between radii nx//2-6 and
        # nx//2-2 (presumably a background estimate near the box edge),
        # then apply the mask.
        nx = volf.get_xsize()
        shell = (model_circle(nx // 2 - 2, nx, nx, nx) -
                 model_circle(nx // 2 - 6, nx, nx, nx))
        stat = Util.infomask(volf, shell, True)
        volf -= stat[0]
        Util.mul_img(volf, mask)

        volf = threshold(volf)
        volf = filt_gaussl(volf, 0.4)

        if ali50S:
            if i == 0:
                # The first model serves as the alignment reference.
                v50S_0 = volf.copy()
                v50S_0 *= mask_50S
            else:
                from applications import ali_vol_3
                from fundamentals import rot_shift3D
                v50S_i = volf.copy()
                v50S_i *= mask_50S

                params = ali_vol_3(v50S_i, v50S_0, 10.0, 0.5, mask=mask_50S)
                volf = rot_shift3D(volf, params[0], params[1], params[2],
                                   params[3], params[4], params[5], 1.0)

        volf.write_image(os.path.join(outdir, "volf%04d.hdf" % total_iter), i)
コード例 #33
0
ファイル: projection.py プロジェクト: a-re/EMAN2-classes
def cml_head_log(stack, outdir, delta, ir, ou, lf, hf, rand_seed, maxit, given,
                 flag_weights, trials, ncpu):
    """Write a summary of the common-lines run parameters through print_msg.

    Besides the arguments, the summary reports the module-level globals
    g_n_prj (number of projections), g_d_psi (sinogram angle accuracy) and
    g_n_anglst (number of angles). Those must be set before calling.
    """
    from utilities import print_msg

    # The globals are only read here; nothing in this function assigns them.
    global g_anglst, g_n_prj, g_d_psi, g_n_anglst

    def emit(template, value):
        # One formatted report line per parameter, in the fixed order below.
        print_msg(template % value)

    emit('Input stack                  : %s\n', stack)
    emit('Number of projections        : %d\n', g_n_prj)
    emit('Output directory             : %s\n', outdir)
    emit('Angular step                 : %5.2f\n', delta)
    emit('Sinogram angle accuracy      : %5.2f\n', g_d_psi)
    emit('Inner particle radius        : %5.2f\n', ir)
    emit('Outer particle radius        : %5.2f\n', ou)
    emit('Filter, minimum frequency    : %5.3f\n', lf)
    emit('Filter, maximum frequency    : %5.3f\n', hf)
    emit('Random seed                  : %i\n', rand_seed)
    emit('Number of maximum iterations : %d\n', maxit)
    emit('Start from given orientations: %s\n', given)
    emit('Number of angles             : %i\n', g_n_anglst)
    emit('Number of trials             : %i\n', trials)
    emit('Number of cpus               : %i\n', ncpu)
    emit('Use Voronoi weights          : %s\n\n', flag_weights)
コード例 #34
0
ファイル: user_functions.py プロジェクト: cryoem/test
def dovolume( ref_data ):
	"""Prepare the reference volume for 3D alignment (corresponds to do_volume).

	ref_data is a list:
	  0 - mask
	  1 - center flag (1 = center the volume)
	  2 - raw average volume
	  3 - fsc result (not used by this function)

	Processing steps:
	  1. Normalize the volume under the mask and threshold it.
	  2. Apply a tangent low-pass filter. Cut-off and fall-off come from the
	     first row of flaa.txt if it can be read; otherwise 0.4 and 0.2 are used.
	  3. Adjust the power spectrum against pwreference.txt if possible.
	  4. Renormalize, threshold, Butterworth low-pass filter and mask.
	  5. If the center flag is 1, shift the volume by minus its phase_cog().

	Returns (vol, cs). cs is the phase_cog() center used for the shift, or
	[0.0, 0.0, 0.0] when no centering was done.
	"""
	from utilities      import print_msg, read_text_row
	from filter         import fit_tanh, filt_tanl
	from fundamentals   import fshift
	from morphology     import threshold

	global  ref_ali2d_counter
	ref_ali2d_counter += 1

	# "dot" comparison of the volume with itself under the mask, reported as GOAL.
	goal = ref_data[2].cmp("dot",ref_data[2], {"negative":0, "mask":ref_data[0]} )
	print_msg("do_volume user function    Step = %5d        GOAL = %10.3e\n"%(ref_ali2d_counter,goal))

	# Normalize with the infomask statistics: subtract stat[0], divide by stat[1].
	stat = Util.infomask(ref_data[2], ref_data[0], False)
	vol = ref_data[2] - stat[0]
	Util.mul_scalar(vol, 1.0/stat[1])
	vol = threshold(vol)
	#Util.mul_img(vol, ref_data[0])

	# Tangent-filter parameters: first row of flaa.txt, or defaults if it is
	# missing or malformed.
	try:
		row = read_text_row("flaa.txt")[0]
		fl = row[0]
		aa = row[1]
	except Exception:
		fl = 0.4
		aa = 0.2
	msg = "Tangent filter:  cut-off frequency = %10.3f        fall-off = %10.3f\n"%(fl, aa)
	print_msg(msg)

	from utilities    import read_text_file
	from fundamentals import rops_table, fftip, fft
	from filter       import filt_table, filt_btwl
	fftip(vol)
	try:
		rt = read_text_file( "pwreference.txt" )
		ro = rops_table(vol)
		#  Here unless I am mistaken it is enough to take the beginning of the reference pw.
		for i in xrange(1,len(ro)):  ro[i] = (rt[i]/ro[i])**0.5
		vol = fft( filt_table( filt_tanl(vol, fl, aa), ro) )
		msg = "Power spectrum adjusted\n"
		print_msg(msg)
	except Exception:
		# No usable reference power spectrum: apply only the tangent filter.
		vol = fft( filt_tanl(vol, fl, aa) )

	stat = Util.infomask(vol, ref_data[0], False)
	vol -= stat[0]
	Util.mul_scalar(vol, 1.0/stat[1])
	vol = threshold(vol)
	vol = filt_btwl(vol, 0.38, 0.5)
	Util.mul_img(vol, ref_data[0])

	if ref_data[1] == 1:
		# Center on the processed volume. The original referenced an undefined
		# `volf` here, which raised NameError.
		cs = vol.phase_cog()
		msg = "Center x = %10.3f        Center y = %10.3f        Center z = %10.3f\n"%(cs[0], cs[1], cs[2])
		print_msg(msg)
		vol = fshift(vol, -cs[0], -cs[1], -cs[2])
	else:
		cs = [0.0]*3

	return  vol, cs
コード例 #35
0
def helicalshiftali_MPI(stack,
                        maskfile=None,
                        maxit=100,
                        CTF=False,
                        snr=1.0,
                        Fourvar=False,
                        search_rng=-1):
    from applications import MPI_start_end
    from utilities import model_circle, model_blank, get_image, peak_search, get_im, pad
    from utilities import reduce_EMData_to_root, bcast_EMData_to_all, send_attr_dict, file_type, bcast_number_to_all, bcast_list_to_all
    from pap_statistics import varf2d_MPI
    from fundamentals import fft, ccf, rot_shift3D, rot_shift2D, fshift
    from utilities import get_params2D, set_params2D, chunks_distribution
    from utilities import print_msg, print_begin_msg, print_end_msg
    import os
    import sys
    from mpi import mpi_init, mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
    from mpi import mpi_reduce, mpi_bcast, mpi_barrier, mpi_gatherv
    from mpi import MPI_SUM, MPI_FLOAT, MPI_INT
    from time import time
    from pixel_error import ordersegments
    from math import sqrt, atan2, tan, pi

    nproc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("helical-shiftali_MPI")

    max_iter = int(maxit)
    if (myid == main_node):
        infils = EMUtil.get_all_attributes(stack, "filament")
        ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
        filaments = ordersegments(infils, ptlcoords)
        total_nfils = len(filaments)
        inidl = [0] * total_nfils
        for i in range(total_nfils):
            inidl[i] = len(filaments[i])
        linidl = sum(inidl)
        nima = linidl
        tfilaments = []
        for i in range(total_nfils):
            tfilaments += filaments[i]
        del filaments
    else:
        total_nfils = 0
        linidl = 0
    total_nfils = bcast_number_to_all(total_nfils, source_node=main_node)
    if myid != main_node:
        inidl = [-1] * total_nfils
    inidl = bcast_list_to_all(inidl, myid, source_node=main_node)
    linidl = bcast_number_to_all(linidl, source_node=main_node)
    if myid != main_node:
        tfilaments = [-1] * linidl
    tfilaments = bcast_list_to_all(tfilaments, myid, source_node=main_node)
    filaments = []
    iendi = 0
    for i in range(total_nfils):
        isti = iendi
        iendi = isti + inidl[i]
        filaments.append(tfilaments[isti:iendi])
    del tfilaments, inidl

    if myid == main_node:
        print_msg("total number of filaments: %d" % total_nfils)
    if total_nfils < nproc:
        ERROR(
            'number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'
            % (nproc, total_nfils), "ehelix_MPI", 1, myid)

    #  balanced load
    temp = chunks_distribution([[len(filaments[i]), i]
                                for i in range(len(filaments))],
                               nproc)[myid:myid + 1][0]
    filaments = [filaments[temp[i][1]] for i in range(len(temp))]
    nfils = len(filaments)

    #filaments = [[0,1]]
    #print "filaments",filaments
    list_of_particles = []
    indcs = []
    k = 0
    for i in range(nfils):
        list_of_particles += filaments[i]
        k1 = k + len(filaments[i])
        indcs.append([k, k1])
        k = k1
    data = EMData.read_images(stack, list_of_particles)
    ldata = len(data)
    print("ldata=", ldata)
    nx = data[0].get_xsize()
    ny = data[0].get_ysize()
    if maskfile == None:
        mrad = min(nx, ny) // 2 - 2
        mask = pad(model_blank(2 * mrad + 1, ny, 1, 1.0), nx, ny, 1, 0.0)
    else:
        mask = get_im(maskfile)

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(ldata):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'],
                               p['mirror'], p['scale'])

    if CTF:
        from filter import filt_ctf
        from morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None
        ctf_abs_sum = None

    from utilities import info

    for im in range(ldata):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            qctf = data[im].get_attr("ctf_applied")
            if qctf == 0:
                data[im] = filt_ctf(fft(data[im]), ctf_params)
                data[im].set_attr('ctf_applied', 1)
            elif qctf != 1:
                ERROR('Incorrectly set qctf flag', "helicalshiftali_MPI", 1,
                      myid)
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)
        else:
            data[im] = fft(data[im])

    del list_of_particles

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
    if CTF:
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            temp = EMData(nx, ny, 1, False)
            tsnr = 1. / snr
            for i in range(0, nx + 2, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, tsnr)
                    temp.set_value_at(i + 1, j, 0.0)
            #info(ctf_2_sum)
            Util.add_img(ctf_2_sum, temp)
            #info(ctf_2_sum)
            del temp

    total_iter = 0
    shift_x = [0.0] * ldata

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        avg = EMData(nx, ny, 1, False)
        for im in range(ldata):
            Util.add_img(avg, fshift(data[im], shift_x[im]))

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF: tavg = Util.divn_filter(avg, ctf_2_sum)
            else: tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = model_blank(nx, ny)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
                vav_r = Util.pack_complex_to_real(vav)
            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            tavg.write_image("tavg.hdf", Iter)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)

        if Fourvar: del vav
        bcast_EMData_to_all(tavg, myid, main_node)
        tavg = fft(tavg)

        sx_sum = 0.0
        nxc = nx // 2

        for ifil in range(nfils):
            """
			# Calculate filament average
			avg = EMData(nx, ny, 1, False)
			filnima = 0
			for im in xrange(indcs[ifil][0], indcs[ifil][1]):
				Util.add_img(avg, data[im])
				filnima += 1
			tavg = Util.mult_scalar(avg, 1.0/float(filnima))
			"""
            # Calculate 1D ccf between each segment and filament average
            nsegms = indcs[ifil][1] - indcs[ifil][0]
            ctx = [None] * nsegms
            pcoords = [None] * nsegms
            for im in range(indcs[ifil][0], indcs[ifil][1]):
                ctx[im - indcs[ifil][0]] = Util.window(ccf(tavg, data[im]), nx,
                                                       1)
                pcoords[im - indcs[ifil][0]] = data[im].get_attr(
                    'ptcl_source_coord')
                #ctx[im-indcs[ifil][0]].write_image("ctx.hdf",im-indcs[ifil][0])
                #print "  CTX  ",myid,im,Util.infomask(ctx[im-indcs[ifil][0]], None, True)
            # search for best x-shift
            cents = nsegms // 2

            dst = sqrt(
                max((pcoords[cents][0] - pcoords[0][0])**2 +
                    (pcoords[cents][1] - pcoords[0][1])**2,
                    (pcoords[cents][0] - pcoords[-1][0])**2 +
                    (pcoords[cents][1] - pcoords[-1][1])**2))
            maxincline = atan2(ny // 2 - 2 - float(search_rng), dst)
            kang = int(dst * tan(maxincline) + 0.5)
            #print  "  settings ",nsegms,cents,dst,search_rng,maxincline,kang

            # ## C code for alignment. @ming
            results = [0.0] * 3
            results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline,
                                         kang, search_rng, nxc)
            sib = int(results[0])
            bang = results[1]
            qm = results[2]
            #print qm, sib, bang

            # qm = -1.e23
            #
            # 			for six in xrange(-search_rng, search_rng+1,1):
            # 				q0 = ctx[cents].get_value_at(six+nxc)
            # 				for incline in xrange(kang+1):
            # 					qt = q0
            # 					qu = q0
            # 					if(kang>0):  tang = tan(maxincline/kang*incline)
            # 					else:        tang = 0.0
            # 					for kim in xrange(cents+1,nsegms):
            # 						dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
            # 						xl = dst*tang+six+nxc
            # 						ixl = int(xl)
            # 						dxl = xl - ixl
            # 						#print "  A  ", ifil,six,incline,kim,xl,ixl,dxl
            # 						qt += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            # 						xl = -dst*tang+six+nxc
            # 						ixl = int(xl)
            # 						dxl = xl - ixl
            # 						qu += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            # 					for kim in xrange(cents):
            # 						dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
            # 						xl = -dst*tang+six+nxc
            # 						ixl = int(xl)
            # 						dxl = xl - ixl
            # 						qt += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            # 						xl =  dst*tang+six+nxc
            # 						ixl = int(xl)
            # 						dxl = xl - ixl
            # 						qu += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            # 					if( qt > qm ):
            # 						qm = qt
            # 						sib = six
            # 						bang = tang
            # 					if( qu > qm ):
            # 						qm = qu
            # 						sib = six
            # 						bang = -tang
            #if incline == 0:  print  "incline = 0  ",six,tang,qt,qu
            #print qm,six,sib,bang
            #print " got results   ",indcs[ifil][0], indcs[ifil][1], ifil,myid,qm,sib,tang,bang,len(ctx),Util.infomask(ctx[0], None, True)
            for im in range(indcs[ifil][0], indcs[ifil][1]):
                kim = im - indcs[ifil][0]
                dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 +
                           (pcoords[cents][1] - pcoords[kim][1])**2)
                if (kim < cents): xl = -dst * bang + sib
                else: xl = dst * bang + sib
                shift_x[im] = xl

            # Average shift
            sx_sum += shift_x[indcs[ifil][0] + cents]

        # #print myid,sx_sum,total_nfils
        sx_sum = mpi_reduce(sx_sum, 1, MPI_FLOAT, MPI_SUM, main_node,
                            MPI_COMM_WORLD)
        if myid == main_node:
            sx_sum = float(sx_sum[0]) / total_nfils
            print_msg("Average shift  %6.2f\n" % (sx_sum))
        else:
            sx_sum = 0.0
        sx_sum = 0.0
        sx_sum = bcast_number_to_all(sx_sum, source_node=main_node)
        for im in range(ldata):
            shift_x[im] -= sx_sum
            #print  "   %3d  %6.3f"%(im,shift_x[im])
        #exit()

    # combine shifts found with the original parameters
    for im in range(ldata):
        t1 = Transform()
        ##import random
        ##shix=random.randint(-10, 10)
        ##t1.set_params({"type":"2D","tx":shix})
        t1.set_params({"type": "2D", "tx": shift_x[im]})
        # combine t0 and t1
        tt = t1 * init_params[im]
        data[im].set_attr("xform.align2d", tt)
    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi_barrier(MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        from utilities import file_type
        if (file_type(stack) == "bdb"):
            from utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata,
                               nproc)
        else:
            from utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
    else:
        send_attr_dict(main_node, data, par_str, 0, ldata)
    if myid == main_node: print_end_msg("helical-shiftali_MPI")
コード例 #36
0
def shiftali_MPI(stack,
                 maskfile=None,
                 maxit=100,
                 CTF=False,
                 snr=1.0,
                 Fourvar=False,
                 search_rng=-1,
                 oneDx=False,
                 search_rng_y=-1):
    """Iteratively translate every image of a 2D stack onto the stack average (MPI).

    Each iteration:
      * builds the average of all (Fourier-space) images, optionally divided by
        the summed CTF^2 (CTF) and/or the Fourier variance (Fourvar), then
        de-means and masks it in real space on the main node;
      * cross-correlates every image with that average and takes the integer
        peak position inside a (2*search_rng+1) x (2*search_rng_y+1) window;
      * subtracts the rounded stack-wide mean peak offset (so the average does
        not drift) and shifts each image by what remains.
    Iteration stops after ``maxit`` passes or as soon as no image is moved.

    The accumulated shifts are composed with each image's original
    ``xform.align2d`` header parameters and written back to ``stack``
    together with the ``ID`` attribute.

    Parameters
    ----------
    stack : str
        Input image stack (file name or ``bdb:`` path); its headers are updated.
    maskfile : str or None
        Mask image; when None a circle of radius min(nx, ny)//2 - 2 is used.
    maxit : int
        Maximum number of iterations.
    CTF : bool
        Multiply images by their CTF and divide the average by the summed
        CTF^2.  The data must not be ``ctf_applied`` (ERROR otherwise).
    snr : float
        Value added at every even x index of the summed CTF^2 image
        (presumably its real parts) before the division.
    Fourvar : bool
        Divide the average by the Fourier variance from ``varf2d_MPI``.
    search_rng : int
        Half-width of the x search window; <= 0 means the full image width.
    oneDx : bool
        Search x shifts only (1-row CCF window); y shifts stay zero.
    search_rng_y : int
        Half-width of the y search window; <= 0 means the full image height.
    """
    from applications import MPI_start_end
    from utilities import model_circle, model_blank, get_image, peak_search, get_im
    from utilities import reduce_EMData_to_root, bcast_EMData_to_all, send_attr_dict, file_type, bcast_number_to_all, bcast_list_to_all
    from pap_statistics import varf2d_MPI
    from fundamentals import fft, ccf, rot_shift3D, rot_shift2D
    from utilities import get_params2D, set_params2D
    from utilities import print_msg, print_begin_msg, print_end_msg
    import os
    import sys
    from mpi import mpi_init, mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
    from mpi import mpi_reduce, mpi_bcast, mpi_barrier, mpi_gatherv
    from mpi import MPI_SUM, MPI_FLOAT, MPI_INT
    from EMAN2 import Processor
    from time import time

    number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("shiftali_MPI")

    max_iter = int(maxit)

    # Only the main node queries the stack size; the others get it by broadcast.
    if myid == main_node:
        if ftp == "bdb":
            from EMAN2db import db_open_dict
            dummy = db_open_dict(stack, True)
        nima = EMUtil.get_image_count(stack)
    else:
        nima = 0
    nima = bcast_number_to_all(nima, source_node=main_node)
    list_of_particles = list(range(nima))

    # Each node processes its own contiguous slice of the stack.
    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)
    list_of_particles = list_of_particles[image_start:image_end]

    # read nx, ny and ctf_app (if CTF) on the main node and broadcast to all
    if myid == main_node:
        ima = EMData()
        ima.read_image(stack, list_of_particles[0], True)
        nx = ima.get_xsize()
        ny = ima.get_ysize()
        if CTF: ctf_app = ima.get_attr_default('ctf_applied', 2)
        del ima
    else:
        nx = 0
        ny = 0
        if CTF: ctf_app = 0
    nx = bcast_number_to_all(nx, source_node=main_node)
    ny = bcast_number_to_all(ny, source_node=main_node)
    if CTF:
        ctf_app = bcast_number_to_all(ctf_app, source_node=main_node)
        if ctf_app > 0:
            ERROR("data cannot be ctf-applied", "shiftali_MPI", 1, myid)

    if maskfile is None:
        mrad = min(nx, ny)
        mask = model_circle(mrad // 2 - 2, nx, ny)
    else:
        mask = get_im(maskfile)

    if CTF:
        from filter import filt_ctf
        from morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None

    from global_def import CACHE_DISABLE
    if CACHE_DISABLE:
        data = EMData.read_images(stack, list_of_particles)
    else:
        # Nodes take turns reading; bdb access is serialized with barriers.
        for i in range(number_of_proc):
            if myid == i:
                data = EMData.read_images(stack, list_of_particles)
            if ftp == "bdb": mpi_barrier(MPI_COMM_WORLD)

    # Tag each image with its global index, subtract the masked mean and
    # accumulate the local CTF^2 / |CTF| sums.
    for im in range(len(data)):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Regularize the later division by sum(CTF^2) with `snr`.
            temp = EMData(nx, ny, 1, False)
            for i in range(0, nx, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, snr)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(len(data)):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im],
                               p['alpha'],
                               sx=p['tx'],
                               sy=p['ty'],
                               mirror=p['mirror'],
                               scale=p['scale'])

    # fourier transform all images, and apply ctf if CTF
    for im in range(len(data)):
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            data[im] = filt_ctf(fft(data[im]), ctf_params)
        else:
            data[im] = fft(data[im])

    sx_sum_total = 0
    sy_sum_total = 0
    shift_x = [0.0] * len(data)  # accumulated shifts over all iterations
    shift_y = [0.0] * len(data)
    ishift_x = [0.0] * len(data)  # raw CCF peak offsets of the current pass
    ishift_y = [0.0] * len(data)

    # CCF peak-search window; loop-invariant.
    if search_rng > 0: nwx = 2 * search_rng + 1
    else: nwx = nx

    if search_rng_y > 0: nwy = 2 * search_rng_y + 1
    else: nwy = ny

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        avg = EMData(nx, ny, 1, False)
        for im in data:
            Util.add_img(avg, im)

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = EMData(nx, ny, 1, False)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))

            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)
            tavg = fft(tavg)

        if Fourvar: del vav
        bcast_EMData_to_all(tavg, myid, main_node)

        # Raw integer offset of each image's CCF peak against the average.
        sx_sum = 0
        sy_sum = 0
        for im in range(len(data)):
            if oneDx:
                ctx = Util.window(ccf(data[im], tavg), nwx, 1)
                p1 = peak_search(ctx)
                p1_x = -int(p1[0][3])
                ishift_x[im] = p1_x
                sx_sum += p1_x
            else:
                p1 = peak_search(Util.window(ccf(data[im], tavg), nwx, nwy))
                p1_x = -int(p1[0][4])
                p1_y = -int(p1[0][5])
                ishift_x[im] = p1_x
                ishift_y[im] = p1_y
                sx_sum += p1_x
                sy_sum += p1_y

        sx_sum = mpi_reduce(sx_sum, 1, MPI_INT, MPI_SUM, main_node,
                            MPI_COMM_WORLD)

        if not oneDx:
            sy_sum = mpi_reduce(sy_sum, 1, MPI_INT, MPI_SUM, main_node,
                                MPI_COMM_WORLD)

        if myid == main_node:
            sx_sum_total = int(sx_sum[0])
            if not oneDx:
                sy_sum_total = int(sy_sum[0])
        else:
            sx_sum_total = 0
            sy_sum_total = 0

        sx_sum_total = bcast_number_to_all(sx_sum_total, source_node=main_node)

        if not oneDx:
            sy_sum_total = bcast_number_to_all(sy_sum_total,
                                               source_node=main_node)

        # Remove the stack-wide mean offset so the average does not drift,
        # then shift each image by what remains.  Track whether any image
        # actually moves: the raw offsets can all be non-zero (e.g. a uniform
        # drift, or a single-image stack) while every applied shift is zero,
        # in which case the next pass would be identical to this one.
        sx_ave = round(float(sx_sum_total) / nima)
        sy_ave = round(float(sy_sum_total) / nima)
        not_zero = 0
        for im in range(len(data)):
            p1_x = ishift_x[im] - sx_ave
            p1_y = ishift_y[im] - sy_ave
            if p1_x != 0 or p1_y != 0:
                not_zero = 1
            params2 = {
                "filter_type": Processor.fourier_filter_types.SHIFT,
                "x_shift": p1_x,
                "y_shift": p1_y,
                "z_shift": 0.0
            }
            data[im] = Processor.EMFourierFilter(data[im], params2)
            shift_x[im] += p1_x
            shift_y[im] += p1_y
        # stop if no image was moved on any node (fixed point reached)
        not_zero = mpi_reduce(not_zero, 1, MPI_INT, MPI_SUM, main_node,
                              MPI_COMM_WORLD)
        if myid == main_node:
            not_zero_all = int(not_zero[0])
        else:
            not_zero_all = 0
        not_zero_all = bcast_number_to_all(not_zero_all, source_node=main_node)

        if myid == main_node:
            print_msg("Time of iteration = %12.2f\n" % (time() - start_time))

        if not_zero_all == 0: break

    # Only header information is used from here on, so data need not be
    # transformed back to real space.
    # combine shifts found with the original parameters
    for im in range(len(data)):
        t0 = init_params[im]
        t1 = Transform()
        t1.set_params({
            "type": "2D",
            "alpha": 0,
            "scale": t0.get_scale(),
            "mirror": 0,
            "tx": shift_x[im],
            "ty": shift_y[im]
        })
        # combine t0 and t1
        tt = t1 * t0
        data[im].set_attr("xform.align2d", tt)

    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi_barrier(MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)
        else:
            from utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, image_start,
                           image_end, number_of_proc)

    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node: print_end_msg("shiftali_MPI")
コード例 #37
0
ファイル: sx3dvariability.py プロジェクト: jkaelber/eman2
def main():
    def params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror):
        if mirror:
            m = 1
            alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0,
                                                       540.0 - psi, 0, 0, 1.0)
        else:
            m = 0
            alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0,
                                                       360.0 - psi, 0, 0, 1.0)
        return alpha, sx, sy, m

    progname = os.path.basename(sys.argv[0])
    usage = progname + " prj_stack  --ave2D= --var2D=  --ave3D= --var3D= --img_per_grp= --fl=0.2 --aa=0.1  --sym=symmetry --CTF"
    parser = OptionParser(usage, version=SPARXVERSION)

    parser.add_option("--ave2D",
                      type="string",
                      default=False,
                      help="write to the disk a stack of 2D averages")
    parser.add_option("--var2D",
                      type="string",
                      default=False,
                      help="write to the disk a stack of 2D variances")
    parser.add_option("--ave3D",
                      type="string",
                      default=False,
                      help="write to the disk reconstructed 3D average")
    parser.add_option("--var3D",
                      type="string",
                      default=False,
                      help="compute 3D variability (time consuming!)")
    parser.add_option("--img_per_grp",
                      type="int",
                      default=10,
                      help="number of neighbouring projections")
    parser.add_option("--no_norm",
                      action="store_true",
                      default=False,
                      help="do not use normalization")
    parser.add_option("--radius",
                      type="int",
                      default=-1,
                      help="radius for 3D variability")
    parser.add_option("--npad",
                      type="int",
                      default=2,
                      help="number of time to pad the original images")
    parser.add_option("--sym", type="string", default="c1", help="symmetry")
    parser.add_option("--fl",
                      type="float",
                      default=0.0,
                      help="stop-band frequency (Default - no filtration)")
    parser.add_option("--aa",
                      type="float",
                      default=0.0,
                      help="fall off of the filter (Default - no filtration)")
    parser.add_option("--CTF",
                      action="store_true",
                      default=False,
                      help="use CFT correction")
    parser.add_option("--VERBOSE",
                      action="store_true",
                      default=False,
                      help="Long output for debugging")
    #parser.add_option("--MPI" , 		action="store_true",	default=False,				help="use MPI version")
    #parser.add_option("--radiuspca", 	type="int"         ,	default=-1   ,				help="radius for PCA" )
    #parser.add_option("--iter", 		type="int"         ,	default=40   ,				help="maximum number of iterations (stop criterion of reconstruction process)" )
    #parser.add_option("--abs", 		type="float"   ,        default=0.0  ,				help="minimum average absolute change of voxels' values (stop criterion of reconstruction process)" )
    #parser.add_option("--squ", 		type="float"   ,	    default=0.0  ,				help="minimum average squared change of voxels' values (stop criterion of reconstruction process)" )
    parser.add_option(
        "--VAR",
        action="store_true",
        default=False,
        help="stack on input consists of 2D variances (Default False)")
    parser.add_option(
        "--decimate",
        type="float",
        default=1.0,
        help="image decimate rate, a number large than 1. default is 1")
    parser.add_option(
        "--window",
        type="int",
        default=0,
        help=
        "reduce images to a small image size without changing pixel_size. Default value is zero."
    )
    #parser.add_option("--SND",			action="store_true",	default=False,				help="compute squared normalized differences (Default False)")
    parser.add_option(
        "--nvec",
        type="int",
        default=0,
        help="number of eigenvectors, default = 0 meaning no PCA calculated")
    parser.add_option(
        "--symmetrize",
        action="store_true",
        default=False,
        help="Prepare input stack for handling symmetry (Default False)")

    (options, args) = parser.parse_args()
    #####
    from mpi import mpi_init, mpi_comm_rank, mpi_comm_size, mpi_recv, MPI_COMM_WORLD
    from mpi import mpi_barrier, mpi_reduce, mpi_bcast, mpi_send, MPI_FLOAT, MPI_SUM, MPI_INT, MPI_MAX
    from applications import MPI_start_end
    from reconstruction import recons3d_em, recons3d_em_MPI
    from reconstruction import recons3d_4nn_MPI, recons3d_4nn_ctf_MPI
    from utilities import print_begin_msg, print_end_msg, print_msg
    from utilities import read_text_row, get_image, get_im
    from utilities import bcast_EMData_to_all, bcast_number_to_all
    from utilities import get_symt

    #  This is code for handling symmetries by the above program.  To be incorporated. PAP 01/27/2015

    from EMAN2db import db_open_dict

    # Set up global variables related to bdb cache
    if global_def.CACHE_DISABLE:
        from utilities import disable_bdb_cache
        disable_bdb_cache()

    # Set up global variables related to ERROR function
    global_def.BATCH = True

    # detect if program is running under MPI
    RUNNING_UNDER_MPI = "OMPI_COMM_WORLD_SIZE" in os.environ
    if RUNNING_UNDER_MPI:
        global_def.MPI = True

    if options.symmetrize:
        if RUNNING_UNDER_MPI:
            try:
                sys.argv = mpi_init(len(sys.argv), sys.argv)
                try:
                    number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
                    if (number_of_proc > 1):
                        ERROR(
                            "Cannot use more than one CPU for symmetry prepration",
                            "sx3dvariability", 1)
                except:
                    pass
            except:
                pass

        #  Input
        #instack = "Clean_NORM_CTF_start_wparams.hdf"
        #instack = "bdb:data"
        instack = args[0]
        sym = options.sym
        if (sym == "c1"):
            ERROR("Thre is no need to symmetrize stack for C1 symmetry",
                  "sx3dvariability", 1)

        if (instack[:4] != "bdb:"):
            stack = "bdb:data"
            delete_bdb(stack)
            junk = cmdexecute("sxcpy.py  " + instack + "  " + stack)
        else:
            stack = instack

        qt = EMUtil.get_all_attributes(stack, 'xform.projection')

        na = len(qt)
        ts = get_symt(sym)
        ks = len(ts)
        angsa = [None] * na
        for k in xrange(ks):
            delete_bdb("bdb:Q%1d" % k)
            junk = cmdexecute("e2bdb.py  " + stack +
                              "  --makevstack=bdb:Q%1d" % k)
            DB = db_open_dict("bdb:Q%1d" % k)
            for i in xrange(na):
                ut = qt[i] * ts[k]
                DB.set_attr(i, "xform.projection", ut)
                #bt = ut.get_params("spider")
                #angsa[i] = [round(bt["phi"],3)%360.0, round(bt["theta"],3)%360.0, bt["psi"], -bt["tx"], -bt["ty"]]
            #write_text_row(angsa, 'ptsma%1d.txt'%k)
            #junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            #junk = cmdexecute("sxheader.py  bdb:Q%1d  --params=xform.projection  --import=ptsma%1d.txt"%(k,k))
            DB.close()
        delete_bdb("bdb:sdata")
        junk = cmdexecute("e2bdb.py . --makevstack=bdb:sdata --filt=Q")
        #junk = cmdexecute("ls  EMAN2DB/sdata*")
        a = get_im("bdb:sdata")
        a.set_attr("variabilitysymmetry", sym)
        a.write_image("bdb:sdata")

    else:

        sys.argv = mpi_init(len(sys.argv), sys.argv)
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
        main_node = 0

        if len(args) == 1:
            stack = args[0]
        else:
            print("usage: " + usage)
            print("Please run '" + progname + " -h' for detailed options")
            return 1

        t0 = time()
        # obsolete flags
        options.MPI = True
        options.nvec = 0
        options.radiuspca = -1
        options.iter = 40
        options.abs = 0.0
        options.squ = 0.0

        if options.fl > 0.0 and options.aa == 0.0:
            ERROR("Fall off has to be given for the low-pass filter",
                  "sx3dvariability", 1, myid)
        if options.VAR and options.SND:
            ERROR("Only one of var and SND can be set!", "sx3dvariability",
                  myid)
            exit()
        if options.VAR and (options.ave2D or options.ave3D or options.var2D):
            ERROR(
                "When VAR is set, the program cannot output ave2D, ave3D or var2D",
                "sx3dvariability", 1, myid)
            exit()
        #if options.SND and (options.ave2D or options.ave3D):
        #	ERROR("When SND is set, the program cannot output ave2D or ave3D", "sx3dvariability", 1, myid)
        #	exit()
        if options.nvec > 0:
            ERROR("PCA option not implemented", "sx3dvariability", 1, myid)
            exit()
        if options.nvec > 0 and options.ave3D == None:
            ERROR("When doing PCA analysis, one must set ave3D",
                  "sx3dvariability",
                  myid=myid)
            exit()
        import string
        options.sym = options.sym.lower()

        # if global_def.CACHE_DISABLE:
        # 	from utilities import disable_bdb_cache
        # 	disable_bdb_cache()
        # global_def.BATCH = True

        if myid == main_node:
            print_begin_msg("sx3dvariability")
            print_msg("%-70s:  %s\n" % ("Input stack", stack))

        img_per_grp = options.img_per_grp
        nvec = options.nvec
        radiuspca = options.radiuspca

        symbaselen = 0
        if myid == main_node:
            nima = EMUtil.get_image_count(stack)
            img = get_image(stack)
            nx = img.get_xsize()
            ny = img.get_ysize()
            if options.sym != "c1":
                imgdata = get_im(stack)
                try:
                    i = imgdata.get_attr("variabilitysymmetry")
                    if (i != options.sym):
                        ERROR(
                            "The symmetry provided does not agree with the symmetry of the input stack",
                            "sx3dvariability",
                            myid=myid)
                except:
                    ERROR(
                        "Input stack is not prepared for symmetry, please follow instructions",
                        "sx3dvariability",
                        myid=myid)
                from utilities import get_symt
                i = len(get_symt(options.sym))
                if ((nima / i) * i != nima):
                    ERROR(
                        "The length of the input stack is incorrect for symmetry processing",
                        "sx3dvariability",
                        myid=myid)
                symbaselen = nima / i
            else:
                symbaselen = nima
        else:
            nima = 0
            nx = 0
            ny = 0
        nima = bcast_number_to_all(nima)
        nx = bcast_number_to_all(nx)
        ny = bcast_number_to_all(ny)
        Tracker = {}
        Tracker["total_stack"] = nima
        if options.decimate == 1.:
            if options.window != 0:
                nx = options.window
                ny = options.window
        else:
            if options.window == 0:
                nx = int(nx / options.decimate)
                ny = int(ny / options.decimate)
            else:
                nx = int(options.window / options.decimate)
                ny = nx
        Tracker["nx"] = nx
        Tracker["ny"] = ny
        Tracker["nz"] = nx
        symbaselen = bcast_number_to_all(symbaselen)
        if radiuspca == -1: radiuspca = nx / 2 - 2

        if myid == main_node:
            print_msg("%-70s:  %d\n" % ("Number of projection", nima))

        img_begin, img_end = MPI_start_end(nima, number_of_proc, myid)
        """
		if options.SND:
			from projection		import prep_vol, prgs
			from statistics		import im_diff
			from utilities		import get_im, model_circle, get_params_proj, set_params_proj
			from utilities		import get_ctf, generate_ctf
			from filter			import filt_ctf
		
			imgdata = EMData.read_images(stack, range(img_begin, img_end))

			if options.CTF:
				vol = recons3d_4nn_ctf_MPI(myid, imgdata, 1.0, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			else:
				vol = recons3d_4nn_MPI(myid, imgdata, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)

			bcast_EMData_to_all(vol, myid)
			volft, kb = prep_vol(vol)

			mask = model_circle(nx/2-2, nx, ny)
			varList = []
			for i in xrange(img_begin, img_end):
				phi, theta, psi, s2x, s2y = get_params_proj(imgdata[i-img_begin])
				ref_prj = prgs(volft, kb, [phi, theta, psi, -s2x, -s2y])
				if options.CTF:
					ctf_params = get_ctf(imgdata[i-img_begin])
					ref_prj = filt_ctf(ref_prj, generate_ctf(ctf_params))
				diff, A, B = im_diff(ref_prj, imgdata[i-img_begin], mask)
				diff2 = diff*diff
				set_params_proj(diff2, [phi, theta, psi, s2x, s2y])
				varList.append(diff2)
			mpi_barrier(MPI_COMM_WORLD)
		"""
        if options.VAR:
            #varList   = EMData.read_images(stack, range(img_begin, img_end))
            varList = []
            this_image = EMData()
            for index_of_particle in xrange(img_begin, img_end):
                this_image.read_image(stack, index_of_particle)
                varList.append(
                    image_decimate_window_xform_ctf(this_image,
                                                    options.decimate,
                                                    options.window,
                                                    options.CTF))
        else:
            from utilities import bcast_number_to_all, bcast_list_to_all, send_EMData, recv_EMData
            from utilities import set_params_proj, get_params_proj, params_3D_2D, get_params2D, set_params2D, compose_transform2
            from utilities import model_blank, nearest_proj, model_circle
            from applications import pca
            from statistics import avgvar, avgvar_ctf, ccc
            from filter import filt_tanl
            from morphology import threshold, square_root
            from projection import project, prep_vol, prgs
            from sets import Set

            if myid == main_node:
                t1 = time()
                proj_angles = []
                aveList = []
                tab = EMUtil.get_all_attributes(stack, 'xform.projection')
                for i in xrange(nima):
                    t = tab[i].get_params('spider')
                    phi = t['phi']
                    theta = t['theta']
                    psi = t['psi']
                    x = theta
                    if x > 90.0: x = 180.0 - x
                    x = x * 10000 + psi
                    proj_angles.append([x, t['phi'], t['theta'], t['psi'], i])
                t2 = time()
                print_msg("%-70s:  %d\n" %
                          ("Number of neighboring projections", img_per_grp))
                print_msg("...... Finding neighboring projections\n")
                if options.VERBOSE:
                    print "Number of images per group: ", img_per_grp
                    print "Now grouping projections"
                proj_angles.sort()
            proj_angles_list = [0.0] * (nima * 4)
            if myid == main_node:
                for i in xrange(nima):
                    proj_angles_list[i * 4] = proj_angles[i][1]
                    proj_angles_list[i * 4 + 1] = proj_angles[i][2]
                    proj_angles_list[i * 4 + 2] = proj_angles[i][3]
                    proj_angles_list[i * 4 + 3] = proj_angles[i][4]
            proj_angles_list = bcast_list_to_all(proj_angles_list, myid,
                                                 main_node)
            proj_angles = []
            for i in xrange(nima):
                proj_angles.append([
                    proj_angles_list[i * 4], proj_angles_list[i * 4 + 1],
                    proj_angles_list[i * 4 + 2],
                    int(proj_angles_list[i * 4 + 3])
                ])
            del proj_angles_list
            proj_list, mirror_list = nearest_proj(proj_angles, img_per_grp,
                                                  range(img_begin, img_end))

            all_proj = Set()
            for im in proj_list:
                for jm in im:
                    all_proj.add(proj_angles[jm][3])

            all_proj = list(all_proj)
            if options.VERBOSE:
                print "On node %2d, number of images needed to be read = %5d" % (
                    myid, len(all_proj))

            index = {}
            for i in xrange(len(all_proj)):
                index[all_proj[i]] = i
            mpi_barrier(MPI_COMM_WORLD)

            if myid == main_node:
                print_msg("%-70s:  %.2f\n" %
                          ("Finding neighboring projections lasted [s]",
                           time() - t2))
                print_msg("%-70s:  %d\n" %
                          ("Number of groups processed on the main node",
                           len(proj_list)))
                if options.VERBOSE:
                    print "Grouping projections took: ", (time() -
                                                          t2) / 60, "[min]"
                    print "Number of groups on main node: ", len(proj_list)
            mpi_barrier(MPI_COMM_WORLD)

            if myid == main_node:
                print_msg("...... calculating the stack of 2D variances \n")
                if options.VERBOSE:
                    print "Now calculating the stack of 2D variances"

            proj_params = [0.0] * (nima * 5)
            aveList = []
            varList = []
            if nvec > 0:
                eigList = [[] for i in xrange(nvec)]

            if options.VERBOSE:
                print "Begin to read images on processor %d" % (myid)
            ttt = time()
            #imgdata = EMData.read_images(stack, all_proj)
            imgdata = []
            for index_of_proj in xrange(len(all_proj)):
                img = EMData()
                img.read_image(stack, all_proj[index_of_proj])
                dmg = image_decimate_window_xform_ctf(img, options.decimate,
                                                      options.window,
                                                      options.CTF)
                #print dmg.get_xsize(), "init"
                imgdata.append(dmg)
            if options.VERBOSE:
                print "Reading images on processor %d done, time = %.2f" % (
                    myid, time() - ttt)
                print "On processor %d, we got %d images" % (myid,
                                                             len(imgdata))
            mpi_barrier(MPI_COMM_WORLD)
            '''	
			imgdata2 = EMData.read_images(stack, range(img_begin, img_end))
			if options.fl > 0.0:
				for k in xrange(len(imgdata2)):
					imgdata2[k] = filt_tanl(imgdata2[k], options.fl, options.aa)
			if options.CTF:
				vol = recons3d_4nn_ctf_MPI(myid, imgdata2, 1.0, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			else:
				vol = recons3d_4nn_MPI(myid, imgdata2, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			if myid == main_node:
				vol.write_image("vol_ctf.hdf")
				print_msg("Writing to the disk volume reconstructed from averages as		:  %s\n"%("vol_ctf.hdf"))
			del vol, imgdata2
			mpi_barrier(MPI_COMM_WORLD)
			'''
            from applications import prepare_2d_forPCA
            from utilities import model_blank
            for i in xrange(len(proj_list)):
                ki = proj_angles[proj_list[i][0]][3]
                if ki >= symbaselen: continue
                mi = index[ki]
                phiM, thetaM, psiM, s2xM, s2yM = get_params_proj(imgdata[mi])

                grp_imgdata = []
                for j in xrange(img_per_grp):
                    mj = index[proj_angles[proj_list[i][j]][3]]
                    phi, theta, psi, s2x, s2y = get_params_proj(imgdata[mj])
                    alpha, sx, sy, mirror = params_3D_2D_NEW(
                        phi, theta, psi, s2x, s2y, mirror_list[i][j])
                    if thetaM <= 90:
                        if mirror == 0:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, phiM - phi, 0.0, 0.0, 1.0)
                        else:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, 180 - (phiM - phi), 0.0,
                                0.0, 1.0)
                    else:
                        if mirror == 0:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, -(phiM - phi), 0.0, 0.0,
                                1.0)
                        else:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, -(180 - (phiM - phi)), 0.0,
                                0.0, 1.0)
                    set_params2D(imgdata[mj], [alpha, sx, sy, mirror, 1.0])
                    grp_imgdata.append(imgdata[mj])
                    #print grp_imgdata[j].get_xsize(), imgdata[mj].get_xsize()

                if not options.no_norm:
                    #print grp_imgdata[j].get_xsize()
                    mask = model_circle(nx / 2 - 2, nx, nx)
                    for k in xrange(img_per_grp):
                        ave, std, minn, maxx = Util.infomask(
                            grp_imgdata[k], mask, False)
                        grp_imgdata[k] -= ave
                        grp_imgdata[k] /= std
                    del mask

                if options.fl > 0.0:
                    from filter import filt_ctf, filt_table
                    from fundamentals import fft, window2d
                    nx2 = 2 * nx
                    ny2 = 2 * ny
                    if options.CTF:
                        from utilities import pad
                        for k in xrange(img_per_grp):
                            grp_imgdata[k] = window2d(
                                fft(
                                    filt_tanl(
                                        filt_ctf(
                                            fft(
                                                pad(grp_imgdata[k], nx2, ny2,
                                                    1, 0.0)),
                                            grp_imgdata[k].get_attr("ctf"),
                                            binary=1), options.fl,
                                        options.aa)), nx, ny)
                            #grp_imgdata[k] = window2d(fft( filt_table( filt_tanl( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1), options.fl, options.aa), fifi) ),nx,ny)
                            #grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)
                    else:
                        for k in xrange(img_per_grp):
                            grp_imgdata[k] = filt_tanl(grp_imgdata[k],
                                                       options.fl, options.aa)
                            #grp_imgdata[k] = window2d(fft( filt_table( filt_tanl( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1), options.fl, options.aa), fifi) ),nx,ny)
                            #grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)
                else:
                    from utilities import pad, read_text_file
                    from filter import filt_ctf, filt_table
                    from fundamentals import fft, window2d
                    nx2 = 2 * nx
                    ny2 = 2 * ny
                    if options.CTF:
                        from utilities import pad
                        for k in xrange(img_per_grp):
                            grp_imgdata[k] = window2d(
                                fft(
                                    filt_ctf(fft(
                                        pad(grp_imgdata[k], nx2, ny2, 1, 0.0)),
                                             grp_imgdata[k].get_attr("ctf"),
                                             binary=1)), nx, ny)
                            #grp_imgdata[k] = window2d(fft( filt_table( filt_tanl( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1), options.fl, options.aa), fifi) ),nx,ny)
                            #grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)
                '''
				if i < 10 and myid == main_node:
					for k in xrange(10):
						grp_imgdata[k].write_image("grp%03d.hdf"%i, k)
				'''
                """
				if myid == main_node and i==0:
					for pp in xrange(len(grp_imgdata)):
						grp_imgdata[pp].write_image("pp.hdf", pp)
				"""
                ave, grp_imgdata = prepare_2d_forPCA(grp_imgdata)
                """
				if myid == main_node and i==0:
					for pp in xrange(len(grp_imgdata)):
						grp_imgdata[pp].write_image("qq.hdf", pp)
				"""

                var = model_blank(nx, ny)
                for q in grp_imgdata:
                    Util.add_img2(var, q)
                Util.mul_scalar(var, 1.0 / (len(grp_imgdata) - 1))
                # Switch to std dev
                var = square_root(threshold(var))
                #if options.CTF:	ave, var = avgvar_ctf(grp_imgdata, mode="a")
                #else:	            ave, var = avgvar(grp_imgdata, mode="a")
                """
				if myid == main_node:
					ave.write_image("avgv.hdf",i)
					var.write_image("varv.hdf",i)
				"""

                set_params_proj(ave, [phiM, thetaM, 0.0, 0.0, 0.0])
                set_params_proj(var, [phiM, thetaM, 0.0, 0.0, 0.0])

                aveList.append(ave)
                varList.append(var)

                if options.VERBOSE:
                    print "%5.2f%% done on processor %d" % (
                        i * 100.0 / len(proj_list), myid)
                if nvec > 0:
                    eig = pca(input_stacks=grp_imgdata,
                              subavg="",
                              mask_radius=radiuspca,
                              nvec=nvec,
                              incore=True,
                              shuffle=False,
                              genbuf=True)
                    for k in xrange(nvec):
                        set_params_proj(eig[k], [phiM, thetaM, 0.0, 0.0, 0.0])
                        eigList[k].append(eig[k])
                    """
					if myid == 0 and i == 0:
						for k in xrange(nvec):
							eig[k].write_image("eig.hdf", k)
					"""

            del imgdata
            #  To this point, all averages, variances, and eigenvectors are computed

            if options.ave2D:
                from fundamentals import fpol
                if myid == main_node:
                    km = 0
                    for i in xrange(number_of_proc):
                        if i == main_node:
                            for im in xrange(len(aveList)):
                                aveList[im].write_image(options.ave2D, km)
                                km += 1
                        else:
                            nl = mpi_recv(1, MPI_INT, i,
                                          SPARX_MPI_TAG_UNIVERSAL,
                                          MPI_COMM_WORLD)
                            nl = int(nl[0])
                            for im in xrange(nl):
                                ave = recv_EMData(i, im + i + 70000)
                                """
								nm = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								nm = int(nm[0])
								members = mpi_recv(nm, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('members', map(int, members))
								members = mpi_recv(nm, MPI_FLOAT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('pix_err', map(float, members))
								members = mpi_recv(3, MPI_FLOAT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('refprojdir', map(float, members))
								"""
                                tmpvol = fpol(ave, Tracker["nx"],
                                              Tracker["nx"], 1)
                                tmpvol.write_image(options.ave2D, km)
                                km += 1
                else:
                    mpi_send(len(aveList), 1, MPI_INT, main_node,
                             SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    for im in xrange(len(aveList)):
                        send_EMData(aveList[im], main_node, im + myid + 70000)
                        """
						members = aveList[im].get_attr('members')
						mpi_send(len(members), 1, MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						mpi_send(members, len(members), MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						members = aveList[im].get_attr('pix_err')
						mpi_send(members, len(members), MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						try:
							members = aveList[im].get_attr('refprojdir')
							mpi_send(members, 3, MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						except:
							mpi_send([-999.0,-999.0,-999.0], 3, MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						"""

            if options.ave3D:
                from fundamentals import fpol
                if options.VERBOSE:
                    print "Reconstructing 3D average volume"
                ave3D = recons3d_4nn_MPI(myid,
                                         aveList,
                                         symmetry=options.sym,
                                         npad=options.npad)
                bcast_EMData_to_all(ave3D, myid)
                if myid == main_node:
                    ave3D = fpol(ave3D, Tracker["nx"], Tracker["nx"],
                                 Tracker["nx"])
                    ave3D.write_image(options.ave3D)
                    print_msg("%-70s:  %s\n" % (
                        "Writing to the disk volume reconstructed from averages as",
                        options.ave3D))
            del ave, var, proj_list, stack, phi, theta, psi, s2x, s2y, alpha, sx, sy, mirror, aveList

            if nvec > 0:
                for k in xrange(nvec):
                    if options.VERBOSE:
                        print "Reconstruction eigenvolumes", k
                    cont = True
                    ITER = 0
                    mask2d = model_circle(radiuspca, nx, nx)
                    while cont:
                        #print "On node %d, iteration %d"%(myid, ITER)
                        eig3D = recons3d_4nn_MPI(myid,
                                                 eigList[k],
                                                 symmetry=options.sym,
                                                 npad=options.npad)
                        bcast_EMData_to_all(eig3D, myid, main_node)
                        if options.fl > 0.0:
                            eig3D = filt_tanl(eig3D, options.fl, options.aa)
                        if myid == main_node:
                            eig3D.write_image("eig3d_%03d.hdf" % k, ITER)
                        Util.mul_img(eig3D,
                                     model_circle(radiuspca, nx, nx, nx))
                        eig3Df, kb = prep_vol(eig3D)
                        del eig3D
                        cont = False
                        icont = 0
                        for l in xrange(len(eigList[k])):
                            phi, theta, psi, s2x, s2y = get_params_proj(
                                eigList[k][l])
                            proj = prgs(eig3Df, kb,
                                        [phi, theta, psi, s2x, s2y])
                            cl = ccc(proj, eigList[k][l], mask2d)
                            if cl < 0.0:
                                icont += 1
                                cont = True
                                eigList[k][l] *= -1.0
                        u = int(cont)
                        u = mpi_reduce([u], 1, MPI_INT, MPI_MAX, main_node,
                                       MPI_COMM_WORLD)
                        icont = mpi_reduce([icont], 1, MPI_INT, MPI_SUM,
                                           main_node, MPI_COMM_WORLD)

                        if myid == main_node:
                            u = int(u[0])
                            print " Eigenvector: ", k, " number changed ", int(
                                icont[0])
                        else:
                            u = 0
                        u = bcast_number_to_all(u, main_node)
                        cont = bool(u)
                        ITER += 1

                    del eig3Df, kb
                    mpi_barrier(MPI_COMM_WORLD)
                del eigList, mask2d

            if options.ave3D: del ave3D
            if options.var2D:
                from fundamentals import fpol
                if myid == main_node:
                    km = 0
                    for i in xrange(number_of_proc):
                        if i == main_node:
                            for im in xrange(len(varList)):
                                tmpvol = fpol(varList[im], Tracker["nx"],
                                              Tracker["nx"], 1)
                                tmpvol.write_image(options.var2D, km)
                                km += 1
                        else:
                            nl = mpi_recv(1, MPI_INT, i,
                                          SPARX_MPI_TAG_UNIVERSAL,
                                          MPI_COMM_WORLD)
                            nl = int(nl[0])
                            for im in xrange(nl):
                                ave = recv_EMData(i, im + i + 70000)
                                tmpvol = fpol(ave, Tracker["nx"],
                                              Tracker["nx"], 1)
                                tmpvol.write_image(options.var2D, km)
                                km += 1
                else:
                    mpi_send(len(varList), 1, MPI_INT, main_node,
                             SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    for im in xrange(len(varList)):
                        send_EMData(varList[im], main_node, im + myid +
                                    70000)  #  What with the attributes??

            mpi_barrier(MPI_COMM_WORLD)

        if options.var3D:
            if myid == main_node and options.VERBOSE:
                print "Reconstructing 3D variability volume"

            t6 = time()
            radiusvar = options.radius
            if (radiusvar < 0): radiusvar = nx // 2 - 3
            res = recons3d_4nn_MPI(myid,
                                   varList,
                                   symmetry=options.sym,
                                   npad=options.npad)
            #res = recons3d_em_MPI(varList, vol_stack, options.iter, radiusvar, options.abs, True, options.sym, options.squ)
            if myid == main_node:
                from fundamentals import fpol
                res = fpol(res, Tracker["nx"], Tracker["nx"], Tracker["nx"])
                res.write_image(options.var3D)

            if myid == main_node:
                print_msg(
                    "%-70s:  %.2f\n" %
                    ("Reconstructing 3D variability took [s]", time() - t6))
                if options.VERBOSE:
                    print "Reconstruction took: %.2f [min]" % (
                        (time() - t6) / 60)

            if myid == main_node:
                print_msg(
                    "%-70s:  %.2f\n" %
                    ("Total time for these computations [s]", time() - t0))
                if options.VERBOSE:
                    print "Total time for these computations: %.2f [min]" % (
                        (time() - t0) / 60)
                print_end_msg("sx3dvariability")

        from mpi import mpi_finalize
        mpi_finalize()

        if RUNNING_UNDER_MPI:
            global_def.MPI = False

        global_def.BATCH = False
コード例 #38
0
ファイル: sxshiftali.py プロジェクト: cpsemmens/eman2
def helicalshiftali_MPI(stack, maskfile=None, maxit=100, CTF=False, snr=1.0, Fourvar=False, search_rng=-1):
	from applications import MPI_start_end
	from utilities    import model_circle, model_blank, get_image, peak_search, get_im, pad
	from utilities    import reduce_EMData_to_root, bcast_EMData_to_all, send_attr_dict, file_type, bcast_number_to_all, bcast_list_to_all
	from statistics   import varf2d_MPI
	from fundamentals import fft, ccf, rot_shift3D, rot_shift2D, fshift
	from utilities    import get_params2D, set_params2D, chunks_distribution
	from utilities    import print_msg, print_begin_msg, print_end_msg
	import os
	import sys
	from mpi 	  	  import mpi_init, mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
	from mpi 	  	  import mpi_reduce, mpi_bcast, mpi_barrier, mpi_gatherv
	from mpi 	  	  import MPI_SUM, MPI_FLOAT, MPI_INT
	from time         import time	
	from pixel_error  import ordersegments
	from math         import sqrt, atan2, tan, pi
	
	nproc = mpi_comm_size(MPI_COMM_WORLD)
	myid = mpi_comm_rank(MPI_COMM_WORLD)
	main_node = 0
		
	ftp = file_type(stack)

	if myid == main_node:
		print_begin_msg("helical-shiftali_MPI")

	max_iter=int(maxit)
	if( myid == main_node):
		infils = EMUtil.get_all_attributes(stack, "filament")
		ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
		filaments = ordersegments(infils, ptlcoords)
		total_nfils = len(filaments)
		inidl = [0]*total_nfils
		for i in xrange(total_nfils):  inidl[i] = len(filaments[i])
		linidl = sum(inidl)
		nima = linidl
		tfilaments = []
		for i in xrange(total_nfils):  tfilaments += filaments[i]
		del filaments
	else:
		total_nfils = 0
		linidl = 0
	total_nfils = bcast_number_to_all(total_nfils, source_node = main_node)
	if myid != main_node:
		inidl = [-1]*total_nfils
	inidl = bcast_list_to_all(inidl, myid, source_node = main_node)
	linidl = bcast_number_to_all(linidl, source_node = main_node)
	if myid != main_node:
		tfilaments = [-1]*linidl
	tfilaments = bcast_list_to_all(tfilaments, myid, source_node = main_node)
	filaments = []
	iendi = 0
	for i in xrange(total_nfils):
		isti = iendi
		iendi = isti+inidl[i]
		filaments.append(tfilaments[isti:iendi])
	del tfilaments,inidl

	if myid == main_node:
		print_msg( "total number of filaments: %d"%total_nfils)
	if total_nfils< nproc:
		ERROR('number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'%(nproc, total_nfils), "ehelix_MPI", 1,myid)

	#  balanced load
	temp = chunks_distribution([[len(filaments[i]), i] for i in xrange(len(filaments))], nproc)[myid:myid+1][0]
	filaments = [filaments[temp[i][1]] for i in xrange(len(temp))]
	nfils     = len(filaments)

	#filaments = [[0,1]]
	#print "filaments",filaments
	list_of_particles = []
	indcs = []
	k = 0
	for i in xrange(nfils):
		list_of_particles += filaments[i]
		k1 = k+len(filaments[i])
		indcs.append([k,k1])
		k = k1
	data = EMData.read_images(stack, list_of_particles)
	ldata = len(data)
	print "ldata=", ldata
	nx = data[0].get_xsize()
	ny = data[0].get_ysize()
	if maskfile == None:
		mrad = min(nx, ny)//2-2
		mask = pad( model_blank(2*mrad+1, ny, 1, 1.0), nx, ny, 1, 0.0)
	else:
		mask = get_im(maskfile)

	# apply initial xform.align2d parameters stored in header
	init_params = []
	for im in xrange(ldata):
		t = data[im].get_attr('xform.align2d')
		init_params.append(t)
		p = t.get_params("2d")
		data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'], p['mirror'], p['scale'])

	if CTF:
		from filter import filt_ctf
		from morphology   import ctf_img
		ctf_abs_sum = EMData(nx, ny, 1, False)
		ctf_2_sum = EMData(nx, ny, 1, False)
	else:
		ctf_2_sum = None
		ctf_abs_sum = None



	from utilities import info

	for im in xrange(ldata):
		data[im].set_attr('ID', list_of_particles[im])
		st = Util.infomask(data[im], mask, False)
		data[im] -= st[0]
		if CTF:
			ctf_params = data[im].get_attr("ctf")
			qctf = data[im].get_attr("ctf_applied")
			if qctf == 0:
				data[im] = filt_ctf(fft(data[im]), ctf_params)
				data[im].set_attr('ctf_applied', 1)
			elif qctf != 1:
				ERROR('Incorrectly set qctf flag', "helicalshiftali_MPI", 1,myid)
			ctfimg = ctf_img(nx, ctf_params, ny=ny)
			Util.add_img2(ctf_2_sum, ctfimg)
			Util.add_img_abs(ctf_abs_sum, ctfimg)
		else:  data[im] = fft(data[im])

	del list_of_particles		

	if CTF:
		reduce_EMData_to_root(ctf_2_sum, myid, main_node)
		reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
	if CTF:
		if myid != main_node:
			del ctf_2_sum
			del ctf_abs_sum
		else:
			temp = EMData(nx, ny, 1, False)
			tsnr = 1./snr
			for i in xrange(0,nx+2,2):
				for j in xrange(ny):
					temp.set_value_at(i,j,tsnr)
					temp.set_value_at(i+1,j,0.0)
			#info(ctf_2_sum)
			Util.add_img(ctf_2_sum, temp)
			#info(ctf_2_sum)
			del temp

	total_iter = 0
	shift_x = [0.0]*ldata

	for Iter in xrange(max_iter):
		if myid == main_node:
			start_time = time()
			print_msg("Iteration #%4d\n"%(total_iter))
		total_iter += 1
		avg = EMData(nx, ny, 1, False)
		for im in xrange(ldata):
			Util.add_img(avg, fshift(data[im], shift_x[im]))

		reduce_EMData_to_root(avg, myid, main_node)

		if myid == main_node:
			if CTF:  tavg = Util.divn_filter(avg, ctf_2_sum)
			else:    tavg = Util.mult_scalar(avg, 1.0/float(nima))
		else:
			tavg = model_blank(nx,ny)

		if Fourvar:
			bcast_EMData_to_all(tavg, myid, main_node)
			vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

		if myid == main_node:
			if Fourvar:
				tavg    = fft(Util.divn_img(fft(tavg), vav))
				vav_r	= Util.pack_complex_to_real(vav)
			# normalize and mask tavg in real space
			tavg = fft(tavg)
			stat = Util.infomask( tavg, mask, False )
			tavg -= stat[0]
			Util.mul_img(tavg, mask)
			tavg.write_image("tavg.hdf",Iter)
			# For testing purposes: shift tavg to some random place and see if the centering is still correct
			#tavg = rot_shift3D(tavg,sx=3,sy=-4)

		if Fourvar:  del vav
		bcast_EMData_to_all(tavg, myid, main_node)
		tavg = fft(tavg)

		sx_sum = 0.0
		nxc = nx//2
		
		for ifil in xrange(nfils):
			"""
			# Calculate filament average
			avg = EMData(nx, ny, 1, False)
			filnima = 0
			for im in xrange(indcs[ifil][0], indcs[ifil][1]):
				Util.add_img(avg, data[im])
				filnima += 1
			tavg = Util.mult_scalar(avg, 1.0/float(filnima))
			"""
			# Calculate 1D ccf between each segment and filament average
			nsegms = indcs[ifil][1]-indcs[ifil][0]
			ctx = [None]*nsegms
			pcoords = [None]*nsegms
			for im in xrange(indcs[ifil][0], indcs[ifil][1]):
				ctx[im-indcs[ifil][0]] = Util.window(ccf(tavg, data[im]), nx, 1)
				pcoords[im-indcs[ifil][0]] = data[im].get_attr('ptcl_source_coord')
				#ctx[im-indcs[ifil][0]].write_image("ctx.hdf",im-indcs[ifil][0])
				#print "  CTX  ",myid,im,Util.infomask(ctx[im-indcs[ifil][0]], None, True)
			# search for best x-shift
			cents = nsegms//2
			
			dst = sqrt(max((pcoords[cents][0] - pcoords[0][0])**2 + (pcoords[cents][1] - pcoords[0][1])**2, (pcoords[cents][0] - pcoords[-1][0])**2 + (pcoords[cents][1] - pcoords[-1][1])**2))
			maxincline = atan2(ny//2-2-float(search_rng),dst)
			kang = int(dst*tan(maxincline)+0.5)
			#print  "  settings ",nsegms,cents,dst,search_rng,maxincline,kang
			
			# ## C code for alignment. @ming
 			results = [0.0]*3;
 			results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline, kang, search_rng,nxc)
			sib = int(results[0])
 			bang = results[1]
 			qm = results[2]
			#print qm, sib, bang
			
			# qm = -1.e23	
# 				
# 			for six in xrange(-search_rng, search_rng+1,1):
# 				q0 = ctx[cents].get_value_at(six+nxc)
# 				for incline in xrange(kang+1):
# 					qt = q0
# 					qu = q0
# 					if(kang>0):  tang = tan(maxincline/kang*incline)
# 					else:        tang = 0.0
# 					for kim in xrange(cents+1,nsegms):
# 						dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
# 						xl = dst*tang+six+nxc
# 						ixl = int(xl)
# 						dxl = xl - ixl
# 						#print "  A  ", ifil,six,incline,kim,xl,ixl,dxl
# 						qt += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
# 						xl = -dst*tang+six+nxc
# 						ixl = int(xl)
# 						dxl = xl - ixl
# 						qu += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
# 					for kim in xrange(cents):
# 						dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
# 						xl = -dst*tang+six+nxc
# 						ixl = int(xl)
# 						dxl = xl - ixl
# 						qt += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
# 						xl =  dst*tang+six+nxc
# 						ixl = int(xl)
# 						dxl = xl - ixl
# 						qu += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
# 					if( qt > qm ):
# 						qm = qt
# 						sib = six
# 						bang = tang
# 					if( qu > qm ):
# 						qm = qu
# 						sib = six
# 						bang = -tang
					#if incline == 0:  print  "incline = 0  ",six,tang,qt,qu
			#print qm,six,sib,bang
			#print " got results   ",indcs[ifil][0], indcs[ifil][1], ifil,myid,qm,sib,tang,bang,len(ctx),Util.infomask(ctx[0], None, True)
			for im in xrange(indcs[ifil][0], indcs[ifil][1]):
				kim = im-indcs[ifil][0]
				dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
				if(kim < cents):  xl = -dst*bang+sib
				else:             xl =  dst*bang+sib
				shift_x[im] = xl
							
			# Average shift
			sx_sum += shift_x[indcs[ifil][0]+cents]
			
			
		# #print myid,sx_sum,total_nfils
		sx_sum = mpi_reduce(sx_sum, 1, MPI_FLOAT, MPI_SUM, main_node, MPI_COMM_WORLD)
		if myid == main_node:
			sx_sum = float(sx_sum[0])/total_nfils
			print_msg("Average shift  %6.2f\n"%(sx_sum))
		else:
			sx_sum = 0.0
		sx_sum = 0.0
		sx_sum = bcast_number_to_all(sx_sum, source_node = main_node)
		for im in xrange(ldata):
			shift_x[im] -= sx_sum
			#print  "   %3d  %6.3f"%(im,shift_x[im])
		#exit()


			
	# combine shifts found with the original parameters
	for im in xrange(ldata):		
		t1 = Transform()
		##import random
		##shix=random.randint(-10, 10)
		##t1.set_params({"type":"2D","tx":shix})
		t1.set_params({"type":"2D","tx":shift_x[im]})
		# combine t0 and t1
		tt = t1*init_params[im]
		data[im].set_attr("xform.align2d", tt)
	# write out headers and STOP, under MPI writing has to be done sequentially
	mpi_barrier(MPI_COMM_WORLD)
	par_str = ["xform.align2d", "ID"]
	if myid == main_node:
		from utilities import file_type
		if(file_type(stack) == "bdb"):
			from utilities import recv_attr_dict_bdb
			recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata, nproc)
		else:
			from utilities import recv_attr_dict
			recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
	else:           send_attr_dict(main_node, data, par_str, 0, ldata)
	if myid == main_node: print_end_msg("helical-shiftali_MPI")				
コード例 #39
0
ファイル: projection.py プロジェクト: cpsemmens/eman2
def cml_head_log(stack, outdir, delta, ir, ou, lf, hf, rand_seed, maxit, given, flag_weights, trials, ncpu):
	"""Emit the parameter banner of a common-lines run through print_msg.

	Every line pairs a fixed-width label/format template with one value. The
	number of projections, the sinogram angle accuracy and the number of
	angles come from the module globals g_n_prj, g_d_psi and g_n_anglst.
	The last line is followed by an extra blank line.
	"""
	from utilities import print_msg

	# module-level state, read only by this report
	global g_anglst, g_n_prj, g_d_psi, g_n_anglst

	banner = (
		('Input stack                  : %s\n',     stack),
		('Number of projections        : %d\n',     g_n_prj),
		('Output directory             : %s\n',     outdir),
		('Angular step                 : %5.2f\n',  delta),
		('Sinogram angle accuracy      : %5.2f\n',  g_d_psi),
		('Inner particle radius        : %5.2f\n',  ir),
		('Outer particle radius        : %5.2f\n',  ou),
		('Filter, minimum frequency    : %5.3f\n',  lf),
		('Filter, maximum frequency    : %5.3f\n',  hf),
		('Random seed                  : %i\n',     rand_seed),
		('Number of maximum iterations : %d\n',     maxit),
		('Start from given orientations: %s\n',     given),
		('Number of angles             : %i\n',     g_n_anglst),
		('Number of trials             : %i\n',     trials),
		('Number of cpus               : %i\n',     ncpu),
		('Use Voronoi weights          : %s\n\n',   flag_weights),
	)
	for template, value in banner:
		print_msg(template % value)
コード例 #40
0
ファイル: functions.py プロジェクト: leschzinerlab/EMAN2recon
def ali3d_MPI(stack, ref_vol, outdir, maskfile = None, ir = 1, ou = -1, rs = 1,
	    xr = "4 2 2 1", yr = "-1", ts = "1 1 0.5 0.25", delta = "10 6 4 4", an = "-1",
	    center = 0, maxit = 5, term = 95, CTF = False, fourvar = False, snr = 1.0,  ref_a = "S", sym = "c1",
	    sort=True, cutoff=999.99, pix_cutoff="0", two_tail=False, model_jump="1 1 1 1 1", restart=False, save_half=False,
	    protos=None, oplane=None, lmask=-1, ilmask=-1, findseam=False, vertstep=None, hpars="-1", hsearch="73.0 170.0",
	    full_output = False, compare_repro = False, compare_ref_free = "-1", ref_free_cutoff= "-1 -1 -1 -1",
	    wcmask = None, debug = False, recon_pad = 4):
	"""Multi-model 3D projection matching refinement, parallelised with MPI.

	The main node creates `outdir` (the run aborts through ERROR when it
	already exists).  Particles flagged 'active' in `stack` are split across
	the MPI ranks; every rank aligns its share against each reference model
	read from `ref_vol`, and the score / pixel-error tables are gathered and
	broadcast so that all ranks assign particles to groups:
	    <model index>  best-scoring model (sort=True and model_jump entry 1)
	    999            all models (sort=False)
	    888            rejected by the CC cutoff (mean - cutoff*std, plus
	                   mean + cutoff*std when two_tail is set)
	    777            rejected by the pixel-error cutoff (pix_score flag)
	    666            rejected by the reference-free comparison
	Each model is reconstructed from its group (rec3D_MPI_noCTF), optionally
	helically symmetrised (processHelicalVol), filtered/centred by ref_ali3d
	on the main node and broadcast to all ranks.

	Space-separated string arguments (xr, yr, ts, delta, an, model_jump,
	pix_cutoff, ref_free_cutoff, hpars, hsearch, protos, wcmask) are parsed
	with get_input_from_string.  The number of angular steps is the length of
	the shortest of xr, yr, ts and delta; the other per-step lists are padded
	to that length by repeating their last value.  Each step runs at most
	`maxit` inner iterations and stops early once more than `term` percent of
	the particles fall in the pixel-error histogram bins up to 1.0.

	Volumes, FSC curves and optional parameter files are written to `outdir`;
	alignment attributes are sent to the main node, which passes them to
	recv_attr_dict together with `stack`.  Nothing is returned.
	"""

	from alignment      import Numrinit, prepare_refrings
	from utilities      import model_circle, get_image, drop_image, get_input_from_string
	from utilities      import bcast_list_to_all, bcast_number_to_all, reduce_EMData_to_root, bcast_EMData_to_all
	from utilities      import send_attr_dict
	from utilities      import get_params_proj, file_type
	from fundamentals   import rot_avg_image
	import os
	import types
	from utilities      import print_begin_msg, print_end_msg, print_msg
	from mpi            import mpi_bcast, mpi_comm_size, mpi_comm_rank, MPI_FLOAT, MPI_COMM_WORLD, mpi_barrier, mpi_reduce
	from mpi            import mpi_reduce, MPI_INT, MPI_SUM, mpi_finalize
	from filter         import filt_ctf
	from projection     import prep_vol, prgs
	from statistics     import hist_list, varf3d_MPI, fsc_mask
	from numpy          import array, bincount, array2string, ones

	number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
	myid           = mpi_comm_rank(MPI_COMM_WORLD)
	main_node = 0
	if myid == main_node:
		if os.path.exists(outdir):  ERROR('Output directory exists, please change the name and restart the program', "ali3d_MPI", 1)
		os.mkdir(outdir)
	mpi_barrier(MPI_COMM_WORLD)

	if debug:
		from time import sleep
		while not os.path.exists(outdir):
			# same text as the former Python 2 print statement
			# `print "Node ",myid,"  waiting..."`, but valid on Python 3 too
			print("Node  %s   waiting..." % (myid,))
			sleep(5)

		info_file = os.path.join(outdir, "progress%04d"%myid)
		finfo = open(info_file, 'w')
	else:
		finfo = None
	mjump = get_input_from_string(model_jump)
	xrng  = get_input_from_string(xr)
	if  yr == "-1":  yrng = xrng
	else:  yrng = get_input_from_string(yr)
	step  = get_input_from_string(ts)
	delta = get_input_from_string(delta)
	ref_free_cutoff = get_input_from_string(ref_free_cutoff)
	pix_cutoff = get_input_from_string(pix_cutoff)

	lstp = min(len(xrng), len(yrng), len(step), len(delta))
	if an == "-1":
		an = [-1] * lstp
	else:
		an = get_input_from_string(an)
	# make sure every per-step list has an entry for each of the lstp angular
	# steps by repeating its last value (previously only pix_cutoff was padded,
	# so a short model_jump, ref_free_cutoff or an list raised IndexError
	# part-way through the refinement)
	for per_step in (pix_cutoff, mjump, ref_free_cutoff, an):
		if 0 < len(per_step) < lstp:
			per_step.extend([per_step[-1]] * (lstp - len(per_step)))
	# don't waste time on sub-pixel alignment for low-resolution ang incr;
	# only the first lstp steps are ever used, and iterating len(step) here
	# raised IndexError on delta when ts listed more steps than delta
	for i in xrange(lstp):
		if (delta[i] > 4 or delta[i] == -1) and step[i] < 1:
			step[i] = 1

	first_ring  = int(ir)
	rstep       = int(rs)
	last_ring   = int(ou)
	max_iter    = int(maxit)
	center      = int(center)

	nrefs   = EMUtil.get_image_count( ref_vol )
	nmasks = 0
	if maskfile:
		# read number of masks within each maskfile (mc)
		nmasks   = EMUtil.get_image_count( maskfile )
		# open masks within maskfile (mc)
		maskF   = EMData.read_images(maskfile, xrange(nmasks))
	vol     = EMData.read_images(ref_vol, xrange(nrefs))
	nx      = vol[0].get_xsize()

	## make sure box sizes are the same
	if myid == main_node:
		im=EMData.read_images(stack,[0])
		bx = im[0].get_xsize()
		if bx!=nx:
			print_msg("Error: Stack box size (%i) differs from initial model (%i)\n"%(bx,nx))
			# NOTE(review): this exits the main node only; the other ranks are
			# not told to stop and may block in the next collective call
			sys.exit()
		del im,bx

	# for helical processing:
	helicalrecon = False
	if protos is not None or hpars != "-1" or findseam is True:
		helicalrecon = True
		# if no out-of-plane param set, use 5 degrees
		if oplane is None:
			oplane=5.0
	if protos is not None:
		proto = get_input_from_string(protos)
		if len(proto) != nrefs:
			print_msg("Error: insufficient protofilament numbers supplied")
			sys.exit()
	if hpars != "-1":
		hpars = get_input_from_string(hpars)
		if len(hpars) != 2*nrefs:
			print_msg("Error: insufficient helical parameters supplied")
			sys.exit()
	## create helical parameter file for helical reconstruction
	if helicalrecon is True and myid == main_node:
		from hfunctions import createHpar
		# create initial helical parameter files
		dp=[0]*nrefs
		dphi=[0]*nrefs
		vdp=[0]*nrefs
		vdphi=[0]*nrefs
		for iref in xrange(nrefs):
			hpar = os.path.join(outdir,"hpar%02d.spi"%(iref))
			params = False
			if hpars != "-1":
				# if helical parameters explicitly given, set twist & rise
				params = [float(hpars[iref*2]),float(hpars[(iref*2)+1])]
			# NOTE(review): proto is only bound when protos was given; a helical
			# run enabled through hpars or findseam alone raises NameError here
			dp[iref],dphi[iref],vdp[iref],vdphi[iref] = createHpar(hpar,proto[iref],params,vertstep)

	# get values for helical search parameters
	hsearch = get_input_from_string(hsearch)
	if len(hsearch) != 2:
		print_msg("Error: specify outer and inner radii for helical search")
		sys.exit()

	if last_ring < 0 or last_ring > int(nx/2)-2 :	last_ring = int(nx/2) - 2

	if myid == main_node:
	#	import user_functions
	#	user_func = user_functions.factory[user_func_name]

		print_begin_msg("ali3d_MPI")
		print_msg("Input stack		 : %s\n"%(stack))
		print_msg("Reference volume	    : %s\n"%(ref_vol))
		print_msg("Output directory	    : %s\n"%(outdir))
		if nmasks > 0:
			print_msg("Maskfile (number of masks)  : %s (%i)\n"%(maskfile,nmasks))
		print_msg("Inner radius		: %i\n"%(first_ring))
		print_msg("Outer radius		: %i\n"%(last_ring))
		print_msg("Ring step		   : %i\n"%(rstep))
		print_msg("X search range	      : %s\n"%(xrng))
		print_msg("Y search range	      : %s\n"%(yrng))
		print_msg("Translational step	  : %s\n"%(step))
		print_msg("Angular step		: %s\n"%(delta))
		print_msg("Angular search range	: %s\n"%(an))
		print_msg("Maximum iteration	   : %i\n"%(max_iter))
		print_msg("Center type		 : %i\n"%(center))
		print_msg("CTF correction	      : %s\n"%(CTF))
		print_msg("Signal-to-Noise Ratio       : %f\n"%(snr))
		print_msg("Reference projection method : %s\n"%(ref_a))
		print_msg("Symmetry group	      : %s\n"%(sym))
		print_msg("Fourier padding for 3D      : %i\n"%(recon_pad))
		print_msg("Number of reference models  : %i\n"%(nrefs))
		print_msg("Sort images between models  : %s\n"%(sort))
		print_msg("Allow images to jump	: %s\n"%(mjump))
		print_msg("CC cutoff standard dev      : %f\n"%(cutoff))
		print_msg("Two tail cutoff	     : %s\n"%(two_tail))
		print_msg("Termination pix error       : %f\n"%(term))
		print_msg("Pixel error cutoff	  : %s\n"%(pix_cutoff))
		print_msg("Restart		     : %s\n"%(restart))
		print_msg("Full output		 : %s\n"%(full_output))
		print_msg("Compare reprojections       : %s\n"%(compare_repro))
		print_msg("Compare ref free class avgs : %s\n"%(compare_ref_free))
		print_msg("Use cutoff from ref free    : %s\n"%(ref_free_cutoff))
		if protos:
			print_msg("Protofilament numbers	: %s\n"%(proto))
			print_msg("Using helical search range   : %s\n"%hsearch)
		if findseam is True:
			print_msg("Using seam-based reconstruction\n")
		if hpars != "-1":
			print_msg("Using hpars		  : %s\n"%hpars)
		if vertstep != None:
			print_msg("Using vertical step    : %.2f\n"%vertstep)
		if save_half is True:
			print_msg("Saving even/odd halves\n")
		for i in xrange(100) : print_msg("*")
		print_msg("\n\n")
	if maskfile:
		if type(maskfile) is types.StringType: mask3D = get_image(maskfile)
		else: mask3D = maskfile
	else: mask3D = model_circle(last_ring, nx, nx, nx)

	numr    = Numrinit(first_ring, last_ring, rstep, "F")
	mask2D  = model_circle(last_ring,nx,nx) - model_circle(first_ring,nx,nx)

	fscmask = model_circle(last_ring,nx,nx,nx)
	if CTF:
		from filter import filt_ctf
	from reconstruction_rjh import rec3D_MPI_noCTF

	if myid == main_node:
		active = EMUtil.get_all_attributes(stack, 'active')
		list_of_particles = []
		for im in xrange(len(active)):
			if active[im]:  list_of_particles.append(im)
		del active
		nima = len(list_of_particles)
	else:
		nima = 0
	total_nima = bcast_number_to_all(nima, source_node = main_node)

	if myid != main_node:
		list_of_particles = [-1]*total_nima
	list_of_particles = bcast_list_to_all(list_of_particles, source_node = main_node)

	image_start, image_end = MPI_start_end(total_nima, number_of_proc, myid)

	# create a list of images for each node
	list_of_particles = list_of_particles[image_start: image_end]
	nima = len(list_of_particles)
	if debug:
		finfo.write("image_start, image_end: %d %d\n" %(image_start, image_end))
		finfo.flush()

	data = EMData.read_images(stack, list_of_particles)

	t_zero = Transform({"type":"spider","phi":0,"theta":0,"psi":0,"tx":0,"ty":0})
	transmulti = [[t_zero for i in xrange(nrefs)] for j in xrange(nima)]

	for iref,im in ((iref,im) for iref in xrange(nrefs) for im in xrange(nima)):
		if nrefs == 1:
			transmulti[im][iref] = data[im].get_attr("xform.projection")
		else:
			# if multi models, keep track of eulers for all models
			try:
				transmulti[im][iref] = data[im].get_attr("eulers_txty.%i"%iref)
			except Exception:
				# attribute missing: start this model from the zero transform
				data[im].set_attr("eulers_txty.%i"%iref,t_zero)

	scoremulti = [[0.0 for i in xrange(nrefs)] for j in xrange(nima)]
	pixelmulti = [[0.0 for i in xrange(nrefs)] for j in xrange(nima)]
	ref_res = [0.0 for x in xrange(nrefs)]
	apix = data[0].get_attr('apix_x')

	# for oplane parameter, create cylindrical mask
	if oplane is not None and myid == main_node:
		from hfunctions import createCylMask
		cmaskf=os.path.join(outdir, "mask3D_cyl.mrc")
		mask3D = createCylMask(data,ou,lmask,ilmask,cmaskf)
		# if finding seam of helix, create wedge masks
		if findseam is True:
			wedgemask=[]
			for pf in xrange(nrefs):
				wedgemask.append(EMData())
			# wedgemask option
			if wcmask is not None:
				wcmask = get_input_from_string(wcmask)
				if len(wcmask) != 3:
					print_msg("Error: wcmask option requires 3 values: x y radius")
					sys.exit()

	# determine if particles have helix info:
	try:
		data[0].get_attr('h_angle')
		original_data = []
		boxmask = True
		from hfunctions import createBoxMask
	except Exception:
		# no 'h_angle' attribute (or no local particles): skip window masking
		boxmask = False

	# prepare particles
	for im in xrange(nima):
		data[im].set_attr('ID', list_of_particles[im])
		data[im].set_attr('pix_score', int(0))
		if CTF:
			# only phaseflip particles, not full CTF correction
			ctf_params = data[im].get_attr("ctf")
			st = Util.infomask(data[im], mask2D, False)
			data[im] -= st[0]
			data[im] = filt_ctf(data[im], ctf_params, sign = -1, binary=1)
			data[im].set_attr('ctf_applied', 1)
		# for window mask:
		if boxmask is True:
			h_angle = data[im].get_attr("h_angle")
			original_data.append(data[im].copy())
			bmask = createBoxMask(nx,apix,ou,lmask,h_angle)
			data[im]*=bmask
			del bmask
	if debug:
		finfo.write( '%d loaded  \n' % nima )
		finfo.flush()
	if myid == main_node:
		# initialize data for the reference preparation function
		ref_data = [ mask3D, max(center,0), None, None, None, None ]
		# for method -1, switch off centering in user function

	from time import time

	#  this is needed for gathering of pixel errors
	disps = []
	recvcount = []
	disps_score = []
	recvcount_score = []
	for im in xrange(number_of_proc):
		if( im == main_node ):
			disps.append(0)
			disps_score.append(0)
		else:
			disps.append(disps[im-1] + recvcount[im-1])
			disps_score.append(disps_score[im-1] + recvcount_score[im-1])
		ib, ie = MPI_start_end(total_nima, number_of_proc, im)
		recvcount.append( ie - ib )
		recvcount_score.append((ie-ib)*nrefs)

	pixer = [0.0]*nima
	cs = [0.0]*3
	total_iter = 0
	volodd = EMData.read_images(ref_vol, xrange(nrefs))
	voleve = EMData.read_images(ref_vol, xrange(nrefs))

	if restart:
		# recreate initial volumes from alignments stored in header
		itout = "000_00"
		for iref in xrange(nrefs):
			if(nrefs == 1):
				modout = ""
			else:
				modout = "_model_%02d"%(iref)

			if(sort):
				group = iref
				for im in xrange(nima):
					imgroup = data[im].get_attr('group')
					if imgroup == iref:
						data[im].set_attr('xform.projection',transmulti[im][iref])
			else:
				group = int(999)
				for im in xrange(nima):
					data[im].set_attr('xform.projection',transmulti[im][iref])

			fscfile = os.path.join(outdir, "fsc_%s%s"%(itout,modout))

			vol[iref], fscc, volodd[iref], voleve[iref] = rec3D_MPI_noCTF(data, sym, fscmask, fscfile, myid, main_node, index = group, npad = recon_pad)

			if myid == main_node:
				if helicalrecon:
					from hfunctions import processHelicalVol

					vstep=None
					if vertstep is not None:
						vstep=(vdp[iref],vdphi[iref])
					print_msg("Old rise and twist for model %i     : %8.3f, %8.3f\n"%(iref,dp[iref],dphi[iref]))
					hvals=processHelicalVol(vol[iref],voleve[iref],volodd[iref],iref,outdir,itout,
								dp[iref],dphi[iref],apix,hsearch,findseam,vstep,wcmask)
					(vol[iref],voleve[iref],volodd[iref],dp[iref],dphi[iref],vdp[iref],vdphi[iref])=hvals
					print_msg("New rise and twist for model %i     : %8.3f, %8.3f\n"%(iref,dp[iref],dphi[iref]))
					# get new FSC from symmetrized half volumes
					fscc = fsc_mask( volodd[iref], voleve[iref], mask3D, rstep, fscfile)
				else:
					vol[iref].write_image(os.path.join(outdir, "vol_%s.hdf"%itout),-1)

				if save_half is True:
					volodd[iref].write_image(os.path.join(outdir, "volodd_%s.hdf"%itout),-1)
					voleve[iref].write_image(os.path.join(outdir, "voleve_%s.hdf"%itout),-1)

				if nmasks > 1:
					# Read mask for multiplying
					ref_data[0] = maskF[iref]
				ref_data[2] = vol[iref]
				ref_data[3] = fscc
				#  call user-supplied function to prepare reference image, i.e., center and filter it
				vol[iref], cs,fl = ref_ali3d(ref_data)
				vol[iref].write_image(os.path.join(outdir, "volf_%s.hdf"%(itout)),-1)
				if (apix == 1):
					res_msg = "Models filtered at spatial frequency of:\t"
					res = fl
				else:
					res_msg = "Models filtered at resolution of:       \t"
					res = apix / fl
				ares = array2string(array(res), precision = 2)
				print_msg("%s%s\n\n"%(res_msg,ares))

			bcast_EMData_to_all(vol[iref], myid, main_node)
			# write out headers, under MPI writing has to be done sequentially
			mpi_barrier(MPI_COMM_WORLD)

	# projection matching
	for N_step in xrange(lstp):
		terminate = 0
		Iter = -1
		while(Iter < max_iter-1 and terminate == 0):
			Iter += 1
			total_iter += 1
			itout = "%03g_%02d" %(delta[N_step], Iter)
			if myid == main_node:
				print_msg("ITERATION #%3d, inner iteration #%3d\nDelta = %4.1f, an = %5.2f, xrange = %5.2f, yrange = %5.2f, step = %5.2f\n\n"%(N_step, Iter, delta[N_step], an[N_step], xrng[N_step],yrng[N_step],step[N_step]))

			for iref in xrange(nrefs):
				if myid == main_node: start_time = time()
				volft,kb = prep_vol( vol[iref] )

				## constrain projections to out of plane parameter
				theta1 = None
				theta2 = None
				if oplane is not None:
					theta1 = 90-oplane
					theta2 = 90+oplane
				refrings = prepare_refrings( volft, kb, nx, delta[N_step], ref_a, sym, numr, MPI=True, phiEqpsi = "Minus", initial_theta=theta1, delta_theta=theta2)

				del volft,kb

				if myid== main_node:
					print_msg( "Time to prepare projections for model %i: %s\n" % (iref, legibleTime(time()-start_time)) )
					start_time = time()

				for im in xrange( nima ):
					data[im].set_attr("xform.projection", transmulti[im][iref])
					if an[N_step] == -1:
						t1, peak, pixer[im] = proj_ali_incore(data[im],refrings,numr,xrng[N_step],yrng[N_step],step[N_step],finfo)
					else:
						t1, peak, pixer[im] = proj_ali_incore_local(data[im],refrings,numr,xrng[N_step],yrng[N_step],step[N_step],an[N_step],finfo)
					#data[im].set_attr("xform.projection"%iref, t1)
					if nrefs > 1: data[im].set_attr("eulers_txty.%i"%iref,t1)
					scoremulti[im][iref] = peak
					from pixel_error import max_3D_pixel_error
					# t1 is the current param, t2 is old
					t2 = transmulti[im][iref]
					pixelmulti[im][iref] = max_3D_pixel_error(t1,t2,numr[-3])
					transmulti[im][iref] = t1

				if myid == main_node:
					print_msg("Time of alignment for model %i: %s\n"%(iref, legibleTime(time()-start_time)))
					start_time = time()


			# gather scoring data from all processors
			from mpi import mpi_gatherv
			scoremultisend = sum(scoremulti,[])
			pixelmultisend = sum(pixelmulti,[])
			tmp = mpi_gatherv(scoremultisend,len(scoremultisend),MPI_FLOAT, recvcount_score, disps_score, MPI_FLOAT, main_node,MPI_COMM_WORLD)
			tmp1 = mpi_gatherv(pixelmultisend,len(pixelmultisend),MPI_FLOAT, recvcount_score, disps_score, MPI_FLOAT, main_node,MPI_COMM_WORLD)
			tmp = mpi_bcast(tmp,(total_nima * nrefs), MPI_FLOAT,0, MPI_COMM_WORLD)
			tmp1 = mpi_bcast(tmp1,(total_nima * nrefs), MPI_FLOAT,0, MPI_COMM_WORLD)
			tmp = map(float,tmp)
			tmp1 = map(float,tmp1)
			score = array(tmp).reshape(-1,nrefs)
			pixelerror = array(tmp1).reshape(-1,nrefs)
			score_local = array(scoremulti)
			mean_score = score.mean(axis=0)
			std_score = score.std(axis=0)
			cut = mean_score - (cutoff * std_score)
			cut2 = mean_score + (cutoff * std_score)
			res_max = score_local.argmax(axis=1)
			minus_cc = [0.0 for x in xrange(nrefs)]
			minus_pix = [0.0 for x in xrange(nrefs)]
			minus_ref = [0.0 for x in xrange(nrefs)]

			#output pixel errors
			if(myid == main_node):
				from statistics import hist_list
				lhist = 20
				pixmin = pixelerror.min(axis=1)
				region, histo = hist_list(pixmin, lhist)
				if(region[0] < 0.0):  region[0] = 0.0
				print_msg("Histogram of pixel errors\n      ERROR       number of particles\n")
				for lhx in xrange(lhist):
					print_msg(" %10.3f     %7d\n"%(region[lhx], histo[lhx]))
				# Terminate if 95% within 1 pixel error
				im = 0
				for lhx in xrange(lhist):
					if(region[lhx] > 1.0): break
					im += histo[lhx]
				print_msg( "Percent of particles with pixel error < 1: %f\n\n"% (im/float(total_nima)*100))
				term_cond = float(term)/100
				if(im/float(total_nima) > term_cond):
					terminate = 1
					print_msg("Terminating internal loop\n")
				del region, histo
			terminate = mpi_bcast(terminate, 1, MPI_INT, 0, MPI_COMM_WORLD)
			terminate = int(terminate[0])

			for im in xrange(nima):
				if(sort==False):
					data[im].set_attr('group',999)
				elif (mjump[N_step]==1):
					data[im].set_attr('group',int(res_max[im]))

				pix_run = data[im].get_attr('pix_score')
				if (pix_cutoff[N_step]==1 and (terminate==1 or Iter == max_iter-1)):
					if (pixelmulti[im][int(res_max[im])] > 1):
						data[im].set_attr('pix_score',int(777))

				if (score_local[im][int(res_max[im])]<cut[int(res_max[im])]) or (two_tail and score_local[im][int(res_max[im])]>cut2[int(res_max[im])]):
					data[im].set_attr('group',int(888))
					minus_cc[int(res_max[im])] = minus_cc[int(res_max[im])] + 1

				if(pix_run == 777):
					data[im].set_attr('group',int(777))
					minus_pix[int(res_max[im])] = minus_pix[int(res_max[im])] + 1

				if (compare_ref_free != "-1") and (ref_free_cutoff[N_step] != -1) and (total_iter > 1):
					id = data[im].get_attr('ID')
					if id in rejects:
						data[im].set_attr('group',int(666))
						minus_ref[int(res_max[im])] = minus_ref[int(res_max[im])] + 1


			minus_cc_tot = mpi_reduce(minus_cc,nrefs,MPI_FLOAT,MPI_SUM,0,MPI_COMM_WORLD)
			minus_pix_tot = mpi_reduce(minus_pix,nrefs,MPI_FLOAT,MPI_SUM,0,MPI_COMM_WORLD)
			minus_ref_tot = mpi_reduce(minus_ref,nrefs,MPI_FLOAT,MPI_SUM,0,MPI_COMM_WORLD)
			if (myid == main_node):
				if(sort):
					tot_max = score.argmax(axis=1)
					res = bincount(tot_max)
				else:
					res = ones(nrefs) * total_nima
				print_msg("Particle distribution:	     \t\t%s\n"%(res*1.0))
				afcut1 = res - minus_cc_tot
				afcut2 = afcut1 - minus_pix_tot
				afcut3 = afcut2 - minus_ref_tot
				print_msg("Particle distribution after cc cutoff:\t\t%s\n"%(afcut1))
				print_msg("Particle distribution after pix cutoff:\t\t%s\n"%(afcut2))
				print_msg("Particle distribution after ref cutoff:\t\t%s\n\n"%(afcut3))


			res = [0.0 for i in xrange(nrefs)]
			for iref in xrange(nrefs):
				if(center == -1):
					from utilities      import estimate_3D_center_MPI, rotate_3D_shift
					dummy=EMData()
					cs[0], cs[1], cs[2], dummy, dummy = estimate_3D_center_MPI(data, total_nima, myid, number_of_proc, main_node)
					cs = mpi_bcast(cs, 3, MPI_FLOAT, main_node, MPI_COMM_WORLD)
					cs = [-float(cs[0]), -float(cs[1]), -float(cs[2])]
					rotate_3D_shift(data, cs)


				if(sort):
					group = iref
					for im in xrange(nima):
						imgroup = data[im].get_attr('group')
						if imgroup == iref:
							data[im].set_attr('xform.projection',transmulti[im][iref])
				else:
					group = int(999)
					for im in xrange(nima):
						data[im].set_attr('xform.projection',transmulti[im][iref])
				if(nrefs == 1):
					modout = ""
				else:
					modout = "_model_%02d"%(iref)

				fscfile = os.path.join(outdir, "fsc_%s%s"%(itout,modout))
				vol[iref], fscc, volodd[iref], voleve[iref] = rec3D_MPI_noCTF(data, sym, fscmask, fscfile, myid, main_node, index=group, npad=recon_pad)

				if myid == main_node:
					print_msg("3D reconstruction time for model %i: %s\n"%(iref, legibleTime(time()-start_time)))
					start_time = time()

				# Compute Fourier variance
				if fourvar:
					outvar = os.path.join(outdir, "volVar_%s.hdf"%(itout))
					ssnr_file = os.path.join(outdir, "ssnr_%s"%(itout))
					varf = varf3d_MPI(data, ssnr_text_file=ssnr_file, mask2D=None, reference_structure=vol[iref], ou=last_ring, rw=1.0, npad=1, CTF=None, sign=1, sym=sym, myid=myid)
					if myid == main_node:
						print_msg("Time to calculate 3D Fourier variance for model %i: %s\n"%(iref, legibleTime(time()-start_time)))
						start_time = time()
						varf = 1.0/varf
						varf.write_image(outvar,-1)
				else:  varf = None

				if myid == main_node:
					if helicalrecon:
						from hfunctions import processHelicalVol

						vstep=None
						if vertstep is not None:
							vstep=(vdp[iref],vdphi[iref])
						print_msg("Old rise and twist for model %i     : %8.3f, %8.3f\n"%(iref,dp[iref],dphi[iref]))
						hvals=processHelicalVol(vol[iref],voleve[iref],volodd[iref],iref,outdir,itout,
									dp[iref],dphi[iref],apix,hsearch,findseam,vstep,wcmask)
						(vol[iref],voleve[iref],volodd[iref],dp[iref],dphi[iref],vdp[iref],vdphi[iref])=hvals
						print_msg("New rise and twist for model %i     : %8.3f, %8.3f\n"%(iref,dp[iref],dphi[iref]))
						# get new FSC from symmetrized half volumes
						fscc = fsc_mask( volodd[iref], voleve[iref], mask3D, rstep, fscfile)

						print_msg("Time to search and apply helical symmetry for model %i: %s\n\n"%(iref, legibleTime(time()-start_time)))
						start_time = time()
					else:
						vol[iref].write_image(os.path.join(outdir, "vol_%s.hdf"%(itout)),-1)

					if save_half is True:
						volodd[iref].write_image(os.path.join(outdir, "volodd_%s.hdf"%(itout)),-1)
						voleve[iref].write_image(os.path.join(outdir, "voleve_%s.hdf"%(itout)),-1)

					if nmasks > 1:
						# Read mask for multiplying
						ref_data[0] = maskF[iref]
					ref_data[2] = vol[iref]
					ref_data[3] = fscc
					ref_data[4] = varf
					#  call user-supplied function to prepare reference image, i.e., center and filter it
					vol[iref], cs,fl = ref_ali3d(ref_data)
					vol[iref].write_image(os.path.join(outdir, "volf_%s.hdf"%(itout)),-1)
					if (apix == 1):
						res_msg = "Models filtered at spatial frequency of:\t"
						res[iref] = fl
					else:
						res_msg = "Models filtered at resolution of:       \t"
						res[iref] = apix / fl

				del varf
				bcast_EMData_to_all(vol[iref], myid, main_node)

				if compare_ref_free != "-1": compare_repro = True
				if compare_repro:
					# NOTE(review): refrings still holds the projections prepared for
					# the last model of the alignment loop above, not necessarily
					# for model iref -- confirm this is intended for nrefs > 1
					outfile_repro = comp_rep(refrings, data, itout, modout, vol[iref], group, nima, nx, myid, main_node, outdir)
					mpi_barrier(MPI_COMM_WORLD)
					if compare_ref_free != "-1":
						ref_free_output = os.path.join(outdir,"ref_free_%s%s"%(itout,modout))
						rejects = compare(compare_ref_free, outfile_repro,ref_free_output,yrng[N_step], xrng[N_step], rstep,nx,apix,ref_free_cutoff[N_step], number_of_proc, myid, main_node)

			# retrieve alignment params from all processors
			par_str = ['xform.projection','ID','group']
			if nrefs > 1:
				for iref in xrange(nrefs):
					par_str.append('eulers_txty.%i'%iref)

			if myid == main_node:
				from utilities import recv_attr_dict
				recv_attr_dict(main_node, stack, data, par_str, image_start, image_end, number_of_proc)

			else:	send_attr_dict(main_node, data, par_str, image_start, image_end)

			if myid == main_node:
				ares = array2string(array(res), precision = 2)
				print_msg("%s%s\n\n"%(res_msg,ares))
				dummy = EMData()
				if full_output:
					nimat = EMUtil.get_image_count(stack)
					output_file = os.path.join(outdir, "paramout_%s"%itout)
					foutput = open(output_file, 'w')
					for im in xrange(nimat):
						# save the parameters for each of the models
						outstring = ""
						dummy.read_image(stack,im,True)
						param3d = dummy.get_attr('xform.projection')
						g = dummy.get_attr("group")
						# retrieve alignments in EMAN-format
						pE = param3d.get_params('eman')
						outstring += "%f\t%f\t%f\t%f\t%f\t%i\n" %(pE["az"], pE["alt"], pE["phi"], pE["tx"], pE["ty"],g)
						foutput.write(outstring)
					foutput.close()
				del dummy
			mpi_barrier(MPI_COMM_WORLD)


#	mpi_finalize()

	# release the per-rank debug log opened above when debug is set
	if finfo is not None:
		finfo.close()
	if myid == main_node: print_end_msg("ali3d_MPI")
コード例 #41
0
ファイル: projection.py プロジェクト: cpsemmens/eman2
def cml_end_log(Ori):
	"""Print the final orientation of every projection through print_msg.

	Ori is indexed as a flat sequence with a stride of four per projection:
	entries 4*i, 4*i+1 and 4*i+2 are reported as phi, theta and psi of
	projection i (entry 4*i+3 is not printed).  The number of projections
	is taken from the module-level global g_n_prj.
	"""
	from utilities import print_msg
	row = 'Projection #%03i: phi %10.5f    theta %10.5f    psi %10.5f\n'
	print_msg('\n\n')
	for iprj in xrange(g_n_prj):
		offset = 4 * iprj
		phi, theta, psi = Ori[offset], Ori[offset + 1], Ori[offset + 2]
		print_msg(row % (iprj, phi, theta, psi))
コード例 #42
0
ファイル: sx3dvariability.py プロジェクト: cryoem/test
def main():

	def params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror):
		if mirror:
			m = 1
			alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0, 540.0-psi, 0, 0, 1.0)
		else:
			m = 0
			alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0, 360.0-psi, 0, 0, 1.0)
		return  alpha, sx, sy, m
	
	progname = os.path.basename(sys.argv[0])
	usage = progname + " prj_stack  --ave2D= --var2D=  --ave3D= --var3D= --img_per_grp= --fl=0.2 --aa=0.1  --sym=symmetry --CTF"
	parser = OptionParser(usage, version=SPARXVERSION)

	parser.add_option("--ave2D",		type="string"	   ,	default=False,				help="write to the disk a stack of 2D averages")
	parser.add_option("--var2D",		type="string"	   ,	default=False,				help="write to the disk a stack of 2D variances")
	parser.add_option("--ave3D",		type="string"	   ,	default=False,				help="write to the disk reconstructed 3D average")
	parser.add_option("--var3D",		type="string"	   ,	default=False,				help="compute 3D variability (time consuming!)")
	parser.add_option("--img_per_grp",	type="int"         ,	default=10   ,				help="number of neighbouring projections")
	parser.add_option("--no_norm",		action="store_true",	default=False,				help="do not use normalization")
	parser.add_option("--radiusvar", 	type="int"         ,	default=-1   ,				help="radius for 3D var" )
	parser.add_option("--npad",			type="int"         ,	default=2    ,				help="number of time to pad the original images")
	parser.add_option("--sym" , 		type="string"      ,	default="c1" ,				help="symmetry")
	parser.add_option("--fl",			type="float"       ,	default=0.0  ,				help="stop-band frequency (Default - no filtration)")
	parser.add_option("--aa",			type="float"       ,	default=0.0  ,				help="fall off of the filter (Default - no filtration)")
	parser.add_option("--CTF",			action="store_true",	default=False,				help="use CFT correction")
	parser.add_option("--VERBOSE",		action="store_true",	default=False,				help="Long output for debugging")
	#parser.add_option("--MPI" , 		action="store_true",	default=False,				help="use MPI version")
	#parser.add_option("--radiuspca", 	type="int"         ,	default=-1   ,				help="radius for PCA" )
	#parser.add_option("--iter", 		type="int"         ,	default=40   ,				help="maximum number of iterations (stop criterion of reconstruction process)" )
	#parser.add_option("--abs", 			type="float"       ,	default=0.0  ,				help="minimum average absolute change of voxels' values (stop criterion of reconstruction process)" )
	#parser.add_option("--squ", 			type="float"       ,	default=0.0  ,				help="minimum average squared change of voxels' values (stop criterion of reconstruction process)" )
	parser.add_option("--VAR" , 		action="store_true",	default=False,				help="stack on input consists of 2D variances (Default False)")
	parser.add_option("--decimate",     type="float",           default=1.0,                 help="image decimate rate, a number large than 1. default is 1")
	parser.add_option("--window",       type="int",             default=0,                   help="reduce images to a small image size without changing pixel_size. Default value is zero.")
	#parser.add_option("--SND",			action="store_true",	default=False,				help="compute squared normalized differences (Default False)")
	parser.add_option("--nvec",			type="int"         ,	default=0    ,				help="number of eigenvectors, default = 0 meaning no PCA calculated")
	parser.add_option("--symmetrize",	action="store_true",	default=False,				help="Prepare input stack for handling symmetry (Default False)")
	
	(options,args) = parser.parse_args()
	#####
	from mpi import mpi_init, mpi_comm_rank, mpi_comm_size, mpi_recv, MPI_COMM_WORLD, MPI_TAG_UB
	from mpi import mpi_barrier, mpi_reduce, mpi_bcast, mpi_send, MPI_FLOAT, MPI_SUM, MPI_INT, MPI_MAX
	from applications import MPI_start_end
	from reconstruction import recons3d_em, recons3d_em_MPI
	from reconstruction	import recons3d_4nn_MPI, recons3d_4nn_ctf_MPI
	from utilities import print_begin_msg, print_end_msg, print_msg
	from utilities import read_text_row, get_image, get_im
	from utilities import bcast_EMData_to_all, bcast_number_to_all
	from utilities import get_symt

	#  This is code for handling symmetries by the above program.  To be incorporated. PAP 01/27/2015

	from EMAN2db import db_open_dict
	
	if options.symmetrize :
		try:
			sys.argv = mpi_init(len(sys.argv), sys.argv)
			try:	
				number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
				if( number_of_proc > 1 ):
					ERROR("Cannot use more than one CPU for symmetry prepration","sx3dvariability",1)
			except:
				pass
		except:
			pass

		#  Input
		#instack = "Clean_NORM_CTF_start_wparams.hdf"
		#instack = "bdb:data"
		instack = args[0]
		sym = options.sym
		if( sym == "c1" ):
			ERROR("Thre is no need to symmetrize stack for C1 symmetry","sx3dvariability",1)

		if(instack[:4] !="bdb:"):
			stack = "bdb:data"
			delete_bdb(stack)
			cmdexecute("sxcpy.py  "+instack+"  "+stack)
		else:
			stack = instack

		qt = EMUtil.get_all_attributes(stack,'xform.projection')

		na = len(qt)
		ts = get_symt(sym)
		ks = len(ts)
		angsa = [None]*na
		for k in xrange(ks):
			delete_bdb("bdb:Q%1d"%k)
			cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
			DB = db_open_dict("bdb:Q%1d"%k)
			for i in xrange(na):
				ut = qt[i]*ts[k]
				DB.set_attr(i, "xform.projection", ut)
				#bt = ut.get_params("spider")
				#angsa[i] = [round(bt["phi"],3)%360.0, round(bt["theta"],3)%360.0, bt["psi"], -bt["tx"], -bt["ty"]]
			#write_text_row(angsa, 'ptsma%1d.txt'%k)
			#cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
			#cmdexecute("sxheader.py  bdb:Q%1d  --params=xform.projection  --import=ptsma%1d.txt"%(k,k))
			DB.close()
		delete_bdb("bdb:sdata")
		cmdexecute("e2bdb.py . --makevstack=bdb:sdata --filt=Q")
		#cmdexecute("ls  EMAN2DB/sdata*")
		a = get_im("bdb:sdata")
		a.set_attr("variabilitysymmetry",sym)
		a.write_image("bdb:sdata")


	else:

		sys.argv = mpi_init(len(sys.argv), sys.argv)
		myid     = mpi_comm_rank(MPI_COMM_WORLD)
		number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
		main_node = 0

		if len(args) == 1:
			stack = args[0]
		else:
			print( "usage: " + usage)
			print( "Please run '" + progname + " -h' for detailed options")
			return 1

		t0 = time()
	
		# obsolete flags
		options.MPI = True
		options.nvec = 0
		options.radiuspca = -1
		options.iter = 40
		options.abs = 0.0
		options.squ = 0.0

		if options.fl > 0.0 and options.aa == 0.0:
			ERROR("Fall off has to be given for the low-pass filter", "sx3dvariability", 1, myid)
		if options.VAR and options.SND:
			ERROR("Only one of var and SND can be set!", "sx3dvariability", myid)
			exit()
		if options.VAR and (options.ave2D or options.ave3D or options.var2D): 
			ERROR("When VAR is set, the program cannot output ave2D, ave3D or var2D", "sx3dvariability", 1, myid)
			exit()
		#if options.SND and (options.ave2D or options.ave3D):
		#	ERROR("When SND is set, the program cannot output ave2D or ave3D", "sx3dvariability", 1, myid)
		#	exit()
		if options.nvec > 0 :
			ERROR("PCA option not implemented", "sx3dvariability", 1, myid)
			exit()
		if options.nvec > 0 and options.ave3D == None:
			ERROR("When doing PCA analysis, one must set ave3D", "sx3dvariability", myid=myid)
			exit()
		import string
		options.sym = options.sym.lower()
		 
		if global_def.CACHE_DISABLE:
			from utilities import disable_bdb_cache
			disable_bdb_cache()
		global_def.BATCH = True

		if myid == main_node:
			print_begin_msg("sx3dvariability")
			print_msg("%-70s:  %s\n"%("Input stack", stack))
	
		img_per_grp = options.img_per_grp
		nvec = options.nvec
		radiuspca = options.radiuspca

		symbaselen = 0
		if myid == main_node:
			nima = EMUtil.get_image_count(stack)
			img  = get_image(stack)
			nx   = img.get_xsize()
			ny   = img.get_ysize()
			if options.sym != "c1" :
				imgdata = get_im(stack)
				try:
					i = imgdata.get_attr("variabilitysymmetry")
					if(i != options.sym):
						ERROR("The symmetry provided does not agree with the symmetry of the input stack", "sx3dvariability", myid=myid)
				except:
					ERROR("Input stack is not prepared for symmetry, please follow instructions", "sx3dvariability", myid=myid)
				from utilities import get_symt
				i = len(get_symt(options.sym))
				if((nima/i)*i != nima):
					ERROR("The length of the input stack is incorrect for symmetry processing", "sx3dvariability", myid=myid)
				symbaselen = nima/i
			else:  symbaselen = nima
		else:
			nima = 0
			nx = 0
			ny = 0
		nima = bcast_number_to_all(nima)
		nx   = bcast_number_to_all(nx)
		ny   = bcast_number_to_all(ny)
		Tracker ={}
		Tracker["nx"]  =nx
		Tracker["ny"]  =ny
		Tracker["total_stack"]=nima
		if options.decimate==1.:
			if options.window !=0:
				nx = options.window
				ny = options.window
		else:
			if options.window ==0:
				nx = int(nx/options.decimate)
				ny = int(ny/options.decimate)
			else:
				nx = int(options.window/options.decimate)
				ny = nx
		symbaselen = bcast_number_to_all(symbaselen)
		if radiuspca == -1: radiuspca = nx/2-2

		if myid == main_node:
			print_msg("%-70s:  %d\n"%("Number of projection", nima))
		
		img_begin, img_end = MPI_start_end(nima, number_of_proc, myid)
		"""
		if options.SND:
			from projection		import prep_vol, prgs
			from statistics		import im_diff
			from utilities		import get_im, model_circle, get_params_proj, set_params_proj
			from utilities		import get_ctf, generate_ctf
			from filter			import filt_ctf
		
			imgdata = EMData.read_images(stack, range(img_begin, img_end))

			if options.CTF:
				vol = recons3d_4nn_ctf_MPI(myid, imgdata, 1.0, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			else:
				vol = recons3d_4nn_MPI(myid, imgdata, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)

			bcast_EMData_to_all(vol, myid)
			volft, kb = prep_vol(vol)

			mask = model_circle(nx/2-2, nx, ny)
			varList = []
			for i in xrange(img_begin, img_end):
				phi, theta, psi, s2x, s2y = get_params_proj(imgdata[i-img_begin])
				ref_prj = prgs(volft, kb, [phi, theta, psi, -s2x, -s2y])
				if options.CTF:
					ctf_params = get_ctf(imgdata[i-img_begin])
					ref_prj = filt_ctf(ref_prj, generate_ctf(ctf_params))
				diff, A, B = im_diff(ref_prj, imgdata[i-img_begin], mask)
				diff2 = diff*diff
				set_params_proj(diff2, [phi, theta, psi, s2x, s2y])
				varList.append(diff2)
			mpi_barrier(MPI_COMM_WORLD)
		"""
		if options.VAR:
			#varList = EMData.read_images(stack, range(img_begin, img_end))
			varList = []
			this_image = EMData()
			for index_of_particle in xrange(img_begin,img_end):
				this_image.read_image(stack,index_of_particle)
				varList.append(image_decimate_window_xform_ctf(img,options.decimate,options.window,options.CTF))
		else:
			from utilities		import bcast_number_to_all, bcast_list_to_all, send_EMData, recv_EMData
			from utilities		import set_params_proj, get_params_proj, params_3D_2D, get_params2D, set_params2D, compose_transform2
			from utilities		import model_blank, nearest_proj, model_circle
			from applications	import pca
			from statistics		import avgvar, avgvar_ctf, ccc
			from filter		    import filt_tanl
			from morphology		import threshold, square_root
			from projection 	import project, prep_vol, prgs
			from sets		    import Set

			if myid == main_node:
				t1 = time()
				proj_angles = []
				aveList = []
				tab = EMUtil.get_all_attributes(stack, 'xform.projection')
				for i in xrange(nima):
					t     = tab[i].get_params('spider')
					phi   = t['phi']
					theta = t['theta']
					psi   = t['psi']
					x     = theta
					if x > 90.0: x = 180.0 - x
					x = x*10000+psi
					proj_angles.append([x, t['phi'], t['theta'], t['psi'], i])
				t2 = time()
				print_msg("%-70s:  %d\n"%("Number of neighboring projections", img_per_grp))
				print_msg("...... Finding neighboring projections\n")
				if options.VERBOSE:
					print "Number of images per group: ", img_per_grp
					print "Now grouping projections"
				proj_angles.sort()

			proj_angles_list = [0.0]*(nima*4)
			if myid == main_node:
				for i in xrange(nima):
					proj_angles_list[i*4]   = proj_angles[i][1]
					proj_angles_list[i*4+1] = proj_angles[i][2]
					proj_angles_list[i*4+2] = proj_angles[i][3]
					proj_angles_list[i*4+3] = proj_angles[i][4]
			proj_angles_list = bcast_list_to_all(proj_angles_list, myid, main_node)
			proj_angles = []
			for i in xrange(nima):
				proj_angles.append([proj_angles_list[i*4], proj_angles_list[i*4+1], proj_angles_list[i*4+2], int(proj_angles_list[i*4+3])])
			del proj_angles_list

			proj_list, mirror_list = nearest_proj(proj_angles, img_per_grp, range(img_begin, img_end))

			all_proj = Set()
			for im in proj_list:
				for jm in im:
					all_proj.add(proj_angles[jm][3])

			all_proj = list(all_proj)
			if options.VERBOSE:
				print "On node %2d, number of images needed to be read = %5d"%(myid, len(all_proj))

			index = {}
			for i in xrange(len(all_proj)): index[all_proj[i]] = i
			mpi_barrier(MPI_COMM_WORLD)

			if myid == main_node:
				print_msg("%-70s:  %.2f\n"%("Finding neighboring projections lasted [s]", time()-t2))
				print_msg("%-70s:  %d\n"%("Number of groups processed on the main node", len(proj_list)))
				if options.VERBOSE:
					print "Grouping projections took: ", (time()-t2)/60	, "[min]"
					print "Number of groups on main node: ", len(proj_list)
			mpi_barrier(MPI_COMM_WORLD)

			if myid == main_node:
				print_msg("...... calculating the stack of 2D variances \n")
				if options.VERBOSE:
					print "Now calculating the stack of 2D variances"

			proj_params = [0.0]*(nima*5)
			aveList = []
			varList = []				
			if nvec > 0:
				eigList = [[] for i in xrange(nvec)]

			if options.VERBOSE: 	print "Begin to read images on processor %d"%(myid)
			ttt = time()
			#imgdata = EMData.read_images(stack, all_proj)
			img     = EMData()
			imgdata = []
			for index_of_proj in xrange(len(all_proj)):
				img.read_image(stack, all_proj[index_of_proj])
				dmg = image_decimate_window_xform_ctf(img,options.decimate,options.window,options.CTF)
				#print dmg.get_xsize(), "init"
				imgdata.append(dmg)
			if options.VERBOSE:
				print "Reading images on processor %d done, time = %.2f"%(myid, time()-ttt)
				print "On processor %d, we got %d images"%(myid, len(imgdata))
			mpi_barrier(MPI_COMM_WORLD)

			'''	
			imgdata2 = EMData.read_images(stack, range(img_begin, img_end))
			if options.fl > 0.0:
				for k in xrange(len(imgdata2)):
					imgdata2[k] = filt_tanl(imgdata2[k], options.fl, options.aa)
			if options.CTF:
				vol = recons3d_4nn_ctf_MPI(myid, imgdata2, 1.0, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			else:
				vol = recons3d_4nn_MPI(myid, imgdata2, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			if myid == main_node:
				vol.write_image("vol_ctf.hdf")
				print_msg("Writing to the disk volume reconstructed from averages as		:  %s\n"%("vol_ctf.hdf"))
			del vol, imgdata2
			mpi_barrier(MPI_COMM_WORLD)
			'''
			from applications import prepare_2d_forPCA
			from utilities import model_blank
			for i in xrange(len(proj_list)):
				ki = proj_angles[proj_list[i][0]][3]
				if ki >= symbaselen:  continue
				mi = index[ki]
				phiM, thetaM, psiM, s2xM, s2yM = get_params_proj(imgdata[mi])

				grp_imgdata = []
				for j in xrange(img_per_grp):
					mj = index[proj_angles[proj_list[i][j]][3]]
					phi, theta, psi, s2x, s2y = get_params_proj(imgdata[mj])
					alpha, sx, sy, mirror = params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror_list[i][j])
					if thetaM <= 90:
						if mirror == 0:  alpha, sx, sy, scale = compose_transform2(alpha, sx, sy, 1.0, phiM-phi, 0.0, 0.0, 1.0)
						else:            alpha, sx, sy, scale = compose_transform2(alpha, sx, sy, 1.0, 180-(phiM-phi), 0.0, 0.0, 1.0)
					else:
						if mirror == 0:  alpha, sx, sy, scale = compose_transform2(alpha, sx, sy, 1.0, -(phiM-phi), 0.0, 0.0, 1.0)
						else:            alpha, sx, sy, scale = compose_transform2(alpha, sx, sy, 1.0, -(180-(phiM-phi)), 0.0, 0.0, 1.0)
					set_params2D(imgdata[mj], [alpha, sx, sy, mirror, 1.0])
					grp_imgdata.append(imgdata[mj])
					#print grp_imgdata[j].get_xsize(), imgdata[mj].get_xsize()

				if not options.no_norm:
					#print grp_imgdata[j].get_xsize()
					mask = model_circle(nx/2-2, nx, nx)
					for k in xrange(img_per_grp):
						ave, std, minn, maxx = Util.infomask(grp_imgdata[k], mask, False)
						grp_imgdata[k] -= ave
						grp_imgdata[k] /= std
					del mask

				if options.fl > 0.0:
					from filter import filt_ctf, filt_table
					from fundamentals import fft, window2d
					nx2 = 2*nx
					ny2 = 2*ny
					if options.CTF:
						from utilities import pad
						for k in xrange(img_per_grp):
							grp_imgdata[k] = window2d(fft( filt_tanl( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1), options.fl, options.aa) ),nx,ny)
							#grp_imgdata[k] = window2d(fft( filt_table( filt_tanl( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1), options.fl, options.aa), fifi) ),nx,ny)
							#grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)
					else:
						for k in xrange(img_per_grp):
							grp_imgdata[k] = filt_tanl( grp_imgdata[k], options.fl, options.aa)
							#grp_imgdata[k] = window2d(fft( filt_table( filt_tanl( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1), options.fl, options.aa), fifi) ),nx,ny)
							#grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)
				else:
					from utilities import pad, read_text_file
					from filter import filt_ctf, filt_table
					from fundamentals import fft, window2d
					nx2 = 2*nx
					ny2 = 2*ny
					if options.CTF:
						from utilities import pad
						for k in xrange(img_per_grp):
							grp_imgdata[k] = window2d( fft( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1) ) , nx,ny)
							#grp_imgdata[k] = window2d(fft( filt_table( filt_tanl( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1), options.fl, options.aa), fifi) ),nx,ny)
							#grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)

				'''
				if i < 10 and myid == main_node:
					for k in xrange(10):
						grp_imgdata[k].write_image("grp%03d.hdf"%i, k)
				'''
				"""
				if myid == main_node and i==0:
					for pp in xrange(len(grp_imgdata)):
						grp_imgdata[pp].write_image("pp.hdf", pp)
				"""
				ave, grp_imgdata = prepare_2d_forPCA(grp_imgdata)
				"""
				if myid == main_node and i==0:
					for pp in xrange(len(grp_imgdata)):
						grp_imgdata[pp].write_image("qq.hdf", pp)
				"""

				var = model_blank(nx,ny)
				for q in grp_imgdata:  Util.add_img2( var, q )
				Util.mul_scalar( var, 1.0/(len(grp_imgdata)-1))
				# Switch to std dev
				var = square_root(threshold(var))
				#if options.CTF:	ave, var = avgvar_ctf(grp_imgdata, mode="a")
				#else:	            ave, var = avgvar(grp_imgdata, mode="a")
				"""
				if myid == main_node:
					ave.write_image("avgv.hdf",i)
					var.write_image("varv.hdf",i)
				"""
			
				set_params_proj(ave, [phiM, thetaM, 0.0, 0.0, 0.0])
				set_params_proj(var, [phiM, thetaM, 0.0, 0.0, 0.0])

				aveList.append(ave)
				varList.append(var)

				if options.VERBOSE:
					print "%5.2f%% done on processor %d"%(i*100.0/len(proj_list), myid)
				if nvec > 0:
					eig = pca(input_stacks=grp_imgdata, subavg="", mask_radius=radiuspca, nvec=nvec, incore=True, shuffle=False, genbuf=True)
					for k in xrange(nvec):
						set_params_proj(eig[k], [phiM, thetaM, 0.0, 0.0, 0.0])
						eigList[k].append(eig[k])
					"""
					if myid == 0 and i == 0:
						for k in xrange(nvec):
							eig[k].write_image("eig.hdf", k)
					"""

			del imgdata
			#  To this point, all averages, variances, and eigenvectors are computed

			if options.ave2D:
				from fundamentals import fpol
				if myid == main_node:
					km = 0
					for i in xrange(number_of_proc):
						if i == main_node :
							for im in xrange(len(aveList)):
								aveList[im].write_image(options.ave2D, km)
								km += 1
						else:
							nl = mpi_recv(1, MPI_INT, i, MPI_TAG_UB, MPI_COMM_WORLD)
							nl = int(nl[0])
							for im in xrange(nl):
								ave = recv_EMData(i, im+i+70000)
								"""
								nm = mpi_recv(1, MPI_INT, i, MPI_TAG_UB, MPI_COMM_WORLD)
								nm = int(nm[0])
								members = mpi_recv(nm, MPI_INT, i, MPI_TAG_UB, MPI_COMM_WORLD)
								ave.set_attr('members', map(int, members))
								members = mpi_recv(nm, MPI_FLOAT, i, MPI_TAG_UB, MPI_COMM_WORLD)
								ave.set_attr('pix_err', map(float, members))
								members = mpi_recv(3, MPI_FLOAT, i, MPI_TAG_UB, MPI_COMM_WORLD)
								ave.set_attr('refprojdir', map(float, members))
								"""
								tmpvol=fpol(ave, Tracker["nx"],Tracker["nx"],Tracker["nx"])								
								tmpvol.write_image(options.ave2D, km)
								km += 1
				else:
					mpi_send(len(aveList), 1, MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
					for im in xrange(len(aveList)):
						send_EMData(aveList[im], main_node,im+myid+70000)
						"""
						members = aveList[im].get_attr('members')
						mpi_send(len(members), 1, MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
						mpi_send(members, len(members), MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
						members = aveList[im].get_attr('pix_err')
						mpi_send(members, len(members), MPI_FLOAT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
						try:
							members = aveList[im].get_attr('refprojdir')
							mpi_send(members, 3, MPI_FLOAT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
						except:
							mpi_send([-999.0,-999.0,-999.0], 3, MPI_FLOAT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
						"""

			if options.ave3D:
				from fundamentals import fpol
				if options.VERBOSE:
					print "Reconstructing 3D average volume"
				ave3D = recons3d_4nn_MPI(myid, aveList, symmetry=options.sym, npad=options.npad)
				bcast_EMData_to_all(ave3D, myid)
				if myid == main_node:
					ave3D=fpol(ave3D,Tracker["nx"],Tracker["nx"],Tracker["nx"])
					ave3D.write_image(options.ave3D)
					print_msg("%-70s:  %s\n"%("Writing to the disk volume reconstructed from averages as", options.ave3D))
			del ave, var, proj_list, stack, phi, theta, psi, s2x, s2y, alpha, sx, sy, mirror, aveList

			if nvec > 0:
				for k in xrange(nvec):
					if options.VERBOSE:
						print "Reconstruction eigenvolumes", k
					cont = True
					ITER = 0
					mask2d = model_circle(radiuspca, nx, nx)
					while cont:
						#print "On node %d, iteration %d"%(myid, ITER)
						eig3D = recons3d_4nn_MPI(myid, eigList[k], symmetry=options.sym, npad=options.npad)
						bcast_EMData_to_all(eig3D, myid, main_node)
						if options.fl > 0.0:
							eig3D = filt_tanl(eig3D, options.fl, options.aa)
						if myid == main_node:
							eig3D.write_image("eig3d_%03d.hdf"%k, ITER)
						Util.mul_img( eig3D, model_circle(radiuspca, nx, nx, nx) )
						eig3Df, kb = prep_vol(eig3D)
						del eig3D
						cont = False
						icont = 0
						for l in xrange(len(eigList[k])):
							phi, theta, psi, s2x, s2y = get_params_proj(eigList[k][l])
							proj = prgs(eig3Df, kb, [phi, theta, psi, s2x, s2y])
							cl = ccc(proj, eigList[k][l], mask2d)
							if cl < 0.0:
								icont += 1
								cont = True
								eigList[k][l] *= -1.0
						u = int(cont)
						u = mpi_reduce([u], 1, MPI_INT, MPI_MAX, main_node, MPI_COMM_WORLD)
						icont = mpi_reduce([icont], 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)

						if myid == main_node:
							u = int(u[0])
							print " Eigenvector: ",k," number changed ",int(icont[0])
						else: u = 0
						u = bcast_number_to_all(u, main_node)
						cont = bool(u)
						ITER += 1

					del eig3Df, kb
					mpi_barrier(MPI_COMM_WORLD)
				del eigList, mask2d

			if options.ave3D: del ave3D
			if options.var2D:
				from fundamentals import fpol 
				if myid == main_node:
					km = 0
					for i in xrange(number_of_proc):
						if i == main_node :
							for im in xrange(len(varList)):
								tmpvol=fpol(varList[im], Tracker["nx"], Tracker["nx"],1)
								tmpvol.write_image(options.var2D, km)
								km += 1
						else:
							nl = mpi_recv(1, MPI_INT, i, MPI_TAG_UB, MPI_COMM_WORLD)
							nl = int(nl[0])
							for im in xrange(nl):
								ave = recv_EMData(i, im+i+70000)
								tmpvol=fpol(ave, Tracker["nx"], Tracker["nx"],1)
								tmpvol.write_image(options.var2D, km)
								km += 1
				else:
					mpi_send(len(varList), 1, MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
					for im in xrange(len(varList)):
						send_EMData(varList[im], main_node, im+myid+70000)#  What with the attributes??

			mpi_barrier(MPI_COMM_WORLD)

		if  options.var3D:
			if myid == main_node and options.VERBOSE:
				print "Reconstructing 3D variability volume"

			t6 = time()
			radiusvar = options.radiusvar
			if( radiusvar < 0 ):  radiusvar = nx//2 -3
			res = recons3d_4nn_MPI(myid, varList, symmetry=options.sym, npad=options.npad)
			#res = recons3d_em_MPI(varList, vol_stack, options.iter, radiusvar, options.abs, True, options.sym, options.squ)
			if myid == main_node:
				from fundamentals import fpol
				res =fpol(res, Tracker["nx"], Tracker["nx"], Tracker["nx"])
				res.write_image(options.var3D)

			if myid == main_node:
				print_msg("%-70s:  %.2f\n"%("Reconstructing 3D variability took [s]", time()-t6))
				if options.VERBOSE:
					print "Reconstruction took: %.2f [min]"%((time()-t6)/60)

			if myid == main_node:
				print_msg("%-70s:  %.2f\n"%("Total time for these computations [s]", time()-t0))
				if options.VERBOSE:
					print "Total time for these computations: %.2f [min]"%((time()-t0)/60)
				print_end_msg("sx3dvariability")

		global_def.BATCH = False

		from mpi import mpi_finalize
		mpi_finalize()